diff --git a/.gitignore b/.gitignore index 10af485..0d6fece 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,4 @@ *.zip cmake-build-debug .idea -.vs -CMakeSettings.json \ No newline at end of file +.vs \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a5a804..e11fee9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ if (CMAKE_SIZEOF_VOID_P EQUAL 8) endif () elseif (CMAKE_SIZEOF_VOID_P EQUAL 4) if (WIN32) - find_library(MARIADB_CLIENT_LIB mariadbclient HINTS "${PROJECT_SOURCE_DIR}MySQL/lib/windows") + find_library(MARIADB_CLIENT_LIB mariadbclient HINTS "${PROJECT_SOURCE_DIR}/MySQL/lib/windows") else () find_library(MARIADB_CLIENT_LIB mariadbclient HINTS "${PROJECT_SOURCE_DIR}/MySQL/lib/linux") find_library(CRYPTO_LIB crypto HINTS "${PROJECT_SOURCE_DIR}/MySQL/lib/linux") diff --git a/CMakeSettings.json b/CMakeSettings.json new file mode 100644 index 0000000..b6d078a --- /dev/null +++ b/CMakeSettings.json @@ -0,0 +1,27 @@ +{ + "configurations": [ + { + "name": "x64-RelDebug", + "generator": "Ninja", + "configurationType": "RelWithDebInfo", + "inheritEnvironments": [ "msvc_x64_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "" + }, + { + "name": "x86-RelDebug", + "generator": "Ninja", + "configurationType": "RelWithDebInfo", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "msvc_x86" ], + "variables": [] + } + ] +} \ No newline at end of file diff --git a/IntegrationTest/README.md b/IntegrationTest/README.md new file mode 100644 index 0000000..1e52e03 --- /dev/null +++ b/IntegrationTest/README.md @@ -0,0 +1,12 @@ +# MySQLOO Lua Integration Tests + +This folder contains integration tests for MySQLOO. + +## Running + +- Place this folder into the server's addons folder. +- Adjust the database settings in lua/autorun/server/init.lua + - ensure that the database used is **empty** as it will be filled with test data +- run `mysqloo_start_tests` in the server console + +Each of the tests outputs its result to the console and at the end there is an overview of all tests printed. \ No newline at end of file diff --git a/IntegrationTest/lua/autorun/server/init.lua b/IntegrationTest/lua/autorun/server/init.lua new file mode 100644 index 0000000..ed15fb8 --- /dev/null +++ b/IntegrationTest/lua/autorun/server/init.lua @@ -0,0 +1,10 @@ +DatabaseSettings = { + Host = "localhost", + Port = 3306, + Username = "root", + Password = "", + Database = "test" +} + +print("Loading MySQLOO Testing Framework") +include("mysqloo/init.lua") \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/init.lua b/IntegrationTest/lua/mysqloo/init.lua new file mode 100644 index 0000000..ad0c60b --- /dev/null +++ b/IntegrationTest/lua/mysqloo/init.lua @@ -0,0 +1,8 @@ +include("testframework.lua") +include("setup.lua") + + +local files = file.Find("mysqloo/tests/*.lua", "LUA") +for _,f in pairs(files) do + include("tests/" .. f) +end \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/setup.lua b/IntegrationTest/lua/mysqloo/setup.lua new file mode 100644 index 0000000..e9ff841 --- /dev/null +++ b/IntegrationTest/lua/mysqloo/setup.lua @@ -0,0 +1,18 @@ +require("mysqloo") + +function TestFramework:ConnectToDatabase() + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, DatabaseSettings.Password, DatabaseSettings.Database, DatabaseSettings.Port) + db:connect() + db:wait() + return db +end + +function TestFramework:RunQuery(db, queryStr) + local query = db:query(queryStr) + query:start() + function query:onError(err) + error(err) + end + query:wait() + return query:getData() +end \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/testframework.lua b/IntegrationTest/lua/mysqloo/testframework.lua new file mode 100644 index 0000000..8791602 --- /dev/null +++ b/IntegrationTest/lua/mysqloo/testframework.lua @@ -0,0 +1,167 @@ + + +TestFramework = TestFramework or {} +TestFramework.RegisteredTests = {} + +local TestMT = {} +TestMT.__index = TestMT + +function TestFramework:RegisterTest(name, f) + local tbl = setmetatable({}, {__index = TestMT}) + tbl.TestFunction = f + tbl.Name = name + table.insert(TestFramework.RegisteredTests, tbl) + print("Registered test ", name) +end + +function TestFramework:RunNextTest() + TestFramework.CurrentIndex = (TestFramework.CurrentIndex or 0) + 1 + TestFramework.TestTimeout = CurTime() + 3 + local test = TestFramework.RegisteredTests[TestFramework.CurrentIndex] + TestFramework.CurrentTest = test + if (!test) then + TestFramework:OnCompleted() + else + test:Run() + end +end + +function TestFramework:CheckTimeout() + if (!TestFramework.CurrentTest) then return end + if (CurTime() > TestFramework.TestTimeout) then + TestFramework.CurrentTest:Fail("TIMEOUT") + end +end + +hook.Add("Think", "TestFrameworkTimeoutCheck", function() + TestFramework:CheckTimeout() +end) + +function TestFramework:ReportResult(success) + TestFramework.TestCount = (TestFramework.TestCount or 0) + 1 + if (success) then + TestFramework.SuccessCount = (TestFramework.SuccessCount or 0) + 1 + else + TestFramework.FailureCount = (TestFramework.FailureCount or 0) + 1 + end +end + +function TestFramework:OnCompleted() + print("[MySQLOO] Tests completed") + MsgC(Color(255, 255, 255), "Completed: ", Color(30, 230, 30), TestFramework.SuccessCount, Color(255, 255, 255), " Failures: ", Color(230, 30, 30), TestFramework.FailureCount, "\n") + + for j = 0, 3 do + timer.Simple(j * 0.5, function() + for i = 1, 100 do + collectgarbage("collect") + end + end) + end + timer.Simple(2, function() + for i = 1, 100 do + collectgarbage("collect") + end + local diffBefore = TestFramework.AllocationCount - TestFramework.DeallocationCount + local diffAfter = mysqloo.allocationCount() - mysqloo.deallocationCount() + if (diffAfter > diffBefore) then + MsgC(Color(255, 255, 255), "Found potential memory leak with ", diffAfter - diffBefore, " new allocations that were not freed\n") + else + MsgC(Color(255, 255, 255), "All allocated objects were freed\n") + end + MsgC(Color(255, 255, 255), "Lua Heap Before: ", TestFramework.LuaMemory, " After: ", collectgarbage("count"), "\n") + end) +end + +function TestFramework:Start() + for i = 1, 5 do + collectgarbage("collect") + end + TestFramework.CurrentIndex = 0 + TestFramework.SuccessCount = 0 + TestFramework.FailureCount = 0 + TestFramework.AllocationCount = mysqloo.allocationCount() + TestFramework.DeallocationCount = mysqloo.deallocationCount() + TestFramework.LuaMemory = collectgarbage("count") + TestFramework:RunNextTest() +end + +function TestMT:Fail(reason) + if (self.Completed) then return end + self.Completed = true + MsgC(Color(230, 30, 30), "FAILED\n") + MsgC(Color(230, 30, 30), "Error: ", reason, "\n") + TestFramework:ReportResult(false) + TestFramework:RunNextTest() +end + +function TestMT:Complete() + if (self.Completed) then return end + self.Completed = true + MsgC(Color(30, 230, 30), "PASSED\n") + TestFramework:ReportResult(true) + TestFramework:RunNextTest() +end + +function TestMT:Run() + MsgC("Test: ", self.Name, " ") + self.Completed = false + local status, err = pcall(function() + self.TestFunction(self) + end) + if (!status) then + self:Fail(err) + end +end + +function TestMT:shouldBeNil(a) + if (a != nil) then + self:Fail(tostring(a) .. " was expected to be nil, but was not nil") + error("Assertion failed") + end +end + +function TestMT:shouldBeGreaterThan(a, num) + if (num >= a) then + self:Fail(tostring(a) .. " was expected to be greater than " .. tostring(num)) + error("Assertion failed") + end +end + +function TestMT:shouldNotBeNil(a) + if (a == nil) then + self:Fail(tostring(a) .. " was expected to not be nil, but was nil") + error("Assertion failed") + end +end + +function TestMT:shouldNotBeEqual(a, b) + if (a == b) then + self:Fail(tostring(a) .. " was equal to " .. tostring(b)) + error("Assertion failed") + end +end + +function TestMT:shouldBeEqual(a, b) + if (a != b) then + self:Fail(tostring(a) .. " was not equal to " .. tostring(b)) + error("Assertion failed") + end +end + +function TestMT:shouldHaveLength(tbl, exactLength) + if (#tbl != exactLength) then + self:Fail("Length of " .. tostring(tbl) .. " was not equal to " .. exactLength) + error("Assertion failed") + end +end + +concommand.Add("mysqloo_start_tests", function(ply) + if (IsValid(ply)) then return end + print("Starting MySQLOO Tests") + if (#player.GetBots() == 0) then + RunConsoleCommand("bot") + end + timer.Simple(0.1, function() + TestFramework:Start() + end) +end) \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/tests/basic_mysql_test.lua b/IntegrationTest/lua/mysqloo/tests/basic_mysql_test.lua new file mode 100644 index 0000000..7da524b --- /dev/null +++ b/IntegrationTest/lua/mysqloo/tests/basic_mysql_test.lua @@ -0,0 +1,14 @@ +TestFramework:RegisterTest("[Basic] selecting 1 should return 1", function(test) + local db = TestFramework:ConnectToDatabase() + local query = db:query("SELECT 3 as test") + function query:onSuccess(data) + test:shouldHaveLength(data, 1) + test:shouldBeEqual(data[1]["test"], 3) + test:Complete() + end + function query:onError(err) + test:Fail(err) + end + query:start() + query:wait() +end) \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/tests/database_tests.lua b/IntegrationTest/lua/mysqloo/tests/database_tests.lua new file mode 100644 index 0000000..38f2930 --- /dev/null +++ b/IntegrationTest/lua/mysqloo/tests/database_tests.lua @@ -0,0 +1,119 @@ +TestFramework:RegisterTest("[Database] should return info correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local serverInfo = db:serverInfo() + local hostInfo = db:hostInfo() + local serverVersion = db:serverVersion() + test:shouldBeGreaterThan(#serverInfo, 0) + test:shouldBeGreaterThan(#hostInfo, 0) + test:shouldBeGreaterThan(serverVersion, 0) + test:Complete() +end) + +TestFramework:RegisterTest("[Database] queue size should return correct size", function(test) + local db = TestFramework:ConnectToDatabase() + test:shouldBeEqual(db:queueSize(), 0) + local query1 = db:query("SELECT SLEEP(0.5)") + local query2 = db:query("SELECT SLEEP(0.5)") + local query3 = db:query("SELECT SLEEP(0.5)") + function query1:onSuccess() + test:shouldBeEqual(db:queueSize(), 1) //1 because the next was already started at this point + end + function query2:onSuccess() + test:shouldBeEqual(db:queueSize(), 0) //0 because the next was already started at this point + test:Complete() + end + function query3:onSuccess() + test:shouldBeEqual(db:queueSize(), 0) + test:Complete() + end + query1:start() + query2:start() + query3:start() + test:shouldBeGreaterThan(db:queueSize(), 1) +end) + +TestFramework:RegisterTest("[Database] should abort all queries correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local query1 = db:query("SELECT SLEEP(0.5)") + local query2 = db:query("SELECT SLEEP(0.5)") + local query3 = db:query("SELECT SLEEP(0.5)") + local abortedCount = 0 + local f = function(q) + abortedCount = abortedCount + 1 + if (abortedCount == 2) then + test:Complete() + end + end + query1.onAborted = f + query2.onAborted = f + query3.onAborted = f + query1:start() + query2:start() + query3:start() + local amountAborted = db:abortAllQueries() + test:shouldBeGreaterThan(amountAborted, 1) //The one already processing might not be aborted +end) + +TestFramework:RegisterTest("[Database] should escape a string correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local escapedStr = db:escape("t'a") + test:shouldBeEqual(escapedStr, "t\\'a") + test:Complete() +end) + +TestFramework:RegisterTest("[Database] should return correct status", function(test) + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, DatabaseSettings.Password, DatabaseSettings.Database, DatabaseSettings.Port) + test:shouldBeEqual(db:status(), mysqloo.DATABASE_NOT_CONNECTED) + db:connect() + function db:onConnected() + test:shouldBeEqual(db:status(), mysqloo.DATABASE_CONNECTED) + + db:disconnect(true) + test:shouldBeEqual(db:status(), mysqloo.DATABASE_NOT_CONNECTED) + test:Complete() + end +end) + +TestFramework:RegisterTest("[Database] should call onConnected callback correctly", function(test) + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, DatabaseSettings.Password, DatabaseSettings.Database, DatabaseSettings.Port) + function db:onConnected() + test:Complete() + end + db:connect() +end) + +TestFramework:RegisterTest("[Database] should call onConnectionFailed callback correctly", function(test) + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, "incorrect_password", DatabaseSettings.Database, DatabaseSettings.Port) + function db:onConnectionFailed(err) + test:shouldBeGreaterThan(#err, 0) + test:Complete() + end + db:connect() +end) + +TestFramework:RegisterTest("[Database] should ping correctly", function(test) + local db = TestFramework:ConnectToDatabase() + test:shouldBeEqual(db:ping(), true) + test:Complete() +end) + +TestFramework:RegisterTest("[Database] allow setting only valid character set", function(test) + local db = TestFramework:ConnectToDatabase() + test:shouldBeEqual(db:setCharacterSet("utf8"), true) + test:shouldBeEqual(db:setCharacterSet("ascii"), true) + test:shouldBeEqual(db:setCharacterSet("invalid_name"), false) + test:Complete() +end) + +TestFramework:RegisterTest("[Database] wait for queries when disconnecting", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT SLEEP(1)") + local wasCalled = false + function qu:onSuccess() + wasCalled = true + end + qu:start() + db:disconnect(true) + test:shouldBeEqual(wasCalled, true) + test:Complete() +end) \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/tests/prepared_query_tests.lua b/IntegrationTest/lua/mysqloo/tests/prepared_query_tests.lua new file mode 100644 index 0000000..8871e8e --- /dev/null +++ b/IntegrationTest/lua/mysqloo/tests/prepared_query_tests.lua @@ -0,0 +1,300 @@ +TestFramework:RegisterTest("[Prepared Query] have correct set... functions", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT ? as a, ? as b, ? as c, ? as d, ? as e") + qu:setNumber(1, 5.5) + qu:setString(2, "test") + qu:setBoolean(3, true) + qu:setBoolean(4, false) + qu:setString(5, "b") + qu:setNull(5) + qu:start() + qu:wait() + local data = qu:getData() + test:shouldHaveLength(data, 1) + test:shouldBeEqual(data[1].a, 5.5) + test:shouldBeEqual(data[1].b, "test") + test:shouldBeEqual(data[1].c, 1) + test:shouldBeEqual(data[1].d, 0) + test:shouldBeNil(data[1].e) + test:Complete() +end) + +TestFramework:RegisterTest("[Prepared Query] allow batching parameters correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT ? as a") + qu:setNumber(1, 1) + qu:putNewParameters() + qu:setNumber(1, 2) + qu:putNewParameters() + qu:setNumber(1, 3) + qu:start() + qu:wait() + test:shouldBeEqual(qu:hasMoreResults(), true) + test:shouldBeEqual(qu:getData()[1].a, 1) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), true) + test:shouldBeEqual(qu:getData()[1].a, 2) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), true) + test:shouldBeEqual(qu:getData()[1].a, 3) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), false) + test:Complete() +end) + +TestFramework:RegisterTest("[Prepared Query] clear parameters correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT ? as a, ? as b") + qu:setNumber(1, 1) + qu:setString(2, "test") + qu:clearParameters() + qu:start() + qu:wait() + local data = qu:getData() + test:shouldHaveLength(data, 1) + test:shouldBeNil(data[1].a) + test:shouldBeNil(data[1].b) + test:Complete() +end) + +TestFramework:RegisterTest("[Prepared Query] last insert should return correct values", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS last_insert_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE last_insert_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + + local qu = db:prepare("INSERT INTO last_insert_test VALUES()") + function qu:onSuccess() + test:shouldBeEqual(qu:lastInsert(), 1) + end + qu:start() + local qu2 = db:prepare("INSERT INTO last_insert_test VALUES()") + function qu2:onSuccess() + test:shouldBeEqual(qu2:lastInsert(), 2) + end + qu2:start() + local qu3 = db:prepare("INSERT INTO last_insert_test VALUES()") + function qu3:onSuccess() + test:shouldBeEqual(qu3:lastInsert(), 3) + function qu3:onSuccess() + test:shouldBeEqual(qu3:lastInsert(), 4) + qu3.onSuccess = nil + qu3:start() + qu3:wait() + test:shouldBeEqual(qu3:lastInsert(), 5) + test:Complete() + end + qu3:start() + end + qu3:start() +end) + +TestFramework:RegisterTest("[Prepared Query] affected rows should return correct values", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS affected_rows_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE affected_rows_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + local qu = db:prepare("DELETE FROM affected_rows_test WHERE id = ?") + qu:setNumber(1, 4) + qu:start() + qu:wait() + test:shouldBeEqual(qu:affectedRows(), 1) + qu:start() + local qu2 = db:prepare("DELETE FROM affected_rows_test") + function qu2:onSuccess() + test:shouldBeEqual(qu2:affectedRows(), 3) + function qu2:onSuccess() + test:shouldBeEqual(qu2:affectedRows(), 0) + test:Complete() + end + qu2:start() + end + qu2:start() +end) + +TestFramework:RegisterTest("[Prepared Query] isRunning should return the correct value", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT SLEEP(0.1)") + test:shouldBeEqual(qu:isRunning(), false) + function qu:onSuccess() + test:shouldBeEqual(qu:isRunning(), true) + timer.Simple(0.1, function() + test:shouldBeEqual(qu:isRunning(), false) + test:Complete() + end) + end + qu:start() + test:shouldBeEqual(qu:isRunning(), true) +end) + +TestFramework:RegisterTest("[Prepared Query] should return correct data", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS data_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE data_test(id INT PRIMARY KEY, str VARCHAR(10), big BIGINT, bin BLOB, num DOUBLE, bool BIT)]]) + TestFramework:RunQuery(db, [[INSERT INTO data_test VALUES(1, '2', 8589934588, X'470047', 3.3, TRUE)]]) + TestFramework:RunQuery(db, [[INSERT INTO data_test VALUES(2, null, -8589930588, X'00AB', 10.1, FALSE)]]) + + local qu = db:prepare("SELECT * FROM data_test") + function qu:onSuccess(data) + test:shouldBeEqual(data, qu:getData()) //Check that it is cached correctly + test:shouldBeEqual(#data, 2) + local row1 = data[1] + test:shouldBeEqual(row1.id, 1) + test:shouldBeEqual(row1.str, "2") + test:shouldBeEqual(row1.big, 8589934588) + test:shouldBeEqual(row1.bin, string.char(0x47,0x00,0x47)) + test:shouldBeEqual(row1.num, 3.3) + test:shouldBeEqual(row1.bool, 1) + local row2 = data[2] + test:shouldBeEqual(row2.id, 2) + test:shouldBeNil(row2.str) + test:shouldBeEqual(row2.big, -8589930588) + test:shouldBeEqual(row2.bin, string.char(0x00,0xAB)) + test:shouldBeEqual(row2.num, 10.1) + test:shouldBeEqual(row2.bool, 0) + test:Complete() + end + function qu:onError(err) + print(err) + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] should return correct data if numeric is enabled", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT 1, 2, 4") + qu:setOption(mysqloo.OPTION_NUMERIC_FIELDS) + function qu:onSuccess(data) + test:shouldBeEqual(#data, 1) + local row = data[1] + test:shouldBeEqual(row[1], 1) + test:shouldBeEqual(row[2], 2) + test:shouldBeEqual(row[3], 4) + test:shouldBeNil(row[4]) + test:Complete() + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] should return correct error", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SEsdg") + function qu:onError(err) + test:shouldBeEqual(qu:error(), err) + test:shouldBeGreaterThan(#qu:error(), 0) + test:Complete() + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] should return correct error", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SEsdg") + function qu:onError(err) + test:shouldBeEqual(qu:error(), err) + test:shouldBeGreaterThan(#qu:error(), 0) + timer.Simple(0.1, function() + test:shouldBeEqual(qu:error(), err) + test:Complete() + end) + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] should return correct error if waiting", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SEsdg") + qu:start() + qu:wait() + test:shouldBeGreaterThan(#qu:error(), 0) + test:Complete() +end) + +TestFramework:RegisterTest("[Prepared Query] prevent multiple statements if disabled", function(test) + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, DatabaseSettings.Password, DatabaseSettings.Database, DatabaseSettings.Port) + db:setMultiStatements(false) + db:connect() + db:wait() + local qu = db:prepare("SELECT 1; SELECT 2;") + function qu:onError() + test:Complete() + end + function qu:onSuccess() + test:Fail("Query should have failed but did not") + end + qu:start() + qu:wait() +end) + +TestFramework:RegisterTest("[Prepared Query] prevent multiple statements even if enabled", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT 1 as a; SELECT 2 as b;") + function qu:onError() + test:Complete() + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] call onData correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT ? as a UNION ALL SELECT ?") + qu:setNumber(1, 1) + qu:setNumber(2, 2) + local callCount = 0 + local sum = 0 + function qu:onSuccess(data) //onData is called before onSuccess + test:shouldHaveLength(data, 2) + test:shouldBeEqual(callCount, 2) + test:shouldBeEqual(sum, 3) + test:Complete() + end + function qu:onData(row) + callCount = callCount + 1 + sum = sum + row.a + end + qu:start() +end) + +TestFramework:RegisterTest("[Prepared Query] abort query correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:prepare("SELECT SLEEP(1)") //This should block for a bit + qu:start() + local qu2 = db:prepare("SELECT 1") + qu2:start() + function qu2:onAborted() + test:Complete() + end + test:shouldBeEqual(qu2:abort(), true) + test:shouldBeEqual(qu:abort(), false) +end) + +TestFramework:RegisterTest("[Prepared Query] Work with stored procedure correctly", function(test) + local db = TestFramework:ConnectToDatabase() + //db:setMultiStatements(false) + TestFramework:RunQuery(db, "DROP PROCEDURE IF EXISTS test_procedure") + TestFramework:RunQuery(db, [[ + CREATE PROCEDURE test_procedure (IN param INT) + BEGIN + SELECT param as a; + SELECT 999 as b; + END + ]]) + local qu = db:prepare("CALL test_procedure(?)") + qu:setNumber(1, 5) + qu:start() + qu:wait() + test:shouldBeEqual(qu:hasMoreResults(), true) + local first = qu:getData() + test:shouldBeEqual(first[1].a, 5) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), true) + local second = qu:getData() + test:shouldBeEqual(second[1].b, 999) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), true) //For some reason, stored procedures add extra result sets + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), false) + test:Complete() +end) \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/tests/query_tests.lua b/IntegrationTest/lua/mysqloo/tests/query_tests.lua new file mode 100644 index 0000000..a7fa375 --- /dev/null +++ b/IntegrationTest/lua/mysqloo/tests/query_tests.lua @@ -0,0 +1,253 @@ +TestFramework:RegisterTest("[Query] last insert should return correct values", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS last_insert_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE last_insert_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + + local qu = db:query("INSERT INTO last_insert_test VALUES()") + function qu:onSuccess() + test:shouldBeEqual(qu:lastInsert(), 1) + end + qu:start() + local qu2 = db:query("INSERT INTO last_insert_test VALUES()") + function qu2:onSuccess() + test:shouldBeEqual(qu2:lastInsert(), 2) + end + qu2:start() + local qu3 = db:query("INSERT INTO last_insert_test VALUES()") + function qu3:onSuccess() + test:shouldBeEqual(qu3:lastInsert(), 3) + function qu3:onSuccess() + test:shouldBeEqual(qu3:lastInsert(), 4) + qu3.onSuccess = nil + qu3:start() + qu3:wait() + test:shouldBeEqual(qu3:lastInsert(), 5) + test:Complete() + end + qu3:start() + end + qu3:start() +end) + +TestFramework:RegisterTest("[Query] affected rows should return correct values", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS affected_rows_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE affected_rows_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + TestFramework:RunQuery(db, "INSERT INTO affected_rows_test VALUES()") + local qu = db:query("DELETE FROM affected_rows_test WHERE id = 4") + qu:start() + qu:wait() + test:shouldBeEqual(qu:affectedRows(), 1) + qu:start() + local qu2 = db:query("DELETE FROM affected_rows_test") + function qu2:onSuccess() + test:shouldBeEqual(qu2:affectedRows(), 3) + function qu2:onSuccess() + test:shouldBeEqual(qu2:affectedRows(), 0) + test:Complete() + end + qu2:start() + end + qu2:start() +end) + +TestFramework:RegisterTest("[Query] isRunning should return the correct value", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT SLEEP(0.1)") + test:shouldBeEqual(qu:isRunning(), false) + function qu:onSuccess() + test:shouldBeEqual(qu:isRunning(), true) + timer.Simple(0.1, function() + test:shouldBeEqual(qu:isRunning(), false) + test:Complete() + end) + end + qu:start() + test:shouldBeEqual(qu:isRunning(), true) +end) + +TestFramework:RegisterTest("[Query] should return correct data", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS data_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE data_test(id INT PRIMARY KEY, str VARCHAR(10), big BIGINT, bin BLOB, num DOUBLE, bool BIT)]]) + TestFramework:RunQuery(db, [[INSERT INTO data_test VALUES(1, '2', 8589934588, X'470047', 3.3, TRUE)]]) + TestFramework:RunQuery(db, [[INSERT INTO data_test VALUES(2, null, -8589930588, X'00AB', 10.1, FALSE)]]) + + local qu = db:query("SELECT * FROM data_test") + function qu:onSuccess(data) + test:shouldBeEqual(data, qu:getData()) //Check that it is cached correctly + test:shouldBeEqual(#data, 2) + local row1 = data[1] + test:shouldBeEqual(row1.id, 1) + test:shouldBeEqual(row1.str, "2") + test:shouldBeEqual(row1.big, 8589934588) + test:shouldBeEqual(row1.bin, string.char(0x47,0x00,0x47)) + test:shouldBeEqual(row1.num, 3.3) + test:shouldBeEqual(row1.bool, 1) + local row2 = data[2] + test:shouldBeEqual(row2.id, 2) + test:shouldBeNil(row2.str) + test:shouldBeEqual(row2.big, -8589930588) + test:shouldBeEqual(row2.bin, string.char(0x00,0xAB)) + test:shouldBeEqual(row2.num, 10.1) + test:shouldBeEqual(row2.bool, 0) + test:Complete() + end + function qu:onError(err) + print(err) + end + qu:start() +end) + +TestFramework:RegisterTest("[Query] should return correct data if numeric is enabled", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT 1, 2, 4") + qu:setOption(mysqloo.OPTION_NUMERIC_FIELDS) + function qu:onSuccess(data) + test:shouldBeEqual(#data, 1) + local row = data[1] + test:shouldBeEqual(row[1], 1) + test:shouldBeEqual(row[2], 2) + test:shouldBeEqual(row[3], 4) + test:shouldBeNil(row[4]) + test:Complete() + end + qu:start() +end) + +TestFramework:RegisterTest("[Query] should return correct error", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SEsdg") + function qu:onError(err) + test:shouldBeEqual(qu:error(), err) + test:shouldBeGreaterThan(#qu:error(), 0) + test:Complete() + end + qu:start() +end) + +TestFramework:RegisterTest("[Query] should return correct error", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SEsdg") + function qu:onError(err) + test:shouldBeEqual(qu:error(), err) + test:shouldBeGreaterThan(#qu:error(), 0) + timer.Simple(0.1, function() + test:shouldBeEqual(qu:error(), err) + test:Complete() + end) + end + qu:start() +end) + +TestFramework:RegisterTest("[Query] should return correct error if waiting", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SEsdg") + qu:start() + qu:wait() + test:shouldBeGreaterThan(#qu:error(), 0) + test:Complete() +end) + +TestFramework:RegisterTest("[Query] prevent multiple statements if disabled", function(test) + local db = mysqloo.connect(DatabaseSettings.Host, DatabaseSettings.Username, DatabaseSettings.Password, DatabaseSettings.Database, DatabaseSettings.Port) + db:setMultiStatements(false) + db:connect() + db:wait() + local qu = db:query("SELECT 1; SELECT 2;") + function qu:onError() + test:Complete() + end + function qu:onSuccess() + test:Fail("Query should have failed but did not") + end + qu:start() + qu:wait() +end) + +TestFramework:RegisterTest("[Query] work correctly with multi statements and multiple results", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT 1 as a; SELECT 2 as b;") + qu:start() + qu:wait() + local data = qu:getData() + test:shouldBeEqual(#data, 1) + test:shouldBeEqual(data[1].a, 1) + test:shouldBeEqual(qu:hasMoreResults(), true) + + qu:getNextResults() + local newData = qu:getData() + test:shouldNotBeEqual(newData, data) + test:shouldBeEqual(#newData, 1) + test:shouldBeEqual(newData[1].b, 2) + test:shouldBeEqual(qu:hasMoreResults(), true) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), false) + test:Complete() +end) + +TestFramework:RegisterTest("[Query] work correctly with multi statements and multiple results with affectedRows/lastInserts", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS last_insert_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE last_insert_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + local qu = db:query("INSERT INTO last_insert_test VALUES(); INSERT INTO last_insert_test VALUES(); INSERT INTO last_insert_test VALUES()") + qu:start() + qu:wait() + test:shouldBeEqual(qu:lastInsert(), 1) + test:shouldBeEqual(qu:hasMoreResults(), true) + qu:getNextResults() + test:shouldBeEqual(qu:lastInsert(), 2) + test:shouldBeEqual(qu:hasMoreResults(), true) + qu:getNextResults() + test:shouldBeEqual(qu:lastInsert(), 3) + test:shouldBeEqual(qu:hasMoreResults(), true) + qu:getNextResults() + test:shouldBeEqual(qu:hasMoreResults(), false) + + local qu2 = db:query("DELETE FROM last_insert_test WHERE id = 1; DELETE FROM last_insert_test") + qu2:start() + qu2:wait() + test:shouldBeEqual(qu2:affectedRows(), 1) + test:shouldBeEqual(qu2:hasMoreResults(), true) + qu2:getNextResults() + test:shouldBeEqual(qu2:affectedRows(), 2) + test:shouldBeEqual(qu2:hasMoreResults(), true) + qu2:getNextResults() + test:shouldBeEqual(qu2:hasMoreResults(), false) + + test:Complete() +end) + +TestFramework:RegisterTest("[Query] call onData correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT 1 as a UNION ALL SELECT 2") + local callCount = 0 + local sum = 0 + function qu:onSuccess(data) //onData is called before onSuccess + test:shouldHaveLength(data, 2) + test:shouldBeEqual(callCount, 2) + test:shouldBeEqual(sum, 3) + test:Complete() + end + function qu:onData(row) + callCount = callCount + 1 + sum = sum + row.a + end + qu:start() +end) + +TestFramework:RegisterTest("[Query] abort query correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local qu = db:query("SELECT SLEEP(1)") //This should block for a bit + qu:start() + local qu2 = db:query("SELECT 1") + qu2:start() + function qu2:onAborted() + test:Complete() + end + test:shouldBeEqual(qu2:abort(), true) + test:shouldBeEqual(qu:abort(), false) +end) \ No newline at end of file diff --git a/IntegrationTest/lua/mysqloo/tests/transaction_test.lua b/IntegrationTest/lua/mysqloo/tests/transaction_test.lua new file mode 100644 index 0000000..6e81465 --- /dev/null +++ b/IntegrationTest/lua/mysqloo/tests/transaction_test.lua @@ -0,0 +1,55 @@ +TestFramework:RegisterTest("[Transaction] should return added queries correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local q1 = db:query("SELECT 1") + local q2 = db:prepare("SELECT ?") + q2:setNumber(2, 2) + local q3 = db:query("SELECT 3") + local transaction = db:createTransaction() + test:shouldHaveLength(transaction:getQueries(), 0) + transaction:addQuery(q1) + transaction:addQuery(q2) + transaction:addQuery(q3) + local queries = transaction:getQueries() + test:shouldHaveLength(transaction:getQueries(), 3) + test:shouldBeEqual(queries[1], q1) + test:shouldBeEqual(queries[2], q2) + test:shouldBeEqual(queries[3], q3) + test:Complete() +end) + +TestFramework:RegisterTest("[Transaction] run transaction with same query correctly", function(test) + local db = TestFramework:ConnectToDatabase() + local transaction = db:createTransaction() + local qu = db:prepare("SELECT ? as a") + qu:setNumber(1, 1) + transaction:addQuery(qu) + qu:setNumber(1, 3) + transaction:addQuery(qu) + function transaction:onSuccess(data) + test:shouldHaveLength(data, 2) + test:shouldBeEqual(data[1][1].a, 1) + test:shouldBeEqual(data[2][1].a, 3) + test:Complete() + end + transaction:start() + transaction:wait() +end) + +TestFramework:RegisterTest("[Transaction] rollback failure correctly", function(test) + local db = TestFramework:ConnectToDatabase() + TestFramework:RunQuery(db, [[DROP TABLE IF EXISTS transaction_test]]) + TestFramework:RunQuery(db, [[CREATE TABLE transaction_test(id INT AUTO_INCREMENT PRIMARY KEY)]]) + local transaction = db:createTransaction() + local qu = db:query("INSERT INTO transaction_test VALUES()") + local qu2 = db:query("gfdgdg") + transaction:addQuery(qu) + transaction:addQuery(qu2) + function transaction:onError() + local qu3 = db:query("SELECT * FROM transaction_test") + qu3:start() + qu3:wait() + test:shouldHaveLength(qu3, 0) + test:Complete() + end + transaction:start() +end) \ No newline at end of file diff --git a/README.md b/README.md index 4776926..8880515 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ This module is an almost entirely rewritten version of MySQLOO 8.1. It supports several new features such as multiple result sets, prepared queries and transactions. The module also fixed the memory leak issues the previous versions of MySQLOO had. -For further information please [visit this forum thread](https://forum.facepunch.com/f/gmodaddon/jjdq/gmsv-mysqloo-v9-Rewritten-MySQL-Module-prepared-statements-transactions/1/). - # Install instructions Download the latest module for your server's operating system and architecture using the links provided below, then place that file within the `garrysmod/lua/bin/` folder on your server. If the `bin` folder doesn't exist, please create it. @@ -132,10 +130,10 @@ Database:ping() -- returns true if the connection is still up, false otherwise Database:setCharacterSet(charSetName) --- Returns [Boolean, String] +-- Returns [Boolean] -- Attempts to set the connection's character set to the one specified. -- Please note that this does block the main server thread if there is a query currently being ran --- Returns true on success, false and an error message on failure +-- Returns true on success, false on failure Database:setSSL(key, cert, ca, capath, cipher) -- Returns nothing @@ -258,7 +256,7 @@ PreparedQuery:clearParameters() PreparedQuery:putNewParameters() -- Returns nothing --- This shouldn't be used anymore, just start the same prepared multiple times with different parameters +-- Deprecated: Start the same prepared statement multiple times instead -- Transaction object @@ -289,17 +287,50 @@ Transaction.onSuccess() # Build instructions: -To build the project you first need to generate the appropriate solution for your system using [premake](https://premake.github.io/download.html). +This project uses [CMake](https://cmake.org/) as a build system. +## Windows + +### Visual Studio +Visual Studio has support for CMake since Visual Studio 2017. To open the project, run Visual Studio and under `File > Open > CMake...` +select the CMakeList.txt from this directory. + +The CMakeSettings.json in this project should already define both a 32 and 64 bit configuration. +You can add new configurations in the combo box that contains the x64 config. Here you can change the build type to Release or RelWithDebInfo and duplicate the config +for a 32 bit build. + +To build the project, you can then simply run `` from the toolbar. The output files are placed in the `out/build/{ConfigurationName}/` subfolder +of this project. + + +### CLion +Simply open the project in CLion and import the CMake project. Assuming you have a [valid toolchain](https://www.jetbrains.com/help/clion/how-to-create-toolchain-in-clion.html) setup, +you can simply build the project using `Build > Build Project` in the toolbar. + +To compile for 32 bit rather than 64 bit, you can select a 32 bit VS toolchain, rather than the 64 bit one. + +The output files are placed within the `cmake-build-debug/` directory of this project. + + +## Linux + + +### Prerequisites +To compile the project, you will need CMake and a functioning c++ compiler. For example, under Ubuntu, the following packages +can be used to compile the module. +```bash +sudo apt install build-essential gcc-multilib cmake ``` -premake5 --os=windows --file=BuildProjects.lua vs2017 -premake5 --os=macosx --file=BuildProjects.lua gmake -premake5 --os=linux --file=BuildProjects.lua gmake -``` -Then building MySQLOO should be as easy as either running make (linux) or pressing the build project button in Visual Studio (windows). -**Note**: To build MySQLOO in 64-bit, run `make config=release_x86_64` +### Compiling +To compile the module, follow the following steps: +- enter the project directory and run `cmake .` in bash. +- in the same directory run `make` in bash. +- The module should be compiled and the resulting binary should be placed directly in the project directory. -**Note**: On Linux you might have to install some additional libraries required in the linking process, but I personally have not experienced any such issues. -**Note:** Mac is currently not supported since the MariaDB connector is not available on mac (at least not precompiled). + +## Mac +Mac is currently not supported since the MariaDB connector is not available on Mac (at least not precompiled). +However, if you are able to compile the connector yourself, building for Mac should broadly follow the same instructions +as for Linux. \ No newline at end of file diff --git a/src/lua/GMModule.cpp b/src/lua/GMModule.cpp index bda48c1..d1d4c4d 100644 --- a/src/lua/GMModule.cpp +++ b/src/lua/GMModule.cpp @@ -72,7 +72,7 @@ static int printOutdatedVersion(lua_State *state) { printMessage(LUA, "Your server is using an outdated mysqloo9 version\n", 255, 0, 0); printMessage(LUA, "Download the latest version from here:\n", 255, 0, 0); printMessage(LUA, "https://github.com/FredyH/MySQLOO/releases\n", 86, 156, 214); - runInTimer(LUA, 300, printOutdatedVersion); + runInTimer(LUA, 3600, printOutdatedVersion); return 0; } diff --git a/src/lua/LuaDatabase.cpp b/src/lua/LuaDatabase.cpp index 2795424..bd0c1aa 100644 --- a/src/lua/LuaDatabase.cpp +++ b/src/lua/LuaDatabase.cpp @@ -142,6 +142,9 @@ MYSQLOO_LUA_FUNCTION(disconnect) { wait = LUA->GetBool(2); } database->m_database->disconnect(wait); + if (wait) { + database->think(LUA); //To set callback data, run callbacks + } return 0; } @@ -288,6 +291,7 @@ void LuaDatabase::think(ILuaBase *LUA) { //Connection callbacks auto database = this->m_database; if (database->isConnectionDone() && !this->m_dbCallbackRan && this->m_tableReference != 0) { + this->m_dbCallbackRan = true; LUA->ReferencePush(this->m_tableReference); if (database->connectionSuccessful()) { LUA->GetField(-1, "onConnected"); @@ -308,7 +312,6 @@ void LuaDatabase::think(ILuaBase *LUA) { } LUA->ReferenceFree(this->m_tableReference); - this->m_dbCallbackRan = true; this->m_tableReference = 0; } diff --git a/src/lua/LuaIQuery.cpp b/src/lua/LuaIQuery.cpp index dfde04d..a1929f1 100644 --- a/src/lua/LuaIQuery.cpp +++ b/src/lua/LuaIQuery.cpp @@ -7,7 +7,7 @@ MYSQLOO_LUA_FUNCTION(start) { auto query = LuaIQuery::getLuaObject(LUA); - auto queryData = query->buildQueryData(LUA, 1); + auto queryData = query->buildQueryData(LUA, 1, true); query->m_query->start(queryData); return 0; } @@ -165,7 +165,7 @@ void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr &iQuery case QUERY_SUCCESS: if (auto query = std::dynamic_pointer_cast(iQuery)) { LuaQuery::runSuccessCallback(LUA, query, std::dynamic_pointer_cast(data)); - } else if (auto transaction = std::dynamic_pointer_cast(query)) { + } else if (auto transaction = std::dynamic_pointer_cast(iQuery)) { LuaTransaction::runSuccessCallback(LUA, transaction, std::dynamic_pointer_cast(data)); } break; diff --git a/src/lua/LuaIQuery.h b/src/lua/LuaIQuery.h index 032a4da..42c5917 100644 --- a/src/lua/LuaIQuery.h +++ b/src/lua/LuaIQuery.h @@ -20,7 +20,7 @@ public: int m_databaseReference = 0; //The table is at the top - virtual std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) = 0; + virtual std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) = 0; static void referenceCallbacks(ILuaBase *LUA, int stackPosition, IQueryData &data); diff --git a/src/lua/LuaPreparedQuery.cpp b/src/lua/LuaPreparedQuery.cpp index 958686c..589cf1f 100644 --- a/src/lua/LuaPreparedQuery.cpp +++ b/src/lua/LuaPreparedQuery.cpp @@ -82,9 +82,11 @@ void LuaPreparedQuery::createMetaTable(ILuaBase *LUA) { LUA->Pop(); //Metatable } -std::shared_ptr LuaPreparedQuery::buildQueryData(ILuaBase* LUA, int stackPosition) { +std::shared_ptr LuaPreparedQuery::buildQueryData(ILuaBase* LUA, int stackPosition, bool shouldRef) { auto query = (PreparedQuery*) m_query.get(); auto data = query->buildQueryData(); - LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + if (shouldRef) { + LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + } return data; } diff --git a/src/lua/LuaPreparedQuery.h b/src/lua/LuaPreparedQuery.h index d441fb5..e4534af 100644 --- a/src/lua/LuaPreparedQuery.h +++ b/src/lua/LuaPreparedQuery.h @@ -8,7 +8,7 @@ class LuaPreparedQuery : public LuaQuery { public: - std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) override; + std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) override; static void createMetaTable(ILuaBase *LUA); diff --git a/src/lua/LuaQuery.cpp b/src/lua/LuaQuery.cpp index 661db06..4d3ae0c 100644 --- a/src/lua/LuaQuery.cpp +++ b/src/lua/LuaQuery.cpp @@ -180,11 +180,13 @@ void LuaQuery::createMetaTable(ILuaBase *LUA) { LUA->Pop(); //Metatable } -std::shared_ptr LuaQuery::buildQueryData(ILuaBase *LUA, int stackPosition) { +std::shared_ptr LuaQuery::buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) { auto query = std::dynamic_pointer_cast(this->m_query); auto data = query->buildQueryData(); data->setStatus(QUERY_COMPLETE); - LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + if (shouldRef) { + LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + } return data; } diff --git a/src/lua/LuaQuery.h b/src/lua/LuaQuery.h index 439557c..58ac2de 100644 --- a/src/lua/LuaQuery.h +++ b/src/lua/LuaQuery.h @@ -17,7 +17,7 @@ public: static void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr& query, const std::shared_ptr &data); - std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) override; + std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) override; void onDestroyedByLua(ILuaBase *LUA) override; diff --git a/src/lua/LuaTransaction.cpp b/src/lua/LuaTransaction.cpp index 083901c..5ef0d26 100644 --- a/src/lua/LuaTransaction.cpp +++ b/src/lua/LuaTransaction.cpp @@ -22,7 +22,7 @@ MYSQLOO_LUA_FUNCTION(addQuery) { LUA->Call(2, 0); LUA->Pop(4); - auto queryData = std::dynamic_pointer_cast(addedLuaQuery->buildQueryData(LUA, 2)); + auto queryData = std::dynamic_pointer_cast(addedLuaQuery->buildQueryData(LUA, 2, false)); luaTransaction->m_addedQueryData.push_back(queryData); return 0; @@ -30,6 +30,10 @@ MYSQLOO_LUA_FUNCTION(addQuery) { MYSQLOO_LUA_FUNCTION(getQueries) { LUA->GetField(1, "__queries"); + if (LUA->IsType(-1, GarrysMod::Lua::Type::Nil)) { + LUA->Pop(); + LUA->CreateTable(); + } return 1; } @@ -59,7 +63,7 @@ void LuaTransaction::createMetaTable(ILuaBase *LUA) { LUA->Pop(); } -std::shared_ptr LuaTransaction::buildQueryData(ILuaBase *LUA, int stackPosition) { +std::shared_ptr LuaTransaction::buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) { LUA->GetField(stackPosition, "__queries"); std::deque, std::shared_ptr>> queries; if (LUA->GetType(-1) != GarrysMod::Lua::Type::Nil) { @@ -79,8 +83,11 @@ std::shared_ptr LuaTransaction::buildQueryData(ILuaBase *LUA, int st } } LUA->Pop(); //Queries table + auto data = Transaction::buildQueryData(queries); - LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + if (shouldRef) { + LuaIQuery::referenceCallbacks(LUA, stackPosition, *data); + } return data; } diff --git a/src/lua/LuaTransaction.h b/src/lua/LuaTransaction.h index 4d42f0d..56be789 100644 --- a/src/lua/LuaTransaction.h +++ b/src/lua/LuaTransaction.h @@ -9,7 +9,7 @@ class LuaTransaction : public LuaIQuery { public: std::deque> m_addedQueryData = {}; - std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) override; + std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition, bool shouldRef) override; static void createMetaTable(ILuaBase *LUA); diff --git a/src/mysql/Database.cpp b/src/mysql/Database.cpp index 47932ce..3c10d6f 100644 --- a/src/mysql/Database.cpp +++ b/src/mysql/Database.cpp @@ -142,7 +142,7 @@ bool Database::setCharacterSet(const std::string &characterSet) { //This mutex makes sure we can safely use the connection to run the query std::unique_lock lk2(m_queryMutex); if (mysql_set_character_set(m_sql, characterSet.c_str())) { - return false; //TODO: Also return error? + return false; } else { return true; } @@ -368,22 +368,22 @@ void Database::run() { }); while (true) { auto pair = this->queryQueue.take(); - //This detects the poison pill that is supposed to shutdown the database + //This detects the poison pill that is supposed to shut down the database if (pair.first == nullptr) { return; } - auto curquery = pair.first; + auto curQuery = pair.first; auto data = pair.second; { //New scope so mutex will be released as soon as possible std::unique_lock queryMutex(m_queryMutex); - curquery->executeStatement(*this, this->m_sql, data); + curQuery->executeStatement(*this, this->m_sql, data); } data->setFinished(true); finishedQueries.put(pair); { - std::unique_lock queryMutex(curquery->m_waitMutex); - curquery->m_waitWakeupVariable.notify_one(); + std::unique_lock queryMutex(curQuery->m_waitMutex); + curQuery->m_waitWakeupVariable.notify_one(); } //So that statements get eventually freed even if the queue is constantly full freeUnusedStatements();