Advertisement
Guest User

Untitled

a guest
Jul 19th, 2018
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 86.61 KB | None | 0 0
  1. --[[
  2. MIT License
  3.  
  4. Copyright (c) 2017 Mark Langen
  5.  
  6. Permission is hereby granted, free of charge, to any person obtaining a copy
  7. of this software and associated documentation files (the "Software"), to deal
  8. in the Software without restriction, including without limitation the rights
  9. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  10. copies of the Software, and to permit persons to whom the Software is
  11. furnished to do so, subject to the following conditions:
  12.  
  13. The above copyright notice and this permission notice shall be included in all
  14. copies or substantial portions of the Software.
  15.  
  16. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  17. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  18. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  19. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  20. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  21. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. SOFTWARE.
  23. ]]
  24.  
  25. function lookupify(tb)
  26. for _, v in pairs(tb) do
  27. tb[v] = true
  28. end
  29. return tb
  30. end
  31.  
  32. function CountTable(tb)
  33. local c = 0
  34. for _ in pairs(tb) do c = c + 1 end
  35. return c
  36. end
  37.  
  38. function FormatTableInt(tb, atIndent, ignoreFunc)
  39. if tb.Print then
  40. return tb.Print()
  41. end
  42. atIndent = atIndent or 0
  43. local useNewlines = (CountTable(tb) > 1)
  44. local baseIndent = string.rep(' ', atIndent+1)
  45. local out = "{"..(useNewlines and '\n' or '')
  46. for k, v in pairs(tb) do
  47. if type(v) ~= 'function' and not ignoreFunc(k) then
  48. out = out..(useNewlines and baseIndent or '')
  49. if type(k) == 'number' then
  50. --nothing to do
  51. elseif type(k) == 'string' and k:match("^[A-Za-z_][A-Za-z0-9_]*$") then
  52. out = out..k.." = "
  53. elseif type(k) == 'string' then
  54. out = out.."[\""..k.."\"] = "
  55. else
  56. out = out.."["..tostring(k).."] = "
  57. end
  58. if type(v) == 'string' then
  59. out = out.."\""..v.."\""
  60. elseif type(v) == 'number' then
  61. out = out..v
  62. elseif type(v) == 'table' then
  63. out = out..FormatTableInt(v, atIndent+(useNewlines and 1 or 0), ignoreFunc)
  64. else
  65. out = out..tostring(v)
  66. end
  67. if next(tb, k) then
  68. out = out..","
  69. end
  70. if useNewlines then
  71. out = out..'\n'
  72. end
  73. end
  74. end
  75. out = out..(useNewlines and string.rep(' ', atIndent) or '').."}"
  76. return out
  77. end
  78.  
  79. function FormatTable(tb, ignoreFunc)
  80. ignoreFunc = ignoreFunc or function()
  81. return false
  82. end
  83. return FormatTableInt(tb, 0, ignoreFunc)
  84. end
  85.  
  86. local WhiteChars = lookupify{' ', '\n', '\t', '\r'}
  87.  
  88. local EscapeForCharacter = {['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t', ['"'] = '\\"', ["'"] = "\\'", ['\\'] = '\\'}
  89.  
  90. local CharacterForEscape = {['r'] = '\r', ['n'] = '\n', ['t'] = '\t', ['"'] = '"', ["'"] = "'", ['\\'] = '\\'}
  91.  
  92. local AllIdentStartChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
  93. 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
  94. 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
  95. 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
  96. 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
  97. 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_'}
  98.  
  99. local AllIdentChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
  100. 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
  101. 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
  102. 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
  103. 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
  104. 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_',
  105. '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
  106.  
  107. local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
  108.  
  109. local HexDigits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
  110. 'A', 'a', 'B', 'b', 'C', 'c', 'D', 'd', 'E', 'e', 'F', 'f'}
  111.  
  112. local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#', '.', ':'}
  113.  
  114. local EqualSymbols = lookupify{'~', '=', '>', '<'}
  115.  
  116. local Keywords = lookupify{
  117. 'and', 'break', 'do', 'else', 'elseif',
  118. 'end', 'false', 'for', 'function', 'goto', 'if',
  119. 'in', 'local', 'nil', 'not', 'or', 'repeat',
  120. 'return', 'then', 'true', 'until', 'while',
  121. };
  122.  
  123. local BlockFollowKeyword = lookupify{'else', 'elseif', 'until', 'end'}
  124.  
  125. local UnopSet = lookupify{'-', 'not', '#'}
  126.  
  127. local BinopSet = lookupify{
  128. '+', '-', '*', '/', '%', '^', '#',
  129. '..', '.', ':',
  130. '>', '<', '<=', '>=', '~=', '==',
  131. 'and', 'or'
  132. }
  133.  
  134. local GlobalRenameIgnore = lookupify{
  135.  
  136. }
  137.  
  138. local BinaryPriority = {
  139. ['+'] = {6, 6};
  140. ['-'] = {6, 6};
  141. ['*'] = {7, 7};
  142. ['/'] = {7, 7};
  143. ['%'] = {7, 7};
  144. ['^'] = {10, 9};
  145. ['..'] = {5, 4};
  146. ['=='] = {3, 3};
  147. ['~='] = {3, 3};
  148. ['>'] = {3, 3};
  149. ['<'] = {3, 3};
  150. ['>='] = {3, 3};
  151. ['<='] = {3, 3};
  152. ['and'] = {2, 2};
  153. ['or'] = {1, 1};
  154. };
  155. local UnaryPriority = 8
  156.  
  157. -- Eof, Ident, Keyword, Number, String, Symbol
  158.  
  159. function CreateLuaTokenStream(text)
  160. -- Tracking for the current position in the buffer, and
  161. -- the current line / character we are on.
  162. local p = 1
  163. local length = #text
  164.  
  165. -- Output buffer for tokens
  166. local tokenBuffer = {}
  167.  
  168. -- Get a character, or '' if at eof
  169. local function look(n)
  170. n = p + (n or 0)
  171. if n <= length then
  172. return text:sub(n, n)
  173. else
  174. return ''
  175. end
  176. end
  177. local function get()
  178. if p <= length then
  179. local c = text:sub(p, p)
  180. p = p + 1
  181. return c
  182. else
  183. return ''
  184. end
  185. end
  186.  
  187. -- Error
  188. local olderr = error
  189. local function error(str)
  190. local q = 1
  191. local line = 1
  192. local char = 1
  193. while q <= p do
  194. if text:sub(q, q) == '\n' then
  195. line = line + 1
  196. char = 1
  197. else
  198. char = char + 1
  199. end
  200. q = q + 1
  201. end
  202. for _, token in pairs(tokenBuffer) do
  203. print(token.Type.."<"..token.Source..">")
  204. end
  205. olderr("file<"..line..":"..char..">: "..str)
  206. end
  207.  
  208. -- Consume a long data with equals count of `eqcount'
  209. local function longdata(eqcount)
  210. while true do
  211. local c = get()
  212. if c == '' then
  213. error("Unfinished long string.")
  214. elseif c == ']' then
  215. local done = true -- Until contested
  216. for i = 1, eqcount do
  217. if look() == '=' then
  218. p = p + 1
  219. else
  220. done = false
  221. break
  222. end
  223. end
  224. if done and get() == ']' then
  225. return
  226. end
  227. end
  228. end
  229. end
  230.  
  231. -- Get the opening part for a long data `[` `=`* `[`
  232. -- Precondition: The first `[` has been consumed
  233. -- Return: nil or the equals count
  234. local function getopen()
  235. local startp = p
  236. while look() == '=' do
  237. p = p + 1
  238. end
  239. if look() == '[' then
  240. p = p + 1
  241. return p - startp - 1
  242. else
  243. p = startp
  244. return nil
  245. end
  246. end
  247.  
  248. -- Add token
  249. local whiteStart = 1
  250. local tokenStart = 1
  251. local function token(type)
  252. local tk = {
  253. Type = type;
  254. LeadingWhite = text:sub(whiteStart, tokenStart-1);
  255. Source = text:sub(tokenStart, p-1);
  256. }
  257. table.insert(tokenBuffer, tk)
  258. whiteStart = p
  259. tokenStart = p
  260. return tk
  261. end
  262.  
  263. -- Parse tokens loop
  264. while true do
  265. -- Mark the whitespace start
  266. whiteStart = p
  267.  
  268. -- Get the leading whitespace + comments
  269. while true do
  270. local c = look()
  271. if c == '' then
  272. break
  273. elseif c == '-' then
  274. if look(1) == '-' then
  275. p = p + 2
  276. -- Consume comment body
  277. if look() == '[' then
  278. p = p + 1
  279. local eqcount = getopen()
  280. if eqcount then
  281. -- Long comment body
  282. longdata(eqcount)
  283. else
  284. -- Normal comment body
  285. while true do
  286. local c2 = get()
  287. if c2 == '' or c2 == '\n' then
  288. break
  289. end
  290. end
  291. end
  292. else
  293. -- Normal comment body
  294. while true do
  295. local c2 = get()
  296. if c2 == '' or c2 == '\n' then
  297. break
  298. end
  299. end
  300. end
  301. else
  302. break
  303. end
  304. elseif WhiteChars[c] then
  305. p = p + 1
  306. else
  307. break
  308. end
  309. end
  310. local leadingWhite = text:sub(whiteStart, p-1)
  311.  
  312. -- Mark the token start
  313. tokenStart = p
  314.  
  315. -- Switch on token type
  316. local c1 = get()
  317. if c1 == '' then
  318. -- End of file
  319. token('Eof')
  320. break
  321. elseif c1 == '\'' or c1 == '\"' then
  322. -- String constant
  323. while true do
  324. local c2 = get()
  325. if c2 == '\\' then
  326. local c3 = get()
  327. local esc = CharacterForEscape[c3]
  328. if not esc then
  329. error("Invalid Escape Sequence `"..c3.."`.")
  330. end
  331. elseif c2 == c1 then
  332. break
  333. end
  334. end
  335. token('String')
  336. elseif AllIdentStartChars[c1] then
  337. -- Ident or Keyword
  338. while AllIdentChars[look()] do
  339. p = p + 1
  340. end
  341. if Keywords[text:sub(tokenStart, p-1)] then
  342. token('Keyword')
  343. else
  344. token('Ident')
  345. end
  346. elseif Digits[c1] or (c1 == '.' and Digits[look()]) then
  347. -- Number
  348. if c1 == '0' and look() == 'x' then
  349. p = p + 1
  350. -- Hex number
  351. while HexDigits[look()] do
  352. p = p + 1
  353. end
  354. else
  355. -- Normal Number
  356. while Digits[look()] do
  357. p = p + 1
  358. end
  359. if look() == '.' then
  360. -- With decimal point
  361. p = p + 1
  362. while Digits[look()] do
  363. p = p + 1
  364. end
  365. end
  366. if look() == 'e' or look() == 'E' then
  367. -- With exponent
  368. p = p + 1
  369. if look() == '-' then
  370. p = p + 1
  371. end
  372. while Digits[look()] do
  373. p = p + 1
  374. end
  375. end
  376. end
  377. token('Number')
  378. elseif c1 == '[' then
  379. -- '[' Symbol or Long String
  380. local eqCount = getopen()
  381. if eqCount then
  382. -- Long string
  383. longdata(eqCount)
  384. token('String')
  385. else
  386. -- Symbol
  387. token('Symbol')
  388. end
  389. elseif c1 == '.' then
  390. -- Greedily consume up to 3 `.` for . / .. / ... tokens
  391. if look() == '.' then
  392. get()
  393. if look() == '.' then
  394. get()
  395. end
  396. end
  397. token('Symbol')
  398. elseif EqualSymbols[c1] then
  399. if look() == '=' then
  400. p = p + 1
  401. end
  402. token('Symbol')
  403. elseif Symbols[c1] then
  404. token('Symbol')
  405. else
  406. error("Bad symbol `"..c1.."` in source.")
  407. end
  408. end
  409. return tokenBuffer
  410. end
  411.  
  412. function CreateLuaParser(text)
  413. -- Token stream and pointer into it
  414. local tokens = CreateLuaTokenStream(text)
  415. -- for _, tok in pairs(tokens) do
  416. -- print(tok.Type..": "..tok.Source)
  417. -- end
  418. local p = 1
  419.  
  420. local function get()
  421. local tok = tokens[p]
  422. if p < #tokens then
  423. p = p + 1
  424. end
  425. return tok
  426. end
  427. local function peek(n)
  428. n = p + (n or 0)
  429. return tokens[n] or tokens[#tokens]
  430. end
  431.  
  432. local function getTokenStartPosition(token)
  433. local line = 1
  434. local char = 0
  435. local tkNum = 1
  436. while true do
  437. local tk = tokens[tkNum]
  438. local text;
  439. if tk == token then
  440. text = tk.LeadingWhite
  441. else
  442. text = tk.LeadingWhite..tk.Source
  443. end
  444. for i = 1, #text do
  445. local c = text:sub(i, i)
  446. if c == '\n' then
  447. line = line + 1
  448. char = 0
  449. else
  450. char = char + 1
  451. end
  452. end
  453. if tk == token then
  454. break
  455. end
  456. tkNum = tkNum + 1
  457. end
  458. return line..":"..(char+1)
  459. end
  460. local function debugMark()
  461. local tk = peek()
  462. return "<"..tk.Type.." `"..tk.Source.."`> at: "..getTokenStartPosition(tk)
  463. end
  464.  
  465. local function isBlockFollow()
  466. local tok = peek()
  467. return tok.Type == 'Eof' or (tok.Type == 'Keyword' and BlockFollowKeyword[tok.Source])
  468. end
  469. local function isUnop()
  470. return UnopSet[peek().Source] or false
  471. end
  472. local function isBinop()
  473. return BinopSet[peek().Source] or false
  474. end
  475. local function expect(type, source)
  476. local tk = peek()
  477. if tk.Type == type and (source == nil or tk.Source == source) then
  478. return get()
  479. else
  480. for i = -3, 3 do
  481. print("Tokens["..i.."] = `"..peek(i).Source.."`")
  482. end
  483. if source then
  484. error(getTokenStartPosition(tk)..": `"..source.."` expected.")
  485. else
  486. error(getTokenStartPosition(tk)..": "..type.." expected.")
  487. end
  488. end
  489. end
  490.  
  491. local function MkNode(node)
  492. local getf = node.GetFirstToken
  493. local getl = node.GetLastToken
  494. function node:GetFirstToken()
  495. local t = getf(self)
  496. assert(t)
  497. return t
  498. end
  499. function node:GetLastToken()
  500. local t = getl(self)
  501. assert(t)
  502. return t
  503. end
  504. return node
  505. end
  506.  
  507. -- Forward decls
  508. local block;
  509. local expr;
  510.  
  511. -- Expression list
  512. local function exprlist()
  513. local exprList = {}
  514. local commaList = {}
  515. table.insert(exprList, expr())
  516. while peek().Source == ',' do
  517. table.insert(commaList, get())
  518. table.insert(exprList, expr())
  519. end
  520. return exprList, commaList
  521. end
  522.  
  523. local function prefixexpr()
  524. local tk = peek()
  525. if tk.Source == '(' then
  526. local oparenTk = get()
  527. local inner = expr()
  528. local cparenTk = expect('Symbol', ')')
  529. return MkNode{
  530. Type = 'ParenExpr';
  531. Expression = inner;
  532. Token_OpenParen = oparenTk;
  533. Token_CloseParen = cparenTk;
  534. GetFirstToken = function(self)
  535. return self.Token_OpenParen
  536. end;
  537. GetLastToken = function(self)
  538. return self.Token_CloseParen
  539. end;
  540. }
  541. elseif tk.Type == 'Ident' then
  542. return MkNode{
  543. Type = 'VariableExpr';
  544. Token = get();
  545. GetFirstToken = function(self)
  546. return self.Token
  547. end;
  548. GetLastToken = function(self)
  549. return self.Token
  550. end;
  551. }
  552. else
  553. print(debugMark())
  554. error(getTokenStartPosition(tk)..": Unexpected symbol")
  555. end
  556. end
  557.  
  558. function tableexpr()
  559. local obrace = expect('Symbol', '{')
  560. local entries = {}
  561. local separators = {}
  562. while peek().Source ~= '}' do
  563. if peek().Source == '[' then
  564. -- Index
  565. local obrac = get()
  566. local index = expr()
  567. local cbrac = expect('Symbol', ']')
  568. local eq = expect('Symbol', '=')
  569. local value = expr()
  570. table.insert(entries, {
  571. EntryType = 'Index';
  572. Index = index;
  573. Value = value;
  574. Token_OpenBracket = obrac;
  575. Token_CloseBracket = cbrac;
  576. Token_Equals = eq;
  577. })
  578. elseif peek().Type == 'Ident' and peek(1).Source == '=' then
  579. -- Field
  580. local field = get()
  581. local eq = get()
  582. local value = expr()
  583. table.insert(entries, {
  584. EntryType = 'Field';
  585. Field = field;
  586. Value = value;
  587. Token_Equals = eq;
  588. })
  589. else
  590. -- Value
  591. local value = expr()
  592. table.insert(entries, {
  593. EntryType = 'Value';
  594. Value = value;
  595. })
  596. end
  597.  
  598. -- Comma or Semicolon separator
  599. if peek().Source == ',' or peek().Source == ';' then
  600. table.insert(separators, get())
  601. else
  602. break
  603. end
  604. end
  605. local cbrace = expect('Symbol', '}')
  606. return MkNode{
  607. Type = 'TableLiteral';
  608. EntryList = entries;
  609. Token_SeparatorList = separators;
  610. Token_OpenBrace = obrace;
  611. Token_CloseBrace = cbrace;
  612. GetFirstToken = function(self)
  613. return self.Token_OpenBrace
  614. end;
  615. GetLastToken = function(self)
  616. return self.Token_CloseBrace
  617. end;
  618. }
  619. end
  620.  
  621. -- List of identifiers
  622. local function varlist()
  623. local varList = {}
  624. local commaList = {}
  625. if peek().Type == 'Ident' then
  626. table.insert(varList, get())
  627. end
  628. while peek().Source == ',' do
  629. table.insert(commaList, get())
  630. local id = expect('Ident')
  631. table.insert(varList, id)
  632. end
  633. return varList, commaList
  634. end
  635.  
  636. -- Body
  637. local function blockbody(terminator)
  638. local body = block()
  639. local after = peek()
  640. if after.Type == 'Keyword' and after.Source == terminator then
  641. get()
  642. return body, after
  643. else
  644. print(after.Type, after.Source)
  645. error(getTokenStartPosition(after)..": "..terminator.." expected.")
  646. end
  647. end
  648.  
  649. -- Function declaration
  650. local function funcdecl(isAnonymous)
  651. local functionKw = get()
  652. --
  653. local nameChain;
  654. local nameChainSeparator;
  655. --
  656. if not isAnonymous then
  657. nameChain = {}
  658. nameChainSeparator = {}
  659. --
  660. table.insert(nameChain, expect('Ident'))
  661. --
  662. while peek().Source == '.' do
  663. table.insert(nameChainSeparator, get())
  664. table.insert(nameChain, expect('Ident'))
  665. end
  666. if peek().Source == ':' then
  667. table.insert(nameChainSeparator, get())
  668. table.insert(nameChain, expect('Ident'))
  669. end
  670. end
  671. --
  672. local oparenTk = expect('Symbol', '(')
  673. local argList, argCommaList = varlist()
  674. local cparenTk = expect('Symbol', ')')
  675. local fbody, enTk = blockbody('end')
  676. --
  677. return MkNode{
  678. Type = (isAnonymous and 'FunctionLiteral' or 'FunctionStat');
  679. NameChain = nameChain;
  680. ArgList = argList;
  681. Body = fbody;
  682. --
  683. Token_Function = functionKw;
  684. Token_NameChainSeparator = nameChainSeparator;
  685. Token_OpenParen = oparenTk;
  686. Token_ArgCommaList = argCommaList;
  687. Token_CloseParen = cparenTk;
  688. Token_End = enTk;
  689. GetFirstToken = function(self)
  690. return self.Token_Function
  691. end;
  692. GetLastToken = function(self)
  693. return self.Token_End;
  694. end;
  695. }
  696. end
  697.  
  698. -- Argument list passed to a funciton
  699. local function functionargs()
  700. local tk = peek()
  701. if tk.Source == '(' then
  702. local oparenTk = get()
  703. local argList = {}
  704. local argCommaList = {}
  705. while peek().Source ~= ')' do
  706. table.insert(argList, expr())
  707. if peek().Source == ',' then
  708. table.insert(argCommaList, get())
  709. else
  710. break
  711. end
  712. end
  713. local cparenTk = expect('Symbol', ')')
  714. return MkNode{
  715. CallType = 'ArgCall';
  716. ArgList = argList;
  717. --
  718. Token_CommaList = argCommaList;
  719. Token_OpenParen = oparenTk;
  720. Token_CloseParen = cparenTk;
  721. GetFirstToken = function(self)
  722. return self.Token_OpenParen
  723. end;
  724. GetLastToken = function(self)
  725. return self.Token_CloseParen
  726. end;
  727. }
  728. elseif tk.Source == '{' then
  729. return MkNode{
  730. CallType = 'TableCall';
  731. TableExpr = expr();
  732. GetFirstToken = function(self)
  733. return self.TableExpr:GetFirstToken()
  734. end;
  735. GetLastToken = function(self)
  736. return self.TableExpr:GetLastToken()
  737. end;
  738. }
  739. elseif tk.Type == 'String' then
  740. return MkNode{
  741. CallType = 'StringCall';
  742. Token = get();
  743. GetFirstToken = function(self)
  744. return self.Token
  745. end;
  746. GetLastToken = function(self)
  747. return self.Token
  748. end;
  749. }
  750. else
  751. error("Function arguments expected.")
  752. end
  753. end
  754.  
  755. local function primaryexpr()
  756. local base = prefixexpr()
  757. assert(base, "nil prefixexpr")
  758. while true do
  759. local tk = peek()
  760. if tk.Source == '.' then
  761. local dotTk = get()
  762. local fieldName = expect('Ident')
  763. base = MkNode{
  764. Type = 'FieldExpr';
  765. Base = base;
  766. Field = fieldName;
  767. Token_Dot = dotTk;
  768. GetFirstToken = function(self)
  769. return self.Base:GetFirstToken()
  770. end;
  771. GetLastToken = function(self)
  772. return self.Field
  773. end;
  774. }
  775. elseif tk.Source == ':' then
  776. local colonTk = get()
  777. local methodName = expect('Ident')
  778. local fargs = functionargs()
  779. base = MkNode{
  780. Type = 'MethodExpr';
  781. Base = base;
  782. Method = methodName;
  783. FunctionArguments = fargs;
  784. Token_Colon = colonTk;
  785. GetFirstToken = function(self)
  786. return self.Base:GetFirstToken()
  787. end;
  788. GetLastToken = function(self)
  789. return self.FunctionArguments:GetLastToken()
  790. end;
  791. }
  792. elseif tk.Source == '[' then
  793. local obrac = get()
  794. local index = expr()
  795. local cbrac = expect('Symbol', ']')
  796. base = MkNode{
  797. Type = 'IndexExpr';
  798. Base = base;
  799. Index = index;
  800. Token_OpenBracket = obrac;
  801. Token_CloseBracket = cbrac;
  802. GetFirstToken = function(self)
  803. return self.Base:GetFirstToken()
  804. end;
  805. GetLastToken = function(self)
  806. return self.Token_CloseBracket
  807. end;
  808. }
  809. elseif tk.Source == '{' then
  810. base = MkNode{
  811. Type = 'CallExpr';
  812. Base = base;
  813. FunctionArguments = functionargs();
  814. GetFirstToken = function(self)
  815. return self.Base:GetFirstToken()
  816. end;
  817. GetLastToken = function(self)
  818. return self.FunctionArguments:GetLastToken()
  819. end;
  820. }
  821. elseif tk.Source == '(' then
  822. base = MkNode{
  823. Type = 'CallExpr';
  824. Base = base;
  825. FunctionArguments = functionargs();
  826. GetFirstToken = function(self)
  827. return self.Base:GetFirstToken()
  828. end;
  829. GetLastToken = function(self)
  830. return self.FunctionArguments:GetLastToken()
  831. end;
  832. }
  833. else
  834. return base
  835. end
  836. end
  837. end
  838.  
  839. local function simpleexpr()
  840. local tk = peek()
  841. if tk.Type == 'Number' then
  842. return MkNode{
  843. Type = 'NumberLiteral';
  844. Token = get();
  845. GetFirstToken = function(self)
  846. return self.Token
  847. end;
  848. GetLastToken = function(self)
  849. return self.Token
  850. end;
  851. }
  852. elseif tk.Type == 'String' then
  853. return MkNode{
  854. Type = 'StringLiteral';
  855. Token = get();
  856. GetFirstToken = function(self)
  857. return self.Token
  858. end;
  859. GetLastToken = function(self)
  860. return self.Token
  861. end;
  862. }
  863. elseif tk.Source == 'nil' then
  864. return MkNode{
  865. Type = 'NilLiteral';
  866. Token = get();
  867. GetFirstToken = function(self)
  868. return self.Token
  869. end;
  870. GetLastToken = function(self)
  871. return self.Token
  872. end;
  873. }
  874. elseif tk.Source == 'true' or tk.Source == 'false' then
  875. return MkNode{
  876. Type = 'BooleanLiteral';
  877. Token = get();
  878. GetFirstToken = function(self)
  879. return self.Token
  880. end;
  881. GetLastToken = function(self)
  882. return self.Token
  883. end;
  884. }
  885. elseif tk.Source == '...' then
  886. return MkNode{
  887. Type = 'VargLiteral';
  888. Token = get();
  889. GetFirstToken = function(self)
  890. return self.Token
  891. end;
  892. GetLastToken = function(self)
  893. return self.Token
  894. end;
  895. }
  896. elseif tk.Source == '{' then
  897. return tableexpr()
  898. elseif tk.Source == 'function' then
  899. return funcdecl(true)
  900. else
  901. return primaryexpr()
  902. end
  903. end
  904.  
  905. local function subexpr(limit)
  906. local curNode;
  907.  
  908. -- Initial Base Expression
  909. if isUnop() then
  910. local opTk = get()
  911. local ex = subexpr(UnaryPriority)
  912. curNode = MkNode{
  913. Type = 'UnopExpr';
  914. Token_Op = opTk;
  915. Rhs = ex;
  916. GetFirstToken = function(self)
  917. return self.Token_Op
  918. end;
  919. GetLastToken = function(self)
  920. return self.Rhs:GetLastToken()
  921. end;
  922. }
  923. else
  924. curNode = simpleexpr()
  925. assert(curNode, "nil simpleexpr")
  926. end
  927.  
  928. -- Apply Precedence Recursion Chain
  929. while isBinop() and BinaryPriority[peek().Source][1] > limit do
  930. local opTk = get()
  931. local rhs = subexpr(BinaryPriority[opTk.Source][2])
  932. assert(rhs, "RhsNeeded")
  933. curNode = MkNode{
  934. Type = 'BinopExpr';
  935. Lhs = curNode;
  936. Rhs = rhs;
  937. Token_Op = opTk;
  938. GetFirstToken = function(self)
  939. return self.Lhs:GetFirstToken()
  940. end;
  941. GetLastToken = function(self)
  942. return self.Rhs:GetLastToken()
  943. end;
  944. }
  945. end
  946.  
  947. -- Return result
  948. return curNode
  949. end
  950.  
  951. -- Expression
  952. expr = function()
  953. return subexpr(0)
  954. end
  955.  
  956. -- Expression statement
  957. local function exprstat()
  958. local ex = primaryexpr()
  959. if ex.Type == 'MethodExpr' or ex.Type == 'CallExpr' then
  960. -- all good, calls can be statements
  961. return MkNode{
  962. Type = 'CallExprStat';
  963. Expression = ex;
  964. GetFirstToken = function(self)
  965. return self.Expression:GetFirstToken()
  966. end;
  967. GetLastToken = function(self)
  968. return self.Expression:GetLastToken()
  969. end;
  970. }
  971. else
  972. -- Assignment expr
  973. local lhs = {ex}
  974. local lhsSeparator = {}
  975. while peek().Source == ',' do
  976. table.insert(lhsSeparator, get())
  977. local lhsPart = primaryexpr()
  978. if lhsPart.Type == 'MethodExpr' or lhsPart.Type == 'CallExpr' then
  979. error("Bad left hand side of assignment")
  980. end
  981. table.insert(lhs, lhsPart)
  982. end
  983. local eq = expect('Symbol', '=')
  984. local rhs = {expr()}
  985. local rhsSeparator = {}
  986. while peek().Source == ',' do
  987. table.insert(rhsSeparator, get())
  988. table.insert(rhs, expr())
  989. end
  990. return MkNode{
  991. Type = 'AssignmentStat';
  992. Rhs = rhs;
  993. Lhs = lhs;
  994. Token_Equals = eq;
  995. Token_LhsSeparatorList = lhsSeparator;
  996. Token_RhsSeparatorList = rhsSeparator;
  997. GetFirstToken = function(self)
  998. return self.Lhs[1]:GetFirstToken()
  999. end;
  1000. GetLastToken = function(self)
  1001. return self.Rhs[#self.Rhs]:GetLastToken()
  1002. end;
  1003. }
  1004. end
  1005. end
  1006.  
  1007. -- If statement
  1008. local function ifstat()
  1009. local ifKw = get()
  1010. local condition = expr()
  1011. local thenKw = expect('Keyword', 'then')
  1012. local ifBody = block()
  1013. local elseClauses = {}
  1014. while peek().Source == 'elseif' or peek().Source == 'else' do
  1015. local elseifKw = get()
  1016. local elseifCondition, elseifThenKw;
  1017. if elseifKw.Source == 'elseif' then
  1018. elseifCondition = expr()
  1019. elseifThenKw = expect('Keyword', 'then')
  1020. end
  1021. local elseifBody = block()
  1022. table.insert(elseClauses, {
  1023. Condition = elseifCondition;
  1024. Body = elseifBody;
  1025. --
  1026. ClauseType = elseifKw.Source;
  1027. Token = elseifKw;
  1028. Token_Then = elseifThenKw;
  1029. })
  1030. if elseifKw.Source == 'else' then
  1031. break
  1032. end
  1033. end
  1034. local enKw = expect('Keyword', 'end')
  1035. return MkNode{
  1036. Type = 'IfStat';
  1037. Condition = condition;
  1038. Body = ifBody;
  1039. ElseClauseList = elseClauses;
  1040. --
  1041. Token_If = ifKw;
  1042. Token_Then = thenKw;
  1043. Token_End = enKw;
  1044. GetFirstToken = function(self)
  1045. return self.Token_If
  1046. end;
  1047. GetLastToken = function(self)
  1048. return self.Token_End
  1049. end;
  1050. }
  1051. end
  1052.  
  1053. -- Do statement
  1054. local function dostat()
  1055. local doKw = get()
  1056. local body, enKw = blockbody('end')
  1057. --
  1058. return MkNode{
  1059. Type = 'DoStat';
  1060. Body = body;
  1061. --
  1062. Token_Do = doKw;
  1063. Token_End = enKw;
  1064. GetFirstToken = function(self)
  1065. return self.Token_Do
  1066. end;
  1067. GetLastToken = function(self)
  1068. return self.Token_End
  1069. end;
  1070. }
  1071. end
  1072.  
  1073. -- While statement
  1074. local function whilestat()
  1075. local whileKw = get()
  1076. local condition = expr()
  1077. local doKw = expect('Keyword', 'do')
  1078. local body, enKw = blockbody('end')
  1079. --
  1080. return MkNode{
  1081. Type = 'WhileStat';
  1082. Condition = condition;
  1083. Body = body;
  1084. --
  1085. Token_While = whileKw;
  1086. Token_Do = doKw;
  1087. Token_End = enKw;
  1088. GetFirstToken = function(self)
  1089. return self.Token_While
  1090. end;
  1091. GetLastToken = function(self)
  1092. return self.Token_End
  1093. end;
  1094. }
  1095. end
  1096.  
  1097. -- For statement
  1098. local function forstat()
  1099. local forKw = get()
  1100. local loopVars, loopVarCommas = varlist()
  1101. local node = {}
  1102. if peek().Source == '=' then
  1103. local eqTk = get()
  1104. local exprList, exprCommaList = exprlist()
  1105. if #exprList < 2 or #exprList > 3 then
  1106. error("expected 2 or 3 values for range bounds")
  1107. end
  1108. local doTk = expect('Keyword', 'do')
  1109. local body, enTk = blockbody('end')
  1110. return MkNode{
  1111. Type = 'NumericForStat';
  1112. VarList = loopVars;
  1113. RangeList = exprList;
  1114. Body = body;
  1115. --
  1116. Token_For = forKw;
  1117. Token_VarCommaList = loopVarCommas;
  1118. Token_Equals = eqTk;
  1119. Token_RangeCommaList = exprCommaList;
  1120. Token_Do = doTk;
  1121. Token_End = enTk;
  1122. GetFirstToken = function(self)
  1123. return self.Token_For
  1124. end;
  1125. GetLastToken = function(self)
  1126. return self.Token_End
  1127. end;
  1128. }
  1129. elseif peek().Source == 'in' then
  1130. local inTk = get()
  1131. local exprList, exprCommaList = exprlist()
  1132. local doTk = expect('Keyword', 'do')
  1133. local body, enTk = blockbody('end')
  1134. return MkNode{
  1135. Type = 'GenericForStat';
  1136. VarList = loopVars;
  1137. GeneratorList = exprList;
  1138. Body = body;
  1139. --
  1140. Token_For = forKw;
  1141. Token_VarCommaList = loopVarCommas;
  1142. Token_In = inTk;
  1143. Token_GeneratorCommaList = exprCommaList;
  1144. Token_Do = doTk;
  1145. Token_End = enTk;
  1146. GetFirstToken = function(self)
  1147. return self.Token_For
  1148. end;
  1149. GetLastToken = function(self)
  1150. return self.Token_End
  1151. end;
  1152. }
  1153. else
  1154. error("`=` or in expected")
  1155. end
  1156. end
  1157.  
  -- Repeat statement: `repeat <block> until <expr>`.
  -- Consumes the `repeat` keyword, the body (through the `until` keyword) and
  -- the loop condition. The resulting 'RepeatStat' node starts at `repeat`
  -- and ends at the last token of the condition.
  local function repeatstat()
      local node = {Type = 'RepeatStat'}
      node.Token_Repeat = get()
      node.Body, node.Token_Until = blockbody('until')
      node.Condition = expr()
      node.GetFirstToken = function(self)
          return self.Token_Repeat
      end
      node.GetLastToken = function(self)
          return self.Condition:GetLastToken()
      end
      return MkNode(node)
  end
  1178.  
  -- Local declaration, entered with `local` as the current token. Handles
  -- either `local function name(...) ... end` ('LocalFunctionStat') or
  -- `local a, b, ... [= explist]` ('LocalVarStat'). Errors if `local` is
  -- followed by neither `function` nor an identifier, or if a local
  -- function's name is a dotted/colon chain.
  local function localdecl()
      local localKw = get()
      local lookahead = peek()

      if lookahead.Source == 'function' then
          local funcStat = funcdecl(false)
          -- `local function a.b()` is not valid Lua: only a plain name may follow.
          if #funcStat.NameChain > 1 then
              error(getTokenStartPosition(funcStat.Token_NameChainSeparator[1])..": `(` expected.")
          end
          local node = {
              Type = 'LocalFunctionStat';
              FunctionStat = funcStat;
              Token_Local = localKw;
          }
          node.GetFirstToken = function(self)
              return self.Token_Local
          end
          node.GetLastToken = function(self)
              return self.FunctionStat:GetLastToken()
          end
          return MkNode(node)
      end

      if lookahead.Type ~= 'Ident' then
          error("`function` or ident expected")
      end

      local varList, varCommaList = varlist()
      local eqToken
      local exprList, exprCommaList
      if peek().Source == '=' then
          eqToken = get()
          exprList, exprCommaList = exprlist()
      else
          -- Declaration without initializers: keep empty lists, not nil.
          exprList, exprCommaList = {}, {}
      end

      local node = {
          Type = 'LocalVarStat';
          VarList = varList;
          ExprList = exprList;
          Token_Local = localKw;
          Token_Equals = eqToken;
          Token_VarCommaList = varCommaList;
          Token_ExprCommaList = exprCommaList;
      }
      node.GetFirstToken = function(self)
          return self.Token_Local
      end
      node.GetLastToken = function(self)
          local exprs = self.ExprList
          if #exprs == 0 then
              -- No initializers: the last declared name ends the statement.
              return self.VarList[#self.VarList]
          end
          return exprs[#exprs]:GetLastToken()
      end
      return MkNode(node)
  end
  1231.  
  -- Return statement: `return [explist]`, entered with `return` as the
  -- current token. The expression list is left empty when the enclosing block
  -- ends or a `;` comes next. Returns a plain 'ReturnStat' table.
  local function retstat()
      local returnKw = get()
      local exprList, commaList = {}, {}
      if not (isBlockFollow() or peek().Source == ';') then
          exprList, commaList = exprlist()
      end

      local node = {
          Type = 'ReturnStat';
          ExprList = exprList;
          Token_Return = returnKw;
          Token_CommaList = commaList;
      }
      node.GetFirstToken = function(self)
          return self.Token_Return
      end
      node.GetLastToken = function(self)
          local count = #self.ExprList
          if count == 0 then
              -- A bare `return` is its own last token.
              return self.Token_Return
          end
          return self.ExprList[count]:GetLastToken()
      end
      return node
  end
  1260.  
  -- Break statement: consumes the `break` keyword and returns a plain
  -- 'BreakStat' table whose first and last token are both that keyword.
  local function breakstat()
      local node = {Type = 'BreakStat'}
      node.Token_Break = get()
      node.GetFirstToken = function(self)
          return self.Token_Break
      end
      node.GetLastToken = function(self)
          return self.Token_Break
      end
      return node
  end
  1275.  
  -- Statement: dispatches on the source text of the next token.
  -- Returns two values: true when the statement must be the last one in its
  -- block (`return` / `break`), false otherwise; and the parsed statement
  -- node. Anything not starting with a statement keyword is parsed as an
  -- expression statement.
  local function statement()
      local leading = peek().Source

      -- Block-terminating statements.
      if leading == 'return' then
          return true, retstat()
      elseif leading == 'break' then
          return true, breakstat()
      elseif leading == 'function' then
          return false, funcdecl(false)
      end

      local parse = exprstat
      if leading == 'if' then
          parse = ifstat
      elseif leading == 'while' then
          parse = whilestat
      elseif leading == 'do' then
          parse = dostat
      elseif leading == 'for' then
          parse = forstat
      elseif leading == 'repeat' then
          parse = repeatstat
      elseif leading == 'local' then
          parse = localdecl
      end
      return false, parse()
  end
  1301.  
  -- Chunk: parses statements until a block-following token is reached or a
  -- statement that must end its block (`return` / `break`) has been parsed.
  -- An optional `;` after a statement is stored in SemicolonList under that
  -- statement's index. Returns a plain 'StatList' table whose token getters
  -- return nil for an empty block.
  block = function()
      local statements, semicolons = {}, {}
      local finished = false
      while not (finished or isBlockFollow()) do
          local stat
          finished, stat = statement()
          statements[#statements + 1] = stat
          local follow = peek()
          if follow.Type == 'Symbol' and follow.Source == ';' then
              semicolons[#statements] = get()
          end
      end

      local node = {
          Type = 'StatList';
          StatementList = statements;
          SemicolonList = semicolons;
      }
      node.GetFirstToken = function(self)
          if #self.StatementList == 0 then
              return nil
          end
          return self.StatementList[1]:GetFirstToken()
      end
      node.GetLastToken = function(self)
          local count = #self.StatementList
          if count == 0 then
              return nil
          end
          -- A `;` after the final statement is the block's last token.
          return self.SemicolonList[count] or self.StatementList[count]:GetLastToken()
      end
      return node
  end
  1340.  
  1341. return block()
  1342. end
  1343.  
  --- Walks an AST node depth-first, in source order, calling visitors.
  -- @param ast root node; visited as a statement when its Type is a known
  --   statement type, otherwise as an expression.
  -- @param visitors table mapping node Type names to either
  --   * a function, called on the node before its children, or
  --   * a table with optional `Pre` / `Post` functions, called on the node
  --     before / after its children.
  --   When the pre-visitor returns a truthy value, the node's children and its
  --   `Post` visitor are skipped (the visitor did its own traversal).
  -- Raises an error when `visitors` has a key that is not a known node Type.
  function VisitAst(ast, visitors)
      local ExprType = lookupify{
          'BinopExpr'; 'UnopExpr';
          'NumberLiteral'; 'StringLiteral'; 'NilLiteral'; 'BooleanLiteral'; 'VargLiteral';
          'FieldExpr'; 'IndexExpr';
          'MethodExpr'; 'CallExpr';
          'FunctionLiteral';
          'VariableExpr';
          'ParenExpr';
          'TableLiteral';
      }

      local StatType = lookupify{
          'StatList';
          'BreakStat';
          'ReturnStat';
          'LocalVarStat';
          'LocalFunctionStat';
          'FunctionStat';
          'RepeatStat';
          'GenericForStat';
          'NumericForStat';
          'WhileStat';
          'DoStat';
          'IfStat';
          'CallExprStat';
          'AssignmentStat';
      }

      -- Check for typos in visitor construction
      for visitorSubject in pairs(visitors) do
          if not StatType[visitorSubject] and not ExprType[visitorSubject] then
              error("Invalid visitor target: `"..visitorSubject.."`")
          end
      end

      -- Helpers to call visitors on a node
      local function preVisit(exprOrStat)
          local visitor = visitors[exprOrStat.Type]
          if type(visitor) == 'function' then
              return visitor(exprOrStat)
          elseif visitor and visitor.Pre then
              return visitor.Pre(exprOrStat)
          end
      end
      local function postVisit(exprOrStat)
          local visitor = visitors[exprOrStat.Type]
          if visitor and type(visitor) == 'table' and visitor.Post then
              return visitor.Post(exprOrStat)
          end
      end

      local visitExpr, visitStat

      -- Note: child lists are walked with ipairs, not pairs. `pairs` does not
      -- specify traversal order even for integer keys, and visitors such as
      -- AddVariableInfo depend on seeing nodes in source order.

      visitExpr = function(expr)
          if preVisit(expr) then
              -- Handler did custom child iteration or blocked child iteration
              return
          end
          if expr.Type == 'BinopExpr' then
              visitExpr(expr.Lhs)
              visitExpr(expr.Rhs)
          elseif expr.Type == 'UnopExpr' then
              visitExpr(expr.Rhs)
          elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
              expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
              expr.Type == 'VargLiteral'
          then
              -- No children to visit, single token literals
          elseif expr.Type == 'FieldExpr' then
              visitExpr(expr.Base)
          elseif expr.Type == 'IndexExpr' then
              visitExpr(expr.Base)
              visitExpr(expr.Index)
          elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
              visitExpr(expr.Base)
              if expr.FunctionArguments.CallType == 'ArgCall' then
                  for _, argExpr in ipairs(expr.FunctionArguments.ArgList) do
                      visitExpr(argExpr)
                  end
              elseif expr.FunctionArguments.CallType == 'TableCall' then
                  visitExpr(expr.FunctionArguments.TableExpr)
              end
          elseif expr.Type == 'FunctionLiteral' then
              visitStat(expr.Body)
          elseif expr.Type == 'VariableExpr' then
              -- No children to visit
          elseif expr.Type == 'ParenExpr' then
              visitExpr(expr.Expression)
          elseif expr.Type == 'TableLiteral' then
              for _, entry in ipairs(expr.EntryList) do
                  if entry.EntryType == 'Field' then
                      visitExpr(entry.Value)
                  elseif entry.EntryType == 'Index' then
                      visitExpr(entry.Index)
                      visitExpr(entry.Value)
                  elseif entry.EntryType == 'Value' then
                      visitExpr(entry.Value)
                  else
                      assert(false, "unreachable")
                  end
              end
          else
              assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
          end
          postVisit(expr)
      end

      visitStat = function(stat)
          if preVisit(stat) then
              -- Handler did custom child iteration or blocked child iteration
              return
          end
          if stat.Type == 'StatList' then
              for _, ch in ipairs(stat.StatementList) do
                  visitStat(ch)
              end
          elseif stat.Type == 'BreakStat' then
              -- No children to visit
          elseif stat.Type == 'ReturnStat' then
              for _, expr in ipairs(stat.ExprList) do
                  visitExpr(expr)
              end
          elseif stat.Type == 'LocalVarStat' then
              if stat.Token_Equals then
                  for _, expr in ipairs(stat.ExprList) do
                      visitExpr(expr)
                  end
              end
          elseif stat.Type == 'LocalFunctionStat' then
              visitStat(stat.FunctionStat.Body)
          elseif stat.Type == 'FunctionStat' then
              visitStat(stat.Body)
          elseif stat.Type == 'RepeatStat' then
              visitStat(stat.Body)
              visitExpr(stat.Condition)
          elseif stat.Type == 'GenericForStat' then
              for _, expr in ipairs(stat.GeneratorList) do
                  visitExpr(expr)
              end
              visitStat(stat.Body)
          elseif stat.Type == 'NumericForStat' then
              for _, expr in ipairs(stat.RangeList) do
                  visitExpr(expr)
              end
              visitStat(stat.Body)
          elseif stat.Type == 'WhileStat' then
              visitExpr(stat.Condition)
              visitStat(stat.Body)
          elseif stat.Type == 'DoStat' then
              visitStat(stat.Body)
          elseif stat.Type == 'IfStat' then
              visitExpr(stat.Condition)
              visitStat(stat.Body)
              for _, clause in ipairs(stat.ElseClauseList) do
                  -- A plain `else` clause has no condition.
                  if clause.Condition then
                      visitExpr(clause.Condition)
                  end
                  visitStat(clause.Body)
              end
          elseif stat.Type == 'CallExprStat' then
              visitExpr(stat.Expression)
          elseif stat.Type == 'AssignmentStat' then
              for _, ex in ipairs(stat.Lhs) do
                  visitExpr(ex)
              end
              for _, ex in ipairs(stat.Rhs) do
                  visitExpr(ex)
              end
          else
              assert(false, "unreachable")
          end
          postVisit(stat)
      end

      if StatType[ast.Type] then
          visitStat(ast)
      else
          visitExpr(ast)
      end
  end
  1525.  
  --- Annotates an AST with scope and variable information.
  -- Walks `ast` with VisitAst while maintaining a tree of scopes. Every
  -- VariableExpr node gets a `.Variable` field pointing at the record it
  -- resolves to: the innermost visible local, otherwise a shared global record.
  -- A variable record has Type ('Local' / 'Global'), Name, Info (locals only),
  -- Scope (nil for globals), AssignedTo, UseCount (bumped by :Reference()),
  -- BeginLocation / EndLocation / ReferenceLocationList (ordinals from one
  -- shared counter, so they order events in traversal order),
  -- ScopeEndLocation (locals, set when their scope closes), and a
  -- :Rename(newName) method that rewrites the Source of every bound token.
  -- @return globalVars list of global variable records
  -- @return rootScope  the outermost scope (ChildScopeList holds nested scopes)
  function AddVariableInfo(ast)
      local globalVars = {}
      local currentScope = nil

      -- Numbering generator for variable lifetimes
      local locationGenerator = 0
      local function markLocation()
          locationGenerator = locationGenerator + 1
          return locationGenerator
      end

      -- Scope management
      local function pushScope()
          currentScope = {
              ParentScope = currentScope;
              ChildScopeList = {};
              VariableList = {};
              BeginLocation = markLocation();
          }
          if currentScope.ParentScope then
              currentScope.Depth = currentScope.ParentScope.Depth + 1
              table.insert(currentScope.ParentScope.ChildScopeList, currentScope)
          else
              currentScope.Depth = 1
          end
          -- Looks up `varName` in this scope, then outward, then in the globals.
          function currentScope:GetVar(varName)
              -- Search newest-first so a local re-declared in the same scope
              -- shadows the earlier one (same rule as getLocalVar below).
              for i = #self.VariableList, 1, -1 do
                  local var = self.VariableList[i]
                  if var.Name == varName then
                      return var
                  end
              end
              if self.ParentScope then
                  return self.ParentScope:GetVar(varName)
              else
                  for _, var in ipairs(globalVars) do
                      if var.Name == varName then
                          return var
                      end
                  end
              end
          end
      end
      local function popScope()
          local scope = currentScope

          -- Mark where this scope ends
          scope.EndLocation = markLocation()

          -- Mark all of the variables in the scope as ending there
          for _, var in ipairs(scope.VariableList) do
              var.ScopeEndLocation = scope.EndLocation
          end

          -- Move to the parent scope
          currentScope = scope.ParentScope

          return scope
      end
      pushScope() -- push initial scope

      -- Add / reference variables
      local function addLocalVar(name, setNameFunc, localInfo)
          assert(localInfo, "Missing localInfo")
          assert(name, "Missing local var name")
          -- Take the locations in an explicit order rather than relying on the
          -- evaluation order of table-constructor fields.
          local beginLocation = markLocation()
          local endLocation = markLocation()
          local declLocation = markLocation()
          local var = {
              Type = 'Local';
              Name = name;
              RenameList = {setNameFunc};
              AssignedTo = false;
              Info = localInfo;
              UseCount = 0;
              Scope = currentScope;
              BeginLocation = beginLocation;
              EndLocation = endLocation;
              ReferenceLocationList = {declLocation};
          }
          function var:Rename(newName)
              self.Name = newName
              for _, renameFunc in ipairs(self.RenameList) do
                  renameFunc(newName)
              end
          end
          function var:Reference()
              self.UseCount = self.UseCount + 1
          end
          table.insert(currentScope.VariableList, var)
          return var
      end
      local function getGlobalVar(name)
          for _, var in ipairs(globalVars) do
              if var.Name == name then
                  return var
              end
          end
          local beginLocation = markLocation()
          local endLocation = markLocation()
          local var = {
              Type = 'Global';
              Name = name;
              RenameList = {};
              AssignedTo = false;
              UseCount = 0;
              Scope = nil; -- Globals have no scope
              BeginLocation = beginLocation;
              EndLocation = endLocation;
              ReferenceLocationList = {};
          }
          function var:Rename(newName)
              self.Name = newName
              for _, renameFunc in ipairs(self.RenameList) do
                  renameFunc(newName)
              end
          end
          function var:Reference()
              self.UseCount = self.UseCount + 1
          end
          table.insert(globalVars, var)
          return var
      end
      local function addGlobalReference(name, setNameFunc)
          assert(name, "Missing var name")
          local var = getGlobalVar(name)
          table.insert(var.RenameList, setNameFunc)
          return var
      end
      local function getLocalVar(scope, name)
          -- First search this scope
          -- Note: Reverse iterate here because Lua does allow shadowing a local
          -- within the same scope, and the later defined variable should
          -- be the one referenced.
          for i = #scope.VariableList, 1, -1 do
              if scope.VariableList[i].Name == name then
                  return scope.VariableList[i]
              end
          end

          -- Then search parent scope
          if scope.ParentScope then
              local var = getLocalVar(scope.ParentScope, name)
              if var then
                  return var
              end
          end

          -- Not a visible local
          return nil
      end
      -- Resolves `name` to a visible local or, failing that, a global, and
      -- records this use as a reference location.
      local function referenceVariable(name, setNameFunc)
          assert(name, "Missing var name")
          local var = getLocalVar(currentScope, name)
          if var then
              table.insert(var.RenameList, setNameFunc)
          else
              var = addGlobalReference(name, setNameFunc)
          end
          -- Update the end location of where this variable is used, and
          -- add this location to the list of references to this variable.
          local curLocation = markLocation()
          var.EndLocation = curLocation
          table.insert(var.ReferenceLocationList, var.EndLocation)
          return var
      end

      local visitor = {}
      visitor.FunctionLiteral = {
          -- Function literal adds a new scope and adds the function literal arguments
          -- as local variables in the scope.
          Pre = function(expr)
              pushScope()
              for index, ident in ipairs(expr.ArgList) do
                  addLocalVar(ident.Source, function(name)
                      ident.Source = name
                  end, {
                      Type = 'Argument';
                      Index = index;
                  })
              end
          end;
          Post = function(expr)
              popScope()
          end;
      }
      visitor.VariableExpr = function(expr)
          -- Variable expression references from existing local variables
          -- in the current scope, annotating the variable usage with variable
          -- information.
          expr.Variable = referenceVariable(expr.Token.Source, function(newName)
              expr.Token.Source = newName
          end)
      end
      visitor.StatList = {
          -- StatList adds a new scope
          Pre = function(stat)
              pushScope()
          end;
          Post = function(stat)
              popScope()
          end;
      }
      visitor.LocalVarStat = {
          Post = function(stat)
              -- Local var stat adds the local variables to the current scope as locals
              -- We need to visit the subexpressions first, because these new locals
              -- will not be in scope for the initialization value expressions. That is:
              --  `local bar = bar + 1`
              -- Is valid code
              for varNum, ident in ipairs(stat.VarList) do
                  addLocalVar(ident.Source, function(name)
                      stat.VarList[varNum].Source = name
                  end, {
                      Type = 'Local';
                  })
              end
          end;
      }
      visitor.LocalFunctionStat = {
          Pre = function(stat)
              -- Local function stat adds the function itself to the current scope as
              -- a local variable, and creates a new scope with the function arguments
              -- as local variables.
              addLocalVar(stat.FunctionStat.NameChain[1].Source, function(name)
                  stat.FunctionStat.NameChain[1].Source = name
              end, {
                  Type = 'LocalFunction';
              })
              pushScope()
              for index, ident in ipairs(stat.FunctionStat.ArgList) do
                  addLocalVar(ident.Source, function(name)
                      ident.Source = name
                  end, {
                      Type = 'Argument';
                      Index = index;
                  })
              end
          end;
          Post = function()
              popScope()
          end;
      }
      visitor.FunctionStat = {
          Pre = function(stat)
              -- `function foo() end` is sugar for `foo = function() end`, so the
              -- first name resolves like any other name: to a visible local if
              -- there is one, otherwise to a global. Only that single-name form
              -- assigns to the variable; `function a.b() end` reads `a` and
              -- assigns into one of its fields (matching AssignmentStat, where
              -- `a.b = ...` does not mark `a` as assigned).
              local nameChain = stat.NameChain
              local var = referenceVariable(nameChain[1].Source, function(name)
                  nameChain[1].Source = name
              end)
              if #nameChain == 1 then
                  var.AssignedTo = true
              end
              -- NOTE(review): only ArgList entries become locals of the new
              -- scope; unless the parser puts an implicit `self` into ArgList
              -- for `function a:b()`, `self` in such a body resolves as a global.
              pushScope()
              for index, ident in ipairs(stat.ArgList) do
                  addLocalVar(ident.Source, function(name)
                      ident.Source = name
                  end, {
                      Type = 'Argument';
                      Index = index;
                  })
              end
          end;
          Post = function()
              popScope()
          end;
      }
      visitor.RepeatStat = {
          Pre = function(stat)
              -- In Lua the `until` condition is still inside the loop body's
              -- scope, so in `repeat local x = f() until x` the `x` in the
              -- condition is the body's local. Visit the body's statements and
              -- the condition within one shared scope, instead of letting the
              -- body's StatList visitor pop that scope before the condition.
              pushScope()
              for _, ch in ipairs(stat.Body.StatementList) do
                  VisitAst(ch, visitor)
              end
              VisitAst(stat.Condition, visitor)
              popScope()
              return true -- Custom visit
          end;
      }
      visitor.GenericForStat = {
          Pre = function(stat)
              -- Generic fors need an extra scope holding the range variables
              -- Need a custom visitor so that the generator expressions can be
              -- visited before we push a scope, but the body can be visited
              -- after we push a scope.
              for _, ex in ipairs(stat.GeneratorList) do
                  VisitAst(ex, visitor)
              end
              pushScope()
              for index, ident in ipairs(stat.VarList) do
                  addLocalVar(ident.Source, function(name)
                      ident.Source = name
                  end, {
                      Type = 'ForRange';
                      Index = index;
                  })
              end
              VisitAst(stat.Body, visitor)
              popScope()
              return true -- Custom visit
          end;
      }
      visitor.NumericForStat = {
          Pre = function(stat)
              -- Numeric fors need an extra scope holding the range variables
              -- Need a custom visitor so that the range expressions can be
              -- visited before we push a scope, but the body can be visited
              -- after we push a scope.
              for _, ex in ipairs(stat.RangeList) do
                  VisitAst(ex, visitor)
              end
              pushScope()
              for index, ident in ipairs(stat.VarList) do
                  addLocalVar(ident.Source, function(name)
                      ident.Source = name
                  end, {
                      Type = 'ForRange';
                      Index = index;
                  })
              end
              VisitAst(stat.Body, visitor)
              popScope()
              return true -- Custom visit
          end;
      }
      visitor.AssignmentStat = {
          Post = function(stat)
              -- For an assignment statement we need to mark the
              -- "assigned to" flag on variables (only plain VariableExpr
              -- targets carry a .Variable; field/index targets do not).
              for _, ex in ipairs(stat.Lhs) do
                  if ex.Variable then
                      ex.Variable.AssignedTo = true
                  end
              end
          end;
      }

      VisitAst(ast, visitor)

      return globalVars, popScope()
  end
  1860.  
  --- Writes an AST back out as Lua source text.
  -- Every token is emitted as its LeadingWhite followed by its Source, in
  -- source order, so an unmodified AST reproduces its original text. Nothing
  -- is returned.
  -- @param ast   root statement node (normally a 'StatList').
  -- @param write optional sink: a function called with each string piece;
  --              defaults to io.write (the current output file).
  -- Raises "Bad token: ..." when a required token is missing or malformed.
  function PrintAst(ast, write)
      write = write or io.write

      local printStat, printExpr

      local function printt(tk)
          if not tk or not tk.LeadingWhite or not tk.Source then
              -- Guard against a missing token too, which previously failed with
              -- an unhelpful "attempt to index nil" before reaching this error.
              error("Bad token: "..(tk and FormatTable(tk) or "nil"))
          end
          write(tk.LeadingWhite)
          write(tk.Source)
      end

      -- Note: all lists are walked with ipairs. Output must follow source
      -- order, and `pairs` does not specify an order even for integer keys.

      printExpr = function(expr)
          if expr.Type == 'BinopExpr' then
              printExpr(expr.Lhs)
              printt(expr.Token_Op)
              printExpr(expr.Rhs)
          elseif expr.Type == 'UnopExpr' then
              printt(expr.Token_Op)
              printExpr(expr.Rhs)
          elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
              expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
              expr.Type == 'VargLiteral'
          then
              -- Just print the token
              printt(expr.Token)
          elseif expr.Type == 'FieldExpr' then
              printExpr(expr.Base)
              printt(expr.Token_Dot)
              printt(expr.Field)
          elseif expr.Type == 'IndexExpr' then
              printExpr(expr.Base)
              printt(expr.Token_OpenBracket)
              printExpr(expr.Index)
              printt(expr.Token_CloseBracket)
          elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
              printExpr(expr.Base)
              if expr.Type == 'MethodExpr' then
                  printt(expr.Token_Colon)
                  printt(expr.Method)
              end
              local args = expr.FunctionArguments
              if args.CallType == 'StringCall' then
                  printt(args.Token)
              elseif args.CallType == 'ArgCall' then
                  printt(args.Token_OpenParen)
                  for index, argExpr in ipairs(args.ArgList) do
                      printExpr(argExpr)
                      local sep = args.Token_CommaList[index]
                      if sep then
                          printt(sep)
                      end
                  end
                  printt(args.Token_CloseParen)
              elseif args.CallType == 'TableCall' then
                  printExpr(args.TableExpr)
              end
          elseif expr.Type == 'FunctionLiteral' then
              printt(expr.Token_Function)
              printt(expr.Token_OpenParen)
              for index, arg in ipairs(expr.ArgList) do
                  printt(arg)
                  local comma = expr.Token_ArgCommaList[index]
                  if comma then
                      printt(comma)
                  end
              end
              printt(expr.Token_CloseParen)
              printStat(expr.Body)
              printt(expr.Token_End)
          elseif expr.Type == 'VariableExpr' then
              printt(expr.Token)
          elseif expr.Type == 'ParenExpr' then
              printt(expr.Token_OpenParen)
              printExpr(expr.Expression)
              printt(expr.Token_CloseParen)
          elseif expr.Type == 'TableLiteral' then
              printt(expr.Token_OpenBrace)
              for index, entry in ipairs(expr.EntryList) do
                  if entry.EntryType == 'Field' then
                      printt(entry.Field)
                      printt(entry.Token_Equals)
                      printExpr(entry.Value)
                  elseif entry.EntryType == 'Index' then
                      printt(entry.Token_OpenBracket)
                      printExpr(entry.Index)
                      printt(entry.Token_CloseBracket)
                      printt(entry.Token_Equals)
                      printExpr(entry.Value)
                  elseif entry.EntryType == 'Value' then
                      printExpr(entry.Value)
                  else
                      assert(false, "unreachable")
                  end
                  local sep = expr.Token_SeparatorList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(expr.Token_CloseBrace)
          else
              assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
          end
      end

      printStat = function(stat)
          if stat.Type == 'StatList' then
              for index, ch in ipairs(stat.StatementList) do
                  printStat(ch)
                  if stat.SemicolonList[index] then
                      printt(stat.SemicolonList[index])
                  end
              end
          elseif stat.Type == 'BreakStat' then
              printt(stat.Token_Break)
          elseif stat.Type == 'ReturnStat' then
              printt(stat.Token_Return)
              for index, expr in ipairs(stat.ExprList) do
                  printExpr(expr)
                  if stat.Token_CommaList[index] then
                      printt(stat.Token_CommaList[index])
                  end
              end
          elseif stat.Type == 'LocalVarStat' then
              printt(stat.Token_Local)
              for index, var in ipairs(stat.VarList) do
                  printt(var)
                  local comma = stat.Token_VarCommaList[index]
                  if comma then
                      printt(comma)
                  end
              end
              if stat.Token_Equals then
                  printt(stat.Token_Equals)
                  for index, expr in ipairs(stat.ExprList) do
                      printExpr(expr)
                      local comma = stat.Token_ExprCommaList[index]
                      if comma then
                          printt(comma)
                      end
                  end
              end
          elseif stat.Type == 'LocalFunctionStat' then
              local func = stat.FunctionStat
              printt(stat.Token_Local)
              printt(func.Token_Function)
              printt(func.NameChain[1])
              printt(func.Token_OpenParen)
              for index, arg in ipairs(func.ArgList) do
                  printt(arg)
                  local comma = func.Token_ArgCommaList[index]
                  if comma then
                      printt(comma)
                  end
              end
              printt(func.Token_CloseParen)
              printStat(func.Body)
              printt(func.Token_End)
          elseif stat.Type == 'FunctionStat' then
              printt(stat.Token_Function)
              for index, part in ipairs(stat.NameChain) do
                  printt(part)
                  local sep = stat.Token_NameChainSeparator[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_OpenParen)
              for index, arg in ipairs(stat.ArgList) do
                  printt(arg)
                  local comma = stat.Token_ArgCommaList[index]
                  if comma then
                      printt(comma)
                  end
              end
              printt(stat.Token_CloseParen)
              printStat(stat.Body)
              printt(stat.Token_End)
          elseif stat.Type == 'RepeatStat' then
              printt(stat.Token_Repeat)
              printStat(stat.Body)
              printt(stat.Token_Until)
              printExpr(stat.Condition)
          elseif stat.Type == 'GenericForStat' then
              printt(stat.Token_For)
              for index, var in ipairs(stat.VarList) do
                  printt(var)
                  local sep = stat.Token_VarCommaList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_In)
              for index, expr in ipairs(stat.GeneratorList) do
                  printExpr(expr)
                  local sep = stat.Token_GeneratorCommaList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_Do)
              printStat(stat.Body)
              printt(stat.Token_End)
          elseif stat.Type == 'NumericForStat' then
              printt(stat.Token_For)
              for index, var in ipairs(stat.VarList) do
                  printt(var)
                  local sep = stat.Token_VarCommaList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_Equals)
              for index, expr in ipairs(stat.RangeList) do
                  printExpr(expr)
                  local sep = stat.Token_RangeCommaList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_Do)
              printStat(stat.Body)
              printt(stat.Token_End)
          elseif stat.Type == 'WhileStat' then
              printt(stat.Token_While)
              printExpr(stat.Condition)
              printt(stat.Token_Do)
              printStat(stat.Body)
              printt(stat.Token_End)
          elseif stat.Type == 'DoStat' then
              printt(stat.Token_Do)
              printStat(stat.Body)
              printt(stat.Token_End)
          elseif stat.Type == 'IfStat' then
              printt(stat.Token_If)
              printExpr(stat.Condition)
              printt(stat.Token_Then)
              printStat(stat.Body)
              for _, clause in ipairs(stat.ElseClauseList) do
                  printt(clause.Token)
                  -- `elseif` clauses carry a condition and `then`; `else` does not.
                  if clause.Condition then
                      printExpr(clause.Condition)
                      printt(clause.Token_Then)
                  end
                  printStat(clause.Body)
              end
              printt(stat.Token_End)
          elseif stat.Type == 'CallExprStat' then
              printExpr(stat.Expression)
          elseif stat.Type == 'AssignmentStat' then
              for index, ex in ipairs(stat.Lhs) do
                  printExpr(ex)
                  local sep = stat.Token_LhsSeparatorList[index]
                  if sep then
                      printt(sep)
                  end
              end
              printt(stat.Token_Equals)
              for index, ex in ipairs(stat.Rhs) do
                  printExpr(ex)
                  local sep = stat.Token_RhsSeparatorList[index]
                  if sep then
                      printt(sep)
                  end
              end
          else
              assert(false, "unreachable")
          end
      end

      printStat(ast)
  end
  2132.  
  2133. -- Adds / removes whitespace in an AST to put it into a "standard formatting"
  2134. local function FormatAst(ast)
  2135. local formatStat, formatExpr;
  2136.  
  2137. local currentIndent = 0
  2138.  
  2139. local function applyIndent(token)
  2140. local indentString = '\n'..('\t'):rep(currentIndent)
  2141. if token.LeadingWhite == '' or (token.LeadingWhite:sub(-#indentString, -1) ~= indentString) then
  2142. -- Trim existing trailing whitespace on LeadingWhite
  2143. -- Trim trailing tabs and spaces, and up to one newline
  2144. token.LeadingWhite = token.LeadingWhite:gsub("\n?[\t ]*$", "")
  2145. token.LeadingWhite = token.LeadingWhite..indentString
  2146. end
  2147. end
  2148.  
  2149. local function indent()
  2150. currentIndent = currentIndent + 1
  2151. end
  2152.  
  2153. local function undent()
  2154. currentIndent = currentIndent - 1
  2155. assert(currentIndent >= 0, "Undented too far")
  2156. end
  2157.  
  2158. local function leadingChar(tk)
  2159. if #tk.LeadingWhite > 0 then
  2160. return tk.LeadingWhite:sub(1,1)
  2161. else
  2162. return tk.Source:sub(1,1)
  2163. end
  2164. end
  2165.  
  2166. local function padToken(tk)
  2167. if not WhiteChars[leadingChar(tk)] then
  2168. tk.LeadingWhite = ' '..tk.LeadingWhite
  2169. end
  2170. end
  2171.  
  2172. local function padExpr(expr)
  2173. padToken(expr:GetFirstToken())
  2174. end
  2175.  
  2176. local function formatBody(openToken, bodyStat, closeToken)
  2177. indent()
  2178. formatStat(bodyStat)
  2179. undent()
  2180. applyIndent(closeToken)
  2181. end
  2182.  
  2183. formatExpr = function(expr)
  2184. if expr.Type == 'BinopExpr' then
  2185. formatExpr(expr.Lhs)
  2186. formatExpr(expr.Rhs)
  2187. if expr.Token_Op.Source == '..' then
  2188. -- No padding on ..
  2189. else
  2190. padExpr(expr.Rhs)
  2191. padToken(expr.Token_Op)
  2192. end
  2193. elseif expr.Type == 'UnopExpr' then
  2194. formatExpr(expr.Rhs)
  2195. --(expr.Token_Op)
  2196. elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
  2197. expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
  2198. expr.Type == 'VargLiteral'
  2199. then
  2200. -- Nothing to do
  2201. --(expr.Token)
  2202. elseif expr.Type == 'FieldExpr' then
  2203. formatExpr(expr.Base)
  2204. --(expr.Token_Dot)
  2205. --(expr.Field)
  2206. elseif expr.Type == 'IndexExpr' then
  2207. formatExpr(expr.Base)
  2208. formatExpr(expr.Index)
  2209. --(expr.Token_OpenBracket)
  2210. --(expr.Token_CloseBracket)
  2211. elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
  2212. formatExpr(expr.Base)
  2213. if expr.Type == 'MethodExpr' then
  2214. --(expr.Token_Colon)
  2215. --(expr.Method)
  2216. end
  2217. if expr.FunctionArguments.CallType == 'StringCall' then
  2218. --(expr.FunctionArguments.Token)
  2219. elseif expr.FunctionArguments.CallType == 'ArgCall' then
  2220. --(expr.FunctionArguments.Token_OpenParen)
  2221. for index, argExpr in pairs(expr.FunctionArguments.ArgList) do
  2222. formatExpr(argExpr)
  2223. if index > 1 then
  2224. padExpr(argExpr)
  2225. end
  2226. local sep = expr.FunctionArguments.Token_CommaList[index]
  2227. if sep then
  2228. --(sep)
  2229. end
  2230. end
  2231. --(expr.FunctionArguments.Token_CloseParen)
  2232. elseif expr.FunctionArguments.CallType == 'TableCall' then
  2233. formatExpr(expr.FunctionArguments.TableExpr)
  2234. end
  2235. elseif expr.Type == 'FunctionLiteral' then
  2236. --(expr.Token_Function)
  2237. --(expr.Token_OpenParen)
  2238. for index, arg in pairs(expr.ArgList) do
  2239. --(arg)
  2240. if index > 1 then
  2241. padToken(arg)
  2242. end
  2243. local comma = expr.Token_ArgCommaList[index]
  2244. if comma then
  2245. --(comma)
  2246. end
  2247. end
  2248. --(expr.Token_CloseParen)
  2249. formatBody(expr.Token_CloseParen, expr.Body, expr.Token_End)
  2250. elseif expr.Type == 'VariableExpr' then
  2251. --(expr.Token)
  2252. elseif expr.Type == 'ParenExpr' then
  2253. formatExpr(expr.Expression)
  2254. --(expr.Token_OpenParen)
  2255. --(expr.Token_CloseParen)
  2256. elseif expr.Type == 'TableLiteral' then
  2257. --(expr.Token_OpenBrace)
  2258. if #expr.EntryList == 0 then
  2259. -- Nothing to do
  2260. else
  2261. indent()
  2262. for index, entry in pairs(expr.EntryList) do
  2263. if entry.EntryType == 'Field' then
  2264. applyIndent(entry.Field)
  2265. padToken(entry.Token_Equals)
  2266. formatExpr(entry.Value)
  2267. padExpr(entry.Value)
  2268. elseif entry.EntryType == 'Index' then
  2269. applyIndent(entry.Token_OpenBracket)
  2270. formatExpr(entry.Index)
  2271. --(entry.Token_CloseBracket)
  2272. padToken(entry.Token_Equals)
  2273. formatExpr(entry.Value)
  2274. padExpr(entry.Value)
  2275. elseif entry.EntryType == 'Value' then
  2276. formatExpr(entry.Value)
  2277. applyIndent(entry.Value:GetFirstToken())
  2278. else
  2279. assert(false, "unreachable")
  2280. end
  2281. local sep = expr.Token_SeparatorList[index]
  2282. if sep then
  2283. --(sep)
  2284. end
  2285. end
  2286. undent()
  2287. applyIndent(expr.Token_CloseBrace)
  2288. end
  2289. --(expr.Token_CloseBrace)
  2290. else
  2291. assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
  2292. end
  2293. end
  2294.  
  2295. formatStat = function(stat)
  2296. if stat.Type == 'StatList' then
  2297. for _, stat in pairs(stat.StatementList) do
  2298. formatStat(stat)
  2299. applyIndent(stat:GetFirstToken())
  2300. end
  2301.  
  2302. elseif stat.Type == 'BreakStat' then
  2303. --(stat.Token_Break)
  2304.  
  2305. elseif stat.Type == 'ReturnStat' then
  2306. --(stat.Token_Return)
  2307. for index, expr in pairs(stat.ExprList) do
  2308. formatExpr(expr)
  2309. padExpr(expr)
  2310. if stat.Token_CommaList[index] then
  2311. --(stat.Token_CommaList[index])
  2312. end
  2313. end
  2314. elseif stat.Type == 'LocalVarStat' then
  2315. --(stat.Token_Local)
  2316. for index, var in pairs(stat.VarList) do
  2317. padToken(var)
  2318. local comma = stat.Token_VarCommaList[index]
  2319. if comma then
  2320. --(comma)
  2321. end
  2322. end
  2323. if stat.Token_Equals then
  2324. padToken(stat.Token_Equals)
  2325. for index, expr in pairs(stat.ExprList) do
  2326. formatExpr(expr)
  2327. padExpr(expr)
  2328. local comma = stat.Token_ExprCommaList[index]
  2329. if comma then
  2330. --(comma)
  2331. end
  2332. end
  2333. end
  2334. elseif stat.Type == 'LocalFunctionStat' then
  2335. --(stat.Token_Local)
  2336. padToken(stat.FunctionStat.Token_Function)
  2337. padToken(stat.FunctionStat.NameChain[1])
  2338. --(stat.FunctionStat.Token_OpenParen)
  2339. for index, arg in pairs(stat.FunctionStat.ArgList) do
  2340. if index > 1 then
  2341. padToken(arg)
  2342. end
  2343. local comma = stat.FunctionStat.Token_ArgCommaList[index]
  2344. if comma then
  2345. --(comma)
  2346. end
  2347. end
  2348. --(stat.FunctionStat.Token_CloseParen)
  2349. formatBody(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End)
  2350. elseif stat.Type == 'FunctionStat' then
  2351. --(stat.Token_Function)
  2352. for index, part in pairs(stat.NameChain) do
  2353. if index == 1 then
  2354. padToken(part)
  2355. end
  2356. local sep = stat.Token_NameChainSeparator[index]
  2357. if sep then
  2358. --(sep)
  2359. end
  2360. end
  2361. --(stat.Token_OpenParen)
  2362. for index, arg in pairs(stat.ArgList) do
  2363. if index > 1 then
  2364. padToken(arg)
  2365. end
  2366. local comma = stat.Token_ArgCommaList[index]
  2367. if comma then
  2368. --(comma)
  2369. end
  2370. end
  2371. --(stat.Token_CloseParen)
  2372. formatBody(stat.Token_CloseParen, stat.Body, stat.Token_End)
  2373. elseif stat.Type == 'RepeatStat' then
  2374. --(stat.Token_Repeat)
  2375. formatBody(stat.Token_Repeat, stat.Body, stat.Token_Until)
  2376. formatExpr(stat.Condition)
  2377. padExpr(stat.Condition)
  2378. elseif stat.Type == 'GenericForStat' then
  2379. --(stat.Token_For)
  2380. for index, var in pairs(stat.VarList) do
  2381. padToken(var)
  2382. local sep = stat.Token_VarCommaList[index]
  2383. if sep then
  2384. --(sep)
  2385. end
  2386. end
  2387. padToken(stat.Token_In)
  2388. for index, expr in pairs(stat.GeneratorList) do
  2389. formatExpr(expr)
  2390. padExpr(expr)
  2391. local sep = stat.Token_GeneratorCommaList[index]
  2392. if sep then
  2393. --(sep)
  2394. end
  2395. end
  2396. padToken(stat.Token_Do)
  2397. formatBody(stat.Token_Do, stat.Body, stat.Token_End)
  2398. elseif stat.Type == 'NumericForStat' then
  2399. --(stat.Token_For)
  2400. for index, var in pairs(stat.VarList) do
  2401. padToken(var)
  2402. local sep = stat.Token_VarCommaList[index]
  2403. if sep then
  2404. --(sep)
  2405. end
  2406. end
  2407. padToken(stat.Token_Equals)
  2408. for index, expr in pairs(stat.RangeList) do
  2409. formatExpr(expr)
  2410. padExpr(expr)
  2411. local sep = stat.Token_RangeCommaList[index]
  2412. if sep then
  2413. --(sep)
  2414. end
  2415. end
  2416. padToken(stat.Token_Do)
  2417. formatBody(stat.Token_Do, stat.Body, stat.Token_End)
  2418. elseif stat.Type == 'WhileStat' then
  2419. --(stat.Token_While)
  2420. formatExpr(stat.Condition)
  2421. padExpr(stat.Condition)
  2422. padToken(stat.Token_Do)
  2423. formatBody(stat.Token_Do, stat.Body, stat.Token_End)
  2424. elseif stat.Type == 'DoStat' then
  2425. --(stat.Token_Do)
  2426. formatBody(stat.Token_Do, stat.Body, stat.Token_End)
  2427. elseif stat.Type == 'IfStat' then
  2428. --(stat.Token_If)
  2429. formatExpr(stat.Condition)
  2430. padExpr(stat.Condition)
  2431. padToken(stat.Token_Then)
  2432. --
  2433. local lastBodyOpen = stat.Token_Then
  2434. local lastBody = stat.Body
  2435. --
  2436. for _, clause in pairs(stat.ElseClauseList) do
  2437. formatBody(lastBodyOpen, lastBody, clause.Token)
  2438. lastBodyOpen = clause.Token
  2439. --
  2440. if clause.Condition then
  2441. formatExpr(clause.Condition)
  2442. padExpr(clause.Condition)
  2443. padToken(clause.Token_Then)
  2444. lastBodyOpen = clause.Token_Then
  2445. end
  2446. lastBody = clause.Body
  2447. end
  2448. --
  2449. formatBody(lastBodyOpen, lastBody, stat.Token_End)
  2450.  
  2451. elseif stat.Type == 'CallExprStat' then
  2452. formatExpr(stat.Expression)
  2453. elseif stat.Type == 'AssignmentStat' then
  2454. for index, ex in pairs(stat.Lhs) do
  2455. formatExpr(ex)
  2456. if index > 1 then
  2457. padExpr(ex)
  2458. end
  2459. local sep = stat.Token_LhsSeparatorList[index]
  2460. if sep then
  2461. --(sep)
  2462. end
  2463. end
  2464. padToken(stat.Token_Equals)
  2465. for index, ex in pairs(stat.Rhs) do
  2466. formatExpr(ex)
  2467. padExpr(ex)
  2468. local sep = stat.Token_RhsSeparatorList[index]
  2469. if sep then
  2470. --(sep)
  2471. end
  2472. end
  2473. else
  2474. assert(false, "unreachable")
  2475. end
  2476. end
  2477.  
  2478. formatStat(ast)
  2479. end
  2480.  
  2481. -- Strips as much whitespace off of tokens in an AST as possible without causing problems
  2482. local function StripAst(ast)
  2483. local stripStat, stripExpr;
  2484.  
  2485. local function stript(token)
  2486. token.LeadingWhite = ''
  2487. end
  2488.  
  2489. -- Make to adjacent tokens as close as possible
  2490. local function joint(tokenA, tokenB)
  2491. -- Strip the second token's whitespace
  2492. stript(tokenB)
  2493.  
  2494. -- Get the trailing A <-> leading B character pair
  2495. local lastCh = tokenA.Source:sub(-1, -1)
  2496. local firstCh = tokenB.Source:sub(1, 1)
  2497.  
  2498. -- Cases to consider:
  2499. -- Touching minus signs -> comment: `- -42` -> `--42' is invalid
  2500. -- Touching words: `a b` -> `ab` is invalid
  2501. -- Touching digits: `2 3`, can't occurr in the Lua syntax as number literals aren't a primary expression
  2502. -- Abiguous syntax: `f(x)\n(x)()` is already disallowed, we can't cause a problem by removing newlines
  2503.  
  2504. -- Figure out what separation is needed
  2505. if
  2506. (lastCh == '-' and firstCh == '-') or
  2507. (AllIdentChars[lastCh] and AllIdentChars[firstCh])
  2508. then
  2509. tokenB.LeadingWhite = ' ' -- Use a separator
  2510. else
  2511. tokenB.LeadingWhite = '' -- Don't use a separator
  2512. end
  2513. end
  2514.  
  2515. -- Join up a statement body and it's opening / closing tokens
  2516. local function bodyjoint(open, body, close)
  2517. stripStat(body)
  2518. stript(close)
  2519. local bodyFirst = body:GetFirstToken()
  2520. local bodyLast = body:GetLastToken()
  2521. if bodyFirst then
  2522. -- Body is non-empty, join body to open / close
  2523. joint(open, bodyFirst)
  2524. joint(bodyLast, close)
  2525. else
  2526. -- Body is empty, just join open and close token together
  2527. joint(open, close)
  2528. end
  2529. end
  2530.  
  2531. stripExpr = function(expr)
  2532. if expr.Type == 'BinopExpr' then
  2533. stripExpr(expr.Lhs)
  2534. stript(expr.Token_Op)
  2535. stripExpr(expr.Rhs)
  2536. -- Handle the `a - -b` -/-> `a--b` case which would otherwise incorrectly generate a comment
  2537. -- Also handles operators "or" / "and" which definitely need joining logic in a bunch of cases
  2538. joint(expr.Token_Op, expr.Rhs:GetFirstToken())
  2539. joint(expr.Lhs:GetLastToken(), expr.Token_Op)
  2540. elseif expr.Type == 'UnopExpr' then
  2541. stript(expr.Token_Op)
  2542. stripExpr(expr.Rhs)
  2543. -- Handle the `- -b` -/-> `--b` case which would otherwise incorrectly generate a comment
  2544. joint(expr.Token_Op, expr.Rhs:GetFirstToken())
  2545. elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
  2546. expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
  2547. expr.Type == 'VargLiteral'
  2548. then
  2549. -- Just print the token
  2550. stript(expr.Token)
  2551. elseif expr.Type == 'FieldExpr' then
  2552. stripExpr(expr.Base)
  2553. stript(expr.Token_Dot)
  2554. stript(expr.Field)
  2555. elseif expr.Type == 'IndexExpr' then
  2556. stripExpr(expr.Base)
  2557. stript(expr.Token_OpenBracket)
  2558. stripExpr(expr.Index)
  2559. stript(expr.Token_CloseBracket)
  2560. elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
  2561. stripExpr(expr.Base)
  2562. if expr.Type == 'MethodExpr' then
  2563. stript(expr.Token_Colon)
  2564. stript(expr.Method)
  2565. end
  2566. if expr.FunctionArguments.CallType == 'StringCall' then
  2567. stript(expr.FunctionArguments.Token)
  2568. elseif expr.FunctionArguments.CallType == 'ArgCall' then
  2569. stript(expr.FunctionArguments.Token_OpenParen)
  2570. for index, argExpr in pairs(expr.FunctionArguments.ArgList) do
  2571. stripExpr(argExpr)
  2572. local sep = expr.FunctionArguments.Token_CommaList[index]
  2573. if sep then
  2574. stript(sep)
  2575. end
  2576. end
  2577. stript(expr.FunctionArguments.Token_CloseParen)
  2578. elseif expr.FunctionArguments.CallType == 'TableCall' then
  2579. stripExpr(expr.FunctionArguments.TableExpr)
  2580. end
  2581. elseif expr.Type == 'FunctionLiteral' then
  2582. stript(expr.Token_Function)
  2583. stript(expr.Token_OpenParen)
  2584. for index, arg in pairs(expr.ArgList) do
  2585. stript(arg)
  2586. local comma = expr.Token_ArgCommaList[index]
  2587. if comma then
  2588. stript(comma)
  2589. end
  2590. end
  2591. stript(expr.Token_CloseParen)
  2592. bodyjoint(expr.Token_CloseParen, expr.Body, expr.Token_End)
  2593. elseif expr.Type == 'VariableExpr' then
  2594. stript(expr.Token)
  2595. elseif expr.Type == 'ParenExpr' then
  2596. stript(expr.Token_OpenParen)
  2597. stripExpr(expr.Expression)
  2598. stript(expr.Token_CloseParen)
  2599. elseif expr.Type == 'TableLiteral' then
  2600. stript(expr.Token_OpenBrace)
  2601. for index, entry in pairs(expr.EntryList) do
  2602. if entry.EntryType == 'Field' then
  2603. stript(entry.Field)
  2604. stript(entry.Token_Equals)
  2605. stripExpr(entry.Value)
  2606. elseif entry.EntryType == 'Index' then
  2607. stript(entry.Token_OpenBracket)
  2608. stripExpr(entry.Index)
  2609. stript(entry.Token_CloseBracket)
  2610. stript(entry.Token_Equals)
  2611. stripExpr(entry.Value)
  2612. elseif entry.EntryType == 'Value' then
  2613. stripExpr(entry.Value)
  2614. else
  2615. assert(false, "unreachable")
  2616. end
  2617. local sep = expr.Token_SeparatorList[index]
  2618. if sep then
  2619. stript(sep)
  2620. end
  2621. end
  2622. stript(expr.Token_CloseBrace)
  2623. else
  2624. assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
  2625. end
  2626. end
  2627.  
  2628. stripStat = function(stat)
  2629. if stat.Type == 'StatList' then
  2630. -- Strip all surrounding whitespace on statement lists along with separating whitespace
  2631. for i = 1, #stat.StatementList do
  2632. local chStat = stat.StatementList[i]
  2633.  
  2634. -- Strip the statement and it's whitespace
  2635. stripStat(chStat)
  2636. stript(chStat:GetFirstToken())
  2637.  
  2638. -- If there was a last statement, join them appropriately
  2639. local lastChStat = stat.StatementList[i-1]
  2640. if lastChStat then
  2641. -- See if we can remove a semi-colon, the only case where we can't is if
  2642. -- this and the last statement have a `);(` pair, where removing the semi-colon
  2643. -- would introduce ambiguous syntax.
  2644. if stat.SemicolonList[i-1] and
  2645. (lastChStat:GetLastToken().Source ~= ')' or chStat:GetFirstToken().Source ~= ')')
  2646. then
  2647. stat.SemicolonList[i-1] = nil
  2648. end
  2649.  
  2650. -- If there isn't a semi-colon, we should safely join the two statements
  2651. -- (If there is one, then no whitespace leading chStat is always okay)
  2652. if not stat.SemicolonList[i-1] then
  2653. joint(lastChStat:GetLastToken(), chStat:GetFirstToken())
  2654. end
  2655. end
  2656. end
  2657.  
  2658. -- A semi-colon is never needed on the last stat in a statlist:
  2659. stat.SemicolonList[#stat.StatementList] = nil
  2660.  
  2661. -- The leading whitespace on the statlist should be stripped
  2662. if #stat.StatementList > 0 then
  2663. stript(stat.StatementList[1]:GetFirstToken())
  2664. end
  2665.  
  2666. elseif stat.Type == 'BreakStat' then
  2667. stript(stat.Token_Break)
  2668.  
  2669. elseif stat.Type == 'ReturnStat' then
  2670. stript(stat.Token_Return)
  2671. for index, expr in pairs(stat.ExprList) do
  2672. stripExpr(expr)
  2673. if stat.Token_CommaList[index] then
  2674. stript(stat.Token_CommaList[index])
  2675. end
  2676. end
  2677. if #stat.ExprList > 0 then
  2678. joint(stat.Token_Return, stat.ExprList[1]:GetFirstToken())
  2679. end
  2680. elseif stat.Type == 'LocalVarStat' then
  2681. stript(stat.Token_Local)
  2682. for index, var in pairs(stat.VarList) do
  2683. if index == 1 then
  2684. joint(stat.Token_Local, var)
  2685. else
  2686. stript(var)
  2687. end
  2688. local comma = stat.Token_VarCommaList[index]
  2689. if comma then
  2690. stript(comma)
  2691. end
  2692. end
  2693. if stat.Token_Equals then
  2694. stript(stat.Token_Equals)
  2695. for index, expr in pairs(stat.ExprList) do
  2696. stripExpr(expr)
  2697. local comma = stat.Token_ExprCommaList[index]
  2698. if comma then
  2699. stript(comma)
  2700. end
  2701. end
  2702. end
  2703. elseif stat.Type == 'LocalFunctionStat' then
  2704. stript(stat.Token_Local)
  2705. joint(stat.Token_Local, stat.FunctionStat.Token_Function)
  2706. joint(stat.FunctionStat.Token_Function, stat.FunctionStat.NameChain[1])
  2707. joint(stat.FunctionStat.NameChain[1], stat.FunctionStat.Token_OpenParen)
  2708. for index, arg in pairs(stat.FunctionStat.ArgList) do
  2709. stript(arg)
  2710. local comma = stat.FunctionStat.Token_ArgCommaList[index]
  2711. if comma then
  2712. stript(comma)
  2713. end
  2714. end
  2715. stript(stat.FunctionStat.Token_CloseParen)
  2716. bodyjoint(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End)
  2717. elseif stat.Type == 'FunctionStat' then
  2718. stript(stat.Token_Function)
  2719. for index, part in pairs(stat.NameChain) do
  2720. if index == 1 then
  2721. joint(stat.Token_Function, part)
  2722. else
  2723. stript(part)
  2724. end
  2725. local sep = stat.Token_NameChainSeparator[index]
  2726. if sep then
  2727. stript(sep)
  2728. end
  2729. end
  2730. stript(stat.Token_OpenParen)
  2731. for index, arg in pairs(stat.ArgList) do
  2732. stript(arg)
  2733. local comma = stat.Token_ArgCommaList[index]
  2734. if comma then
  2735. stript(comma)
  2736. end
  2737. end
  2738. stript(stat.Token_CloseParen)
  2739. bodyjoint(stat.Token_CloseParen, stat.Body, stat.Token_End)
  2740. elseif stat.Type == 'RepeatStat' then
  2741. stript(stat.Token_Repeat)
  2742. bodyjoint(stat.Token_Repeat, stat.Body, stat.Token_Until)
  2743. stripExpr(stat.Condition)
  2744. joint(stat.Token_Until, stat.Condition:GetFirstToken())
  2745. elseif stat.Type == 'GenericForStat' then
  2746. stript(stat.Token_For)
  2747. for index, var in pairs(stat.VarList) do
  2748. if index == 1 then
  2749. joint(stat.Token_For, var)
  2750. else
  2751. stript(var)
  2752. end
  2753. local sep = stat.Token_VarCommaList[index]
  2754. if sep then
  2755. stript(sep)
  2756. end
  2757. end
  2758. joint(stat.VarList[#stat.VarList], stat.Token_In)
  2759. for index, expr in pairs(stat.GeneratorList) do
  2760. stripExpr(expr)
  2761. if index == 1 then
  2762. joint(stat.Token_In, expr:GetFirstToken())
  2763. end
  2764. local sep = stat.Token_GeneratorCommaList[index]
  2765. if sep then
  2766. stript(sep)
  2767. end
  2768. end
  2769. joint(stat.GeneratorList[#stat.GeneratorList]:GetLastToken(), stat.Token_Do)
  2770. bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
  2771. elseif stat.Type == 'NumericForStat' then
  2772. stript(stat.Token_For)
  2773. for index, var in pairs(stat.VarList) do
  2774. if index == 1 then
  2775. joint(stat.Token_For, var)
  2776. else
  2777. stript(var)
  2778. end
  2779. local sep = stat.Token_VarCommaList[index]
  2780. if sep then
  2781. stript(sep)
  2782. end
  2783. end
  2784. joint(stat.VarList[#stat.VarList], stat.Token_Equals)
  2785. for index, expr in pairs(stat.RangeList) do
  2786. stripExpr(expr)
  2787. if index == 1 then
  2788. joint(stat.Token_Equals, expr:GetFirstToken())
  2789. end
  2790. local sep = stat.Token_RangeCommaList[index]
  2791. if sep then
  2792. stript(sep)
  2793. end
  2794. end
  2795. joint(stat.RangeList[#stat.RangeList]:GetLastToken(), stat.Token_Do)
  2796. bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
  2797. elseif stat.Type == 'WhileStat' then
  2798. stript(stat.Token_While)
  2799. stripExpr(stat.Condition)
  2800. stript(stat.Token_Do)
  2801. joint(stat.Token_While, stat.Condition:GetFirstToken())
  2802. joint(stat.Condition:GetLastToken(), stat.Token_Do)
  2803. bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
  2804. elseif stat.Type == 'DoStat' then
  2805. stript(stat.Token_Do)
  2806. stript(stat.Token_End)
  2807. bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
  2808. elseif stat.Type == 'IfStat' then
  2809. stript(stat.Token_If)
  2810. stripExpr(stat.Condition)
  2811. joint(stat.Token_If, stat.Condition:GetFirstToken())
  2812. joint(stat.Condition:GetLastToken(), stat.Token_Then)
  2813. --
  2814. local lastBodyOpen = stat.Token_Then
  2815. local lastBody = stat.Body
  2816. --
  2817. for _, clause in pairs(stat.ElseClauseList) do
  2818. bodyjoint(lastBodyOpen, lastBody, clause.Token)
  2819. lastBodyOpen = clause.Token
  2820. --
  2821. if clause.Condition then
  2822. stripExpr(clause.Condition)
  2823. joint(clause.Token, clause.Condition:GetFirstToken())
  2824. joint(clause.Condition:GetLastToken(), clause.Token_Then)
  2825. lastBodyOpen = clause.Token_Then
  2826. end
  2827. stripStat(clause.Body)
  2828. lastBody = clause.Body
  2829. end
  2830. --
  2831. bodyjoint(lastBodyOpen, lastBody, stat.Token_End)
  2832.  
  2833. elseif stat.Type == 'CallExprStat' then
  2834. stripExpr(stat.Expression)
  2835. elseif stat.Type == 'AssignmentStat' then
  2836. for index, ex in pairs(stat.Lhs) do
  2837. stripExpr(ex)
  2838. local sep = stat.Token_LhsSeparatorList[index]
  2839. if sep then
  2840. stript(sep)
  2841. end
  2842. end
  2843. stript(stat.Token_Equals)
  2844. for index, ex in pairs(stat.Rhs) do
  2845. stripExpr(ex)
  2846. local sep = stat.Token_RhsSeparatorList[index]
  2847. if sep then
  2848. stript(sep)
  2849. end
  2850. end
  2851. else
  2852. assert(false, "unreachable")
  2853. end
  2854. end
  2855.  
  2856. stripStat(ast)
  2857. end
  2858.  
  2859. local idGen = 0
  2860. local VarDigits = {}
  2861. for i = ('a'):byte(), ('z'):byte() do table.insert(VarDigits, string.char(i)) end
  2862. for i = ('A'):byte(), ('Z'):byte() do table.insert(VarDigits, string.char(i)) end
  2863. for i = ('0'):byte(), ('9'):byte() do table.insert(VarDigits, string.char(i)) end
  2864. table.insert(VarDigits, '_')
  2865. local VarStartDigits = {}
  2866. for i = ('a'):byte(), ('z'):byte() do table.insert(VarStartDigits, string.char(i)) end
  2867. for i = ('A'):byte(), ('Z'):byte() do table.insert(VarStartDigits, string.char(i)) end
  2868. local function indexToVarName(index)
  2869. local id = ''
  2870. local d = index % #VarStartDigits
  2871. index = (index - d) / #VarStartDigits
  2872. id = id..VarStartDigits[d+1]
  2873. while index > 0 do
  2874. local d = index % #VarDigits
  2875. index = (index - d) / #VarDigits
  2876. id = id..VarDigits[d+1]
  2877. end
  2878. return id
  2879. end
  2880. local function genNextVarName()
  2881. local varToUse = idGen
  2882. idGen = idGen + 1
  2883. return indexToVarName(varToUse)
  2884. end
  2885. local function genVarName()
  2886. local varName = ''
  2887. repeat
  2888. varName = genNextVarName()
  2889. until not Keywords[varName]
  2890. return varName
  2891. end
  2892. local function MinifyVariables(globalScope, rootScope)
  2893. -- externalGlobals is a set of global variables that have not been assigned to, that is
  2894. -- global variables defined "externally to the script". We are not going to be renaming
  2895. -- those, and we have to make sure that we don't collide with them when renaming
  2896. -- things so we keep track of them in this set.
  2897. local externalGlobals = {}
  2898.  
  2899. -- First we want to rename all of the variables to unique temoraries, so that we can
  2900. -- easily use the scope::GetVar function to check whether renames are valid.
  2901. local temporaryIndex = 0
  2902. for _, var in pairs(globalScope) do
  2903. if var.AssignedTo then
  2904. var:Rename('_TMP_'..temporaryIndex..'_')
  2905. temporaryIndex = temporaryIndex + 1
  2906. else
  2907. -- Not assigned to, external global
  2908. externalGlobals[var.Name] = true
  2909. end
  2910. end
  2911. local function temporaryRename(scope)
  2912. for _, var in pairs(scope.VariableList) do
  2913. var:Rename('_TMP_'..temporaryIndex..'_')
  2914. temporaryIndex = temporaryIndex + 1
  2915. end
  2916. for _, childScope in pairs(scope.ChildScopeList) do
  2917. temporaryRename(childScope)
  2918. end
  2919. end
  2920.  
  2921. -- Now we go through renaming, first do globals, we probably want them
  2922. -- to have shorter names in general.
  2923. -- TODO: Rename all vars based on frequency patterns, giving variables
  2924. -- used more shorter names.
  2925. local nextFreeNameIndex = 0
  2926. for _, var in pairs(globalScope) do
  2927. if var.AssignedTo then
  2928. local varName = ''
  2929. repeat
  2930. varName = indexToVarName(nextFreeNameIndex)
  2931. nextFreeNameIndex = nextFreeNameIndex + 1
  2932. until not Keywords[varName] and not externalGlobals[varName]
  2933. var:Rename(varName)
  2934. end
  2935. end
  2936.  
  2937. -- Now rename all local vars
  2938. rootScope.FirstFreeName = nextFreeNameIndex
  2939. local function doRenameScope(scope)
  2940. for _, var in pairs(scope.VariableList) do
  2941. local varName = ''
  2942. repeat
  2943. varName = indexToVarName(scope.FirstFreeName)
  2944. scope.FirstFreeName = scope.FirstFreeName + 1
  2945. until not Keywords[varName] and not externalGlobals[varName]
  2946. var:Rename(varName)
  2947. end
  2948. for _, childScope in pairs(scope.ChildScopeList) do
  2949. childScope.FirstFreeName = scope.FirstFreeName
  2950. doRenameScope(childScope)
  2951. end
  2952. end
  2953. doRenameScope(rootScope)
  2954. end
  2955.  
  2956. local function MinifyVariables_2(globalScope, rootScope)
  2957. -- Variable names and other names that are fixed, that we cannot use
  2958. -- Either these are Lua keywords, or globals that are not assigned to,
  2959. -- that is environmental globals that are assigned elsewhere beyond our
  2960. -- control.
  2961. local globalUsedNames = {}
  2962. for kw, _ in pairs(Keywords) do
  2963. globalUsedNames[kw] = true
  2964. end
  2965.  
  2966. -- Gather a list of all of the variables that we will rename
  2967. local allVariables = {}
  2968. local allLocalVariables = {}
  2969. do
  2970. -- Add applicable globals
  2971. for _, var in pairs(globalScope) do
  2972. if var.AssignedTo then
  2973. -- We can try to rename this global since it was assigned to
  2974. -- (and thus presumably initialized) in the script we are
  2975. -- minifying.
  2976. table.insert(allVariables, var)
  2977. else
  2978. -- We can't rename this global, mark it as an unusable name
  2979. -- and don't add it to the nename list
  2980. globalUsedNames[var.Name] = true
  2981. end
  2982. end
  2983.  
  2984. -- Recursively add locals, we can rename all of those
  2985. local function addFrom(scope)
  2986. for _, var in pairs(scope.VariableList) do
  2987. table.insert(allVariables, var)
  2988. table.insert(allLocalVariables, var)
  2989. end
  2990. for _, childScope in pairs(scope.ChildScopeList) do
  2991. addFrom(childScope)
  2992. end
  2993. end
  2994. addFrom(rootScope)
  2995. end
  2996.  
  2997. -- Add used name arrays to variables
  2998. for _, var in pairs(allVariables) do
  2999. var.UsedNameArray = {}
  3000. end
  3001.  
  3002. -- Sort the least used variables first
  3003. table.sort(allVariables, function(a, b)
  3004. return #a.RenameList < #b.RenameList
  3005. end)
  3006.  
  3007. -- Lazy generator for valid names to rename to
  3008. local nextValidNameIndex = 0
  3009. local varNamesLazy = {}
  3010. local function varIndexToValidVarName(i)
  3011. local name = varNamesLazy[i]
  3012. if not name then
  3013. repeat
  3014. name = indexToVarName(nextValidNameIndex)
  3015. nextValidNameIndex = nextValidNameIndex + 1
  3016. until not globalUsedNames[name]
  3017. varNamesLazy[i] = name
  3018. end
  3019. return name
  3020. end
  3021.  
  3022. -- For each variable, go to rename it
  3023. for _, var in pairs(allVariables) do
  3024. -- Lazy... todo: Make theis pair a proper for-each-pair-like set of loops
  3025. -- rather than using a renamed flag.
  3026. var.Renamed = true
  3027.  
  3028. -- Find the first unused name
  3029. local i = 1
  3030. while var.UsedNameArray[i] do
  3031. i = i + 1
  3032. end
  3033.  
  3034. -- Rename the variable to that name
  3035. var:Rename(varIndexToValidVarName(i))
  3036.  
  3037. if var.Scope then
  3038. -- Now we need to mark the name as unusable by any variables:
  3039. -- 1) At the same depth that overlap lifetime with this one
  3040. -- 2) At a deeper level, which have a reference to this variable in their lifetimes
  3041. -- 3) At a shallower level, which are referenced during this variable's lifetime
  3042. for _, otherVar in pairs(allVariables) do
  3043. if not otherVar.Renamed then
  3044. if not otherVar.Scope or otherVar.Scope.Depth < var.Scope.Depth then
  3045. -- Check Global variable (Which is always at a shallower level)
  3046. -- or
  3047. -- Check case 3
  3048. -- The other var is at a shallower depth, is there a reference to it
  3049. -- durring this variable's lifetime?
  3050. for _, refAt in pairs(otherVar.ReferenceLocationList) do
  3051. if refAt >= var.BeginLocation and refAt <= var.ScopeEndLocation then
  3052. -- Collide
  3053. otherVar.UsedNameArray[i] = true
  3054. break
  3055. end
  3056. end
  3057.  
  3058. elseif otherVar.Scope.Depth > var.Scope.Depth then
  3059. -- Check Case 2
  3060. -- The other var is at a greater depth, see if any of the references
  3061. -- to this variable are in the other var's lifetime.
  3062. for _, refAt in pairs(var.ReferenceLocationList) do
  3063. if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
  3064. -- Collide
  3065. otherVar.UsedNameArray[i] = true
  3066. break
  3067. end
  3068. end
  3069.  
  3070. else --otherVar.Scope.Depth must be equal to var.Scope.Depth
  3071. -- Check case 1
  3072. -- The two locals are in the same scope
  3073. -- Just check if the usage lifetimes overlap within that scope. That is, we
  3074. -- can shadow a local variable within the same scope as long as the usages
  3075. -- of the two locals do not overlap.
  3076. if var.BeginLocation < otherVar.EndLocation and
  3077. var.EndLocation > otherVar.BeginLocation
  3078. then
  3079. otherVar.UsedNameArray[i] = true
  3080. end
  3081. end
  3082. end
  3083. end
  3084. else
  3085. -- This is a global var, all other globals can't collide with it, and
  3086. -- any local variable with a reference to this global in it's lifetime
  3087. -- can't collide with it.
  3088. for _, otherVar in pairs(allVariables) do
  3089. if not otherVar.Renamed then
  3090. if otherVar.Type == 'Global' then
  3091. otherVar.UsedNameArray[i] = true
  3092. elseif otherVar.Type == 'Local' then
  3093. -- Other var is a local, see if there is a reference to this global within
  3094. -- that local's lifetime.
  3095. for _, refAt in pairs(var.ReferenceLocationList) do
  3096. if refAt >= otherVar.BeginLocation and refAt <= otherVar.ScopeEndLocation then
  3097. -- Collide
  3098. otherVar.UsedNameArray[i] = true
  3099. break
  3100. end
  3101. end
  3102. else
  3103. assert(false, "unreachable")
  3104. end
  3105. end
  3106. end
  3107. end
  3108. end
  3109.  
  3110.  
  3111. -- --
  3112. -- print("Total Variables: "..#allVariables)
  3113. -- print("Total Range: "..rootScope.BeginLocation.."-"..rootScope.EndLocation)
  3114. -- print("")
  3115. -- for _, var in pairs(allVariables) do
  3116. -- io.write("`"..var.Name.."':\n\t#symbols: "..#var.RenameList..
  3117. -- "\n\tassigned to: "..tostring(var.AssignedTo))
  3118. -- if var.Type == 'Local' then
  3119. -- io.write("\n\trange: "..var.BeginLocation.."-"..var.EndLocation)
  3120. -- io.write("\n\tlocal type: "..var.Info.Type)
  3121. -- end
  3122. -- io.write("\n\n")
  3123. -- end
  3124.  
  3125. -- -- First we want to rename all of the variables to unique temoraries, so that we can
  3126. -- -- easily use the scope::GetVar function to check whether renames are valid.
  3127. -- local temporaryIndex = 0
  3128. -- for _, var in pairs(allVariables) do
  3129. -- var:Rename('_TMP_'..temporaryIndex..'_')
  3130. -- temporaryIndex = temporaryIndex + 1
  3131. -- end
  3132.  
  3133. -- For each variable, we need to build a list of names that collide with it
  3134.  
  3135. --
  3136. --error()
  3137. end
  3138.  
  --- Rename every variable to a predictable, easily searchable name.
  -- Globals that the chunk assigns to become `G_<n>`. Locals become `L_<n>_`
  -- followed by a suffix describing how they were declared (`arg<i>`, `func`
  -- or `forvar<i>`; plain locals get no suffix).
  -- Globals that are never assigned in the chunk keep their names. No
  -- generated name may equal one of those names: otherwise a renamed
  -- variable would capture or shadow references to that external global.
  -- @param globalScope collection of global variable records
  --   (fields used: `Name`, `AssignedTo`, `RenameList`).
  -- @param rootScope root scope record (fields used: `VariableList`,
  --   `ChildScopeList`). Its locals are numbered before those of its
  --   children, recursively.
  local function BeautifyVariables(globalScope, rootScope)
    -- Names of globals that are only read, never assigned, in this chunk.
    -- These keep their names, so generated names must avoid them.
    local externalGlobals = {}
    for _, var in pairs(globalScope) do
      if not var.AssignedTo then
        externalGlobals[var.Name] = true
      end
    end

    local localNumber = 1
    local globalNumber = 1

    -- Record the new name on `var` and call every setter in its RenameList
    -- (presumably one per place in the AST where the name is spelled).
    local function setVarName(var, name)
      var.Name = name
      for _, setter in pairs(var.RenameList) do
        setter(name)
      end
    end

    -- Next 'G_<n>' that is not the name of an external global.
    local function nextGlobalName()
      local name
      repeat
        name = 'G_'..globalNumber
        globalNumber = globalNumber + 1
      until not externalGlobals[name]
      return name
    end

    -- Next 'L_<n>_<suffix>' for `var` that is not the name of an external
    -- global. The suffix records how the local was introduced.
    local function nextLocalName(var)
      local info = var.Info
      local suffix = ''
      if info.Type == 'Argument' then
        suffix = 'arg'..info.Index
      elseif info.Type == 'LocalFunction' then
        suffix = 'func'
      elseif info.Type == 'ForRange' then
        suffix = 'forvar'..info.Index
      end
      local name
      repeat
        name = 'L_'..localNumber..'_'..suffix
        localNumber = localNumber + 1
      until not externalGlobals[name]
      return name
    end

    for _, var in pairs(globalScope) do
      if var.AssignedTo then
        setVarName(var, nextGlobalName())
      end
    end

    -- Depth-first walk: a scope's own locals are numbered before its children's.
    local function modify(scope)
      for _, var in pairs(scope.VariableList) do
        setVarName(var, nextLocalName(var))
      end
      for _, childScope in pairs(scope.ChildScopeList) do
        modify(childScope)
      end
    end
    modify(rootScope)
  end
  3183.  
  --- Abort with the command-line usage text.
  -- The message is raised with error level 0, so no "file:line:" position is
  -- prepended to it. The leading newline presumably keeps the usage text off
  -- the interpreter's own error-prefix line.
  local function usageError()
    error(
      "\nusage: minify <file> or unminify <file>\n" ..
      " The modified code will be printed to the stdout, pipe it to a file, the\n" ..
      " lua interpreter, or something else as desired EG:\n\n" ..
      " lua minify.lua minify input.lua > output.lua\n\n" ..
      " * minify will minify the code in the file.\n" ..
      " * unminify will beautify the code and replace the variable names with easily\n" ..
      " find-replaceable ones to aid in reverse engineering minified code.\n", 0)
  end
  3194.  
  -- Command-line entry point:
  --   lua minify.lua minify <file>     (minified code is printed)
  --   lua minify.lua unminify <file>   (beautified code is printed)
  local args = {...}
  if #args ~= 2 then
    usageError()
  end

  -- `minify` pipeline: MinifyVariables and StripAst, then print the AST.
  local function minify(ast, global_scope, root_scope)
    MinifyVariables(global_scope, root_scope)
    StripAst(ast)
    PrintAst(ast)
  end

  -- `unminify` pipeline: rename variables to G_<n>/L_<n>_ names, FormatAst,
  -- then print the AST.
  local function beautify(ast, global_scope, root_scope)
    BeautifyVariables(global_scope, root_scope)
    FormatAst(ast)
    PrintAst(ast)
  end

  local commands = {
    minify = minify,
    unminify = beautify,
  }

  -- Validate the command before touching the input, so an unknown command
  -- reports the usage text rather than an unrelated I/O or parse error.
  local command = commands[args[1]]
  if not command then
    usageError()
  end

  local sourceFile, openErr = io.open(args[2], 'r')
  if not sourceFile then
    error("Could not open the input file `" .. args[2] .. "`: " .. tostring(openErr), 0)
  end

  -- read() returns nil plus a message on failure (e.g. the path is a
  -- directory). Close the handle in either case before reporting.
  local data, readErr = sourceFile:read('*all')
  sourceFile:close()
  if not data then
    error("Could not read the input file `" .. args[2] .. "`: " .. tostring(readErr), 0)
  end

  local ast = CreateLuaParser(data)
  local global_scope, root_scope = AddVariableInfo(ast)
  command(ast, global_scope, root_scope)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement