From 95f97501b8071839bce565019ae1c108dacaf15a Mon Sep 17 00:00:00 2001 From: FredyH Date: Tue, 10 Sep 2024 00:42:30 +0200 Subject: [PATCH] Added onDisconnected callback called after calling db:disconnect. --- .gitignore | 3 ++- README.md | 6 ++++++ src/lua/LuaDatabase.cpp | 32 ++++++++++++++++++++++++++++---- src/lua/LuaDatabase.h | 1 + src/lua/LuaObject.cpp | 5 +++++ src/mysql/Database.cpp | 9 ++++++++- src/mysql/Database.h | 3 ++- 7 files changed, 52 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 0d6fece..ada10e0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ *.zip cmake-build-debug .idea -.vs \ No newline at end of file +.vs +.cache \ No newline at end of file diff --git a/README.md b/README.md index 3181a0a..abc5f69 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ Database:connect() Database:disconnect(shouldWait) -- Returns nothing -- disconnects from the database and waits for all queries to finish if shouldWait is true +-- This function calls the onDisconnected callback if it existed on the database before the database was connected. Database:query( sql ) -- Returns [Query] @@ -149,6 +150,11 @@ Database.onConnected( db ) Database.onConnectionFailed( db, err ) -- Called when the connection to the MySQL server fails, [String] err is why. +Database.onDisconnected( db ) +-- Called after Database.disconnect has been called and all queries have finished executing +-- Note: You have to set this callback before calling Database:connect() or it will not be called. + + -- Query/PreparedQuery object (transactions also inherit all functions, some have no effect though) -- Functions diff --git a/src/lua/LuaDatabase.cpp b/src/lua/LuaDatabase.cpp index 518426e..52ba493 100644 --- a/src/lua/LuaDatabase.cpp +++ b/src/lua/LuaDatabase.cpp @@ -103,6 +103,12 @@ MYSQLOO_LUA_FUNCTION(connect) { LUA->Push(1); database->m_tableReference = LuaReferenceCreate(LUA); } + + LUA->ReferencePush(database->m_tableReference); + LUA->GetField(-1, "onDisconnected"); + database->m_hasOnDisconnected = LUA->IsType(-1, GarrysMod::Lua::Type::Function); + LUA->Pop(2); // callback, table + database->m_database->connect(); return 0; } @@ -351,7 +357,6 @@ void LuaDatabase::think(ILuaBase *LUA) { LUA->ReferencePush(this->m_tableReference); pcallWithErrorReporter(LUA, 1); } - LUA->Pop(); //Callback function } else { LUA->GetField(-1, "onConnectionFailed"); if (LUA->GetType(-1) == GarrysMod::Lua::Type::Function) { @@ -360,11 +365,15 @@ void LuaDatabase::think(ILuaBase *LUA) { LUA->PushString(error.c_str()); pcallWithErrorReporter(LUA, 2); } - LUA->Pop(); //Callback function } + LUA->Pop(); // DB Table - LuaReferenceFree(LUA, this->m_tableReference); - this->m_tableReference = 0; + if (!this->m_hasOnDisconnected) { + // Only free the table reference if we do not have an onDisconnected callback. + // Otherwise, it will be freed after the onDisconnected callback was called. + LuaReferenceFree(LUA, this->m_tableReference); + this->m_tableReference = 0; + } } //Run callbacks of finished queries @@ -372,6 +381,21 @@ void LuaDatabase::think(ILuaBase *LUA) { for (auto &pair: finishedQueries) { LuaQuery::runCallback(LUA, pair.first, pair.second); } + + if (database->wasDisconnected() && this->m_hasOnDisconnected && this->m_tableReference != 0) { + this->m_hasOnDisconnected = false; + + LUA->ReferencePush(this->m_tableReference); + + LUA->GetField(-1, "onDisconnected"); + if (LUA->GetType(-1) == GarrysMod::Lua::Type::Function) { + LUA->ReferencePush(this->m_tableReference); + pcallWithErrorReporter(LUA, 1); + } + LUA->Pop(1); // DB Table + + LuaReferenceFree(LUA, this->m_tableReference); + } } void LuaDatabase::onDestroyedByLua(ILuaBase *LUA) { diff --git a/src/lua/LuaDatabase.h b/src/lua/LuaDatabase.h index d06062d..6d8f636 100644 --- a/src/lua/LuaDatabase.h +++ b/src/lua/LuaDatabase.h @@ -17,6 +17,7 @@ public: void think(ILuaBase *LUA); int m_tableReference = 0; + bool m_hasOnDisconnected = false; std::shared_ptr m_database; bool m_dbCallbackRan = false; diff --git a/src/lua/LuaObject.cpp b/src/lua/LuaObject.cpp index f68ee22..652ecaf 100644 --- a/src/lua/LuaObject.cpp +++ b/src/lua/LuaObject.cpp @@ -68,6 +68,11 @@ LUA_FUNCTION(errorReporter) { return 1; } +/** + * Similar to LUA->PCall but also uses an error reporter and prints the + * error to the console using ErrorNoHalt (if it exists). + * Consumes the function and all nargs arguments on the stack, does not return any values. + */ void LuaObject::pcallWithErrorReporter(ILuaBase *LUA, int nargs) { LUA->PushCFunction(errorReporter); int errorHandlerIndex = LUA->Top() - nargs - 1; diff --git a/src/mysql/Database.cpp b/src/mysql/Database.cpp index fcfaa62..ae660db 100644 --- a/src/mysql/Database.cpp +++ b/src/mysql/Database.cpp @@ -215,7 +215,13 @@ void Database::disconnect(bool wait) { if (wait && m_thread.joinable()) { m_thread.join(); } - disconnected = true; +} + +/* + * Returns true after the database has been fully disconnected and no more queries are in the queue. + */ +bool Database::wasDisconnected() { + return disconnected; } /* Returns the status of the database, constants can be found in GMModule @@ -361,6 +367,7 @@ void Database::connectRun() { if (m_status == DATABASE_CONNECTED) { m_status = DATABASE_NOT_CONNECTED; } + disconnected = true; }); { auto connectionSignaler = finally([&] { m_connectWakeupVariable.notify_one(); }); diff --git a/src/mysql/Database.h b/src/mysql/Database.h index ff0b3e2..44ea35f 100644 --- a/src/mysql/Database.h +++ b/src/mysql/Database.h @@ -115,6 +115,7 @@ public: return finishedQueries.clear(); } + bool wasDisconnected(); private: Database(std::string host, std::string username, std::string pw, std::string database, unsigned int port, std::string unixSocket); @@ -158,10 +159,10 @@ private: bool shouldAutoReconnect = true; bool useMultiStatements = true; bool startedConnecting = false; - bool disconnected = false; bool m_canWait = false; std::pair, std::shared_ptr> m_waitingQuery = {nullptr, nullptr}; std::atomic m_success{true}; + std::atomic disconnected { false }; std::atomic m_connectionDone{false}; std::atomic cachePreparedStatements{true}; std::condition_variable m_queryWakeupVariable{};