Improved several mysqloo functions

Queries can now be started multiple times
PreparedQueries now cache and reuse the prepared statements allocated on the server
PreparedQueries can now return multiple result sets
Database:ping() now always attempts to reconnect to the database regardless of autoReconnect status
Added automatic update check that checks if an update is available
Added mysqloo.MINOR_VERSION
Added database:setCachePreparedStatements
slightly changed behavior of query:hasMoreResults()
Added lua library for ease of use
Added tmysql4 wrapper
Added connection pool lua library
Removed out folder
Removed logger
This commit is contained in:
syl0r 2017-01-12 04:42:56 +01:00
parent c2db04fe3f
commit a86d2e792f
29 changed files with 1605 additions and 1078 deletions

6
.gitignore vendored
View File

@ -1 +1,7 @@
/solutions
/out/windows/gmsv_mysqloo_win32.iobj
/out/windows/gmsv_mysqloo_win32.ipdb
/out/windows/gmsv_mysqloo_win32.pdb
/premake4
/premake4.exe
/out

View File

@ -14,13 +14,12 @@ preparedQuery:setString(2, "''``--lel") -- you don't need to escape the string i
preparedQuery:setNull(3)
preparedQuery:setNumber(4, 100)
preparedQuery:setBoolean(5, true)
preparedQuery:start()
preparedQuery:putNewParameters()
-- you can now reuse prepared queries
preparedQuery:setString(1, "STEAM_0:0:654321")
preparedQuery:setString(2, "HufflePuffle")
preparedQuery:setString(3, "owner")
preparedQuery:setNumber(4, 100000)
preparedQuery:setBoolean(5, false)
preparedQuery:start() -- this will insert 2 rows into the table users
preparedQuery:start()

View File

@ -2,15 +2,12 @@ include("database.lua") -- check database.lua for an example of how to create da
local query = db:query("SELECT 1, 2, 3; SELECT 4, 5, 6; SELECT 7, 8, 9") -- In mysqloo 9 a query can be started before the database is connected
function query:onSuccess(data)
local row = data[1]
for k,v in pairs(row) do
print(v) -- should print 1, 2, 3 in any order
end
while(self:hasMoreResults()) do -- should be true twice
row = self:getNextResults()[1]
while(self:hasMoreResults()) do -- should be true three times
row = query:getData()[1]
for k,v in pairs(row) do
print(v) -- should print 4, 5, 6, 7, 8, 9 in any order
print(v) -- should print 1, 2, 3, 4, 5, 6, 7, 8, 9 in any order
end
self:getNextResults()
end
end

View File

@ -16,20 +16,17 @@
class DatabaseThread;
class ConnectThread;
enum DatabaseStatus
{
enum DatabaseStatus {
DATABASE_CONNECTED = 0,
DATABASE_CONNECTING = 1,
DATABASE_NOT_CONNECTED = 2,
DATABASE_CONNECTION_FAILED = 3
};
class Database : LuaObjectBase
{
class Database : LuaObjectBase {
friend class IQuery;
public:
enum
{
enum {
INTEGER = 0,
BIT,
FLOATING_POINT,
@ -37,11 +34,19 @@ public:
};
Database(lua_State* state, std::string host, std::string username, std::string pw, std::string database, unsigned int port, std::string unixSocket);
~Database(void);
void enqueueQuery(IQuery* query);
void enqueueQuery(IQuery* query, std::shared_ptr<IQueryData> data);
void think(lua_State*);
void cacheStatement(MYSQL_STMT* stmt);
void freeStatement(MYSQL_STMT* stmt);
void setAutoReconnect(my_bool autoReconnect);
my_bool getAutoReconnect();
bool shouldCachePreparedStatements() {
return cachePreparedStatements;
}
private:
void run();
void connectRun();
void freeUnusedStatements();
static int query(lua_State* state);
static int prepare(lua_State* state);
static int createTransaction(lua_State* state);
@ -57,13 +62,17 @@ private:
static int setAutoReconnect(lua_State* state);
static int setMultiStatements(lua_State* state);
static int ping(lua_State* state);
std::deque<std::shared_ptr<IQuery>> finishedQueries;
std::deque<std::shared_ptr<IQuery>> queryQueue;
static int setCachePreparedStatements(lua_State* state);
std::deque<std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>>> finishedQueries;
std::deque<std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>>> queryQueue;
std::vector<MYSQL_STMT*> cachedStatements;
std::vector<MYSQL_STMT*> freedStatements;
MYSQL* m_sql;
std::thread m_thread;
std::mutex m_queryQueueMutex;
std::mutex m_finishedQueueMutex;
std::mutex m_connectMutex;
std::mutex m_stmtMutex;
std::condition_variable m_connectWakeupVariable;
unsigned int m_serverVersion = 0;
std::string m_serverInfo = "";
@ -75,6 +84,7 @@ private:
std::atomic<bool> destroyed{ false };
std::atomic<bool> m_success{ true };
std::atomic<bool> m_connectionDone{ false };
std::atomic<bool> cachePreparedStatements{ true };
std::atomic<DatabaseStatus> m_status{ DATABASE_NOT_CONNECTED };
std::string m_connection_err;
std::condition_variable m_queryWakupVariable;

View File

@ -5,47 +5,35 @@
#include <string>
#include <mutex>
#include <atomic>
#include <vector>
#include <condition_variable>
#include "ResultData.h"
class DataRow;
class Database;
struct QueryResultData
{
std::string error = "";
bool finished = false;
bool failed = false;
unsigned int errno;
unsigned int affectedRows;
};
enum QueryStatus
{
enum QueryStatus {
QUERY_NOT_RUNNING = 0,
QUERY_RUNNING = 1,
QUERY_COMPLETE = 3, //Query is complete right before callback is run
QUERY_ABORTED = 4,
QUERY_WAITING = 5,
};
enum QueryResultStatus
{
enum QueryResultStatus {
QUERY_NONE = 0,
QUERY_ERROR,
QUERY_SUCCESS
};
enum
{
enum {
OPTION_NUMERIC_FIELDS = 1,
OPTION_NAMED_FIELDS = 2,
OPTION_INTERPRET_DATA = 4,
OPTION_CACHE = 8,
};
class MySQLException : public std::runtime_error
{
class IQueryData;
class MySQLException : public std::runtime_error {
public:
MySQLException(int errorCode, const char* message) : runtime_error(message)
{
MySQLException(int errorCode, const char* message) : runtime_error(message) {
this->m_errorCode = errorCode;
}
int getErrorCode() const { return m_errorCode; }
@ -53,35 +41,33 @@ private:
int m_errorCode = 0;
};
class IQuery : public LuaObjectBase
{
friend class Database;
class IQuery : public LuaObjectBase {
friend class Database;
public:
IQuery(Database* dbase, lua_State* state);
virtual ~IQuery();
virtual void doCallback(lua_State* state) = 0;
void onDestroyed(lua_State* state);
void setResultStatus(QueryResultStatus);
void setStatus(QueryStatus);
virtual void doCallback(lua_State* state, std::shared_ptr<IQueryData> queryData) = 0;
virtual void onDestroyed(lua_State* state) {};
virtual std::shared_ptr<IQueryData> buildQueryData(lua_State* state) = 0;
void addQueryData(lua_State* state, std::shared_ptr<IQueryData> data, bool shouldRefCallbacks = true);
void onQueryDataFinished(lua_State* state, std::shared_ptr<IQueryData> data);
void setCallbackData(std::shared_ptr<IQueryData> data) {
callbackQueryData = data;
}
protected:
//methods
QueryResultStatus m_resultStatus = QUERY_NONE;
QueryResultStatus getResultStatus();
virtual bool executeStatement(MYSQL* m_sql) = 0;
void dataToLua(lua_State* state, int rowReference, unsigned int column, std::string &columnValue, const char* columnName, int columnType, bool isNull);
virtual bool executeStatement(MYSQL* m_sql, std::shared_ptr<IQueryData> data) = 0;
virtual void think(lua_State* state) {};
static int start(lua_State* state);
static int isRunning(lua_State* state);
static int lastInsert(lua_State* state);
static int affectedRows(lua_State* state);
static int getData_Wrapper(lua_State* state);
static int hasMoreResults(lua_State* state);
static int getNextResults(lua_State* state);
static int setOption(lua_State* state);
static int wait(lua_State* state);
static int error(lua_State* state);
static int abort(lua_State* state);
int getData(lua_State* state);
static int wait(lua_State* state);
bool hasCallbackData() {
return callbackQueryData.get() != nullptr;
}
//Wrapper functions for c api that throw exceptions
void mysqlQuery(MYSQL* sql, std::string &query);
void mysqlAutocommit(MYSQL* sql, bool auto_mode);
@ -89,14 +75,76 @@ protected:
bool mysqlNextResult(MYSQL* sql);
//fields
Database* m_database = nullptr;
std::atomic<bool> finished{ false };
std::atomic<QueryStatus> m_status{ QUERY_NOT_RUNNING };
std::string m_errorText = "";
std::deque<my_ulonglong> m_affectedRows;
std::deque<my_ulonglong> m_insertIds;
std::deque<ResultData> results;
std::condition_variable m_waitWakeupVariable;
int m_options = 0;
int dataReference = 0;
std::vector<std::shared_ptr<IQueryData>> runningQueryData;
std::shared_ptr<IQueryData> callbackQueryData;
bool hasBeenStarted = false;
};
class IQueryData {
friend class IQuery;
public:
std::string getError() {
return m_errorText;
}
void setError(std::string err) {
m_errorText = err;
}
bool isFinished() {
return finished;
}
void setFinished(bool isFinished) {
finished = isFinished;
}
QueryStatus getStatus() {
return m_status;
}
void setStatus(QueryStatus status) {
this->m_status = status;
}
QueryResultStatus getResultStatus() {
return m_resultStatus;
}
void setResultStatus(QueryResultStatus status) {
m_resultStatus = status;
}
int getErrorReference() {
return m_errorReference;
}
int getSuccessReference() {
return m_successReference;
}
int getOnDataReference() {
return m_onDataReference;
}
int getAbortReference() {
return m_abortReference;
}
bool isFirstData() {
return m_wasFirstData;
}
protected:
std::string m_errorText = "";
std::atomic<bool> finished{ false };
std::atomic<QueryStatus> m_status{ QUERY_NOT_RUNNING };
std::atomic<QueryResultStatus> m_resultStatus{ QUERY_NONE };
int m_successReference = 0;
int m_errorReference = 0;
int m_abortReference = 0;
int m_onDataReference = 0;
bool m_wasFirstData = false;
};
#endif

View File

@ -1,18 +0,0 @@
#ifndef LOGGER_
#define LOGGER_
#define LOGGER_ENABLED 0
#ifdef LINUX
#define __FUNCSIG__ __PRETTY_FUNCTION__
#endif
#if LOGGER_ENABLED == 1
#define LOG_CURRENT_FUNCTIONCALL Logger::Log("Calling function %s in file %s line:%d\n", __FUNCSIG__, __FILE__, __LINE__);
#else
#define LOG_CURRENT_FUNCTIONCALL ;
#endif
namespace Logger
{
void Log(const char* format, ...);
};
#endif

View File

@ -10,14 +10,12 @@
#include <map>
#include <algorithm>
#include <string>
enum
{
enum {
TYPE_DATABASE = 1,
TYPE_QUERY = 2
};
class LuaObjectBase : public std::enable_shared_from_this<LuaObjectBase>
{
class LuaObjectBase : public std::enable_shared_from_this<LuaObjectBase> {
public:
LuaObjectBase(lua_State* state, bool shouldthink, unsigned char type);
LuaObjectBase(lua_State* state, unsigned char type);
@ -32,6 +30,8 @@ public:
static LuaObjectBase* unpackLuaObject(lua_State* state, int index, int type, bool shouldReference);
int pushTableReference(lua_State* state);
bool hasCallback(lua_State* state, const char* functionName);
int getCallbackReference(lua_State* state, const char* functionName);
void runFunction(lua_State* state, int funcRef, const char* sig = 0, ...);
void runCallback(lua_State* state, const char* functionName, const char* sig = 0, ...);
static std::deque<std::shared_ptr<LuaObjectBase>> luaObjects;
static std::deque<std::shared_ptr<LuaObjectBase>> luaThinkObjects;
@ -40,6 +40,7 @@ public:
std::shared_ptr<LuaObjectBase> getSharedPointerInstance();
void unreference(lua_State* state);
protected:
void runFunctionVaList(lua_State* state, int funcRef, const char* sig, va_list list);
bool scheduledForRemoval = false;
bool shouldthink = false;
int m_tableReference = 0;
@ -47,6 +48,7 @@ protected:
bool canbedestroyed = true;
const char* classname = "LuaObject";
unsigned char type = 255;
static void referenceTable(lua_State* state, LuaObjectBase* object, int index);
private:
std::map<std::string, GarrysMod::Lua::CFunc> m_callbackFunctions;
static int tableMetaTable;

View File

@ -8,14 +8,13 @@
#include <string.h>
class PingQuery : Query
{
class PingQuery : Query {
friend class Database;
public:
PingQuery(Database* dbase, lua_State* state);
virtual ~PingQuery(void);
protected:
void executeQuery(MYSQL* m_sql);
void executeQuery(MYSQL* m_sql, std::shared_ptr<IQueryData>);
bool pingSuccess = false;
};
#endif

View File

@ -7,8 +7,7 @@
#include <string.h>
class PreparedQueryField
{
class PreparedQueryField {
friend class PreparedQuery;
public:
PreparedQueryField(unsigned int index, int type) : m_index(index), m_type(type) {}
@ -20,39 +19,49 @@ private:
};
template< typename T >
class TypedQueryField : public PreparedQueryField
{
class TypedQueryField : public PreparedQueryField {
friend class PreparedQuery;
public:
TypedQueryField(unsigned int index, int type, const T& data)
: PreparedQueryField(index, type), m_data(data){};
: PreparedQueryField(index, type), m_data(data) {};
virtual ~TypedQueryField() {}
private:
T m_data;
};
class PreparedQuery : Query
{
class PreparedQuery : public Query {
friend class Database;
public:
PreparedQuery(Database* dbase, lua_State* state);
virtual ~PreparedQuery(void);
bool executeStatement(MYSQL* connection);
bool executeStatement(MYSQL* connection, std::shared_ptr<IQueryData> data);
virtual void onDestroyed(lua_State* state);
protected:
void executeQuery(MYSQL* m_sql);
virtual std::shared_ptr<IQueryData> buildQueryData(lua_State* state);
void executeQuery(MYSQL* m_sql, std::shared_ptr<IQueryData> data);
private:
std::deque<std::unordered_map<unsigned int, std::unique_ptr<PreparedQueryField>>> parameters;
std::deque<std::unordered_map<unsigned int, std::shared_ptr<PreparedQueryField>>> m_parameters;
static int setNumber(lua_State* state);
static int setString(lua_State* state);
static int setBoolean(lua_State* state);
static int setNull(lua_State* state);
static int putNewParameters(lua_State* state);
MYSQL_STMT *mysqlStmtInit(MYSQL* sql);
void generateMysqlBinds(MYSQL_BIND* binds, std::unordered_map<unsigned int, std::unique_ptr<PreparedQueryField>> *map, unsigned int parameterCount);
void generateMysqlBinds(MYSQL_BIND* binds, std::unordered_map<unsigned int, std::shared_ptr<PreparedQueryField>> &map, unsigned int parameterCount);
void mysqlStmtBindParameter(MYSQL_STMT* sql, MYSQL_BIND* bind);
void mysqlStmtPrepare(MYSQL_STMT* sql, const char* str);
void mysqlStmtExecute(MYSQL_STMT* sql);
void mysqlStmtStoreResult(MYSQL_STMT* sql);
bool mysqlStmtNextResult(MYSQL_STMT* sql);
//This is atomic to prevent visibility issues
std::atomic<MYSQL_STMT*> cachedStatement{ nullptr };
};
class PreparedQueryData : public QueryData {
friend class PreparedQuery;
protected:
std::deque<std::unordered_map<unsigned int, std::shared_ptr<PreparedQueryField>>> m_parameters;
bool firstAttempt = true;
};
#endif

View File

@ -5,18 +5,63 @@
#include <condition_variable>
#include "IQuery.h"
class Query : public IQuery
{
class QueryData;
class Query : public IQuery {
friend class Database;
public:
Query(Database* dbase, lua_State* state);
virtual ~Query(void);
void setQuery(std::string query);
virtual bool executeStatement(MYSQL* m_sql);
virtual void executeQuery(MYSQL* m_sql);
void doCallback(lua_State* state);
virtual bool executeStatement(MYSQL* m_sql, std::shared_ptr<IQueryData> data);
virtual void executeQuery(MYSQL* m_sql, std::shared_ptr<IQueryData> data);
virtual void onDestroyed(lua_State* state);
virtual void doCallback(lua_State* state, std::shared_ptr<IQueryData> queryData);
virtual std::shared_ptr<IQueryData> buildQueryData(lua_State* state);
protected:
void dataToLua(lua_State* state, int rowReference, unsigned int column, std::string &columnValue, const char* columnName, int columnType, bool isNull);
static int lastInsert(lua_State* state);
static int affectedRows(lua_State* state);
static int getData_Wrapper(lua_State* state);
static int hasMoreResults(lua_State* state);
static int getNextResults(lua_State* state);
int getData(lua_State* state);
int dataReference = 0;
std::string m_query;
};
class QueryData : public IQueryData {
friend class Query;
public:
my_ulonglong getLastInsertID() {
return (m_insertIds.size() == 0) ? 0 : m_insertIds.front();
}
my_ulonglong getAffectedRows() {
return (m_affectedRows.size() == 0) ? 0 : m_affectedRows.front();
}
bool hasMoreResults() {
return m_insertIds.size() > 0 && m_affectedRows.size() > 0 && m_results.size() > 0;
}
bool getNextResults() {
if (!hasMoreResults()) return false;
m_results.pop_front();
m_insertIds.pop_front();
m_affectedRows.pop_front();
return true;
}
ResultData& getResult() {
return m_results.front();
}
std::deque<ResultData> getResults() {
return m_results;
}
protected:
std::deque<my_ulonglong> m_affectedRows;
std::deque<my_ulonglong> m_insertIds;
std::deque<ResultData> m_results;
};
#endif

View File

@ -0,0 +1,11 @@
#ifndef QUERY_DATA_
#define QUERY_DATA_
class QueryData
{
int lastInsert;
std::deque<my_ulonglong> m_affectedRows;
std::deque<my_ulonglong> m_insertIds;
std::deque<ResultData> results;
std::condition_variable m_waitWakeupVariable;
}
#endif

View File

@ -5,17 +5,14 @@
#include <memory>
#include "MySQLHeader.h"
class ResultDataRow
{
class ResultDataRow {
public:
ResultDataRow(unsigned long *lengths, MYSQL_ROW row, unsigned int columnCount);
ResultDataRow(MYSQL_STMT* statement, MYSQL_BIND* bind, unsigned int columnCount);
std::vector<std::string> & getValues()
{
std::vector<std::string> & getValues() {
return values;
}
bool isFieldNull(int index)
{
bool isFieldNull(int index) {
return nullFields[index];
}
private:
@ -25,16 +22,15 @@ private:
std::vector<std::string> values;
};
class ResultData
{
class ResultData {
public:
ResultData(MYSQL_RES* result);
ResultData(MYSQL_STMT* result);
ResultData();
~ResultData();
std::vector<std::string> & getColumns(){ return columns; }
std::vector<ResultDataRow> & getRows(){ return rows; }
std::vector<int> & getColumnTypes(){ return columnTypes; }
std::vector<std::string> & getColumns() { return columns; }
std::vector<ResultDataRow> & getRows() { return rows; }
std::vector<int> & getColumnTypes() { return columnTypes; }
private:
ResultData(unsigned int columns, unsigned int rows);
unsigned int columnCount = 0;

View File

@ -7,20 +7,24 @@
#include "Query.h"
class Transaction : public IQuery
{
class Transaction : public IQuery {
friend class Database;
public:
Transaction(Database* dbase, lua_State* state);
void doCallback(lua_State* state);
void doCallback(lua_State* state, std::shared_ptr<IQueryData> data);
virtual std::shared_ptr<IQueryData> buildQueryData(lua_State* state);
protected:
static int clearQueries(lua_State* state);
static int addQuery(lua_State* state);
static int getQueries(lua_State* state);
bool executeStatement(MYSQL * connection);
bool executeStatement(MYSQL * connection, std::shared_ptr<IQueryData> data);
void onDestroyed(lua_State* state);
private:
std::deque<Query*> queries;
};
class TransactionData : public IQueryData {
friend class Transaction;
protected:
std::deque<std::pair<std::shared_ptr<Query>, std::shared_ptr<IQueryData>>> m_queries;
bool retried = false;
std::mutex m_queryMutex;
};
#endif

View File

@ -3,7 +3,6 @@
#include "IQuery.h"
#include "PreparedQuery.h"
#include "PingQuery.h"
#include "Logger.h"
#include "Transaction.h"
#include <string>
#include <cstring>
@ -11,9 +10,7 @@
#include <chrono>
Database::Database(lua_State* state, std::string host, std::string username, std::string pw, std::string database, unsigned int port, std::string unixSocket) :
LuaObjectBase(state, TYPE_DATABASE), database(database), host(host), username(username), pw(pw), socket(unixSocket), port(port)
{
LOG_CURRENT_FUNCTIONCALL
LuaObjectBase(state, TYPE_DATABASE), database(database), host(host), username(username), pw(pw), socket(unixSocket), port(port) {
classname = "Database";
registerFunction(state, "prepare", Database::prepare);
registerFunction(state, "escape", Database::escape);
@ -28,29 +25,41 @@ LuaObjectBase(state, TYPE_DATABASE), database(database), host(host), username(us
registerFunction(state, "status", Database::status);
registerFunction(state, "queueSize", Database::queueSize);
registerFunction(state, "setAutoReconnect", Database::setAutoReconnect);
registerFunction(state, "setCachePreparedStatements", Database::setCachePreparedStatements);
registerFunction(state, "setMultiStatements", Database::setMultiStatements);
registerFunction(state, "ping", Database::ping);
}
Database::~Database()
{
LOG_CURRENT_FUNCTIONCALL
Database::~Database() {
this->destroyed = true;
{
std::unique_lock<std::mutex> lck(m_queryQueueMutex);
this->m_queryWakupVariable.notify_all();
}
if (this->m_thread.joinable())
{
if (this->m_thread.joinable()) {
this->m_thread.join();
}
}
//This makes sure that all stmts always get freed
void Database::cacheStatement(MYSQL_STMT* stmt) {
if (stmt == nullptr) return;
std::unique_lock<std::mutex> lck(m_stmtMutex);
cachedStatements.push_back(stmt);
}
//This notifies the database thread to free this statement some time in the future
void Database::freeStatement(MYSQL_STMT* stmt) {
if (stmt == nullptr) return;
std::unique_lock<std::mutex> lck(m_stmtMutex);
cachedStatements.erase(std::remove(cachedStatements.begin(), cachedStatements.end(), stmt));
freedStatements.push_back(stmt);
}
/* Creates and returns a query instance and enqueues it into the queue of accepted queries.
*/
int Database::query(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::query(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
LUA->CheckType(2, GarrysMod::Lua::Type::STRING);
unsigned int outLen = 0;
@ -63,9 +72,7 @@ int Database::query(lua_State* state)
/* Creates and returns a PreparedQuery instance and enqueues it into the queue of accepted queries.
*/
int Database::prepare(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::prepare(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
LUA->CheckType(2, GarrysMod::Lua::Type::STRING);
unsigned int outLen = 0;
@ -78,9 +85,7 @@ int Database::prepare(lua_State* state)
/* Creates and returns a PreparedQuery instance and enqueues it into the queue of accepted queries.
*/
int Database::createTransaction(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::createTransaction(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
Transaction* transactionObject = new Transaction(object, state);
transactionObject->pushTableReference(state);
@ -89,14 +94,12 @@ int Database::createTransaction(lua_State* state)
/* Enqueues a query into the queue of accepted queries.
*/
void Database::enqueueQuery(IQuery* query)
{
LOG_CURRENT_FUNCTIONCALL
void Database::enqueueQuery(IQuery* query, std::shared_ptr<IQueryData> queryData) {
std::unique_lock<std::mutex> qlck(m_queryQueueMutex);
query->canbedestroyed = false;
//std::shared_ptr<IQuery> sharedPtr = query->getSharedPointerInstance();
queryQueue.push_back(std::dynamic_pointer_cast<IQuery>(query->getSharedPointerInstance()));
query->m_status = QUERY_WAITING;
queryQueue.push_back(std::make_pair(std::dynamic_pointer_cast<IQuery>(query->getSharedPointerInstance()), queryData));
queryData->setStatus(QUERY_WAITING);
this->m_queryWakupVariable.notify_one();
}
@ -104,28 +107,25 @@ void Database::enqueueQuery(IQuery* query)
/* Returns the amount of queued querys in the database instance
* If a query is currently being processed, it does not count towards the queue size
*/
int Database::queueSize(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::queueSize(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
std::unique_lock<std::mutex> qlck(object->m_queryQueueMutex);
LUA->PushNumber(object->queryQueue.size());
return 1;
}
/* Aborts all queries that are in the queue of started queries and returns the number of successfully aborted queries.
* Does not abort queries that are already taken from the queue and being processed.
*/
int Database::abortAllQueries(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::abortAllQueries(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
std::lock_guard<std::mutex> lock(object->m_queryQueueMutex);
for (auto& query : object->queryQueue)
{
query->m_status = QUERY_ABORTED;
for (auto& pair : object->queryQueue) {
auto query = pair.first;
auto data = pair.second;
data->setStatus(QUERY_ABORTED);
query->unreference(state);
}
LUA->PushNumber((double)object->queryQueue.size());
@ -136,12 +136,9 @@ int Database::abortAllQueries(lua_State* state)
/* Waits for the connection of the database to finish by blocking the current thread until the connect thread finished.
* Callbacks are going to be called before this function returns
*/
int Database::wait(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::wait(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (!object->startedConnecting)
{
if (!object->startedConnecting) {
LUA->ThrowError("Tried to wait for database connection to finish without starting the connection!");
}
std::unique_lock<std::mutex> lck(object->m_connectMutex);
@ -153,9 +150,7 @@ int Database::wait(lua_State* state)
/* Escapes an unescaped string using the database taking into account the characterset of the database.
* This might break if the characterset of the database is changed after the connection was done
*/
int Database::escape(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::escape(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
std::lock_guard<std::mutex>(object->m_connectMutex);
//No query mutex needed since this doesn't use the connection at all
@ -173,12 +168,9 @@ int Database::escape(lua_State* state)
/* Starts the thread that connects to the database and then handles queries.
*/
int Database::connect(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::connect(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE, true);
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting)
{
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting) {
LUA->ThrowError("Database already connected.");
}
object->startedConnecting = true;
@ -189,9 +181,7 @@ int Database::connect(lua_State* state)
/* Returns the status of the database, constants can be found in GMModule
*/
int Database::status(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::status(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
LUA->PushNumber(object->m_status);
return 1;
@ -203,12 +193,9 @@ int Database::status(lua_State* state)
/* Returns the server version as a formatted integer (XYYZZ, X= major-, Y=minor, Z=sub-version)
* Only works as soon as the connection has been established
*/
int Database::serverVersion(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::serverVersion(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (!object->m_connectionDone)
{
if (!object->m_connectionDone) {
LUA->ThrowError("Tried to get server version when client is not connected to server yet!");
}
LUA->PushNumber(object->m_serverVersion);
@ -218,12 +205,9 @@ int Database::serverVersion(lua_State* state)
/* Returns the server version as a string (for example 5.0.96)
* Only works as soon as the connection has been established
*/
int Database::serverInfo(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::serverInfo(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (!object->m_connectionDone)
{
if (!object->m_connectionDone) {
LUA->ThrowError("Tried to get server info when client is not connected to server yet!");
}
LUA->PushString(object->m_serverInfo.c_str());
@ -233,24 +217,18 @@ int Database::serverInfo(lua_State* state)
/* Returns a string of the hostname connected to as well as the connection type
* Only works as soon as the connection has been established
*/
int Database::hostInfo(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::hostInfo(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (!object->m_connectionDone)
{
if (!object->m_connectionDone) {
LUA->ThrowError("Tried to get server info when client is not connected to server yet!");
}
LUA->PushString(object->m_hostInfo.c_str());
return 1;
}
int Database::setAutoReconnect(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::setAutoReconnect(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting)
{
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting) {
LUA->ThrowError("Database already connected.");
}
LUA->CheckType(2, GarrysMod::Lua::Type::BOOL);
@ -258,12 +236,9 @@ int Database::setAutoReconnect(lua_State* state)
return 0;
}
int Database::setMultiStatements(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::setMultiStatements(lua_State* state) {
Database* object = (Database*)unpackSelf(state, TYPE_DATABASE);
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting)
{
if (object->m_status != DATABASE_NOT_CONNECTED || object->startedConnecting) {
LUA->ThrowError("Database already connected.");
}
LUA->CheckType(2, GarrysMod::Lua::Type::BOOL);
@ -271,12 +246,9 @@ int Database::setMultiStatements(lua_State* state)
return 0;
}
int Database::ping(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int Database::ping(lua_State* state) {
Database* database = (Database*)unpackSelf(state, TYPE_DATABASE);
if (database->m_status != DATABASE_CONNECTED)
{
if (database->m_status != DATABASE_CONNECTED) {
LUA->PushBool(false);
return 1;
}
@ -296,36 +268,55 @@ int Database::ping(lua_State* state)
return 1;
}
//Set this to false if your database server imposes a low prepared statements limit
//Or if you might create a very high amount of prepared queries in a short period of time
int Database::setCachePreparedStatements(lua_State* state) {
Database* database = (Database*)unpackSelf(state, TYPE_DATABASE);
if (database->m_status != DATABASE_NOT_CONNECTED) {
LUA->ThrowError("setCachePreparedStatements has to be called before db:start()");
return 0;
}
LUA->CheckType(2, GarrysMod::Lua::Type::BOOL);
database->cachePreparedStatements = LUA->GetBool();
}
//Should only be called from the db thread
//While the mysql documentation says that mysql_options should only be called
//before the connection is done it appears to work after just fine (at least for reconnect)
void Database::setAutoReconnect(my_bool autoReconnect) {
mysql_options(m_sql, MYSQL_OPT_RECONNECT, &autoReconnect);
}
//Should only be called from the db thread
my_bool Database::getAutoReconnect() {
return m_sql->reconnect;
}
/* Thread that connects to the database, on success it continues to handle queries in the run method.
*/
void Database::connectRun()
{
LOG_CURRENT_FUNCTIONCALL
void Database::connectRun() {
mysql_thread_init();
auto threadEnd = finally([&] { mysql_thread_end(); });
{
auto connectionSignaliser = finally([&] { m_connectWakeupVariable.notify_one(); });
std::lock_guard<std::mutex>(this->m_connectMutex);
this->m_sql = mysql_init(nullptr);
if (this->m_sql == nullptr)
{
if (this->m_sql == nullptr) {
m_success = false;
m_connection_err = "Out of memory";
m_connectionDone = true;
m_status = DATABASE_CONNECTION_FAILED;
return;
}
if (this->shouldAutoReconnect)
{
my_bool reconnect = 1;
mysql_options(this->m_sql, MYSQL_OPT_RECONNECT, &reconnect);
if (this->shouldAutoReconnect) {
setAutoReconnect((my_bool)1);
}
const char* socket = (this->socket.length() == 0) ? nullptr : this->socket.c_str();
unsigned long clientFlag = (this->useMultiStatements) ? CLIENT_MULTI_STATEMENTS : 0;
if (mysql_real_connect(this->m_sql, this->host.c_str(), this->username.c_str(), this->pw.c_str(),
this->database.c_str(), this->port, socket, clientFlag) != this->m_sql)
{
clientFlag |= CLIENT_MULTI_RESULTS;
if (mysql_real_connect(this->m_sql, this->host.c_str(), this->username.c_str(), this->pw.c_str(),
this->database.c_str(), this->port, socket, clientFlag) != this->m_sql) {
m_success = false;
m_connection_err = mysql_error(this->m_sql);
m_connectionDone = true;
@ -341,8 +332,7 @@ void Database::connectRun()
m_hostInfo = mysql_get_host_info(this->m_sql);
}
auto closeConnection = finally([&] { mysql_close(this->m_sql); this->m_sql = nullptr; });
if (m_success)
{
if (m_success) {
run();
}
}
@ -351,16 +341,11 @@ void Database::connectRun()
* In case the database connection was established or failed for the first time the connection callbacks are being run.
* Takes all the queries from the finished queries queue and runs the callback for them.
*/
void Database::think(lua_State* state)
{
if (m_connectionDone && !dbCallbackRan)
{
if (m_success)
{
void Database::think(lua_State* state) {
if (m_connectionDone && !dbCallbackRan) {
if (m_success) {
runCallback(state, "onConnected");
}
else
{
} else {
runCallback(state, "onConnectionFailed", "s", m_connection_err.c_str());
}
this->unreference(state);
@ -368,49 +353,74 @@ void Database::think(lua_State* state)
}
//Needs to lock for condition check to prevent race conditions
std::unique_lock<std::mutex> lock(m_finishedQueueMutex);
while (!finishedQueries.empty())
{
std::shared_ptr<IQuery> curquery = finishedQueries.front();
while (!finishedQueries.empty()) {
auto pair = finishedQueries.front();
auto query = pair.first;
auto data = pair.second;
finishedQueries.pop_front();
//Unlocking here because the lock isn't needed for the callbacks
//Allows the database thread to add more finished queries
lock.unlock();
curquery->doCallback(state);
curquery->canbedestroyed = true;
curquery->unreference(state);
query->setCallbackData(data);
data->setStatus(QUERY_COMPLETE);
query->doCallback(state, data);
query->onQueryDataFinished(state, data);
lock.lock();
}
}
void Database::freeUnusedStatements() {
std::lock_guard<std::mutex> stmtLock(m_stmtMutex);
for (auto& stmt : freedStatements) {
mysql_stmt_close(stmt);
}
freedStatements.clear();
}
/* The run method of the thread of the database instance.
*/
void Database::run()
{
LOG_CURRENT_FUNCTIONCALL
while (true)
{
void Database::run() {
while (true) {
std::unique_lock<std::mutex> lock(m_queryQueueMutex);
//Passively waiting for new queries to arrive
while (this->queryQueue.empty() && !this->destroyed) this->m_queryWakupVariable.wait(lock);
uint64_t counter = 0;
//While there are new queries, execute them
while (!this->queryQueue.empty())
{
std::shared_ptr<IQuery> curquery = this->queryQueue.front();
while (!this->queryQueue.empty()) {
auto pair = this->queryQueue.front();
auto curquery = pair.first;
auto data = pair.second;
this->queryQueue.pop_front();
//The lock isn't needed for this section anymore, since it is not operating on the query queue
lock.unlock();
curquery->executeStatement(this->m_sql);
curquery->executeStatement(this->m_sql, data);
{
//New scope so no nested locking occurs
std::lock_guard<std::mutex> lock(m_finishedQueueMutex);
curquery->finished = true;
finishedQueries.push_back(curquery);
finishedQueries.push_back(pair);
data->setFinished(true);
curquery->m_waitWakeupVariable.notify_one();
}
//So that statements get freed sometimes even if the queue is constantly full
if (counter++ % 200 == 0) {
freeUnusedStatements();
}
lock.lock();
}
if (this->destroyed && this->queryQueue.empty())
lock.unlock();
freeUnusedStatements();
lock.lock();
if (this->destroyed && this->queryQueue.empty()) {
std::lock_guard<std::mutex> lock(m_stmtMutex);
for (auto& stmt : cachedStatements) {
mysql_stmt_close(stmt);
}
cachedStatements.clear();
for (auto& stmt : freedStatements) {
mysql_stmt_close(stmt);
}
freedStatements.clear();
return;
}
}
}

View File

@ -1,18 +1,14 @@
#include "GarrysMod/Lua/Interface.h"
#include "LuaObjectBase.h"
#include "Database.h"
#include "Logger.h"
#include "MySQLHeader.h"
#include <iostream>
#include <iostream>
#include <fstream>
#define MYSQLOO_VERSION "9"
#define MYSQLOO_MINOR_VERSION "1"
GMOD_MODULE_CLOSE()
{
Logger::Log("-----------------------------\n");
Logger::Log("MySQLOO closing\n");
Logger::Log("-----------------------------\n");
GMOD_MODULE_CLOSE() {
/* Deletes all the remaining luaobjects when the server changes map
*/
for (auto query : LuaObjectBase::luaRemovalObjects) {
@ -28,9 +24,7 @@ GMOD_MODULE_CLOSE()
/* Connects to the database and returns a Database instance that can be used
* as an interface to the mysql server.
*/
static int connect(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
static int connect(lua_State* state) {
LUA->CheckType(1, GarrysMod::Lua::Type::STRING);
LUA->CheckType(2, GarrysMod::Lua::Type::STRING);
LUA->CheckType(3, GarrysMod::Lua::Type::STRING);
@ -41,34 +35,93 @@ static int connect(lua_State* state)
std::string database = LUA->GetString(4);
unsigned int port = 3306;
std::string unixSocket = "";
if (LUA->IsType(5, GarrysMod::Lua::Type::NUMBER))
{
port = (int) LUA->GetNumber(5);
if (LUA->IsType(5, GarrysMod::Lua::Type::NUMBER)) {
port = (int)LUA->GetNumber(5);
}
if (LUA->IsType(6, GarrysMod::Lua::Type::STRING))
{
if (LUA->IsType(6, GarrysMod::Lua::Type::STRING)) {
unixSocket = LUA->GetString(6);
}
Database* object = new Database(state, host, username, pw, database, port, unixSocket);
((LuaObjectBase*) object)->pushTableReference(state);
((LuaObjectBase*)object)->pushTableReference(state);
return 1;
}
/* Returns the amount of LuaObjectBase objects that are currently in use
* This includes Database and Query instances
*/
static int objectCount(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
static int objectCount(lua_State* state) {
LUA->PushNumber(LuaObjectBase::luaObjects.size());
return 1;
}
GMOD_MODULE_OPEN()
{
Logger::Log("-----------------------------\n");
Logger::Log("MySQLOO starting\n");
Logger::Log("-----------------------------\n");
static void runInTimer(lua_State* state, double delay, GarrysMod::Lua::CFunc func) {
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "timer");
//In case someone removes the timer library
if (LUA->IsType(-1, GarrysMod::Lua::Type::NIL)) {
LUA->Pop(2);
return;
}
LUA->GetField(-1, "Simple");
LUA->PushNumber(delay);
LUA->PushCFunction(func);
LUA->Call(2, 0);
LUA->Pop(2);
}
static void printMessage(lua_State* state, const char* str, int r, int g, int b) {
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "Color");
LUA->PushNumber(r);
LUA->PushNumber(g);
LUA->PushNumber(b);
LUA->Call(3, 1);
int ref = LUA->ReferenceCreate();
LUA->GetField(-1, "MsgC");
LUA->ReferencePush(ref);
LUA->PushString(str);
LUA->Call(2, 0);
LUA->Pop();
LUA->ReferenceFree(ref);
}
static int printOutdatatedVersion(lua_State* state) {
printMessage(state, "Your server is using an outdated mysqloo9 version\n", 255, 0, 0);
printMessage(state, "Download the latest version from here:\n", 255, 0, 0);
printMessage(state, "https://github.com/syl0r/MySQLOO/releases\n", 86, 156, 214);
runInTimer(state, 300, printOutdatatedVersion);
return 0;
}
static int fetchSuccessful(lua_State* state) {
std::string version = LUA->GetString(1);
//version.size() < 3 so that the 404 response gets ignored
if (version != MYSQLOO_MINOR_VERSION && version.size() <= 3) {
printOutdatatedVersion(state);
} else {
printMessage(state, "Your server is using the latest mysqloo9 version\n", 0, 255, 0);
}
return 0;
}
static int fetchFailed(lua_State* state) {
printMessage(state, "Failed to retrieve latest version of mysqloo9\n", 255, 0, 0);
return 0;
}
static int doVersionCheck(lua_State* state) {
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "http");
LUA->GetField(-1, "Fetch");
LUA->PushString("https://raw.githubusercontent.com/syl0r/MySQLOO/master/minorversion.txt");
LUA->PushCFunction(fetchSuccessful);
LUA->PushCFunction(fetchFailed);
LUA->Call(3, 0);
LUA->Pop(2);
return 0;
}
GMOD_MODULE_OPEN() {
if (mysql_library_init(0, nullptr, nullptr)) {
LUA->ThrowError("Could not initialize mysql library.");
}
@ -83,30 +136,32 @@ GMOD_MODULE_OPEN()
LUA->Pop();
LUA->Pop();
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->CreateTable();
LUA->CreateTable();
LUA->PushString(MYSQLOO_VERSION); LUA->SetField(-2, "VERSION");
LUA->PushString(MYSQLOO_VERSION); LUA->SetField(-2, "VERSION");
LUA->PushString(MYSQLOO_MINOR_VERSION); LUA->SetField(-2, "MINOR_VERSION");
LUA->PushNumber(DATABASE_CONNECTED); LUA->SetField(-2, "DATABASE_CONNECTED");
LUA->PushNumber(DATABASE_CONNECTING); LUA->SetField(-2, "DATABASE_CONNECTING");
LUA->PushNumber(DATABASE_NOT_CONNECTED); LUA->SetField(-2, "DATABASE_NOT_CONNECTED");
LUA->PushNumber(DATABASE_CONNECTION_FAILED); LUA->SetField(-2, "DATABASE_CONNECTION_FAILED");
LUA->PushNumber(DATABASE_CONNECTED); LUA->SetField(-2, "DATABASE_CONNECTED");
LUA->PushNumber(DATABASE_CONNECTING); LUA->SetField(-2, "DATABASE_CONNECTING");
LUA->PushNumber(DATABASE_NOT_CONNECTED); LUA->SetField(-2, "DATABASE_NOT_CONNECTED");
LUA->PushNumber(DATABASE_CONNECTION_FAILED); LUA->SetField(-2, "DATABASE_CONNECTION_FAILED");
LUA->PushNumber(QUERY_NOT_RUNNING); LUA->SetField(-2, "QUERY_NOT_RUNNING");
LUA->PushNumber(QUERY_RUNNING); LUA->SetField(-2, "QUERY_RUNNING");
LUA->PushNumber(QUERY_COMPLETE); LUA->SetField(-2, "QUERY_COMPLETE");
LUA->PushNumber(QUERY_ABORTED); LUA->SetField(-2, "QUERY_ABORTED");
LUA->PushNumber(QUERY_WAITING); LUA->SetField(-2, "QUERY_WAITING");
LUA->PushNumber(QUERY_NOT_RUNNING); LUA->SetField(-2, "QUERY_NOT_RUNNING");
LUA->PushNumber(QUERY_RUNNING); LUA->SetField(-2, "QUERY_RUNNING");
LUA->PushNumber(QUERY_COMPLETE); LUA->SetField(-2, "QUERY_COMPLETE");
LUA->PushNumber(QUERY_ABORTED); LUA->SetField(-2, "QUERY_ABORTED");
LUA->PushNumber(QUERY_WAITING); LUA->SetField(-2, "QUERY_WAITING");
LUA->PushNumber(OPTION_NUMERIC_FIELDS); LUA->SetField(-2, "OPTION_NUMERIC_FIELDS");
LUA->PushNumber(OPTION_INTERPRET_DATA); LUA->SetField(-2, "OPTION_INTERPRET_DATA"); //Not used anymore
LUA->PushNumber(OPTION_NAMED_FIELDS); LUA->SetField(-2, "OPTION_NAMED_FIELDS"); //Not used anymore
LUA->PushNumber(OPTION_CACHE); LUA->SetField(-2, "OPTION_CACHE"); //Not used anymore
LUA->PushNumber(OPTION_NUMERIC_FIELDS); LUA->SetField(-2, "OPTION_NUMERIC_FIELDS");
LUA->PushNumber(OPTION_INTERPRET_DATA); LUA->SetField(-2, "OPTION_INTERPRET_DATA"); //Not used anymore
LUA->PushNumber(OPTION_NAMED_FIELDS); LUA->SetField(-2, "OPTION_NAMED_FIELDS"); //Not used anymore
LUA->PushNumber(OPTION_CACHE); LUA->SetField(-2, "OPTION_CACHE"); //Not used anymore
LUA->PushCFunction(connect); LUA->SetField(-2, "connect");
LUA->PushCFunction(objectCount); LUA->SetField(-2, "objectCount");
LUA->PushCFunction(connect); LUA->SetField(-2, "connect");
LUA->PushCFunction(objectCount); LUA->SetField(-2, "objectCount");
LUA->SetField(-2, "mysqloo");
LUA->SetField(-2, "mysqloo");
LUA->Pop();
runInTimer(state, 5, doVersionCheck);
return 1;
}

View File

@ -1,370 +1,158 @@
#include "IQuery.h"
#include "Database.h"
#include "ResultData.h"
#include "Logger.h"
//Important:
//Calling any query functions that rely on data from the query thread
//before the callback is called can result in race conditions.
//Always check for QUERY_COMPLETE!!!
IQuery::IQuery(Database* dbase, lua_State* state) : LuaObjectBase(state, false, TYPE_QUERY), m_database(dbase)
{
LOG_CURRENT_FUNCTIONCALL
IQuery::IQuery(Database* dbase, lua_State* state) : LuaObjectBase(state, false, TYPE_QUERY), m_database(dbase) {
m_options = OPTION_NAMED_FIELDS | OPTION_INTERPRET_DATA | OPTION_CACHE;
m_status = QUERY_NOT_RUNNING;
registerFunction(state, "start", IQuery::start);
registerFunction(state, "affectedRows", IQuery::affectedRows);
registerFunction(state, "lastInsert", IQuery::lastInsert);
registerFunction(state, "getData", IQuery::getData_Wrapper);
registerFunction(state, "error", IQuery::error);
registerFunction(state, "wait", IQuery::wait);
registerFunction(state, "setOption", IQuery::setOption);
registerFunction(state, "isRunning", IQuery::isRunning);
registerFunction(state, "abort", IQuery::abort);
registerFunction(state, "hasMoreResults", IQuery::hasMoreResults);
registerFunction(state, "getNextResults", IQuery::getNextResults);
}
IQuery::~IQuery()
{
}
void IQuery::setResultStatus(QueryResultStatus status) {
this->m_resultStatus = status;
}
void IQuery::setStatus(QueryStatus status) {
this->m_status = status;
}
IQuery::~IQuery() {}
QueryResultStatus IQuery::getResultStatus() {
return this->m_resultStatus;
}
//When the query is destroyed by lua
void IQuery::onDestroyed(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
if (this->dataReference != 0 && state != nullptr)
{
//Make sure data associated with this query can be freed as well
LUA->ReferenceFree(this->dataReference);
this->dataReference = 0;
if (!hasCallbackData()) {
return QUERY_NONE;
}
}
//This function just returns the data associated with the query
//Data is only created once (and then the reference to that data is returned)
int IQuery::getData_Wrapper(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_resultStatus != QUERY_SUCCESS)
{
LUA->PushNil();
}
else
{
LUA->ReferencePush(object->getData(state));
}
return 1;
}
//Returns true if a query has at least one additional ResultSet left
int IQuery::hasMoreResults(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_COMPLETE)
{
LUA->ThrowError("Query not completed yet");
}
LUA->PushBool(object->results.size() > 0);
return 1;
}
//Unreferences the current result set and uses the next result set
int IQuery::getNextResults(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_COMPLETE)
{
LUA->ThrowError("Query not completed yet");
}
else if (object->results.size() == 0)
{
LUA->ThrowError("Query doesn't have any more results");
}
if (object->dataReference != 0)
{
LUA->ReferenceFree(object->dataReference);
object->dataReference = 0;
}
if (object->m_insertIds.size() > 0)
{
object->m_insertIds.pop_front();
}
if (object->m_affectedRows.size() > 0)
{
object->m_affectedRows.pop_front();
}
LUA->ReferencePush(object->getData(state));
return 1;
return callbackQueryData->getResultStatus();
}
//Wrapper for c api calls
//Just throws an exception if anything goes wrong for ease of use
void IQuery::mysqlAutocommit(MYSQL* sql, bool auto_mode)
{
LOG_CURRENT_FUNCTIONCALL
int result = mysql_autocommit(sql, auto_mode);
if (result != 0)
{
void IQuery::mysqlAutocommit(MYSQL* sql, bool auto_mode) {
int result = mysql_autocommit(sql, auto_mode);
if (result != 0) {
const char* errorMessage = mysql_error(sql);
int errorCode = mysql_errno(sql);
throw MySQLException(errorCode, errorMessage);
}
}
//Stores the data associated with the current result set of the query
//Only called once per result set (and then cached)
int IQuery::getData(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
if (this->dataReference != 0)
return this->dataReference;
LUA->CreateTable();
if (this->results.size() > 0)
{
ResultData& currentData = this->results.front();
for (unsigned int i = 0; i < currentData.getRows().size(); i++)
{
ResultDataRow& row = currentData.getRows()[i];
LUA->CreateTable();
int rowObject = LUA->ReferenceCreate();
for (unsigned int j = 0; j < row.getValues().size(); j++)
{
dataToLua( state, rowObject, j + 1, row.getValues()[j], currentData.getColumns()[j].c_str(),
currentData.getColumnTypes()[j], row.isFieldNull(j));
}
LUA->PushNumber(i+1);
LUA->ReferencePush(rowObject);
LUA->SetTable(-3);
LUA->ReferenceFree(rowObject);
}
//ResultSet is not needed anymore since we stored it in lua tables.
this->results.pop_front();
}
this->dataReference = LUA->ReferenceCreate();
return this->dataReference;
};
//Function that converts the data stored in a mysql field into a lua type
void IQuery::dataToLua(lua_State* state, int rowReference, unsigned int column, std::string &columnValue, const char* columnName, int columnType, bool isNull)
{
LUA->ReferencePush(rowReference);
if (this->m_options & OPTION_NUMERIC_FIELDS)
{
LUA->PushNumber(column);
}
if (isNull)
{
LUA->PushNil();
}
else
{
switch (columnType)
{
case MYSQL_TYPE_FLOAT:
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_TINY:
case MYSQL_TYPE_SHORT:
LUA->PushNumber(atof(columnValue.c_str()));
break;
case MYSQL_TYPE_BIT:
LUA->PushNumber(static_cast<int>(columnValue[0]));
break;
case MYSQL_TYPE_NULL:
LUA->PushNil();
break;
default:
LUA->PushString(columnValue.c_str(), columnValue.length());
break;
}
}
if (this->m_options & OPTION_NUMERIC_FIELDS)
{
LUA->SetTable(-3);
}
else
{
LUA->SetField(-2, columnName);
}
LUA->Pop();
}
//Queues the query into the queue of the database instance associated with it
int IQuery::start(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY, true);
if (object->m_status != QUERY_NONE)
{
LUA->ThrowError("Query already started");
}
else
{
object->m_database->enqueueQuery(object);
int IQuery::start(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
if (object->runningQueryData.size() == 0) {
referenceTable(state, object, 1);
}
std::shared_ptr<IQueryData> ptr = object->buildQueryData(state);
object->addQueryData(state, ptr);
object->m_database->enqueueQuery(object, ptr);
object->hasBeenStarted = true;
return 0;
}
//Returns if the query has been queued with the database instance
int IQuery::isRunning(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int IQuery::isRunning(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
LUA->PushBool(object->m_status == QUERY_RUNNING || object->m_status == QUERY_WAITING);
return 1;
}
//Returns the last insert id produced by INSERT INTO statements (or 0 if there is none)
int IQuery::lastInsert(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
//Calling lastInsert() after query was executed but before the callback is run can cause race conditions
if (object->m_status != QUERY_COMPLETE || object->m_insertIds.size() == 0)
LUA->PushNumber(0);
else
LUA->PushNumber((double) object->m_insertIds.front());
return 1;
}
//Returns the last affected rows produced by INSERT/DELETE/UPDATE (0 for none, -1 for errors)
//For a SELECT statement this returns the amount of rows returned
int IQuery::affectedRows(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
//Calling affectedRows() after query was executed but before the callback is run can cause race conditions
if (object->m_status != QUERY_COMPLETE || object->m_affectedRows.size() == 0)
LUA->PushNumber(0);
else
LUA->PushNumber((double) object->m_affectedRows.front());
LUA->PushBool(object->runningQueryData.size() > 0);
return 1;
}
//Blocks the current Thread until the query has finished processing
//Possibly dangerous (dead lock when database goes down while waiting)
//If the second argument is set to true, the query is going to be swapped to the front of the query queue
int IQuery::wait(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int IQuery::wait(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
bool shouldSwap = false;
if (LUA->IsType(2, GarrysMod::Lua::Type::BOOL))
{
if (LUA->IsType(2, GarrysMod::Lua::Type::BOOL)) {
shouldSwap = LUA->GetBool(2);
}
if (object->m_status == QUERY_NOT_RUNNING)
{
if (object->runningQueryData.size() == 0) {
LUA->ThrowError("Query not started.");
}
std::shared_ptr<IQueryData> lastInsertedQuery = object->runningQueryData.back();
//Changing the order of the query might have unwanted side effects, so this is disabled by default
if(shouldSwap)
{
if (shouldSwap) {
std::lock_guard<std::mutex> lck(object->m_database->m_queryQueueMutex);
auto pos = std::find_if(object->m_database->queryQueue.begin(), object->m_database->queryQueue.end(), [&](std::shared_ptr<IQuery> const& p) {
return p.get() == object;
auto pos = std::find_if(object->m_database->queryQueue.begin(), object->m_database->queryQueue.end(),
[&](std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>> const& p) {
return p.second.get() == lastInsertedQuery.get();
});
if (pos != object->m_database->queryQueue.begin() && pos != object->m_database->queryQueue.end())
{
if (pos != object->m_database->queryQueue.begin() && pos != object->m_database->queryQueue.end()) {
std::iter_swap(pos, object->m_database->queryQueue.begin());
}
}
{
std::unique_lock<std::mutex> lck(object->m_database->m_finishedQueueMutex);
while (!object->finished) object->m_waitWakeupVariable.wait(lck);
while (!lastInsertedQuery->isFinished()) object->m_waitWakeupVariable.wait(lck);
}
object->m_database->think(state);
return 0;
}
//Returns the error message produced by the mysql query or 0 if there is none
int IQuery::error(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int IQuery::error(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_COMPLETE)
if (object->hasCallbackData()) {
return 0;
else
{
LUA->PushString(object->m_errorText.c_str());
return 1;
}
//Calling affectedRows() after query was executed but before the callback is run can cause race conditions
LUA->PushString(object->callbackQueryData->getError().c_str());
return 1;
}
//Attempts to abort the query, returns true if it was able to stop the query in time, false otherwise
int IQuery::abort(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
//Attempts to abort the query, returns true if it was able to stop at least one query in time, false otherwise
int IQuery::abort(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
std::lock_guard<std::mutex> lock(object->m_database->m_queryQueueMutex);
auto it = std::find_if(object->m_database->queryQueue.begin(), object->m_database->queryQueue.end(), [&](std::shared_ptr<IQuery> const& p) {
return p.get() == object;
});
if (it != object->m_database->queryQueue.end())
{
object->m_database->queryQueue.erase(it);
object->m_status = QUERY_ABORTED;
object->unreference(state);
LUA->PushBool(true);
object->runCallback(state, "onAborted");
}
else
{
LUA->PushBool(false);
bool wasAborted = false;
//This is copied so that I can remove entries from that vector
auto vec = object->runningQueryData;
for (auto& data : vec) {
//It doesn't really matter if any of them are in a transaction since in that case they
//aren't in the query queue
auto it = std::find_if(object->m_database->queryQueue.begin(), object->m_database->queryQueue.end(),
[&](std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>> const& p) {
return p.second.get() == data.get();
});
if (it != object->m_database->queryQueue.end()) {
object->m_database->queryQueue.erase(it);
data->setStatus(QUERY_ABORTED);
wasAborted = true;
if (data->getAbortReference() != 0) {
object->runFunction(state, data->getAbortReference());
} else if (data->isFirstData()) {
object->runCallback(state, "onAborted");
}
object->onQueryDataFinished(state, data);
}
}
LUA->PushBool(wasAborted);
return 1;
}
//Sets several query options
int IQuery::setOption(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int IQuery::setOption(lua_State* state) {
IQuery* object = (IQuery*)unpackSelf(state, TYPE_QUERY);
LUA->CheckType(2, GarrysMod::Lua::Type::NUMBER);
bool set = true;
int option = (int) LUA->GetNumber(2);
int option = (int)LUA->GetNumber(2);
if (option != OPTION_NUMERIC_FIELDS &&
option != OPTION_NAMED_FIELDS &&
option != OPTION_INTERPRET_DATA &&
option != OPTION_CACHE)
{
option != OPTION_CACHE) {
LUA->ThrowError("Invalid option");
return 0;
}
if (LUA->Top() >= 3)
{
if (LUA->Top() >= 3) {
LUA->CheckType(3, GarrysMod::Lua::Type::BOOL);
set = LUA->GetBool(3);
}
if (set)
{
if (set) {
object->m_options |= option;
}
else
{
} else {
object->m_options &= ~option;
}
return 0;
@ -373,26 +161,20 @@ int IQuery::setOption(lua_State* state)
//Wrapper for c api calls
//Just throws an exception if anything goes wrong for ease of use
void IQuery::mysqlQuery(MYSQL* sql, std::string &query)
{
LOG_CURRENT_FUNCTIONCALL
void IQuery::mysqlQuery(MYSQL* sql, std::string &query) {
int result = mysql_real_query(sql, query.c_str(), query.length());
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_error(sql);
int errorCode = mysql_errno(sql);
throw MySQLException(errorCode, errorMessage);
}
}
MYSQL_RES* IQuery::mysqlStoreResults(MYSQL* sql)
{
MYSQL_RES* IQuery::mysqlStoreResults(MYSQL* sql) {
MYSQL_RES* result = mysql_store_result(sql);
if (result == nullptr)
{
if (result == nullptr) {
int errorCode = mysql_errno(sql);
if (errorCode != 0)
{
if (errorCode != 0) {
const char* errorMessage = mysql_error(sql);
throw MySQLException(errorCode, errorMessage);
}
@ -400,16 +182,46 @@ MYSQL_RES* IQuery::mysqlStoreResults(MYSQL* sql)
return result;
}
bool IQuery::mysqlNextResult(MYSQL* sql)
{
bool IQuery::mysqlNextResult(MYSQL* sql) {
int result = mysql_next_result(sql);
if (result == 0) return true;
if (result == -1) return false;
int errorCode = mysql_errno(sql);
if (errorCode != 0)
{
if (errorCode != 0) {
const char* errorMessage = mysql_error(sql);
throw MySQLException(errorCode, errorMessage);
}
return false;
}
void IQuery::addQueryData(lua_State* state, std::shared_ptr<IQueryData> data, bool shouldRefCallbacks) {
if (!hasBeenStarted) {
data->m_wasFirstData = true;
}
runningQueryData.push_back(data);
if (shouldRefCallbacks) {
data->m_onDataReference = this->getCallbackReference(state, "onData");
data->m_errorReference = this->getCallbackReference(state, "onError");
data->m_abortReference = this->getCallbackReference(state, "onAborted");
data->m_successReference = this->getCallbackReference(state, "onSuccess");
}
}
void IQuery::onQueryDataFinished(lua_State* state, std::shared_ptr<IQueryData> data) {
runningQueryData.erase(std::remove(runningQueryData.begin(), runningQueryData.end(), data));
if (runningQueryData.size() == 0) {
canbedestroyed = true;
unreference(state);
}
if (data->m_onDataReference) {
LUA->ReferenceFree(data->m_onDataReference);
}
if (data->m_errorReference) {
LUA->ReferenceFree(data->m_errorReference);
}
if (data->m_abortReference) {
LUA->ReferenceFree(data->m_abortReference);
}
if (data->m_successReference) {
LUA->ReferenceFree(data->m_successReference);
}
}

View File

@ -1,69 +0,0 @@
#define _CRT_SECURE_NO_DEPRECATE
#include "Logger.h"
#include <string>
#include <mutex>
#include <cstdio>
#include <stdarg.h>
namespace Logger
{
static std::mutex loggerMutex;
static FILE* logFile;
static bool loggingFailed = false;
//Disables logging
static void disableLogging(const char* message)
{
if (logFile != nullptr)
{
fclose(logFile);
logFile = nullptr;
}
loggingFailed = true;
printf("%s\n", message);
}
//Attempts to initialize the logfile stream
//If it fails it disables logging
static bool initFileStream()
{
logFile = fopen("mysqloo.log", "a");
if (logFile == nullptr)
{
disableLogging("Logger failed to open log file, logging disabled");
return false;
}
return true;
}
//Logs a message to mysqloo.log, if it fails logging will be disabled
void Log(const char* format, ...)
{
#ifdef LOGGER_ENABLED
//Double check in case one thread doesn't know about the change to loggingFailed yet,
//but to increase performance in case it already does
if (loggingFailed) return;
std::lock_guard<std::mutex> lock(loggerMutex);
if (loggingFailed) return;
if (logFile == nullptr)
{
if (!initFileStream())
{
return;
}
}
va_list varArgs;
va_start(varArgs, format);
int result = vfprintf(logFile, format, varArgs);
va_end(varArgs);
if (result < 0)
{
disableLogging("Failed to write to log file");
}
else if (fflush(logFile) < 0)
{
disableLogging("Failed to flush to log file");
}
#endif
}
};

View File

@ -1,5 +1,4 @@
#include "LuaObjectBase.h"
#include "Logger.h"
#include <stdarg.h>
#include <iostream>
#include <sstream>
@ -13,87 +12,72 @@ int LuaObjectBase::tableMetaTable = 0;
int LuaObjectBase::userdataMetaTable = 0;
LuaObjectBase::LuaObjectBase(lua_State* state, bool shouldthink, unsigned char type) : type(type)
{
LOG_CURRENT_FUNCTIONCALL
LuaObjectBase::LuaObjectBase(lua_State* state, bool shouldthink, unsigned char type) : type(type) {
classname = "LuaObject";
this->shouldthink = shouldthink;
std::shared_ptr<LuaObjectBase> ptr(this);
if (shouldthink)
{
if (shouldthink) {
luaThinkObjects.push_back(ptr);
}
luaObjects.push_back(ptr);
}
LuaObjectBase::LuaObjectBase(lua_State* state, unsigned char type) : LuaObjectBase::LuaObjectBase(state, true, type)
{
}
LuaObjectBase::LuaObjectBase(lua_State* state, unsigned char type) : LuaObjectBase::LuaObjectBase(state, true, type) {}
//Important!!!!
//LuaObjectBase should never be deleted manually
//Let shared_ptrs handle it
LuaObjectBase::~LuaObjectBase()
{
}
LuaObjectBase::~LuaObjectBase() {}
//Makes C++ functions callable from lua
void LuaObjectBase::registerFunction(lua_State* state, std::string name, GarrysMod::Lua::CFunc func)
{
LOG_CURRENT_FUNCTIONCALL
void LuaObjectBase::registerFunction(lua_State* state, std::string name, GarrysMod::Lua::CFunc func) {
this->m_callbackFunctions[name] = func;
}
//Returns the shared_ptr that exists for this instance
std::shared_ptr<LuaObjectBase> LuaObjectBase::getSharedPointerInstance()
{
std::shared_ptr<LuaObjectBase> LuaObjectBase::getSharedPointerInstance() {
return shared_from_this();
}
//Gets the C++ object associated with a lua table that represents it in LUA
LuaObjectBase* LuaObjectBase::unpackSelf(lua_State* state, int type, bool shouldReference)
{
LuaObjectBase* LuaObjectBase::unpackSelf(lua_State* state, int type, bool shouldReference) {
return unpackLuaObject(state, 1, type, shouldReference);
}
//Gets the C++ object associated with a lua table that represents it in LUA
LuaObjectBase* LuaObjectBase::unpackLuaObject(lua_State* state, int index, int type, bool shouldReference)
{
LOG_CURRENT_FUNCTIONCALL
LuaObjectBase* LuaObjectBase::unpackLuaObject(lua_State* state, int index, int type, bool shouldReference) {
LUA->CheckType(index, GarrysMod::Lua::Type::TABLE);
LUA->GetField(index, "___lua_userdata_object");
GarrysMod::Lua::UserData* ud = (GarrysMod::Lua::UserData*) LUA->GetUserdata(-1);
if (ud->type != type && type != -1)
{
if (ud->type != type && type != -1) {
std::ostringstream oss;
oss << "Wrong type, expected " << type << " got " << ((int)ud->type);
LUA->ThrowError(oss.str().c_str());
}
LuaObjectBase* object = (LuaObjectBase*)ud->data;
if (shouldReference)
{
if (object->m_userdataReference == 0 && object->m_tableReference == 0) {
object->m_userdataReference = LUA->ReferenceCreate();
LUA->Push(index); //Pushes table that needs to be referenced
object->m_tableReference = LUA->ReferenceCreate();
}
else {
LUA->ThrowError("Tried to reference lua object twice (Query started twice?)");
}
}
else
{
LUA->Pop();
if (shouldReference) {
referenceTable(state, object, index);
}
LUA->Pop();
return object;
}
void LuaObjectBase::referenceTable(lua_State* state, LuaObjectBase* object, int index) {
if (object->m_userdataReference != 0 || object->m_tableReference != 0) {
LUA->ThrowError("Tried to reference lua object twice (Query started twice?)");
}
LUA->CheckType(index, GarrysMod::Lua::Type::TABLE);
LUA->Push(index);
object->m_tableReference = LUA->ReferenceCreate();
LUA->Push(index);
LUA->GetField(index, "___lua_userdata_object");
object->m_userdataReference = LUA->ReferenceCreate();
}
//Pushes the table reference of a C++ object that represents it in LUA
int LuaObjectBase::pushTableReference(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
if (m_tableReference != 0)
{
int LuaObjectBase::pushTableReference(lua_State* state) {
if (m_tableReference != 0) {
LUA->ReferencePush(m_tableReference);
return 1;
}
@ -108,8 +92,7 @@ int LuaObjectBase::pushTableReference(lua_State* state)
LUA->CreateTable();
LUA->ReferencePush(userdatareference);
LUA->SetField(-2, "___lua_userdata_object");
for (auto& callback : this->m_callbackFunctions)
{
for (auto& callback : this->m_callbackFunctions) {
LUA->PushCFunction(callback.second);
LUA->SetField(-2, callback.first.c_str());
}
@ -120,25 +103,19 @@ int LuaObjectBase::pushTableReference(lua_State* state)
}
//Unreferences the table that represents this C++ object in lua, so that it can be gc'ed
void LuaObjectBase::unreference(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
if (m_tableReference != 0)
{
void LuaObjectBase::unreference(lua_State* state) {
if (m_tableReference != 0) {
LUA->ReferenceFree(m_tableReference);
m_tableReference = 0;
}
if (m_userdataReference != 0)
{
if (m_userdataReference != 0) {
LUA->ReferenceFree(m_userdataReference);
m_userdataReference = 0;
}
}
//Checks whether or not a callback exists
bool LuaObjectBase::hasCallback(lua_State* state, const char* functionName)
{
LOG_CURRENT_FUNCTIONCALL
bool LuaObjectBase::hasCallback(lua_State* state, const char* functionName) {
if (this->m_tableReference == 0) return false;
LUA->ReferencePush(this->m_tableReference);
LUA->GetField(-1, functionName);
@ -147,120 +124,124 @@ bool LuaObjectBase::hasCallback(lua_State* state, const char* functionName)
return hasCallback;
}
//Runs callbacks associated with the lua object
void LuaObjectBase::runCallback(lua_State* state, const char* functionName, const char* sig, ...)
{
LOG_CURRENT_FUNCTIONCALL
if (this->m_tableReference == 0) return;
int LuaObjectBase::getCallbackReference(lua_State* state, const char* functionName) {
if (this->m_tableReference == 0) return 0;
LUA->ReferencePush(this->m_tableReference);
LUA->GetField(-1, functionName);
if (LUA->GetType(-1) != GarrysMod::Lua::Type::FUNCTION)
{
if (LUA->GetType(-1) != GarrysMod::Lua::Type::FUNCTION) {
LUA->Pop(2);
return;
return 0;
}
//Hacky solution so there isn't too much stuff on the stack after this
int ref = LUA->ReferenceCreate();
return ref;
}
void LuaObjectBase::runFunction(lua_State* state, int funcRef, const char* sig, ...) {
if (funcRef == 0) return;
va_list arguments;
va_start(arguments, sig);
runFunctionVaList(state, funcRef, sig, arguments);
va_end(arguments);
}
void LuaObjectBase::runFunctionVaList(lua_State* state, int funcRef, const char* sig, va_list arguments) {
if (funcRef == 0) return;
if (this->m_tableReference == 0) return;
LUA->ReferencePush(funcRef);
pushTableReference(state);
int numArguments = 1;
if (sig)
{
va_list arguments;
va_start(arguments, sig);
for (unsigned int i = 0; i < std::strlen(sig); i++)
{
if (sig) {
for (unsigned int i = 0; i < std::strlen(sig); i++) {
char option = sig[i];
if (option == 'i')
{
if (option == 'i') {
int value = va_arg(arguments, int);
LUA->PushNumber(value);
numArguments++;
}
else if (option == 'f')
{
} else if (option == 'f') {
float value = static_cast<float>(va_arg(arguments, double));
LUA->PushNumber(value);
numArguments++;
}
else if (option == 'b')
{
} else if (option == 'b') {
bool value = va_arg(arguments, int) != 0;
LUA->PushBool(value);
numArguments++;
}
else if (option == 's')
{
} else if (option == 's') {
char* value = va_arg(arguments, char*);
LUA->PushString(value);
numArguments++;
}
else if (option == 'o')
{
} else if (option == 'o') {
int value = va_arg(arguments, int);
LUA->ReferencePush(value);
numArguments++;
}
else if (option == 'r')
{
} else if (option == 'r') {
int reference = va_arg(arguments, int);
LUA->ReferencePush(reference);
numArguments++;
}
else if (option == 'F')
{
} else if (option == 'F') {
GarrysMod::Lua::CFunc value = va_arg(arguments, GarrysMod::Lua::CFunc);
LUA->PushCFunction(value);
numArguments++;
}
}
va_end(arguments);
}
if (LUA->PCall(numArguments, 0, 0)) {
const char* err = LUA->GetString(-1);
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "ErrorNoHalt");
//In case someone removes ErrorNoHalt this doesn't break everything
if (!LUA->IsType(-1, GarrysMod::Lua::Type::FUNCTION)) {
LUA->Pop(2);
return;
}
LUA->PushString(err);
LUA->Call(1, 0);
LUA->Pop(1);
}
}
LUA->Pop(1);
//Runs callbacks associated with the lua object
void LuaObjectBase::runCallback(lua_State* state, const char* functionName, const char* sig, ...) {
if (this->m_tableReference == 0) return;
int funcRef = getCallbackReference(state, functionName);
if (funcRef == 0) {
LUA->ReferenceFree(funcRef);
return;
}
va_list arguments;
va_start(arguments, sig);
runFunctionVaList(state, funcRef, sig, arguments);
va_end(arguments);
LUA->ReferenceFree(funcRef);
}
//Called every tick, checks if the object can be destroyed
int LuaObjectBase::doThink(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int LuaObjectBase::doThink(lua_State* state) {
//Think objects need to be copied because a think call could modify the original thinkObject queue
//which leads to it invalidating the iterator and thus undefined behaviour
std::deque<std::shared_ptr<LuaObjectBase>> thinkObjectsCopy = luaThinkObjects;
for (auto& query : luaThinkObjects) {
query->think(state);
}
if (luaRemovalObjects.size() > 0)
{
for (auto it = luaRemovalObjects.begin(); it != luaRemovalObjects.end(); )
{
if (luaRemovalObjects.size() > 0) {
for (auto it = luaRemovalObjects.begin(); it != luaRemovalObjects.end(); ) {
LuaObjectBase* obj = (*it).get();
if (obj->canbedestroyed)
{
if (obj->canbedestroyed) {
obj->onDestroyed(state);
it = luaRemovalObjects.erase(it);
auto objectIt = std::find_if(luaObjects.begin(), luaObjects.end(), [&](std::shared_ptr<LuaObjectBase> const& p) {
return p.get() == obj;
});
if (objectIt != luaObjects.end())
{
if (objectIt != luaObjects.end()) {
luaObjects.erase(objectIt);
}
auto thinkObjectIt = std::find_if(luaThinkObjects.begin(), luaThinkObjects.end(), [&](std::shared_ptr<LuaObjectBase> const& p) {
return p.get() == obj;
});
if (thinkObjectIt != luaThinkObjects.end())
{
if (thinkObjectIt != luaThinkObjects.end()) {
luaThinkObjects.erase(thinkObjectIt);
}
}
else
{
} else {
++it;
}
}
@ -270,17 +251,13 @@ int LuaObjectBase::doThink(lua_State* state)
//Called when the LUA table representing this C++ object has been gc'ed
//Deletes the associated C++ object
int LuaObjectBase::gcDeleteWrapper(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int LuaObjectBase::gcDeleteWrapper(lua_State* state) {
GarrysMod::Lua::UserData* obj = (GarrysMod::Lua::UserData*) LUA->GetUserdata(1);
if (!obj || (obj->type != TYPE_DATABASE && obj->type != TYPE_QUERY))
return 0;
LuaObjectBase* object = (LuaObjectBase*)obj->data;
if (!object->scheduledForRemoval)
{
if (object->m_userdataReference != 0)
{
if (!object->scheduledForRemoval) {
if (object->m_userdataReference != 0) {
LUA->ReferenceFree(object->m_userdataReference);
object->m_userdataReference = 0;
}
@ -293,9 +270,7 @@ int LuaObjectBase::gcDeleteWrapper(lua_State* state)
}
//Prints the name of the object
int LuaObjectBase::toStringWrapper(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int LuaObjectBase::toStringWrapper(lua_State* state) {
LuaObjectBase* object = unpackSelf(state);
std::stringstream ss;
ss << object->classname << " " << object;
@ -304,9 +279,7 @@ int LuaObjectBase::toStringWrapper(lua_State* state)
}
//Creates metatables used for the LUA representation of the C++ table
int LuaObjectBase::createMetatables(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int LuaObjectBase::createMetatables(lua_State* state) {
LUA->CreateTable();
LUA->PushCFunction(LuaObjectBase::gcDeleteWrapper);

View File

@ -1,24 +1,22 @@
#include "PingQuery.h"
#include "Logger.h"
#ifdef LINUX
#include <stdlib.h>
#endif
#include "Database.h"
//Dummy class just used with the Database::ping function
PingQuery::PingQuery(Database* dbase, lua_State* state) : Query(dbase, state)
{
PingQuery::PingQuery(Database* dbase, lua_State* state) : Query(dbase, state) {
classname = "PingQuery";
}
PingQuery::~PingQuery(void)
{
}
PingQuery::~PingQuery(void) {}
/* Executes the ping query
*/
void PingQuery::executeQuery(MYSQL* connection)
{
LOG_CURRENT_FUNCTIONCALL
void PingQuery::executeQuery(MYSQL* connection, std::shared_ptr<IQueryData> data) {
my_bool oldAutoReconnect = this->m_database->getAutoReconnect();
this->m_database->setAutoReconnect((my_bool) 1);
this->pingSuccess = mysql_ping(connection) == 0;
this->m_database->setAutoReconnect(oldAutoReconnect);
}

View File

@ -1,105 +1,98 @@
#include "PreparedQuery.h"
#include "Logger.h"
#include "Database.h"
#include "errmsg.h"
#ifdef LINUX
#include <stdlib.h>
#endif
//This is dirty but hopefully will be consistent between mysql connector versions
#define ER_MAX_PREPARED_STMT_COUNT_REACHED 1461
PreparedQuery::PreparedQuery(Database* dbase, lua_State* state) : Query(dbase, state)
{
PreparedQuery::PreparedQuery(Database* dbase, lua_State* state) : Query(dbase, state) {
classname = "PreparedQuery";
registerFunction(state, "setNumber", PreparedQuery::setNumber);
registerFunction(state, "setString", PreparedQuery::setString);
registerFunction(state, "setBoolean", PreparedQuery::setBoolean);
registerFunction(state, "setNull", PreparedQuery::setNull);
registerFunction(state, "putNewParameters", PreparedQuery::putNewParameters);
this->parameters.push_back(std::unordered_map<unsigned int, std::unique_ptr<PreparedQueryField>>());
this->m_parameters.push_back(std::unordered_map<unsigned int, std::shared_ptr<PreparedQueryField>>());
}
PreparedQuery::~PreparedQuery(void)
{
PreparedQuery::~PreparedQuery(void) {}
//When the query is destroyed by lua
void PreparedQuery::onDestroyed(lua_State* state) {
{
//There can't be any race conditions here
//This always runs after PreparedQuery::executeQuery() is done
//I am using atomic to prevent visibility issues though
MYSQL_STMT* stmt = this->cachedStatement;
if (stmt != nullptr) {
m_database->freeStatement(cachedStatement);
cachedStatement = nullptr;
}
}
IQuery::onDestroyed(state);
}
int PreparedQuery::setNumber(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int PreparedQuery::setNumber(lua_State* state) {
PreparedQuery* object = (PreparedQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_NONE)
LUA->ThrowError("Query already started");
LUA->CheckType(2, GarrysMod::Lua::Type::NUMBER);
LUA->CheckType(3, GarrysMod::Lua::Type::NUMBER);
double index = LUA->GetNumber(2);
if (index < 1) LUA->ThrowError("Index must be greater than 0");
unsigned int uIndex = (unsigned int) index;
unsigned int uIndex = (unsigned int)index;
double value = LUA->GetNumber(3);
object->parameters.back().insert(std::make_pair(uIndex, std::unique_ptr<PreparedQueryField>(new TypedQueryField<double>(uIndex, MYSQL_TYPE_DOUBLE, value))));
object->m_parameters.back()[uIndex] = std::shared_ptr<PreparedQueryField>(new TypedQueryField<double>(uIndex, MYSQL_TYPE_DOUBLE, value));
return 0;
}
int PreparedQuery::setString(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int PreparedQuery::setString(lua_State* state) {
PreparedQuery* object = (PreparedQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_NONE)
LUA->ThrowError("Query already started");
LUA->CheckType(2, GarrysMod::Lua::Type::NUMBER);
LUA->CheckType(3, GarrysMod::Lua::Type::STRING);
double index = LUA->GetNumber(2);
if (index < 1) LUA->ThrowError("Index must be greater than 0");
unsigned int uIndex = (unsigned int) index;
unsigned int uIndex = (unsigned int)index;
unsigned int length = 0;
const char* string = LUA->GetString(3, &length);
object->parameters.back().insert(std::make_pair(uIndex, std::unique_ptr<PreparedQueryField>(new TypedQueryField<std::string>(uIndex, MYSQL_TYPE_STRING, std::string(string, length)))));
object->m_parameters.back()[uIndex] = std::shared_ptr<PreparedQueryField>(new TypedQueryField<std::string>(uIndex, MYSQL_TYPE_STRING, std::string(string, length)));
return 0;
}
int PreparedQuery::setBoolean(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int PreparedQuery::setBoolean(lua_State* state) {
PreparedQuery* object = (PreparedQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_NONE)
LUA->ThrowError("Query already started");
LUA->CheckType(2, GarrysMod::Lua::Type::NUMBER);
LUA->CheckType(3, GarrysMod::Lua::Type::BOOL);
double index = LUA->GetNumber(2);
if (index < 1) LUA->ThrowError("Index must be greater than 0");
unsigned int uIndex = (unsigned int) index;
unsigned int uIndex = (unsigned int)index;
bool value = LUA->GetBool(3);
object->parameters.back().insert(std::make_pair(uIndex, std::unique_ptr<PreparedQueryField>(new TypedQueryField<bool>(uIndex, MYSQL_TYPE_BIT, value))));
object->m_parameters.back()[uIndex] = std::shared_ptr<PreparedQueryField>(new TypedQueryField<bool>(uIndex, MYSQL_TYPE_BIT, value));
return 0;
}
int PreparedQuery::setNull(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int PreparedQuery::setNull(lua_State* state) {
PreparedQuery* object = (PreparedQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_NONE)
LUA->ThrowError("Query already started");
LUA->CheckType(2, GarrysMod::Lua::Type::NUMBER);
double index = LUA->GetNumber(2);
if (index < 1) LUA->ThrowError("Index must be greater than 0");
unsigned int uIndex = (unsigned int) index;
object->parameters.back().insert(std::make_pair(uIndex, std::unique_ptr<PreparedQueryField>(new PreparedQueryField(uIndex, MYSQL_TYPE_NULL))));
unsigned int uIndex = (unsigned int)index;
object->m_parameters.back()[uIndex] = std::shared_ptr<PreparedQueryField>(new PreparedQueryField(uIndex, MYSQL_TYPE_NULL));
return 0;
}
//Adds an additional set of parameters to the prepared query
//This makes it relatively easy to insert multiple rows at once
int PreparedQuery::putNewParameters(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
int PreparedQuery::putNewParameters(lua_State* state) {
PreparedQuery* object = (PreparedQuery*)unpackSelf(state, TYPE_QUERY);
if (object->m_status != QUERY_NONE)
LUA->ThrowError("Query already started");
object->parameters.push_back(std::unordered_map<unsigned int, std::unique_ptr<PreparedQueryField>>());
object->m_parameters.emplace_back();
return 0;
}
//Wrapper functions that might throw errors
MYSQL_STMT* PreparedQuery::mysqlStmtInit(MYSQL* sql)
{
MYSQL_STMT* PreparedQuery::mysqlStmtInit(MYSQL* sql) {
MYSQL_STMT* stmt = mysql_stmt_init(sql);
if (stmt == nullptr)
{
if (stmt == nullptr) {
const char* errorMessage = mysql_error(sql);
int errorCode = mysql_errno(sql);
throw MySQLException(errorCode, errorMessage);
@ -107,106 +100,103 @@ MYSQL_STMT* PreparedQuery::mysqlStmtInit(MYSQL* sql)
return stmt;
}
void PreparedQuery::mysqlStmtBindParameter(MYSQL_STMT* stmt, MYSQL_BIND* bind)
{
void PreparedQuery::mysqlStmtBindParameter(MYSQL_STMT* stmt, MYSQL_BIND* bind) {
int result = mysql_stmt_bind_param(stmt, bind);
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
}
void PreparedQuery::mysqlStmtPrepare(MYSQL_STMT* stmt, const char* str)
{
void PreparedQuery::mysqlStmtPrepare(MYSQL_STMT* stmt, const char* str) {
unsigned int length = strlen(str);
int result = mysql_stmt_prepare(stmt, str, length);
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
}
void PreparedQuery::mysqlStmtExecute(MYSQL_STMT* stmt)
{
void PreparedQuery::mysqlStmtExecute(MYSQL_STMT* stmt) {
int result = mysql_stmt_execute(stmt);
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
}
void PreparedQuery::mysqlStmtStoreResult(MYSQL_STMT* stmt)
{
void PreparedQuery::mysqlStmtStoreResult(MYSQL_STMT* stmt) {
int result = mysql_stmt_store_result(stmt);
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
}
bool PreparedQuery::mysqlStmtNextResult(MYSQL_STMT* stmt) {
int result = mysql_stmt_next_result(stmt);
if (result > 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
return result == 0;
}
static my_bool nullBool = 1;
static int trueValue = 1;
static int falseValue = 0;
//Generates binds for a prepared query. In this case the binds are used to send the parameters to the server
void PreparedQuery::generateMysqlBinds(MYSQL_BIND* binds, std::unordered_map<unsigned int, std::unique_ptr<PreparedQueryField>> *map, unsigned int parameterCount)
{
LOG_CURRENT_FUNCTIONCALL
for (unsigned int i = 1; i <= parameterCount; i++)
{
auto it = map->find(i);
if (it == map->end())
{
MYSQL_BIND* bind = &binds[i-1];
void PreparedQuery::generateMysqlBinds(MYSQL_BIND* binds, std::unordered_map<unsigned int, std::shared_ptr<PreparedQueryField>> &map, unsigned int parameterCount) {
for (unsigned int i = 1; i <= parameterCount; i++) {
auto it = map.find(i);
if (it == map.end()) {
MYSQL_BIND* bind = &binds[i - 1];
bind->buffer_type = MYSQL_TYPE_NULL;
bind->is_null = &nullBool;
continue;
}
unsigned int index = it->second->m_index - 1;
if (index >= parameterCount)
{
if (index >= parameterCount) {
std::stringstream errStream;
errStream << "Invalid parameter index " << index + 1;
throw MySQLException(0, errStream.str().c_str());
}
MYSQL_BIND* bind = &binds[index];
switch (it->second->m_type)
switch (it->second->m_type) {
case MYSQL_TYPE_DOUBLE:
{
case MYSQL_TYPE_DOUBLE:
{
TypedQueryField<double>* doubleField = static_cast<TypedQueryField<double>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_DOUBLE;
bind->buffer = (char*)&doubleField->m_data;
break;
}
case MYSQL_TYPE_BIT:
{
TypedQueryField<bool>* boolField = static_cast<TypedQueryField<bool>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_LONG;
bind->buffer = (char*)& ((boolField->m_data) ? trueValue : falseValue);
break;
}
case MYSQL_TYPE_STRING:
{
TypedQueryField<std::string>* textField = static_cast<TypedQueryField<std::string>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_STRING;
bind->buffer = (char*)textField->m_data.c_str();
bind->buffer_length = textField->m_data.length();
break;
}
case MYSQL_TYPE_NULL:
{
bind->buffer_type = MYSQL_TYPE_NULL;
bind->is_null = &nullBool;
break;
}
TypedQueryField<double>* doubleField = static_cast<TypedQueryField<double>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_DOUBLE;
bind->buffer = (char*)&doubleField->m_data;
break;
}
case MYSQL_TYPE_BIT:
{
TypedQueryField<bool>* boolField = static_cast<TypedQueryField<bool>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_LONG;
bind->buffer = (char*)& ((boolField->m_data) ? trueValue : falseValue);
break;
}
case MYSQL_TYPE_STRING:
{
TypedQueryField<std::string>* textField = static_cast<TypedQueryField<std::string>*>(it->second.get());
bind->buffer_type = MYSQL_TYPE_STRING;
bind->buffer = (char*)textField->m_data.c_str();
bind->buffer_length = textField->m_data.length();
break;
}
case MYSQL_TYPE_NULL:
{
bind->buffer_type = MYSQL_TYPE_NULL;
bind->is_null = &nullBool;
break;
}
}
}
}
@ -218,56 +208,100 @@ void PreparedQuery::generateMysqlBinds(MYSQL_BIND* binds, std::unordered_map<uns
* Note: If an error occurs at the nth query all the actions done before
* that nth query won't be reverted even though this query results in an error
*/
void PreparedQuery::executeQuery(MYSQL* connection)
{
MYSQL_STMT* stmt = mysqlStmtInit(connection);
my_bool attrMaxLength = 1;
mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &attrMaxLength);
mysqlStmtPrepare(stmt, this->m_query.c_str());
auto queryFree = finally([&] {
if (stmt != nullptr) {
mysql_stmt_close(stmt);
stmt = nullptr;
void PreparedQuery::executeQuery(MYSQL* connection, std::shared_ptr<IQueryData> ptr) {
PreparedQueryData* data = (PreparedQueryData*)ptr.get();
my_bool oldReconnectStatus = m_database->getAutoReconnect();
//Autoreconnect has to be disabled for prepared statement since prepared statements
//get reset on the server if the connection fails and auto reconnects
m_database->setAutoReconnect((my_bool) 0);
auto resetReconnectStatus = finally([&] { m_database->setAutoReconnect(oldReconnectStatus); });
try {
MYSQL_STMT* stmt = nullptr;
auto stmtClose = finally([&] {
if (!m_database->shouldCachePreparedStatements() && stmt != nullptr) {
mysql_stmt_close(stmt);
}
});
if (this->cachedStatement.load() != nullptr) {
stmt = this->cachedStatement;
} else {
stmt = mysqlStmtInit(connection);
my_bool attrMaxLength = 1;
mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &attrMaxLength);
mysqlStmtPrepare(stmt, this->m_query.c_str());
if (m_database->shouldCachePreparedStatements()) {
this->cachedStatement = stmt;
m_database->cacheStatement(stmt);
}
}
this->parameters.clear();
});
unsigned int parameterCount = mysql_stmt_param_count(stmt);
std::vector<MYSQL_BIND> mysqlParameters(parameterCount);
unsigned int parameterCount = mysql_stmt_param_count(stmt);
std::vector<MYSQL_BIND> mysqlParameters(parameterCount);
for (auto& currentMap : parameters)
{
generateMysqlBinds(mysqlParameters.data(), &currentMap, parameterCount);
mysqlStmtBindParameter(stmt, mysqlParameters.data());
mysqlStmtExecute(stmt);
mysqlStmtStoreResult(stmt);
auto resultFree = finally([&] { mysql_stmt_free_result(stmt); });
this->results.emplace_back(stmt);
this->m_affectedRows.push_back(mysql_stmt_affected_rows(stmt));
this->m_insertIds.push_back(mysql_stmt_insert_id(stmt));
this->m_resultStatus = QUERY_SUCCESS;
//This is used to clear the connection in case there are
//more ResultSets from a Procedure
while (this->mysqlNextResult(connection))
{
MYSQL_RES * result = this->mysqlStoreResults(connection);
mysql_free_result(result);
for (auto& currentMap : data->m_parameters) {
generateMysqlBinds(mysqlParameters.data(), currentMap, parameterCount);
mysqlStmtBindParameter(stmt, mysqlParameters.data());
mysqlStmtExecute(stmt);
do {
//There is a potential race condition here. What happens
//when the query executes fine but something goes wrong while storing the result?
mysqlStmtStoreResult(stmt);
auto resultFree = finally([&] { mysql_stmt_free_result(stmt); });
data->m_results.emplace_back(stmt);
data->m_affectedRows.push_back(mysql_stmt_affected_rows(stmt));
data->m_insertIds.push_back(mysql_stmt_insert_id(stmt));
data->m_resultStatus = QUERY_SUCCESS;
} while (mysqlStmtNextResult(stmt));
/*
//This is used to clear the connection in case there are
//more ResultSets from a procedure
while (this->mysqlNextResult(connection)) {
MYSQL_RES * result = this->mysqlStoreResults(connection);
mysql_free_result(result);
}*/
}
} catch (const MySQLException& error) {
int errorCode = error.getErrorCode();
if ((errorCode == CR_SERVER_LOST || errorCode == CR_SERVER_GONE_ERROR || errorCode == ER_MAX_PREPARED_STMT_COUNT_REACHED)) {
m_database->freeStatement(this->cachedStatement);
this->cachedStatement = nullptr;
//Because autoreconnect is disabled we want to try and explicitly execute the prepared query once more
//if we can get the client to reconnect (reconnect is caused by mysql_ping)
//If this fails we just go ahead and error
if (oldReconnectStatus && data->firstAttempt) {
m_database->setAutoReconnect((my_bool)1);
if (mysql_ping(connection) == 0) {
data->firstAttempt = false;
executeQuery(connection, ptr);
return;
}
}
}
//Rethrow error to be handled by executeStatement()
throw error;
}
}
bool PreparedQuery::executeStatement(MYSQL* connection)
{
LOG_CURRENT_FUNCTIONCALL
this->m_status = QUERY_RUNNING;
try
{
this->executeQuery(connection);
this->m_resultStatus = QUERY_SUCCESS;
}
catch (const MySQLException& error)
{
this->m_resultStatus = QUERY_ERROR;
this->m_errorText = error.what();
bool PreparedQuery::executeStatement(MYSQL* connection, std::shared_ptr<IQueryData> ptr) {
PreparedQueryData* data = (PreparedQueryData*)ptr.get();
data->setStatus(QUERY_RUNNING);
try {
this->executeQuery(connection, ptr);
data->setResultStatus(QUERY_SUCCESS);
} catch (const MySQLException& error) {
data->setResultStatus(QUERY_ERROR);
data->setError(error.what());
}
return true;
}
std::shared_ptr<IQueryData> PreparedQuery::buildQueryData(lua_State* state) {
std::shared_ptr<IQueryData> ptr(new PreparedQueryData());
PreparedQueryData* data = (PreparedQueryData*)ptr.get();
data->m_parameters = this->m_parameters;
while (m_parameters.size() > 1) {
//Front so the last used parameters are the ones that are gonna stay
m_parameters.pop_front();
}
return ptr;
}

View File

@ -1,5 +1,4 @@
#include "Query.h"
#include "Logger.h"
#include <iostream>
#include <sstream>
#include <cstring>
@ -8,87 +7,251 @@
#include <stdlib.h>
#endif
Query::Query(Database* dbase, lua_State* state) : IQuery(dbase, state)
{
Query::Query(Database* dbase, lua_State* state) : IQuery(dbase, state) {
classname = "Query";
registerFunction(state, "affectedRows", Query::affectedRows);
registerFunction(state, "lastInsert", Query::lastInsert);
registerFunction(state, "getData", Query::getData_Wrapper);
registerFunction(state, "hasMoreResults", Query::hasMoreResults);
registerFunction(state, "getNextResults", Query::getNextResults);
}
Query::~Query(void)
{
}
Query::~Query(void) {}
//Calls the lua callbacks associated with this query
void Query::doCallback(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
this->m_status = QUERY_COMPLETE;
switch (this->m_resultStatus)
{
void Query::doCallback(lua_State* state, std::shared_ptr<IQueryData> data) {
if (this->dataReference != 0) {
LUA->ReferenceFree(this->dataReference);
this->dataReference = 0;
}
switch (data->getResultStatus()) {
case QUERY_NONE:
break;
case QUERY_ERROR:
this->runCallback(state, "onError", "ss", this->m_errorText.c_str(), this->m_query.c_str());
if (data->getErrorReference() != 0) {
this->runFunction(state, data->getErrorReference(), "ss", data->getError().c_str(), this->m_query.c_str());
} else if (data->isFirstData()) {
//This is to preserve somewhat of a backwards compatibility
//In case people set their callbacks after they start their queries
//If it was the first data this query has been started with then
//it can also search on the object for the callback
//This might break some code under very specific circumstances
//but I doubt this will ever be an issue
this->runCallback(state, "onError", "ss", data->getError().c_str(), this->m_query.c_str());
}
break;
case QUERY_SUCCESS:
int dataref = this->getData(state);
if (this->hasCallback(state, "onData"))
{
if (data->getOnDataReference() != 0 || (this->hasCallback(state, "onData") && data->isFirstData())) {
LUA->ReferencePush(dataref);
LUA->PushNil();
while (LUA->Next(-2))
{
while (LUA->Next(-2)) {
//Top is now the row, top-1 row index
int rowReference = LUA->ReferenceCreate();
this->runCallback(state, "onData", "r", rowReference);
if (data->getOnDataReference() != 0) {
this->runFunction(state, data->getOnDataReference(), "r", rowReference);
} else if (data->isFirstData()) {
this->runCallback(state, "onData", "r", rowReference);
}
LUA->ReferenceFree(rowReference);
//Don't have to pop since reference create consumed the value
}
LUA->Pop();
}
this->runCallback(state, "onSuccess", "r", dataref);
if (data->getSuccessReference() != 0) {
this->runFunction(state, data->getSuccessReference(), "r", dataref);
} else if (data->isFirstData()) {
this->runCallback(state, "onSuccess", "r", dataref);
}
break;
}
}
void Query::executeQuery(MYSQL* connection)
{
void Query::executeQuery(MYSQL* connection, std::shared_ptr<IQueryData> data) {
QueryData* queryData = (QueryData*)data.get();
this->mysqlQuery(connection, this->m_query);
//Stores all result sets
//MySQL result sets shouldn't be accessed from different threads!
do
{
do {
MYSQL_RES * results = this->mysqlStoreResults(connection);
auto resultFree = finally([&] { mysql_free_result(results); });
if (results != nullptr)
this->results.emplace_back(results);
queryData->m_results.emplace_back(results);
else
this->results.emplace_back();
this->m_insertIds.push_back(mysql_insert_id(connection));
this->m_affectedRows.push_back(mysql_affected_rows(connection));
queryData->m_results.emplace_back();
queryData->m_insertIds.push_back(mysql_insert_id(connection));
queryData->m_affectedRows.push_back(mysql_affected_rows(connection));
} while (this->mysqlNextResult(connection));
}
//Executes the raw query
bool Query::executeStatement(MYSQL* connection)
{
LOG_CURRENT_FUNCTIONCALL
this->m_status = QUERY_RUNNING;
try
{
this->executeQuery(connection);
this->m_resultStatus = QUERY_SUCCESS;
}
catch (const MySQLException& error)
{
this->m_resultStatus = QUERY_ERROR;
this->m_errorText = error.what();
bool Query::executeStatement(MYSQL* connection, std::shared_ptr<IQueryData> data) {
QueryData* queryData = (QueryData*)data.get();
queryData->setStatus(QUERY_RUNNING);
try {
this->executeQuery(connection, data);
queryData->m_resultStatus = QUERY_SUCCESS;
} catch (const MySQLException& error) {
queryData->m_resultStatus = QUERY_ERROR;
queryData->m_errorText = error.what();
}
return true;
}
//Sets the mysql query string
void Query::setQuery(std::string query)
{
void Query::setQuery(std::string query) {
m_query = query;
}
//This function just returns the data associated with the query
//Data is only created once (and then the reference to that data is returned)
int Query::getData_Wrapper(lua_State* state) {
Query* object = (Query*)unpackSelf(state, TYPE_QUERY);
if (!object->hasCallbackData() || object->callbackQueryData->getResultStatus() == QUERY_ERROR) {
LUA->PushNil();
} else {
LUA->ReferencePush(object->getData(state));
}
return 1;
}
//Stores the data associated with the current result set of the query
//Only called once per result set (and then cached)
int Query::getData(lua_State* state) {
if (this->dataReference != 0)
return this->dataReference;
LUA->CreateTable();
if (hasCallbackData()) {
QueryData* data = (QueryData*)callbackQueryData.get();
if (data->hasMoreResults()) {
ResultData& currentData = data->getResult();
for (unsigned int i = 0; i < currentData.getRows().size(); i++) {
ResultDataRow& row = currentData.getRows()[i];
LUA->CreateTable();
int rowObject = LUA->ReferenceCreate();
for (unsigned int j = 0; j < row.getValues().size(); j++) {
dataToLua(state, rowObject, j + 1, row.getValues()[j], currentData.getColumns()[j].c_str(),
currentData.getColumnTypes()[j], row.isFieldNull(j));
}
LUA->PushNumber(i + 1);
LUA->ReferencePush(rowObject);
LUA->SetTable(-3);
LUA->ReferenceFree(rowObject);
}
}
}
this->dataReference = LUA->ReferenceCreate();
return this->dataReference;
};
//Function that converts the data stored in a mysql field into a lua type
void Query::dataToLua(lua_State* state, int rowReference, unsigned int column, std::string &columnValue, const char* columnName, int columnType, bool isNull) {
LUA->ReferencePush(rowReference);
if (this->m_options & OPTION_NUMERIC_FIELDS) {
LUA->PushNumber(column);
}
if (isNull) {
LUA->PushNil();
} else {
switch (columnType) {
case MYSQL_TYPE_FLOAT:
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_TINY:
case MYSQL_TYPE_SHORT:
LUA->PushNumber(atof(columnValue.c_str()));
break;
case MYSQL_TYPE_BIT:
LUA->PushNumber(static_cast<int>(columnValue[0]));
break;
case MYSQL_TYPE_NULL:
LUA->PushNil();
break;
default:
LUA->PushString(columnValue.c_str(), columnValue.length());
break;
}
}
if (this->m_options & OPTION_NUMERIC_FIELDS) {
LUA->SetTable(-3);
} else {
LUA->SetField(-2, columnName);
}
LUA->Pop();
}
//Returns true if a query has at least one additional ResultSet left
int Query::hasMoreResults(lua_State* state) {
Query* object = (Query*)unpackSelf(state, TYPE_QUERY);
if (!object->hasCallbackData()) {
LUA->ThrowError("Query not completed yet");
}
QueryData* data = (QueryData*)object->callbackQueryData.get();
LUA->PushBool(data->hasMoreResults());
return 1;
}
//Unreferences the current result set and uses the next result set
int Query::getNextResults(lua_State* state) {
Query* object = (Query*)unpackSelf(state, TYPE_QUERY);
if (!object->hasCallbackData()) {
LUA->ThrowError("Query not completed yet");
}
QueryData* data = (QueryData*)object->callbackQueryData.get();
if (!data->getNextResults()) {
LUA->ThrowError("Query doesn't have any more results");
}
if (object->dataReference != 0) {
LUA->ReferenceFree(object->dataReference);
object->dataReference = 0;
}
LUA->ReferencePush(object->getData(state));
return 1;
}
//Returns the last insert id produced by INSERT INTO statements (or 0 if there is none)
int Query::lastInsert(lua_State* state) {
Query* object = (Query*)unpackSelf(state, TYPE_QUERY);
if (!object->hasCallbackData()) {
LUA->PushNumber(0);
return 1;
}
QueryData* data = (QueryData*)object->callbackQueryData.get();
//Calling lastInsert() after query was executed but before the callback is run can cause race conditions
LUA->PushNumber(data->getLastInsertID());
return 1;
}
//Returns the last affected rows produced by INSERT/DELETE/UPDATE (0 for none, -1 for errors)
//For a SELECT statement this returns the amount of rows returned
int Query::affectedRows(lua_State* state) {
Query* object = (Query*)unpackSelf(state, TYPE_QUERY);
if (!object->hasCallbackData()) {
LUA->PushNumber(0);
return 1;
}
QueryData* data = (QueryData*)object->callbackQueryData.get();
//Calling affectedRows() after query was executed but before the callback is run can cause race conditions
LUA->PushNumber(data->getAffectedRows());
return 1;
}
//When the query is destroyed by lua
void Query::onDestroyed(lua_State* state) {
if (this->dataReference != 0 && state != nullptr) {
//Make sure data associated with this query can be freed as well
LUA->ReferenceFree(this->dataReference);
this->dataReference = 0;
}
}
std::shared_ptr<IQueryData> Query::buildQueryData(lua_State* state) {
std::shared_ptr<IQueryData> ptr(new QueryData());
return ptr;
}

View File

@ -2,64 +2,50 @@
#include "IQuery.h"
#include "Database.h"
#include "string.h"
#include "Logger.h"
#include <iostream>
#include <fstream>
ResultData::ResultData(unsigned int columnCount, unsigned int rows)
{
LOG_CURRENT_FUNCTIONCALL
ResultData::ResultData(unsigned int columnCount, unsigned int rows) {
this->columnCount = columnCount;
this->columns.resize(columnCount);
this->columnTypes.resize(columnCount);
this->rows.reserve(rows);
}
ResultData::ResultData() : ResultData(0,0)
{
}
ResultData::ResultData() : ResultData(0, 0) {}
//Stores all of the rows of a result set
//This is used so the result set can be free'd and doesn't have to be used in
//another thread (which is not safe)
ResultData::ResultData(MYSQL_RES* result) : ResultData((unsigned int) mysql_num_fields(result), (unsigned int) mysql_num_rows(result))
{
ResultData::ResultData(MYSQL_RES* result) : ResultData((unsigned int)mysql_num_fields(result), (unsigned int)mysql_num_rows(result)) {
if (columnCount == 0) return;
for (unsigned int i = 0; i < columnCount; i++)
{
for (unsigned int i = 0; i < columnCount; i++) {
MYSQL_FIELD *field = mysql_fetch_field_direct(result, i);
columnTypes[i] = field->type;
columns[i] = field->name;
}
MYSQL_ROW currentRow;
//This shouldn't error since mysql_store_results stores ALL rows already
while ((currentRow = mysql_fetch_row(result)) != nullptr)
{
while ((currentRow = mysql_fetch_row(result)) != nullptr) {
unsigned long *lengths = mysql_fetch_lengths(result);
this->rows.emplace_back(lengths, currentRow, columnCount);
}
}
static bool mysqlStmtFetch(MYSQL_STMT* stmt)
{
static bool mysqlStmtFetch(MYSQL_STMT* stmt) {
int result = mysql_stmt_fetch(stmt);
if (result == 0) return true;
if (result == 1)
{
if (result == 1) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
}
else
{
} else {
return false;
}
}
static void mysqlStmtBindResult(MYSQL_STMT* stmt, MYSQL_BIND* bind)
{
static void mysqlStmtBindResult(MYSQL_STMT* stmt, MYSQL_BIND* bind) {
int result = mysql_stmt_bind_result(stmt, bind);
if (result != 0)
{
if (result != 0) {
const char* errorMessage = mysql_stmt_error(stmt);
int errorCode = mysql_stmt_errno(stmt);
throw MySQLException(errorCode, errorMessage);
@ -68,60 +54,48 @@ static void mysqlStmtBindResult(MYSQL_STMT* stmt, MYSQL_BIND* bind)
//Stores all of the rows of a prepared query
//This needs to be done because the query shouldn't be accessed from a different thread
ResultData::ResultData(MYSQL_STMT* result) : ResultData((unsigned int) mysql_stmt_field_count(result), (unsigned int) mysql_stmt_num_rows(result))
{
ResultData::ResultData(MYSQL_STMT* result) : ResultData((unsigned int)mysql_stmt_field_count(result), (unsigned int)mysql_stmt_num_rows(result)) {
if (this->columnCount == 0) return;
MYSQL_RES * metaData = mysql_stmt_result_metadata(result);
if (metaData == nullptr){ throw std::runtime_error("mysql_stmt_result_metadata: Unknown Error"); }
if (metaData == nullptr) { throw std::runtime_error("mysql_stmt_result_metadata: Unknown Error"); }
MYSQL_FIELD *fields = mysql_fetch_fields(metaData);
std::vector<MYSQL_BIND> binds(columnCount);
std::vector<my_bool> isFieldNull(columnCount);
std::vector<std::vector<char>> buffers;
std::vector<unsigned long> lengths(columnCount);
for (unsigned int i = 0; i < columnCount; i++)
{
for (unsigned int i = 0; i < columnCount; i++) {
columnTypes[i] = fields[i].type;
columns[i] = fields[i].name;
MYSQL_BIND& bind = binds[i];
bind.buffer_type = MYSQL_TYPE_STRING;
buffers.emplace_back(fields[i].max_length + 2);
bind.buffer = buffers.back().data();
bind.buffer_length = fields[i].max_length+1;
bind.buffer_length = fields[i].max_length + 1;
bind.length = &lengths[i];
bind.is_null = &isFieldNull[i];
bind.is_unsigned = 0;
}
mysqlStmtBindResult(result, binds.data());
while (mysqlStmtFetch(result))
{
while (mysqlStmtFetch(result)) {
this->rows.emplace_back(result, binds.data(), columnCount);
}
}
ResultData::~ResultData()
{
}
ResultData::~ResultData() {}
ResultDataRow::ResultDataRow(unsigned int columnCount)
{
ResultDataRow::ResultDataRow(unsigned int columnCount) {
this->columnCount = columnCount;
this->values.resize(columnCount);
this->nullFields.resize(columnCount);
}
//Datastructure that stores a row of mysql data
ResultDataRow::ResultDataRow(unsigned long *lengths, MYSQL_ROW row, unsigned int columnCount) : ResultDataRow(columnCount)
{
for (unsigned int i = 0; i < columnCount; i++)
{
if (row[i])
{
ResultDataRow::ResultDataRow(unsigned long *lengths, MYSQL_ROW row, unsigned int columnCount) : ResultDataRow(columnCount) {
for (unsigned int i = 0; i < columnCount; i++) {
if (row[i]) {
this->values[i] = std::string(row[i], lengths[i]);
}
else
{
if (lengths[i] == 0)
{
} else {
if (lengths[i] == 0) {
this->nullFields[i] = true;
}
this->values[i] = "";
@ -130,18 +104,12 @@ ResultDataRow::ResultDataRow(unsigned long *lengths, MYSQL_ROW row, unsigned int
}
//Datastructure that stores a row of mysql data of a prepared query
ResultDataRow::ResultDataRow(MYSQL_STMT* statement, MYSQL_BIND* bind, unsigned int columnCount) : ResultDataRow(columnCount)
{
for (unsigned int i = 0; i < columnCount; i++)
{
if (!*(bind[i].is_null) && bind[i].buffer)
{
ResultDataRow::ResultDataRow(MYSQL_STMT* statement, MYSQL_BIND* bind, unsigned int columnCount) : ResultDataRow(columnCount) {
for (unsigned int i = 0; i < columnCount; i++) {
if (!*(bind[i].is_null) && bind[i].buffer) {
this->values[i] = std::string((char*)bind[i].buffer, *bind[i].length);
}
else
{
if (*(bind[i].is_null))
{
} else {
if (*(bind[i].is_null)) {
this->nullFields[i] = true;
}
this->values[i] = "";

View File

@ -1,22 +1,15 @@
#include "Transaction.h"
#include "ResultData.h"
#include "Logger.h"
#include "errmsg.h"
#include "Database.h"
Transaction::Transaction(Database* dbase, lua_State* state) : IQuery(dbase, state) {
registerFunction(state, "addQuery", Transaction::addQuery);
registerFunction(state, "getQueries", Transaction::getQueries);
registerFunction(state, "clearQueries", Transaction::clearQueries);
}
void Transaction::onDestroyed(lua_State* state) {
//This unreferences all queries once the transaction has been gc'ed
if (state != nullptr) {
for (auto& query : queries) {
query->unreference(state);
}
}
this->queries.clear();
}
void Transaction::onDestroyed(lua_State* state) {}
//TODO Fix memory leak if transaction is never started
int Transaction::addQuery(lua_State* state) {
@ -24,15 +17,29 @@ int Transaction::addQuery(lua_State* state) {
if (transaction == nullptr) {
LUA->ThrowError("Tried to pass wrong self");
}
IQuery* iQuery = (IQuery*)unpackLuaObject(state, 2, TYPE_QUERY, true);
IQuery* iQuery = (IQuery*)unpackLuaObject(state, 2, TYPE_QUERY, false);
Query* query = dynamic_cast<Query*>(iQuery);
if (query == nullptr) {
LUA->ThrowError("Tried to pass non query to addQuery()");
}
std::lock_guard<std::mutex> lock(transaction->m_queryMutex);
transaction->queries.push_back(query);
//This is all very ugly
LUA->Push(1);
LUA->GetField(-1, "__queries");
if (LUA->IsType(-1, GarrysMod::Lua::Type::NIL)) {
LUA->Pop();
LUA->CreateTable();
LUA->SetField(-2, "__queries");
LUA->GetField(-1, "__queries");
}
int tblIndex = LUA->Top();
LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB);
LUA->GetField(-1, "table");
LUA->GetField(-1, "insert");
LUA->Push(tblIndex);
LUA->Push(2);
return 1;
LUA->Call(2, 0);
LUA->Push(4);
return 0;
}
int Transaction::getQueries(lua_State* state) {
@ -40,93 +47,143 @@ int Transaction::getQueries(lua_State* state) {
if (transaction == nullptr) {
LUA->ThrowError("Tried to pass wrong self");
}
LUA->CreateTable();
for (size_t i = 0; i < transaction->queries.size(); i++)
{
Query* q = transaction->queries[i];
LUA->PushNumber(i + 1);
q->pushTableReference(state);
LUA->SetTable(-3);
}
LUA->Push(1);
LUA->GetField(-1, "__queries");
return 1;
}
int Transaction::clearQueries(lua_State* state) {
Transaction* transaction = dynamic_cast<Transaction*>(unpackSelf(state, TYPE_QUERY));
if (transaction == nullptr) {
LUA->ThrowError("Tried to pass wrong self");
}
LUA->Push(1);
LUA->PushNil();
LUA->SetField(-2, "__queries");
LUA->Pop();
return 0;
}
//Calls the lua callbacks associated with this query
void Transaction::doCallback(lua_State* state)
{
LOG_CURRENT_FUNCTIONCALL
this->m_status = QUERY_COMPLETE;
switch (this->m_resultStatus)
{
void Transaction::doCallback(lua_State* state, std::shared_ptr<IQueryData> ptr) {
TransactionData* data = (TransactionData*)ptr.get();
data->setStatus(QUERY_COMPLETE);
for (auto& pair : data->m_queries) {
auto query = pair.first;
auto queryData = pair.second;
query->setCallbackData(queryData);
}
switch (data->getResultStatus()) {
case QUERY_NONE:
break;
case QUERY_ERROR:
this->runCallback(state, "onError", "s", this->m_errorText.c_str());
if (data->getErrorReference() != 0) {
this->runFunction(state, data->getErrorReference(), "s", data->getError().c_str());
} else if (data->isFirstData()) {
this->runCallback(state, "onError", "s", data->getError().c_str());
}
break;
case QUERY_SUCCESS:
this->runCallback(state, "onSuccess");
if (data->getSuccessReference() != 0) {
this->runFunction(state, data->getSuccessReference());
} else if (data->isFirstData()) {
this->runCallback(state, "onSuccess");
}
break;
}
for (auto& pair : data->m_queries) {
auto query = pair.first;
auto queryData = pair.second;
query->onQueryDataFinished(state, queryData);
}
}
bool Transaction::executeStatement(MYSQL* connection)
{
LOG_CURRENT_FUNCTIONCALL
this->m_status = QUERY_RUNNING;
bool Transaction::executeStatement(MYSQL* connection, std::shared_ptr<IQueryData> ptr) {
TransactionData* data = (TransactionData*)ptr.get();
data->setStatus(QUERY_RUNNING);
//This temporarily disables reconnect, since a reconnect
//would rollback (and cancel) a transaction
//Which could lead to parts of the transaction being executed outside of a transaction
//If they are being executed after the reconnect
my_bool oldReconnectStatus = connection->reconnect;
connection->reconnect = false;
auto resetReconnectStatus = finally([&] { connection->reconnect = oldReconnectStatus; });
try
{
//TODO autoreconnect fucks things up
this->mysqlAutocommit(connection, false);
my_bool oldReconnectStatus = m_database->getAutoReconnect();
m_database->setAutoReconnect((my_bool)0);
auto resetReconnectStatus = finally([&] { m_database->setAutoReconnect(oldReconnectStatus); });
try {
this->mysqlAutocommit(connection, false);
{
std::lock_guard<std::mutex> lock(this->m_queryMutex);
for (auto& query : queries) {
query->executeQuery(connection);
for (auto& query : data->m_queries) {
query.first->executeQuery(connection, query.second);
}
}
mysql_commit(connection);
this->m_resultStatus = QUERY_SUCCESS;
data->setResultStatus(QUERY_SUCCESS);
this->mysqlAutocommit(connection, true);
}
catch (const MySQLException& error)
{
} catch (const MySQLException& error) {
//This check makes sure that setting mysqlAutocommit back to true doesn't cause the transaction to fail
//Even though the transaction was executed successfully
if (this->m_resultStatus != QUERY_SUCCESS) {
if (data->getResultStatus() != QUERY_SUCCESS) {
int errorCode = error.getErrorCode();
if (oldReconnectStatus && !this->retried &&
if (oldReconnectStatus && !data->retried &&
(errorCode == CR_SERVER_LOST || errorCode == CR_SERVER_GONE_ERROR)) {
//Because autoreconnect is disabled we want to try and explicitly execute the transaction once more
//if we can get the client to reconnect (reconnect is caused by mysql_ping)
//If this fails we just go ahead and error
connection->reconnect = true;
m_database->setAutoReconnect((my_bool)1);
if (mysql_ping(connection) == 0) {
this->retried = true;
return executeStatement(connection);
data->retried = true;
return executeStatement(connection, ptr);
}
}
//If this call fails it means that the connection was (probably) lost
//In that case the mysql server rolls back any transaction anyways so it doesn't
//matter if it fails
mysql_rollback(connection);
this->m_resultStatus = QUERY_ERROR;
data->setResultStatus(QUERY_ERROR);
}
//If this fails it probably means that the connection was lost
//In that case autocommit is turned back on anyways (once the connection is reestablished)
//See: https://dev.mysql.com/doc/refman/5.7/en/auto-reconnect.html
mysql_autocommit(connection, true);
this->m_errorText = error.what();
data->setError(error.what());
}
for (auto& query : queries) {
query->setResultStatus(this->m_resultStatus);
query->setStatus(QUERY_COMPLETE);
for (auto& pair : data->m_queries) {
pair.second->setResultStatus(data->getResultStatus());
pair.second->setStatus(QUERY_COMPLETE);
}
this->m_status = QUERY_COMPLETE;
data->setStatus(QUERY_COMPLETE);
return true;
}
std::shared_ptr<IQueryData> Transaction::buildQueryData(lua_State* state) {
//At this point the transaction is guaranteed to have a referenced table
//since this is always called shortly after transaction:start()
std::shared_ptr<IQueryData> ptr(new TransactionData());
TransactionData* data = (TransactionData*)ptr.get();
this->pushTableReference(state);
LUA->GetField(-1, "__queries");
if (!LUA->IsType(-1, GarrysMod::Lua::Type::TABLE)) {
LUA->Pop(2);
return ptr;
}
int index = 1;
//Stuff could go horribly wrong here if a lua error occurs
//but it really shouldn't
while (true) {
LUA->PushNumber(index++);
LUA->GetTable(-2);
if (!LUA->IsType(-1, GarrysMod::Lua::Type::TABLE)) {
LUA->Pop();
break;
}
//This would error if it's not a query
Query* iQuery = (Query*)unpackLuaObject(state, -1, TYPE_QUERY, false);
auto queryPtr = std::dynamic_pointer_cast<Query>(iQuery->getSharedPointerInstance());
auto queryData = iQuery->buildQueryData(state);
iQuery->addQueryData(state, queryData, false);
data->m_queries.push_back(std::make_pair(queryPtr, queryData));
LUA->Pop();
}
LUA->Pop(2);
return ptr;
}

70
lua/connectionpool.lua Normal file
View File

@ -0,0 +1,70 @@
//A simple connection pool for mysqloo9
//Use mysqloo.CreateConnectionPool(connectionCount, host, username, password [, database, port, socket]) to create a connection pool
//That automatically balances queries between several connections and acts like a regular database instance
require("mysqloo")
if (mysqloo.VERSION != "9") then
error("Using outdated mysqloo version")
end
local pool = {}
local poolMT = {__index = pool}
function mysqloo.CreateConnectionPool(conCount, ...)
if (conCount < 1) then
error("Has to contain at least one connection")
end
local newPool = setmetatable({}, poolMT)
newPool._Connections = {}
local function failCallback(db, err)
print("Failed to connect all db pool connections:")
print(err)
end
for i = 1, conCount do
local db = mysqloo.connect(...)
db.onConnectionFailed = failCallback
db:connect()
table.insert(newPool._Connections, db)
end
return newPool
end
function pool:queueSize()
local count = 0
for k,v in pairs(self._Connections) do
count = count + v:queueSize()
end
return count
end
function pool:abortAllQueries()
for k,v in pairs(self._Connections) do
v:abortAllQueries()
end
end
function pool:getLeastOccupiedDB()
local lowest = nil
local lowestCount = 0
for k, db in pairs(self._Connections) do
local queueSize = db:queueSize()
if (!lowest || queueSize < lowestCount) then
lowest = db
lowestCount = queueSize
end
end
if (!lowest) then
error("failed to find database in the pool")
end
return lowest
end
local overrideFunctions = {"escape", "query", "prepare", "createTransaction", "status", "serverVersion", "hostInfo", "ping"}
for _, name in pairs(overrideFunctions) do
pool[name] = function(pool, ...)
local db = pool:getLeastOccupiedDB()
return db[name](db, ...)
end
end

223
lua/mysqloolib.lua Normal file
View File

@ -0,0 +1,223 @@
--[==[
This library aims to provide an easier and less verbose way to use mysqloo
Function overview:
mysqloo.ConvertDatabase(database)
Returns: the modified database
Modifies an existing database to make use of the extended functionality of this library
mysqloo.CreateDatabase(host, username, password [, database, port, socket])
Returns: the newly created database instance
Does the same as mysqloo.connect() but adds convenient functions provided by this library
Query callbacks are of this structure:
function callback([additionalArgs], query, status, dataOrError) end
additionalArgs are any additional arguments that are passed after the callback in RunQuery and similar
query is the query object that represents the started query
status is true if the query executed successfully, false otherwise
dataOrError is either the results returned by query:getData() if the query executed successfully
or and error message if an error occured (use status to know which one)
Note: dataOrError is nil for transaction if the transaction finished successfully
Database:RunQuery(queryStr, [callback [, additionalArgs]])
Parameters:
queryStr: the query to run
callback: the callback function that is called when the query is done
additionalArgs: any args that will be passed to the callback function on success (see callback structure)
Returns: the query that has been created and started
Description: Creates and runs a mysqloo query using the specified queryStr and runs the provided
callback function when the query finished. If no callback function is provided then an error message is printed if the query errors
Example: database:RunQuery("SELECT 1", function(query, status, data)
PrintTable(data)
end)
Database:PrepareQuery(queryStr, parameterValues, [callback [, additionalArgs])
Parameters:
queryStr: the query string to run with ? representing parameters to be passed in parameterValues
parameterValues: a table containing values that are supposed to replace the ? in the prepared query
additionalArgs: see Database:RunQuery()
Returns: the prepared query that has been created and started
Description: Creates and runs a mysqloo prepared query using the specified queryStr and parameters and runs the provided
callback function when the query finished. If no callback function is provided then an error message is printed if the query errors
Example: database:PrepareQuery("SELECT ?, ?", {1, "a"}, function(query, status, data)
PrintTable(data)
end)
Database:CreateTransaction()
Parameters: none
Returns: a transaction object
Transaction:Query(queryStr)
Parameters:
queryStr: the query to run
Returns: the query that has been added to the transaction
Description: Same as Database:RunQuery() but doesn't take a callback and instead of starting the query
adds the query to the transaction
Transaction:Prepare(queryStr, parameterValues)
Parameters:
queryStr: the query string to run with ? representing parameters to be passed in parameterValues
parameterValues: a table containing values that are supposed to replace the ? in the prepared query
Returns: the prepared query that has been added to the transaction
Description: Same as Database:PrepareQuery() but doesn't take a callback and instead of starting the query
adds the query to the transaction
Transaction:Start(callback [, additionalArgs])
Parameters:
callback: The callback that is called when the transaction is finished
additionalArgs: see Database:RunQuery()
Returns: nothing
Description: starts the transaction and calls the callback when done
If the transaction finishes successfully all queries that belong to it have been
executed successfully. If the transaction fails none of the queries will have had effect
Check https://en.wikipedia.org/wiki/ACID for more information
Transaction example:
local transaction = Database:CreateTransaction()
transaction:Query("SELECT 1")
transaction:Prepare("INSERT INTO `some_tbl` (`some_field`) VALUES(?)", {1})
transaction:Query("SELECT * FROM `some_tbl` WHERE `id` = LAST_INSERT_ID()")
transaction:Start(function(transaction, status, err)
if (!status) then error(err) end
PrintTable(transaction:getQueries()[1]:getData())
PrintTable(transaction:getQueries()[3]:getData())
end)
]==]
require("mysqloo")
if (mysqloo.VERSION != "9" || !mysqloo.MINOR_VERSION || tonumber(mysqloo.MINOR_VERSION) < 1) then
MsgC(Color(255, 0, 0), "You are using an outdated mysqloo version\n")
MsgC(Color(255, 0, 0), "Download the latest mysqloo9 from here\n")
MsgC(Color(86, 156, 214), "https://github.com/syl0r/MySQLOO/releases")
return
end
local db = {}
local dbMetatable = {__index = db}
//This converts an already existing database instance to be able to make use
//of the easier functionality provided by mysqloo.CreateDatabase
function mysqloo.ConvertDatabase(database)
return setmetatable(database, dbMetatable)
end
//The same as mysqloo.connect() but adds easier functionality
function mysqloo.CreateDatabase(...)
local db = mysqloo.connect(...)
db:connect()
return mysqloo.ConvertDatabase(db)
end
local function addQueryFunctions(query, func, ...)
local oldtrace = debug.traceback()
local args = {...}
table.insert(args, query)
function query.onAborted(qu)
table.insert(args, false)
table.insert(args, "aborted")
if (func) then
func(unpack(args))
end
end
function query.onError(qu, err)
table.insert(args, false)
table.insert(args, err)
if (func) then
func(unpack(args))
else
ErrorNoHalt(err .. "\n" .. oldtrace .. "\n")
end
end
function query.onSuccess(qu, data)
table.insert(args, true)
table.insert(args, data)
if (func) then
func(unpack(args))
end
end
end
function db:RunQuery(str, callback, ...)
local query = self:query(str)
addQueryFunctions(query, callback, ...)
query:start()
return query
end
local function setPreparedQueryArguments(query, values)
if (type(values) != "table") then
values = { values }
end
local typeFunctions = {
["string"] = function(query, index, value) query:setString(index, value) end,
["number"] = function(query, index, value) query:setNumber(index, value) end,
["boolean"] = function(query, index, value) query:setBoolean(index, value) end,
}
//This has to be pairs instead of ipairs
//because nil is allowed as value
for k, v in pairs(values) do
local varType = type(v)
if (typeFunctions[varType]) then
typeFunctions[varType](query, k, v)
else
query:setString(k, tostring(v))
end
end
end
function db:PrepareQuery(str, values, callback, ...)
self.CachedStatements = self.CachedStatements or {}
local preparedQuery = self.CachedStatements[str] or self:prepare(str)
addQueryFunctions(preparedQuery, callback, ...)
setPreparedQueryArguments(preparedQuery, values)
preparedQuery:start()
return preparedQuery
end
local transaction = {}
local transactionMT = {__index = transaction}
function transaction:Prepare(str, values)
//TODO: Cache queries
local preparedQuery = self._db:prepare(str)
setPreparedQueryArguments(preparedQuery, values)
self:addQuery(preparedQuery)
return preparedQuery
end
function transaction:Query(str)
local query = self._db:query(str)
self:addQuery(query)
return query
end
function transaction:Start(callback, ...)
local args = {...}
table.insert(args, self)
function self:onSuccess()
table.insert(args, true)
if (callback) then
callback(unpack(args))
end
end
function self:onError(err)
err = err or "aborted"
table.insert(args, false)
table.insert(args, err)
if (callback) then
callback(unpack(args))
else
ErrorNoHalt(err)
end
end
self.onAborted = self.onError
self:start()
end
function db:CreateTransaction()
local transaction = self:createTransaction()
transaction._db = self
setmetatable(transaction, transactionMT)
return transaction
end

124
lua/tmysql4.lua Normal file
View File

@ -0,0 +1,124 @@
//Put this into your server's lua/includes/modules/ folder to replace tmysql4 functionality with mysqloo
//This is only a temporary solution and the best would be to change to mysqloo completely
//A few incompatibilities:
//Mysql Bit fields are returned as numbers instead of a single char
//Mysql Bigint fields are returned as numbers
//This might pose a problem if you have a bigint field for steamid64
//Always make sure to cast that field to a string in the SELECT clause of your query
//Example: SELECT CAST(steamid64 as 'CHAR') as steamid64 FROM ...
require("mysqloo")
if (mysqloo.VERSION != "9") then
error("using outdated mysqloo version")
end
tmysql = tmysql or {}
tmysql.Connections = tmysql.Connections or {}
local database = {}
local databaseMT = {__index = database}
function database:Escape(...)
if (self.Disconnected) then error("database already disconnected") end
return self:escape(...)
end
function database:Connect()
self:connect()
self:wait() //this is dumb
//Unfortunately mysqloo only passes the error message to a callback
//because waiting for the db to connect is really dumb
//so there is no way to retrieve the actual error message here
if (self:status() != mysqloo.DATABASE_CONNECTED) then
return false, "[TMYSQL Wrapper]: Failed to connect to database"
end
table.insert(tmysql.Connections, self)
return true
end
function database:Query(str, callback, ...)
if (self.Disconnected) then error("database already disconnected") end
local additionalArgs = {...}
local qu = self:query(str)
if (!callback) then
qu:start()
return
end
qu.onSuccess = function(qu, result)
local results = {
{
status = true,
error = nil,
affected = qu:affectedRows(),
lastid = qu:lastInsert(),
data = result
}
}
while(qu:hasMoreResults()) do
result = qu:getNextResults()
table.insert(results, {
status = true,
error = nil,
affected = qu:affectedRows(),
lastid = qu:lastInsert(),
data = result
})
end
table.insert(additionalArgs, results)
callback(unpack(additionalArgs))
end
qu.onAborted = function(qu)
local data = {
status = false,
error = "Query aborted"
}
table.insert(additionalArgs, {data})
callback(unpack(additionalArgs))
end
qu.onError = function(qu, err)
local data = {
status = false,
error = err
}
table.insert(additionalArgs, {data})
callback(unpack(additionalArgs))
end
qu:start()
end
function database:Disconnect()
if (self.Disconnected) then error("database already disconnected") end
self:abortAllQueries()
table.RemoveByValue(tmysql.Connections, self)
self.Disconnected = true
end
function database:Option(option, value)
if (self.Disconnected) then error("database already disconnected") end
if (option == bit.lshift(1, 16)) then
self:setMultiStatements(tobool(value))
else
print("[TMYSQL Wrapper]: Unsupported tmysql option")
end
end
function tmysql.GetTable()
return tmysql.Connections
end
//Clientflags are ignored, multistatements are always enabled by default
function tmysql.initialize(host, user, password, database, port, unixSocketPath, clientFlags)
local db = mysqloo.connect(host, user, password, database, port, unixSocketPath)
setmetatable(db, databaseMT)
local status, err = db:Connect()
if (!status) then
return nil, err
end
return db, err
end
function tmysql.Create(host, user, password, database, port, unixSocketPath, clientFlags)
local db = mysqloo.connect(host, user, password, database, port, unixSocketPath)
setmetatable(db, databaseMT)
return db
end
tmysql.Connect = tmysql.initialize

1
minorversion.txt Normal file
View File

@ -0,0 +1 @@
1

Binary file not shown.

Binary file not shown.