diff --git a/GmodLUA/GarrysMod/Lua/Interface.h b/GmodLUA/GarrysMod/Lua/Interface.h index 7152362..bb3b8ca 100644 --- a/GmodLUA/GarrysMod/Lua/Interface.h +++ b/GmodLUA/GarrysMod/Lua/Interface.h @@ -3,8 +3,7 @@ #include "LuaBase.h" -struct lua_State -{ +struct lua_State { #if defined( _WIN32 ) && !defined( _M_X64 ) // Win32 unsigned char _ignore_this_common_lua_header_[48 + 22]; @@ -24,27 +23,27 @@ struct lua_State // macOS64 unsigned char _ignore_this_common_lua_header_[92 + 22]; #else - #error agh +#error agh #endif - GarrysMod::Lua::ILuaBase* luabase; + GarrysMod::Lua::ILuaBase *luabase; }; #ifndef GMOD - #ifdef _WIN32 - #define DLL_EXPORT extern "C" __declspec( dllexport ) - #else - #define DLL_EXPORT extern "C" __attribute__((visibility("default"))) - #endif +#ifdef _WIN32 +#define DLL_EXPORT extern "C" __declspec( dllexport ) +#else +#define DLL_EXPORT extern "C" __attribute__((visibility("default"))) +#endif - #ifdef GMOD_ALLOW_DEPRECATED - // Stop using this and use LUA_FUNCTION! - #define LUA ( state->luabase ) +#ifdef GMOD_ALLOW_DEPRECATED +// Stop using this and use LUA_FUNCTION! +#define LUA ( state->luabase ) - #define GMOD_MODULE_OPEN() DLL_EXPORT int gmod13_open( lua_State* state ) - #define GMOD_MODULE_CLOSE() DLL_EXPORT int gmod13_close( lua_State* state ) - #else - #define GMOD_MODULE_OPEN() \ +#define GMOD_MODULE_OPEN() DLL_EXPORT int gmod13_open( lua_State* state ) +#define GMOD_MODULE_CLOSE() DLL_EXPORT int gmod13_close( lua_State* state ) +#else +#define GMOD_MODULE_OPEN() \ int gmod13_open__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ DLL_EXPORT int gmod13_open( lua_State* L ) \ { \ @@ -52,7 +51,7 @@ struct lua_State } \ int gmod13_open__Imp( GarrysMod::Lua::ILuaBase* LUA ) - #define GMOD_MODULE_CLOSE() \ +#define GMOD_MODULE_CLOSE() \ int gmod13_close__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ DLL_EXPORT int gmod13_close( lua_State* L ) \ { \ @@ -60,16 +59,16 @@ struct lua_State } \ int gmod13_close__Imp( GarrysMod::Lua::ILuaBase* LUA ) - #define LUA_FUNCTION( FUNC ) \ - int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ - int FUNC( lua_State* L ) \ +#define LUA_FUNCTION(FUNC) \ + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ + static int FUNC( lua_State* L ) \ { \ GarrysMod::Lua::ILuaBase* LUA = L->luabase; \ LUA->SetState(L); \ return FUNC##__Imp( LUA ); \ } \ - int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ) - #endif + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ) +#endif #endif #endif diff --git a/MySQLOO/include/lua/LuaDatabase.h b/MySQLOO/include/lua/LuaDatabase.h deleted file mode 100644 index 0a7b952..0000000 --- a/MySQLOO/include/lua/LuaDatabase.h +++ /dev/null @@ -1,114 +0,0 @@ -#ifndef DATABASE_ -#define DATABASE_ - -#include "MySQLHeader.h" -#include -#include -#include -#include -#include -#include -#include -#include "GarrysMod/Lua/Interface.h" -#include "BlockingQueue.h" -#include "LuaObjectBase.h" -#include "Query.h" -#include "IQuery.h" - -class SSLSettings { -public: - bool customSSLSettings = false; - std::string key = ""; - std::string cert = ""; - std::string ca = ""; - std::string capath = ""; - std::string cipher = ""; - void applySSLSettings(MYSQL* m_sql); -}; - -class DatabaseThread; -class ConnectThread; -enum DatabaseStatus { - DATABASE_CONNECTED = 0, - DATABASE_CONNECTING = 1, - DATABASE_NOT_CONNECTED = 2, - DATABASE_CONNECTION_FAILED = 3 -}; - -class Database : public LuaObjectBase { - friend class IQuery; -public: - enum { - INTEGER = 0, - BIT, - FLOATING_POINT, - STRING, - }; - Database(GarrysMod::Lua::ILuaBase* LUA, std::string host, std::string username, std::string pw, std::string database, unsigned int port, std::string unixSocket); - ~Database(void); - void cacheStatement(MYSQL_STMT* stmt); - void freeStatement(MYSQL_STMT* stmt); - void enqueueQuery(IQuery* query, std::shared_ptr data); - void think(GarrysMod::Lua::ILuaBase*); - void setAutoReconnect(bool autoReconnect); - bool getAutoReconnect(); - bool shouldCachePreparedStatements() { - return cachePreparedStatements; - } -private: - void shutdown(); - void freeCachedStatements(); - void freeUnusedStatements(); - void run(); - void connectRun(); - static int query(lua_State* state); - static int prepare(lua_State* state); - static int createTransaction(lua_State* state); - static int escape(lua_State* state); - static int setCharacterSet(lua_State* state); - static int connect(lua_State* state); - static int wait(lua_State* state); - static int abortAllQueries(lua_State* state); - static int status(lua_State* state); - static int serverVersion(lua_State* state); - static int serverInfo(lua_State* state); - static int hostInfo(lua_State* state); - static int queueSize(lua_State* state); - static int setAutoReconnect(lua_State* state); - static int setMultiStatements(lua_State* state); - static int ping(lua_State* state); - static int setCachePreparedStatements(lua_State* state); - static int disconnect(lua_State* state); - static int setSSLSettings(lua_State* state); - BlockingQueue, std::shared_ptr>> finishedQueries; - BlockingQueue, std::shared_ptr>> queryQueue; - BlockingQueue cachedStatements; - BlockingQueue freedStatements; - MYSQL* m_sql = nullptr; - std::thread m_thread; - std::mutex m_connectMutex; //Mutex used during connection - std::mutex m_queryMutex; //Mutex that is locked while query thread operates on m_sql object - std::condition_variable m_connectWakeupVariable; - unsigned int m_serverVersion = 0; - std::string m_serverInfo = ""; - std::string m_hostInfo = ""; - bool shouldAutoReconnect = true; - bool useMultiStatements = true; - bool dbCallbackRan = false; - bool startedConnecting = false; - bool disconnected = false; - std::atomic m_success{ true }; - std::atomic m_connectionDone{ false }; - std::atomic cachePreparedStatements{ true }; - std::atomic m_status{ DATABASE_NOT_CONNECTED }; - std::string m_connection_err; - std::condition_variable m_queryWakupVariable; - std::string database = ""; - std::string host = ""; - std::string username = ""; - std::string pw = ""; - std::string socket = ""; - unsigned int port; - SSLSettings customSSLSettings { }; -}; -#endif \ No newline at end of file diff --git a/src/Main.cpp b/src/Main.cpp index 9ef9a9b..a5d9329 100644 --- a/src/Main.cpp +++ b/src/Main.cpp @@ -5,6 +5,7 @@ static std::shared_ptr db; int main() { + mysql_library_init(0, nullptr, nullptr); std::cout << "Test" << std::endl; db = Database::createDatabase("127.0.0.1", "root", "test", "mysql", 3306, ""); db->connect(); @@ -19,9 +20,27 @@ int main() { query->start(queryData); query->wait(true); auto firstResultSet = queryData->getResult(); - auto& firstRow = firstResultSet.getRows().front(); - auto& firstValue = firstRow.getValues().front(); + auto &firstRow = firstResultSet.getRows().front(); + auto &firstValue = firstRow.getValues().front(); std::cout << "Result: " << firstValue << std::endl; } - //mysql_library_end(); + + auto transaction = db->transaction(); + auto transactionQuery1 = db->prepare("SELECT ?"); + transactionQuery1->setNumber(1, 3.0); + auto transactionData1 = transactionQuery1->buildQueryData(); + auto transactionQuery2 = db->query("SELECT 12"); + auto transactionData2 = transactionQuery1->buildQueryData(); + std::deque, std::shared_ptr>> transactionQueries; + transactionQueries.emplace_back(transactionQuery1, transactionData1); + transactionQueries.emplace_back(transactionQuery2, transactionData2); + auto transactionData = transaction->buildQueryData(transactionQueries); + transaction->start(transactionData); + transaction->wait(true); + auto firstResultSet = transactionData2->getResult(); + auto &firstRow = firstResultSet.getRows().front(); + auto &firstValue = firstRow.getValues().front(); + std::cout << "Transaction Result: " << firstValue << std::endl; + + mysql_library_end(); } \ No newline at end of file diff --git a/src/lua/GMModule.cpp b/src/lua/GMModule.cpp new file mode 100644 index 0000000..81d8a07 --- /dev/null +++ b/src/lua/GMModule.cpp @@ -0,0 +1,218 @@ +#include "GarrysMod/Lua/Interface.h" +#include "../mysql/Database.h" +#include +#include +#include "LuaObject.h" +#include "LuaDatabase.h" +#define MYSQLOO_VERSION "9" +#define MYSQLOO_MINOR_VERSION "7" + +// Variable to hold the reference to the version check ConVar object +static int versionCheckConVar = 0; + +GMOD_MODULE_CLOSE() { + // Free the version check ConVar object reference + if (versionCheckConVar != 0) { + LUA->ReferenceFree(versionCheckConVar); + } + + /* Deletes all the remaining luaobjects when the server changes map + */ + /* + for (auto query : LuaObjectBase::luaRemovalObjects) { + query->onDestroyed(nullptr); + }*/ + //LuaObjectBase::luaRemovalObjects.clear(); + LuaObject::luaObjects.clear(); + LuaObject::luaThinkObjects.clear(); + mysql_thread_end(); + mysql_library_end(); + return 0; +} + +/* Connects to the database and returns a Database instance that can be used + * as an interface to the mysql server. + */ +LUA_FUNCTION(mysqlooConnect) { + LUA->CheckType(1, GarrysMod::Lua::Type::String); + LUA->CheckType(2, GarrysMod::Lua::Type::String); + LUA->CheckType(3, GarrysMod::Lua::Type::String); + LUA->CheckType(4, GarrysMod::Lua::Type::String); + std::string host = LUA->GetString(1); + std::string username = LUA->GetString(2); + std::string pw = LUA->GetString(3); + std::string database = LUA->GetString(4); + unsigned int port = 3306; + std::string unixSocket; + if (LUA->IsType(5, GarrysMod::Lua::Type::Number)) { + port = (int)LUA->GetNumber(5); + } + if (LUA->IsType(6, GarrysMod::Lua::Type::String)) { + unixSocket = LUA->GetString(6); + } + auto object = Database::createDatabase(host, username, pw, database, port, unixSocket); + return 1; +} + +/* Returns the amount of LuaObjectBase objects that are currently in use + * This includes Database and Query instances + */ +LUA_FUNCTION(objectCount) { + LUA->PushNumber((double) LuaObject::luaObjects.size()); + return 1; +} + +static void runInTimer(GarrysMod::Lua::ILuaBase* LUA, double delay, GarrysMod::Lua::CFunc func) { + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->GetField(-1, "timer"); + //In case someone removes the timer library + if (LUA->IsType(-1, GarrysMod::Lua::Type::Nil)) { + LUA->Pop(2); + return; + } + LUA->GetField(-1, "Simple"); + LUA->PushNumber(delay); + LUA->PushCFunction(func); + LUA->Call(2, 0); + LUA->Pop(2); +} + +static void printMessage(GarrysMod::Lua::ILuaBase* LUA, const char* str, int r, int g, int b) { + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->GetField(-1, "Color"); + LUA->PushNumber(r); + LUA->PushNumber(g); + LUA->PushNumber(b); + LUA->Call(3, 1); + int ref = LUA->ReferenceCreate(); + LUA->GetField(-1, "MsgC"); + LUA->ReferencePush(ref); + LUA->PushString(str); + LUA->Call(2, 0); + LUA->Pop(); + LUA->ReferenceFree(ref); +} + +static int printOutdatatedVersion(lua_State* state) { + GarrysMod::Lua::ILuaBase* LUA = state->luabase; + LUA->SetState(state); + printMessage(LUA, "Your server is using an outdated mysqloo9 version\n", 255, 0, 0); + printMessage(LUA, "Download the latest version from here:\n", 255, 0, 0); + printMessage(LUA, "https://github.com/FredyH/MySQLOO/releases\n", 86, 156, 214); + runInTimer(LUA, 300, printOutdatatedVersion); + return 0; +} + +static int fetchSuccessful(lua_State* state) { + GarrysMod::Lua::ILuaBase* LUA = state->luabase; + LUA->SetState(state); + std::string version = LUA->GetString(1); + //version.size() < 3 so that the 404 response gets ignored + if (version != MYSQLOO_MINOR_VERSION && version.size() <= 3) { + printOutdatatedVersion(state); + } else { + printMessage(LUA, "Your server is using the latest mysqloo9 version\n", 0, 255, 0); + } + return 0; +} + +static int fetchFailed(lua_State* state) { + GarrysMod::Lua::ILuaBase* LUA = state->luabase; + LUA->SetState(state); + printMessage(LUA, "Failed to retrieve latest version of mysqloo9\n", 255, 0, 0); + return 0; +} + +static int doVersionCheck(lua_State* state) { + GarrysMod::Lua::ILuaBase* LUA = state->luabase; + LUA->SetState(state); + + // Check if the reference to the ConVar object is set + if (versionCheckConVar != 0) { + // Retrieve the value of the ConVar + LUA->ReferencePush(versionCheckConVar); // Push the ConVar object + LUA->GetField(-1, "GetInt"); // Push the name of the function + LUA->ReferencePush(versionCheckConVar); // Push the ConVar object as the first self argument + LUA->Call(1, 1); // Call with 1 argument and 1 return + int versionCheckEnabled = (int)LUA->GetNumber(-1); // Retrieve the returned value + + // Check if the version check convar is set to 1 + if (versionCheckEnabled == 1) { + // Execute the HTTP request + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->GetField(-1, "http"); + LUA->GetField(-1, "Fetch"); + LUA->PushString("https://raw.githubusercontent.com/FredyH/MySQLOO/master/minorversion.txt"); + LUA->PushCFunction(fetchSuccessful); + LUA->PushCFunction(fetchFailed); + LUA->PCall(3, 0, 0); + LUA->Pop(2); + } + } + + return 0; +} + +GMOD_MODULE_OPEN() { + if (mysql_library_init(0, nullptr, nullptr)) { + LUA->ThrowError("Could not initialize mysql library."); + } + + //Creating MetaTables + LuaDatabase::createMetaTable(LUA); + + //LuaObjectBase::createMetatables(LUA); + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->GetField(-1, "hook"); + LUA->GetField(-1, "Add"); + LUA->PushString("Think"); + LUA->PushString("__MySQLOOThinkHook"); + LUA->PushCFunction(LuaObject::luaObjectThink); + LUA->Call(3, 0); + LUA->Pop(); + LUA->Pop(); + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->CreateTable(); + + LUA->PushString(MYSQLOO_VERSION); LUA->SetField(-2, "VERSION"); + LUA->PushString(MYSQLOO_MINOR_VERSION); LUA->SetField(-2, "MINOR_VERSION"); + + LUA->PushNumber(DATABASE_CONNECTED); LUA->SetField(-2, "DATABASE_CONNECTED"); + LUA->PushNumber(DATABASE_CONNECTING); LUA->SetField(-2, "DATABASE_CONNECTING"); + LUA->PushNumber(DATABASE_NOT_CONNECTED); LUA->SetField(-2, "DATABASE_NOT_CONNECTED"); + LUA->PushNumber(DATABASE_CONNECTION_FAILED); LUA->SetField(-2, "DATABASE_CONNECTION_FAILED"); + + LUA->PushNumber(QUERY_NOT_RUNNING); LUA->SetField(-2, "QUERY_NOT_RUNNING"); + LUA->PushNumber(QUERY_RUNNING); LUA->SetField(-2, "QUERY_RUNNING"); + LUA->PushNumber(QUERY_COMPLETE); LUA->SetField(-2, "QUERY_COMPLETE"); + LUA->PushNumber(QUERY_ABORTED); LUA->SetField(-2, "QUERY_ABORTED"); + LUA->PushNumber(QUERY_WAITING); LUA->SetField(-2, "QUERY_WAITING"); + + LUA->PushNumber(OPTION_NUMERIC_FIELDS); LUA->SetField(-2, "OPTION_NUMERIC_FIELDS"); + LUA->PushNumber(OPTION_INTERPRET_DATA); LUA->SetField(-2, "OPTION_INTERPRET_DATA"); //Not used anymore + LUA->PushNumber(OPTION_NAMED_FIELDS); LUA->SetField(-2, "OPTION_NAMED_FIELDS"); //Not used anymore + LUA->PushNumber(OPTION_CACHE); LUA->SetField(-2, "OPTION_CACHE"); //Not used anymore + + LUA->PushCFunction(mysqlooConnect); LUA->SetField(-2, "connect"); + LUA->PushCFunction(objectCount); LUA->SetField(-2, "objectCount"); + + LUA->SetField(-2, "mysqloo"); + LUA->Pop(); + + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); // Push the global table + // Create the version check ConVar + LUA->GetField(-1, "CreateConVar"); + LUA->PushString("sv_mysqloo_versioncheck"); // Name + LUA->PushString("1"); // Default value + LUA->PushNumber(128); // FCVAR flags + LUA->PushString("Enable or disable the MySQLOO update checker."); // Help text + LUA->PushNumber(0); // Min value + LUA->PushNumber(1); // Max value + LUA->Call(6, 1); // Call with 6 arguments and 1 result + versionCheckConVar = LUA->ReferenceCreate(); // Store the created ConVar object as a global variable + LUA->Pop(); // Pop the global table + + runInTimer(LUA, 5, doVersionCheck); + + return 1; +} diff --git a/src/lua/LuaDatabase.cpp b/src/lua/LuaDatabase.cpp new file mode 100644 index 0000000..71671e9 --- /dev/null +++ b/src/lua/LuaDatabase.cpp @@ -0,0 +1,197 @@ +// +// Created by Fredy on 28/10/2021. +// + +#include "LuaDatabase.h" + +MYSQLOO_LUA_FUNCTION(connect) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + if (database->m_tableReference == 0) { + LUA->Push(-1); + database->m_tableReference = LUA->ReferenceCreate(); + } + database->m_database->connect(); + return 0; +} + +MYSQLOO_LUA_FUNCTION(escape) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + unsigned int nQueryLength; + const char *sQuery = LUA->GetString(2, &nQueryLength); + auto escaped = database->m_database->escape(std::string(sQuery, nQueryLength)); + LUA->PushString(escaped.c_str(), escaped.size()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(setCharacterSet) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->CheckType(2, GarrysMod::Lua::Type::String); + const char *charset = LUA->GetString(2); + bool success = database->m_database->setCharacterSet(charset); + LUA->PushBool(success); + LUA->PushString(""); + return 2; +} + +MYSQLOO_LUA_FUNCTION(setSSLSettings) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + SSLSettings sslSettings; + if (LUA->IsType(2, GarrysMod::Lua::Type::String)) { + sslSettings.key = LUA->GetString(2); + } + if (LUA->IsType(3, GarrysMod::Lua::Type::String)) { + sslSettings.cert = LUA->GetString(3); + } + if (LUA->IsType(4, GarrysMod::Lua::Type::String)) { + sslSettings.ca = LUA->GetString(4); + } + if (LUA->IsType(5, GarrysMod::Lua::Type::String)) { + sslSettings.capath = LUA->GetString(5); + } + if (LUA->IsType(6, GarrysMod::Lua::Type::String)) { + sslSettings.cipher = LUA->GetString(6); + } + database->m_database->setSSLSettings(sslSettings); + return 0; +} + +MYSQLOO_LUA_FUNCTION(disconnect) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + bool wait = false; + if (LUA->IsType(2, GarrysMod::Lua::Type::Bool)) { + wait = LUA->GetBool(2); + } + database->m_database->disconnect(wait); + return 0; +} + +MYSQLOO_LUA_FUNCTION(status) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushNumber(database->m_database->m_status); + return 1; +} + +MYSQLOO_LUA_FUNCTION(serverVersion) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushNumber(database->m_database->serverVersion()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(serverInfo) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushString(database->m_database->serverInfo().c_str()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(hostInfo) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushString(database->m_database->hostInfo().c_str()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(setAutoReconnect) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->CheckType(2, GarrysMod::Lua::Type::Bool); + database->m_database->setAutoReconnect(LUA->GetBool(2)); + return 0; +} + +MYSQLOO_LUA_FUNCTION(setMultiStatements) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->CheckType(2, GarrysMod::Lua::Type::Bool); + database->m_database->setMultiStatements(LUA->GetBool(2)); + return 0; +} + +MYSQLOO_LUA_FUNCTION(setCachePreparedStatements) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->CheckType(2, GarrysMod::Lua::Type::Bool); + database->m_database->setCachePreparedStatements(LUA->GetBool(2)); + return 0; +} + +MYSQLOO_LUA_FUNCTION(abortAllQueries) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + auto abortedQueries = database->m_database->abortAllQueries(); + for (auto pair: abortedQueries) { + //TODO: + //query->onQueryDataFinished(LUA, data); + } + LUA->PushNumber((double) abortedQueries.size()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(queueSize) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushNumber((double) database->m_database->queueSize()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(ping) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + LUA->PushBool(database->m_database->ping()); + return 1; +} + +MYSQLOO_LUA_FUNCTION(wait) { + auto database = getLuaObject(LUA, LuaObject::TYPE_DATABASE); + database->m_database->wait(); + return 0; +} + +void LuaDatabase::createMetaTable(ILuaBase *LUA) { + LuaObject::TYPE_DATABASE = LUA->CreateMetaTable("MySQLOO Database"); + LuaObject::addMetaTableFunctions(LUA); + + LUA->PushCFunction(connect); + LUA->SetField(-2, "connect"); + + LUA->PushCFunction(escape); + LUA->SetField(-2, "escape"); + + LUA->PushCFunction(setCharacterSet); + LUA->SetField(-2, "setCharacterSet"); + + LUA->PushCFunction(setSSLSettings); + LUA->SetField(-2, "setSSLSettings"); + + LUA->PushCFunction(disconnect); + LUA->SetField(-2, "disconnect"); + + LUA->PushCFunction(status); + LUA->SetField(-2, "status"); + + LUA->PushCFunction(serverVersion); + LUA->SetField(-2, "serverVersion"); + + LUA->PushCFunction(serverInfo); + LUA->SetField(-2, "serverInfo"); + + LUA->PushCFunction(hostInfo); + LUA->SetField(-2, "hostInfo"); + + LUA->PushCFunction(setAutoReconnect); + LUA->SetField(-2, "setAutoReconnect"); + + LUA->PushCFunction(setMultiStatements); + LUA->SetField(-2, "setMultiStatements"); + + LUA->PushCFunction(setCachePreparedStatements); + LUA->SetField(-2, "setCachePreparedStatements"); + + LUA->PushCFunction(abortAllQueries); + LUA->SetField(-2, "abortAllQueries"); + + LUA->PushCFunction(queueSize); + LUA->SetField(-2, "queueSize"); + + LUA->PushCFunction(ping); + LUA->SetField(-2, "ping"); + + LUA->PushCFunction(wait); + LUA->SetField(-2, "wait"); +} + +void LuaDatabase::think(ILuaBase *lua) { + LuaObject::think(lua); +} \ No newline at end of file diff --git a/src/lua/LuaDatabase.h b/src/lua/LuaDatabase.h new file mode 100644 index 0000000..30d7bf2 --- /dev/null +++ b/src/lua/LuaDatabase.h @@ -0,0 +1,30 @@ +// +// Created by Fredy on 28/10/2021. +// + +#ifndef MYSQLOO_LUADATABASE_H +#define MYSQLOO_LUADATABASE_H + + +#include "../mysql/Database.h" + +#include +#include "LuaObject.h" + +class LuaDatabase : public LuaObject { +public: + static void createMetaTable(ILuaBase *LUA); + + void think(ILuaBase *lua) override; + + int m_tableReference = 0; + std::shared_ptr m_database; +protected: + explicit LuaDatabase(std::shared_ptr database) : LuaObject("Database"), + m_database(std::move(database)) { + + } +}; + + +#endif //MYSQLOO_LUADATABASE_H diff --git a/src/lua/LuaObject.cpp b/src/lua/LuaObject.cpp new file mode 100644 index 0000000..5949947 --- /dev/null +++ b/src/lua/LuaObject.cpp @@ -0,0 +1,61 @@ +#include +#include "LuaObject.h" + +std::string LuaObject::toString() { + std::stringstream ss; + ss << s_className << " " << this; + return ss.str(); +} + +std::shared_ptr LuaObject::checkMySQLOOType(ILuaBase *lua, int position) { + int type = lua->GetType(position); + if (type != LuaObject::TYPE_DATABASE && type != LuaObject::TYPE_PREPARED_QUERY && + type != LuaObject::TYPE_QUERY && type != LuaObject::TYPE_TRANSACTION) { + lua->ThrowError("Provided argument is not a valid MySQLOO object"); + } + return lua->GetUserType(position, type)->shared_from_this(); +} + +LUA_FUNCTION(luaObjectGc) { + auto luaObject = LuaObject::checkMySQLOOType(LUA); + LuaObject::luaThinkObjects.push_back(luaObject); + LuaObject::luaThinkObjects.erase( + std::remove(LuaObject::luaThinkObjects.begin(), LuaObject::luaThinkObjects.end(), luaObject), + LuaObject::luaThinkObjects.end() + ); + LuaObject::luaObjects.erase( + std::remove(LuaObject::luaObjects.begin(), LuaObject::luaObjects.end(), luaObject), + LuaObject::luaObjects.end() + ); + //After this function this object should be deleted. + //For the Database this might cause the database thread to join + //TODO: Figure out if this is wise. + return 0; +} + +LUA_FUNCTION(luaObjectToString) { + auto luaObject = LuaObject::checkMySQLOOType(LUA); + auto str = luaObject->toString(); + LUA->PushString(str.c_str()); + return 1; +} + +LUA_CLASS_FUNCTION(LuaObject, luaObjectThink) { + std::deque> thinkObjectsCopy = LuaObject::luaThinkObjects; + for (auto &query: thinkObjectsCopy) { + query->think(LUA); + } + return 0; +} + + +void LuaObject::addMetaTableFunctions(GarrysMod::Lua::ILuaBase *lua) { + lua->CreateTable(); + lua->SetField(-2, "__index"); + + lua->PushCFunction(luaObjectGc); + lua->SetField(-2, "__gc"); + + lua->PushCFunction(luaObjectToString); + lua->SetField(-2, "__tostring"); +} \ No newline at end of file diff --git a/src/lua/LuaObject.h b/src/lua/LuaObject.h new file mode 100644 index 0000000..4c82248 --- /dev/null +++ b/src/lua/LuaObject.h @@ -0,0 +1,82 @@ + + +#ifndef MYSQLOO_LUAOBJECT_H +#define MYSQLOO_LUAOBJECT_H + +#include +#include +#include +#include +#include "GarrysMod/Lua/Interface.h" +#include "../mysql/MySQLOOException.h" + +using namespace GarrysMod::Lua; + +class LuaObject : public std::enable_shared_from_this { +public: + virtual void think(ILuaBase *lua) {}; + + std::string toString(); + + + static std::deque> luaObjects; + static std::deque> luaThinkObjects; + + static std::shared_ptr checkMySQLOOType(ILuaBase *lua, int position = 1); + + static int TYPE_DATABASE; + static int TYPE_QUERY; + static int TYPE_PREPARED_QUERY; + static int TYPE_TRANSACTION; + + static void addMetaTableFunctions(ILuaBase *lua); + + //Lua functions + static int luaObjectThink(lua_State *L); + +protected: + + explicit LuaObject(std::string className) : s_className(std::move(className)) { + + } + + std::string s_className; +}; + +template +T *getLuaObject(ILuaBase *LUA, int type, int stackPos = 1) { + T *returnValue = LUA->GetUserType(-1, type); + if (returnValue == nullptr) { + LUA->ThrowError("[MySQLOO] Expected MySQLOO table"); + } + return returnValue; +} + + +#define MYSQLOO_LUA_FUNCTION(FUNC) \ + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ + static int FUNC( lua_State* L ) \ + { \ + GarrysMod::Lua::ILuaBase* LUA = L->luabase; \ + LUA->SetState(L); \ + try { \ + return FUNC##__Imp( LUA ); \ + } catch (const MySQLOOException& error) { \ + LUA->ThrowError(error.message.c_str()); \ + return 0; \ + } \ + } \ + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ) + +#define LUA_CLASS_FUNCTION(CLASS, FUNC) \ + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ); \ + int CLASS::FUNC( lua_State* L ) \ + { \ + GarrysMod::Lua::ILuaBase* LUA = L->luabase; \ + LUA->SetState(L); \ + return FUNC##__Imp( LUA ); \ + } \ + static int FUNC##__Imp( GarrysMod::Lua::ILuaBase* LUA ) + + +#endif //MYSQLOO_LUAOBJECT_H diff --git a/src/mysql/Database.cpp b/src/mysql/Database.cpp index eb07b87..c267f50 100644 --- a/src/mysql/Database.cpp +++ b/src/mysql/Database.cpp @@ -2,6 +2,7 @@ #include "Query.h" #include "IQuery.h" #include "MySQLOOException.h" +#include "Transaction.h" #include #include #include @@ -95,16 +96,14 @@ size_t Database::queueSize() { /* Aborts all queries that are in the queue of started queries and returns the number of successfully aborted queries. * Does not abort queries that are already taken from the queue and being processed. */ -size_t Database::abortAllQueries() { +std::deque, std::shared_ptr>> Database::abortAllQueries() { auto canceledQueries = queryQueue.clear(); for (auto &pair: canceledQueries) { auto query = pair.first; auto data = pair.second; data->setStatus(QUERY_ABORTED); - //TODO: - //query->onQueryDataFinished(LUA, data); } - return canceledQueries.size(); + return canceledQueries; } /* Waits for the connection of the database to finish by blocking the current thread until the connect thread finished. @@ -277,6 +276,10 @@ std::shared_ptr Database::prepare(const std::string &query) { return std::shared_ptr(new PreparedQuery(shared_from_this(), query)); } +std::shared_ptr Database::transaction() { + return std::shared_ptr(new Transaction(shared_from_this())); +} + bool Database::ping() { auto query = std::shared_ptr(new PingQuery(shared_from_this())); auto queryData = query->buildQueryData(); diff --git a/src/mysql/Database.h b/src/mysql/Database.h index f7f8f62..d42d333 100644 --- a/src/mysql/Database.h +++ b/src/mysql/Database.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -15,6 +16,7 @@ #include "PreparedQuery.h" #include "IQuery.h" #include "PingQuery.h" +#include "Transaction.h" struct SSLSettings { bool customSSLSettings = false; @@ -76,9 +78,9 @@ public: std::shared_ptr prepare(const std::string &query); - bool ping(); + std::shared_ptr transaction(); - //std::shared_ptr createTransaction(lua_State* state); + bool ping(); std::string escape(const std::string &str); @@ -88,7 +90,7 @@ public: void wait(); - size_t abortAllQueries(); + std::deque, std::shared_ptr>> abortAllQueries(); DatabaseStatus status(); @@ -102,12 +104,13 @@ public: void setMultiStatements(bool multiStatement); - //std::shared_ptr ping(); void setCachePreparedStatements(bool cachePreparedStatements); void disconnect(bool wait); void setSSLSettings(const SSLSettings &settings); + std::atomic m_status{DATABASE_NOT_CONNECTED}; + std::string m_connection_err; private: Database(std::string host, std::string username, std::string pw, std::string database, unsigned int port, @@ -143,8 +146,6 @@ private: std::atomic m_success{true}; std::atomic m_connectionDone{false}; std::atomic cachePreparedStatements{true}; - std::atomic m_status{DATABASE_NOT_CONNECTED}; - std::string m_connection_err; std::condition_variable m_queryWakupVariable; std::string database; std::string host; diff --git a/src/mysql/IQuery.h b/src/mysql/IQuery.h index 7266108..6981f59 100644 --- a/src/mysql/IQuery.h +++ b/src/mysql/IQuery.h @@ -110,7 +110,7 @@ public: } void setError(std::string err) { - m_errorText = err; + m_errorText = std::move(err); } bool isFinished() { @@ -166,6 +166,7 @@ protected: int m_errorReference = 0; int m_abortReference = 0; int m_onDataReference = 0; + int m_tableReference = 0; bool m_wasFirstData = false; }; diff --git a/src/mysql/MySQLOOException.h b/src/mysql/MySQLOOException.h index b91fa5c..c46d721 100644 --- a/src/mysql/MySQLOOException.h +++ b/src/mysql/MySQLOOException.h @@ -1,8 +1,6 @@ #ifndef MYSQLOO_MYSQLERROR_H #define MYSQLOO_MYSQLERROR_H -#endif //MYSQLOO_MYSQLERROR_H - #include //When called from the outside @@ -12,4 +10,6 @@ public: } std::string message{}; -}; \ No newline at end of file +}; + +#endif //MYSQLOO_MYSQLERROR_H diff --git a/src/mysql/Transaction.h b/src/mysql/Transaction.h index 7f6aa0a..13df8b1 100644 --- a/src/mysql/Transaction.h +++ b/src/mysql/Transaction.h @@ -25,15 +25,15 @@ class Transaction : public IQuery { friend class Database; public: - explicit Transaction(const std::weak_ptr &database) : IQuery(database) { - - } static std::shared_ptr buildQueryData(const std::deque, std::shared_ptr>>& queries); protected: bool executeStatement(Database &database, MYSQL *connection, std::shared_ptr data) override; + explicit Transaction(const std::weak_ptr &database) : IQuery(database) { + + } private: std::vector> addedQueryData;