Advertisement
Guest User

Untitled

a guest
Jun 25th, 2017
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.57 KB | None | 0 0
  1. local function stub() end
  2. local function untab(str)
  3. local pref = "\n" .. string.rep("\t", math.max(#str:match("^[\r\n]*(\t*)"), 1))
  4. return string.gsub("\n" .. str, pref, "\n"):sub(2):gsub("[\t\r\n]*$", "")
  5. end
  6. local function ms(i, j)
  7. return (math.random(i or 25, j or 100) + math.random(i or 25, j or 100))/2*0.001
  8. end
  9. local COL = SERVER and Color(137, 222, 255) or Color(231, 219, 116)
  10.  
  11. dbal = {ASYNC = stub}
  12.  
  13. --[[
  14. dbal.new(dbtype, host, user, pass, dbname, port = 3306, aliases = {})
  15. Creates a new database object.
  16. - dbtype: "sqlite" or "mysqloo"
  17. - aliases: key-value table of SQL aliases, ex.
  18. "select * from $key" -> "select * from value"
  19. ]]
  20. function dbal.new(dbtype, host, user, pass, name, port, aliases)
  21. assert(dbtype == "sqlite" or dbtype == "mysqloo", "unsupported dbtype: " .. dbtype)
  22. -- because mysqloo has horrible error reporting
  23. assert(type(host) == "string")
  24. assert(type(user) == "string")
  25. assert(type(pass) == "string")
  26. assert(type(name) == "string")
  27. assert(port == nil or type(port) == "number")
  28. local db = {
  29. Type = dbtype,
  30. Host = host,
  31. User = user,
  32. Pass = pass,
  33. Name = name,
  34. Port = port or 3306,
  35. Aliases = table.Copy(aliases or {}),
  36. }
  37.  
  38. for k, v in pairs(dbal) do
  39. if dbal[k .. "_" .. dbtype] then
  40. db[k] = dbal[k .. "_" .. dbtype]
  41. end
  42. end
  43.  
  44. -- hack?
  45. db.Aliases.affected_rows = db.ChangesFn
  46. db.Aliases.last_insert_id = db.LastInsertFn
  47.  
  48. setmetatable(db, {
  49. __index = dbal,
  50. __call = dbal.Query,
  51. __tostring = dbal.__tostring
  52. })
  53.  
  54. print(("SQL: Connecting to %s on %s@%s:%i. (%s)"):format(name, user, host, port or 3306, dbtype))
  55. db:Connect()
  56. print("SQL: Connected.")
  57.  
  58. return db
  59. end
  60.  
  61. -- Called automatically to connect to the database.
  62. -- Don't call this yourself.
  63. function dbal:Connect()
  64. require("mysqloo")
  65. self.DB = mysqloo.connect(self.Host, self.User, self.Pass, self.Name, self.Port)
  66.  
  67. -- kind of ghetto
  68. local err
  69. function self.DB:onConnectionFailed(e) err = e end
  70. self.DB:connect()
  71. self.DB:wait()
  72.  
  73. assert(not err, "Connection failed!\n" .. tostring(err))
  74. end
  75.  
  76. -- Disconnects from the server.
  77. function dbal:Disconnect()
  78. print(("SQL: Disconnected from %s on %s@%s:%i. (%s)"):format(self.Name, self.User, self.Host, self.Port, self.Type))
  79. self.DB = nil
  80. end
  81.  
  82. dbal.Connect_sqlite = stub
  83. dbal.Disconnect_sqlite = stub
  84.  
  85. -- Escapes a value. Used by db:Query() to escape its arguments.
  86. function dbal:Escape(v)
  87. if v == true or v == false then
  88. return v and 1 or 0
  89. elseif type(v) == "number" then
  90. return v
  91. elseif v == nil or v == NULL then
  92. return "NULL"
  93. elseif type(v) == "function" then
  94. local fi = debug.getinfo(v)
  95. error(("tried to escape function: %s:%i"):format(fi.source, fi.linedefined))
  96. end
  97. return self:EscapeStr(tostring(v))
  98. end
  99. function dbal:EscapeStr(v) return "'" .. self.DB:escape(v) .. "'" end
  100. function dbal:EscapeStr_sqlite(v) return sql.SQLStr(v) end
  101.  
  102. -- Substitutes table aliases, replaces ?'s with varargs.
  103. -- Returns a ready-to-use query and how many tokens were replaced.
  104. function dbal:Format(query, ...)
  105. query = query:gsub("$([a-z_]+)", self.Aliases)
  106. local args = {...}
  107. local i = 0
  108. local function repl()
  109. i = i + 1
  110. return self:Escape(args[i])
  111. end
  112. return query:gsub("%?", repl)
  113. end
  114.  
  115. --[[
  116. Safely formats the given query and runs it. If a callback function is
  117. given after the format arguments, the query is run NON-BLOCKING:
  118.  
  119. 1. db:Query() returns nil
  120. 2. when query finishes, callback is called as function(results, ...)
  121.  
  122. Usage:
  123.  
  124. local function cb(result, start)
  125. printf("Query got %i rows and took %i seconds.",
  126. #result, os.time() - start
  127. )
  128. end
  129. db:Query("SELECT * FROM $players WHERE steamid = ?",
  130. ply:SteamID(), cb, os.time()
  131. )
  132. ]]
  133. function dbal:Query(fmt, ...)
  134. if type(fmt) == "table" then
  135. return self:DoTransaction(fmt, ...)
  136. end
  137. local query, n = self:Format(fmt, ...)
  138. if self.DEBUG then
  139. MsgC(Color(160, 0, 160), untab(query), "\n")
  140. end
  141.  
  142. -- Format() removes values it uses from the table,
  143. -- so anything left is callback + cb args
  144. local args = {...}
  145. local cb = args[n + 1]
  146. assert(cb == nil or type(cb) == "function", "callback must be a function (arg #" .. n + 1 .. ")")
  147.  
  148. return (self:RawQuery(query, unpack(args, n + 1)))
  149. end
  150.  
  151. local function getResults(self, q)
  152. if (tonumber(mysqloo.MINOR_VERSION) or 0) >= 1 then
  153. local res = {}
  154. while q:hasMoreResults() do
  155. res[#res + 1] = q:getData()
  156. q:getNextResults()
  157. end
  158. if self.MULTIRES then
  159. return res
  160. end
  161. return res[#res]
  162. else
  163. local res = {q:getData()}
  164. while q:hasMoreResults() do
  165. res[#res + 1] = q:getNextResults()
  166. end
  167. if self.MULTIRES then
  168. return res
  169. end
  170. return res[#res]
  171. end
  172. end
  173.  
  174. -- Like db:Query(), but doesn't do query formatting.
  175. -- You probably shouldn't touch this.
  176. function dbal:RawQuery(query, cb, ...)
  177. local q = self.DB:query(query)
  178.  
  179. local tb = debug.traceback()
  180. local err
  181. function q:onError(e)
  182. err = e
  183. MsgC(COL, "--- SQL ERROR ---\n")
  184. MsgC(COL, "Query:\n", untab(query), "\n")
  185. MsgC(COL, "Error: ", err, "\n", tb, "\n")
  186. end
  187. function q:onAborted()
  188. MsgC(COL, "--- SQL QUERY ABORTED ---\n")
  189. MsgC(COL, "Query:\n", untab(query), "\n")
  190. end
  191.  
  192. if cb then
  193. -- callback was provided - run nonblocking
  194. local args = {...}
  195.  
  196. function q.onSuccess(this)
  197. assert(xpcall(cb, debug.traceback, getResults(self, this), unpack(args)))
  198. end
  199. q:start()
  200.  
  201. return nil
  202. else
  203. -- bad idea? of course!
  204. local cr = coroutine.running()
  205. if cr then
  206. -- running in a coroutine
  207. -- yield, and resume when the query completes
  208. function q.onSuccess(this)
  209. local ok, res = coroutine.resume(cr, getResults(self, this))
  210. if not ok then
  211. error("] " .. debug.traceback(cr, res))
  212. end
  213. end
  214. q:start()
  215.  
  216. return coroutine.yield()
  217. else
  218. -- running normally, no callback provided
  219. -- block until the query finishes
  220. q:start()
  221. q:wait()
  222. assert(not err, err)
  223. return getResults(self, q)
  224. end
  225. end
  226. end
  227.  
  228. function dbal:RawQuery_sqlite(query, cb, ...)
  229. local data = sql.Query(query)
  230. local err = sql.LastError()
  231.  
  232. local tb = debug.traceback()
  233. if cb then
  234. local args = {...}
  235.  
  236. timer.Simple(ms(25, 1000), function()
  237. -- check for errors here so we get the error, but not in this context
  238. if data == false then
  239. MsgC(COL, "--- SQL ERROR ---\n")
  240. MsgC(COL, "Query:\n", query, "\n")
  241. MsgC(COL, "Error: ", err, "\n", tb, "\n")
  242. end
  243. local ok, err = xpcall(cb, debug.traceback, data or {}, unpack(args))
  244. assert(ok, "] " .. tostring(err))
  245. end)
  246.  
  247. return nil
  248. else
  249. if data == false then
  250. MsgC(COL, "--- SQL ERROR ---\n")
  251. MsgC(COL, "Query:\n", query, "\n")
  252. MsgC(COL, "Error: ", err, "\n", tb, "\n")
  253. end
  254.  
  255. -- bad idea? of course!
  256. local cr = coroutine.running()
  257. if cr then
  258. timer.Simple(ms(), function()
  259. local ok, res = coroutine.resume(cr, data or {})
  260. if not ok then
  261. error("] " .. debug.traceback(cr, res))
  262. end
  263. end)
  264.  
  265. return coroutine.yield() or {}
  266. else
  267. return data or {}
  268. end
  269. end
  270. end
  271.  
  272. -- Like sql.QueryRow, returns the first row of the result set.
  273. function dbal:QueryRow(query, ...)
  274. return self:Query(query, ...)[1]
  275. end
  276.  
  277. -- Like sql.QueryValue, returns the first value in the result set.
  278. -- You should probably only select a single column for this.
  279. function dbal:QueryValue(query, ...)
  280. local key, val = next(self:Query(query, ...)[1] or {})
  281. return val
  282. end
  283.  
  284. -- Transactions. Completely untested.
  285. -- eg. db:DoTransaction({"select ?, ?", 1, 2}, {"update $tbl set col=?", "foo"}, cb, cbargs...)
  286. function dbal:DoTransaction(...)
  287. local args = {...}
  288. local tr = self.DB:createTransaction()
  289.  
  290. -- first few arguments should be {fmt, fmtargs...}
  291. local n = 0
  292. for i, v in ipairs(args) do
  293. if type(v) ~= "table" then break end
  294. local query = self:Format(table.remove(v, 1), v)
  295. tr:addQuery(self.DB:query(query))
  296. n = n + 1
  297. end
  298.  
  299. local tb = debug.traceback()
  300. function tr:onError(err)
  301. error("SQL transaction failed!\nError: %s\n%s\n", err, tb)
  302. end
  303.  
  304. local cb = args[n]
  305. assert(cb == nil or type(cb) == "function", "callback must be a function")
  306. if cb then
  307. function tr.onSuccess(this)
  308. for i, v in ipairs(this:getQueries()) do
  309. table.insert(args, n + i + 1, getResults(self, v))
  310. end
  311. assert(xpcall(cb, debug.traceback, unpack(args, n + 2)))
  312. end
  313. end
  314.  
  315. tr:start()
  316. end
  317.  
  318. function dbal:DoTransaction_sqlite(...)
  319. local args = {...}
  320. local queries = {}
  321.  
  322. -- first few arguments should be {fmt, fmtargs...}
  323. for i, v in ipairs(args) do
  324. if type(v) ~= "table" then break end
  325. local query = self:Format(table.remove(v, 1), v)
  326. queries[#queries + 1] = query
  327. end
  328.  
  329. local tb = debug.traceback()
  330. local cb = args[#queries + 1]
  331.  
  332. sql.Query("BEGIN")
  333. for i, v in ipairs(queries) do
  334. local res = sql.Query(v)
  335. if res == false then
  336. sql.Query("ROLLBACK")
  337. end
  338. end
  339. assert(xpcall(cb, debug.traceback, unpack(args, n + 1)))
  340. sql.Query("COMMIT")
  341. end
  342.  
  343. --[[
  344. Builds an INSERT query from the given data table, which may be
  345.  
  346. 1. A key-value table, where keys are SQL columns
  347. 2. An array of key-value tables
  348.  
  349. All values are escaped. If cb is given it's used as a callback similar
  350. to db:Query().
  351. ]]
  352. function dbal:Insert(name, data, cb, ...)
  353. name = string.gsub(name, "$([a-z]+)", self.Aliases)
  354. assert(next(data), "data table is empty")
  355. local cols, rows = {}, {}
  356. for k, v in SortedPairs(data[1] or data) do
  357. cols[#cols + 1] = k
  358. end
  359. for k, row in pairs(data[1] and data or {data}) do
  360. local vals = {}
  361. for i, col in ipairs(cols) do
  362. vals[#vals + 1] = self:Escape(row[col])
  363. end
  364. rows[#rows + 1] = "(" .. table.concat(vals, ", ") .. ")"
  365. end
  366.  
  367. local query = string.format("INSERT INTO %s (%s) VALUES\n%s;\nSELECT %s() AS id;",
  368. name, table.concat(cols, ", "), table.concat(rows, ",\n"), self.LastInsertFn
  369. )
  370. local res = self:RawQuery(query, cb, ...)
  371. if res then
  372. return tonumber((res[1] or {}).id)
  373. end
  374. end
  375.  
  376. --[[
  377. Builds an UPDATE query from the given data table, which may be a key-value
  378. table where keys are SQL columns.
  379. ]]
  380. function dbal:Update(name, data, cond, ...)
  381. if not next(data) then return end
  382. name = name:gsub("$([a-z]+)", self.Aliases)
  383. local args, n = {...}
  384.  
  385. local cols = {}
  386. for k, v in pairs(data) do
  387. cols[#cols + 1] = ("`%s` = %s"):format(k, self:Escape(v))
  388. end
  389.  
  390. if cond then
  391. cond, n = self:Format(cond, ...)
  392. cond = " WHERE " .. cond
  393. end
  394.  
  395. self:RawQuery(("UPDATE %s SET %s%s"):format(name, table.concat(cols, ","), cond or ""), unpack(args, n + 1))
  396. end
  397.  
  398. dbal.LastInsertFn = "last_insert_id"
  399. dbal.LastInsertFn_sqlite = "last_insert_rowid"
  400.  
  401. dbal.ChangesFn = "affected_rows"
  402. dbal.ChangesFn_sqlite = "changes"
  403.  
  404. function dbal:LastInsertID()
  405. return tonumber(self:RawQuery("SELECT " .. self.LastInsertFn .. "() AS id")[1].id)
  406. end
  407.  
  408. function dbal:__tostring()
  409. return ("SQL: %s (%s@%s:%i)"):format(self.Name, self.User, self.Host, self.Port)
  410. end
  411.  
  412. function dbal:__tostring_sqlite()
  413. return ("SQL: %s (%s@%s:%i)"):format("sv.db", "sqlite", "localhost", 0)
  414. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement