Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- --[[
- MIT License
- Copyright (c) 2017 Mark Langen
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
- ]]
--- Turn a list of values into a membership set, in place.
-- For every value `v` stored in `tb`, sets `tb[v] = true`. The original
-- entries (e.g. the array part) are kept and the same table is returned.
-- @param tb table of values, normally an array literal
-- @return the same table `tb`
function lookupify(tb)
	-- Snapshot the values first: assigning to fields that do not exist yet
	-- while traversing with pairs()/next() is undefined behaviour in Lua.
	local values = {}
	for _, v in pairs(tb) do
		values[#values + 1] = v
	end
	for i = 1, #values do
		tb[values[i]] = true
	end
	return tb
end
--- Count the key/value pairs of a table.
-- Counts every entry that pairs() yields, array part and hash part alike.
-- @param tb table to measure
-- @return number of entries
function CountTable(tb)
	local total = 0
	for _key in pairs(tb) do
		total = total + 1
	end
	return total
end
--- Render a table as Lua-like source text (recursive worker of FormatTable).
-- If the table has a `Print` field, the result of calling it (with no
-- arguments) is returned instead. Function values and keys rejected by
-- `ignoreFunc` are omitted. Tables with more than one entry are printed one
-- entry per line, indented by nesting depth.
-- @param tb table to format
-- @param atIndent current nesting depth (defaults to 0)
-- @param ignoreFunc optional predicate `ignoreFunc(key)`; truthy skips the key
-- @return the formatted string
function FormatTableInt(tb, atIndent, ignoreFunc)
	if tb.Print then
		return tb.Print()
	end
	atIndent = atIndent or 0
	-- Allow direct calls without a predicate: ignore nothing by default.
	ignoreFunc = ignoreFunc or function()
		return false
	end

	-- Only "more than one entry" matters, so stop counting after two.
	local entryCount = 0
	for _ in pairs(tb) do
		entryCount = entryCount + 1
		if entryCount > 1 then
			break
		end
	end
	local useNewlines = entryCount > 1
	local baseIndent = string.rep(' ', atIndent + 1)

	-- Collect pieces and join once instead of repeated concatenation.
	local parts = {"{"}
	local function emit(piece)
		parts[#parts + 1] = piece
	end
	if useNewlines then
		emit('\n')
	end

	for k, v in pairs(tb) do
		if type(v) ~= 'function' and not ignoreFunc(k) then
			if useNewlines then
				emit(baseIndent)
			end

			-- Key: numeric keys are implicit, identifier-like strings are
			-- bare, everything else is bracketed.
			if type(k) == 'number' then
				-- nothing to do
			elseif type(k) == 'string' and k:match("^[A-Za-z_][A-Za-z0-9_]*$") then
				emit(k.." = ")
			elseif type(k) == 'string' then
				emit("[\""..k.."\"] = ")
			else
				emit("["..tostring(k).."] = ")
			end

			-- Value
			if type(v) == 'string' then
				emit("\""..v.."\"")
			elseif type(v) == 'table' then
				emit(FormatTableInt(v, atIndent + (useNewlines and 1 or 0), ignoreFunc))
			else
				emit(tostring(v))
			end

			-- A comma follows whenever the raw traversal has another key.
			if next(tb, k) then
				emit(",")
			end
			if useNewlines then
				emit('\n')
			end
		end
	end

	if useNewlines then
		emit(string.rep(' ', atIndent))
	end
	emit("}")
	return table.concat(parts)
end
--- Format a table as Lua-like source text, starting at indent level 0.
-- @param tb table to format
-- @param ignoreFunc optional predicate `ignoreFunc(key)`; keys for which it
--        returns a truthy value are left out. Defaults to "ignore nothing".
-- @return the formatted string produced by FormatTableInt
function FormatTable(tb, ignoreFunc)
	if not ignoreFunc then
		ignoreFunc = function()
			return false
		end
	end
	return FormatTableInt(tb, 0, ignoreFunc)
end
-- Character classes and operator tables used by the tokenizer and parser.
-- Sets are built with lookupify, so membership is tested as `Set[x]`.

-- Characters skipped as whitespace between tokens.
local WhiteChars = lookupify{' ', '\n', '\t', '\r'}

-- Raw character -> escape sequence that spells it inside a string literal.
-- A backslash has to be written as two backslashes.
local EscapeForCharacter = {
	['\r'] = '\\r', ['\n'] = '\\n', ['\t'] = '\\t',
	['"'] = '\\"', ["'"] = "\\'", ['\\'] = '\\\\',
}

-- Character following a backslash -> the character that escape denotes.
local CharacterForEscape = {
	['r'] = '\r', ['n'] = '\n', ['t'] = '\t',
	['"'] = '"', ["'"] = "'", ['\\'] = '\\',
	['a'] = '\a', ['b'] = '\b', ['f'] = '\f', ['v'] = '\v',
}

-- Characters that may start an identifier.
local AllIdentStartChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
	'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
	's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
	'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
	'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
	'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_'}

-- Characters that may appear anywhere in an identifier.
local AllIdentChars = lookupify{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
	'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r',
	's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
	'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
	'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
	'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_',
	'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}

local Digits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}

local HexDigits = lookupify{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
	'A', 'a', 'B', 'b', 'C', 'c', 'D', 'd', 'E', 'e', 'F', 'f'}

-- Single-character symbols.
local Symbols = lookupify{'+', '-', '*', '/', '^', '%', ',', '{', '}', '[', ']', '(', ')', ';', '#', '.', ':'}

-- Symbols that may be followed by `=` to form a two-character operator.
local EqualSymbols = lookupify{'~', '=', '>', '<'}

local Keywords = lookupify{
	'and', 'break', 'do', 'else', 'elseif',
	'end', 'false', 'for', 'function', 'goto', 'if',
	'in', 'local', 'nil', 'not', 'or', 'repeat',
	'return', 'then', 'true', 'until', 'while',
}

-- Keywords that terminate the statement list of a block.
local BlockFollowKeyword = lookupify{'else', 'elseif', 'until', 'end'}

local UnopSet = lookupify{'-', 'not', '#'}

local BinopSet = lookupify{
	'+', '-', '*', '/', '%', '^', '#',
	'..', '.', ':',
	'>', '<', '<=', '>=', '~=', '==',
	'and', 'or'
}

local GlobalRenameIgnore = lookupify{
}

-- Binary operator -> {left priority, right priority}. A right priority lower
-- than the left one (`^`, `..`) makes the operator right-associative.
local BinaryPriority = {
	['+'] = {6, 6};
	['-'] = {6, 6};
	['*'] = {7, 7};
	['/'] = {7, 7};
	['%'] = {7, 7};
	['^'] = {10, 9};
	['..'] = {5, 4};
	['=='] = {3, 3};
	['~='] = {3, 3};
	['>'] = {3, 3};
	['<'] = {3, 3};
	['>='] = {3, 3};
	['<='] = {3, 3};
	['and'] = {2, 2};
	['or'] = {1, 1};
}

-- Priority with which a unary operator binds its operand.
local UnaryPriority = 8
--- Split Lua source text into a flat list of tokens.
-- Token types produced: Eof, Ident, Keyword, Number, String, Symbol.
-- Each token is a table with `Type`, `LeadingWhite` (the whitespace and
-- comments preceding the token) and `Source` (the token's exact text).
-- The last token of the returned list is always the Eof token.
-- On malformed input the tokens read so far are printed and an error of the
-- form "file<line:char>: message" is raised.
-- @param text Lua source code as a string
-- @return array of token tables
function CreateLuaTokenStream(text)
	-- Index of the next unread character of `text`.
	local p = 1
	local length = #text

	-- Output buffer for tokens
	local tokenBuffer = {}

	-- Characters accepted after a backslash in addition to the keys of
	-- CharacterForEscape and the decimal digits of a \ddd escape. The
	-- tokenizer only validates escapes, it never decodes them.
	local ExtraEscapes = {
		a = true, b = true, f = true, v = true,
		x = true, z = true, u = true,
		['\n'] = true, ['\r'] = true,
	}

	-- Peek at the character `n` positions ahead (default 0), '' at eof
	local function look(n)
		n = p + (n or 0)
		if n <= length then
			return text:sub(n, n)
		else
			return ''
		end
	end

	-- Consume and return the next character, '' at eof
	local function get()
		if p <= length then
			local c = text:sub(p, p)
			p = p + 1
			return c
		else
			return ''
		end
	end

	-- Error: shadows the global `error` inside this function. Computes a
	-- line/char position from the read cursor, dumps the tokens produced so
	-- far, then raises through the original `error`.
	local olderr = error
	local function error(str)
		local q = 1
		local line = 1
		local char = 1
		while q <= p do
			if text:sub(q, q) == '\n' then
				line = line + 1
				char = 1
			else
				char = char + 1
			end
			q = q + 1
		end
		for _, token in pairs(tokenBuffer) do
			print(token.Type.."<"..token.Source..">")
		end
		olderr("file<"..line..":"..char..">: "..str)
	end

	-- Consume the body of a long string / long comment up to and including
	-- the closing `]` `=`*eqcount `]`.
	local function longdata(eqcount)
		while true do
			local c = get()
			if c == '' then
				error("Unfinished long string.")
			elseif c == ']' then
				local done = true -- Until contested
				for i = 1, eqcount do
					if look() == '=' then
						p = p + 1
					else
						done = false
						break
					end
				end
				if done and get() == ']' then
					return
				end
			end
		end
	end

	-- Get the opening part for a long data `[` `=`* `[`
	-- Precondition: The first `[` has been consumed
	-- Return: nil (cursor restored) or the equals count
	local function getopen()
		local startp = p
		while look() == '=' do
			p = p + 1
		end
		if look() == '[' then
			p = p + 1
			return p - startp - 1
		else
			p = startp
			return nil
		end
	end

	-- Consume the rest of the current line, including its newline.
	local function skipLine()
		while true do
			local c = get()
			if c == '' or c == '\n' then
				return
			end
		end
	end

	-- Append a token covering text[tokenStart, p-1]; text[whiteStart,
	-- tokenStart-1] becomes its leading whitespace.
	local whiteStart = 1
	local tokenStart = 1
	local function token(type)
		local tk = {
			Type = type;
			LeadingWhite = text:sub(whiteStart, tokenStart-1);
			Source = text:sub(tokenStart, p-1);
		}
		table.insert(tokenBuffer, tk)
		whiteStart = p
		tokenStart = p
		return tk
	end

	-- Parse tokens loop
	while true do
		-- Mark the whitespace start
		whiteStart = p

		-- Skip the leading whitespace + comments
		while true do
			local c = look()
			if c == '-' and look(1) == '-' then
				p = p + 2
				-- `--[` `=`* `[` opens a long comment, anything else is a
				-- comment running to the end of the line.
				local eqcount
				if look() == '[' then
					p = p + 1
					eqcount = getopen()
				end
				if eqcount then
					longdata(eqcount)
				else
					skipLine()
				end
			elseif WhiteChars[c] then
				p = p + 1
			else
				-- eof, a lone `-`, or the first character of a token
				break
			end
		end

		-- Mark the token start
		tokenStart = p

		-- Switch on token type
		local c1 = get()
		if c1 == '' then
			-- End of file
			token('Eof')
			break
		elseif c1 == '\'' or c1 == '\"' then
			-- String constant
			while true do
				local c2 = get()
				if c2 == '' then
					-- Without this check get() keeps returning '' at eof
					-- and the loop would never terminate.
					error("Unfinished string.")
				elseif c2 == '\\' then
					local c3 = get()
					if not (CharacterForEscape[c3] or ExtraEscapes[c3] or Digits[c3]) then
						error("Invalid Escape Sequence `"..c3.."`.")
					end
				elseif c2 == c1 then
					break
				end
			end
			token('String')
		elseif AllIdentStartChars[c1] then
			-- Ident or Keyword
			while AllIdentChars[look()] do
				p = p + 1
			end
			if Keywords[text:sub(tokenStart, p-1)] then
				token('Keyword')
			else
				token('Ident')
			end
		elseif Digits[c1] or (c1 == '.' and Digits[look()]) then
			-- Number
			if c1 == '0' and (look() == 'x' or look() == 'X') then
				p = p + 1
				-- Hex number
				while HexDigits[look()] do
					p = p + 1
				end
			else
				-- Normal Number
				while Digits[look()] do
					p = p + 1
				end
				if look() == '.' then
					-- With decimal point
					p = p + 1
					while Digits[look()] do
						p = p + 1
					end
				end
				if look() == 'e' or look() == 'E' then
					-- With exponent, optionally signed (`1e-5`, `1e+5`)
					p = p + 1
					if look() == '-' or look() == '+' then
						p = p + 1
					end
					while Digits[look()] do
						p = p + 1
					end
				end
			end
			token('Number')
		elseif c1 == '[' then
			-- '[' Symbol or Long String
			local eqCount = getopen()
			if eqCount then
				-- Long string
				longdata(eqCount)
				token('String')
			else
				-- Symbol
				token('Symbol')
			end
		elseif c1 == '.' then
			-- Greedily consume up to 3 `.` for . / .. / ... tokens
			if look() == '.' then
				get()
				if look() == '.' then
					get()
				end
			end
			token('Symbol')
		elseif EqualSymbols[c1] then
			-- `~`, `=`, `<`, `>` optionally followed by `=`
			if look() == '=' then
				p = p + 1
			end
			token('Symbol')
		elseif Symbols[c1] then
			token('Symbol')
		else
			error("Bad symbol `"..c1.."` in source.")
		end
	end
	return tokenBuffer
end
- function CreateLuaParser(text)
- -- Token stream and pointer into it
- local tokens = CreateLuaTokenStream(text)
- -- for _, tok in pairs(tokens) do
- -- print(tok.Type..": "..tok.Source)
- -- end
- local p = 1
-- Return the current token and advance to the next one. The cursor never
-- moves past the last token, so repeated calls at the end keep returning it.
local function get()
	local current = tokens[p]
	local atLastToken = (p >= #tokens)
	if not atLastToken then
		p = p + 1
	end
	return current
end
-- Look at the token `n` positions from the cursor (default 0) without
-- consuming it. Positions outside the token list yield the last token.
local function peek(n)
	local tok = tokens[p + (n or 0)]
	if tok then
		return tok
	end
	return tokens[#tokens]
end
-- Compute the "line:column" position at which `token` starts by replaying
-- the text of every token before it (leading whitespace plus source) and
-- then the leading whitespace of `token` itself.
local function getTokenStartPosition(token)
	local line, column = 1, 0

	-- Advance the line/column counters over a piece of source text.
	local function advance(str)
		for i = 1, #str do
			if str:sub(i, i) == '\n' then
				line = line + 1
				column = 0
			else
				column = column + 1
			end
		end
	end

	local index = 1
	local tk = tokens[index]
	while tk ~= token do
		advance(tk.LeadingWhite..tk.Source)
		index = index + 1
		tk = tokens[index]
	end
	advance(tk.LeadingWhite)
	return line..":"..(column + 1)
end
-- Describe the current token (type, source text and start position) for
-- diagnostic output.
local function debugMark()
	local current = peek()
	local where = getTokenStartPosition(current)
	return "<"..current.Type.." `"..current.Source.."`> at: "..where
end
-- Does the current token end a statement list? True at Eof, and for the
-- keywords listed in BlockFollowKeyword (result is nil for other keywords).
local function isBlockFollow()
	local tok = peek()
	if tok.Type == 'Eof' then
		return true
	end
	if tok.Type ~= 'Keyword' then
		return false
	end
	return BlockFollowKeyword[tok.Source]
end
-- Is the current token's text a member of UnopSet? Returns the set entry or
-- false.
local function isUnop()
	local entry = UnopSet[peek().Source]
	if entry then
		return entry
	end
	return false
end
-- Is the current token's text a member of BinopSet? Returns the set entry or
-- false.
local function isBinop()
	local entry = BinopSet[peek().Source]
	if entry then
		return entry
	end
	return false
end
-- Consume and return the current token if it has the given type (and, when
-- `source` is given, exactly that text). Otherwise print the surrounding
-- tokens and raise "<line:col>: ... expected.".
local function expect(type, source)
	local tk = peek()
	local matches = tk.Type == type and (source == nil or tk.Source == source)
	if matches then
		return get()
	end

	-- Dump a window of three tokens either side to help locate the problem.
	for offset = -3, 3 do
		print("Tokens["..offset.."] = `"..peek(offset).Source.."`")
	end
	local wanted = type
	if source then
		wanted = "`"..source.."`"
	end
	error(getTokenStartPosition(tk)..": "..wanted.." expected.")
end
-- Finish an AST node: wrap its GetFirstToken / GetLastToken accessors so that
-- they assert a token was actually produced. Returns the same node table.
local function MkNode(node)
	local rawFirst = node.GetFirstToken
	local rawLast = node.GetLastToken

	node.GetFirstToken = function(self)
		local tok = rawFirst(self)
		assert(tok)
		return tok
	end

	node.GetLastToken = function(self)
		local tok = rawLast(self)
		assert(tok)
		return tok
	end

	return node
end
- -- Forward decls
- local block;
- local expr;
-- Expression list: one or more expressions separated by commas.
-- Returns the list of expression nodes and the list of comma tokens.
local function exprlist()
	local expressions = {expr()}
	local commas = {}
	while peek().Source == ',' do
		commas[#commas + 1] = get()
		local nextExpr = expr()
		expressions[#expressions + 1] = nextExpr
	end
	return expressions, commas
end
-- Parse the head of a primary expression: either a parenthesised expression
-- (ParenExpr) or a bare identifier (VariableExpr). Anything else is reported
-- as an unexpected symbol.
local function prefixexpr()
	local tk = peek()
	local isParen = (tk.Source == '(')

	if not isParen and tk.Type ~= 'Ident' then
		print(debugMark())
		error(getTokenStartPosition(tk)..": Unexpected symbol")
	end

	if isParen then
		local open = get()
		local inner = expr()
		local close = expect('Symbol', ')')
		return MkNode{
			Type = 'ParenExpr';
			Expression = inner;
			Token_OpenParen = open;
			Token_CloseParen = close;
			GetFirstToken = function(self)
				return self.Token_OpenParen
			end;
			GetLastToken = function(self)
				return self.Token_CloseParen
			end;
		}
	end

	-- A single identifier token is both the first and the last token.
	local function ownToken(self)
		return self.Token
	end
	return MkNode{
		Type = 'VariableExpr';
		Token = get();
		GetFirstToken = ownToken;
		GetLastToken = ownToken;
	}
end
-- Parse a table constructor `{ ... }` into a TableLiteral node.
-- Entries are `[expr] = expr` (EntryType 'Index'), `name = expr` ('Field')
-- or a plain `expr` ('Value'), separated by `,` or `;`; a trailing separator
-- is allowed.
-- Declared `local`: it is only used inside the parser, and as a global it
-- would be redefined by every call to CreateLuaParser.
local function tableexpr()
	local obrace = expect('Symbol', '{')
	local entries = {}
	local separators = {}
	while peek().Source ~= '}' do
		if peek().Source == '[' then
			-- Index
			local obrac = get()
			local index = expr()
			local cbrac = expect('Symbol', ']')
			local eq = expect('Symbol', '=')
			local value = expr()
			table.insert(entries, {
				EntryType = 'Index';
				Index = index;
				Value = value;
				Token_OpenBracket = obrac;
				Token_CloseBracket = cbrac;
				Token_Equals = eq;
			})
		elseif peek().Type == 'Ident' and peek(1).Source == '=' then
			-- Field
			local field = get()
			local eq = get()
			local value = expr()
			table.insert(entries, {
				EntryType = 'Field';
				Field = field;
				Value = value;
				Token_Equals = eq;
			})
		else
			-- Value
			local value = expr()
			table.insert(entries, {
				EntryType = 'Value';
				Value = value;
			})
		end

		-- Comma or Semicolon separator; without one the list must end here
		if peek().Source == ',' or peek().Source == ';' then
			table.insert(separators, get())
		else
			break
		end
	end
	local cbrace = expect('Symbol', '}')
	return MkNode{
		Type = 'TableLiteral';
		EntryList = entries;
		Token_SeparatorList = separators;
		Token_OpenBrace = obrace;
		Token_CloseBrace = cbrace;
		GetFirstToken = function(self)
			return self.Token_OpenBrace
		end;
		GetLastToken = function(self)
			return self.Token_CloseBrace
		end;
	}
end
-- List of identifiers separated by commas. Returns the identifier tokens and
-- the comma tokens. A missing first identifier is tolerated here (the result
-- may be an empty list).
-- NOTE(review): `...` is not accepted here, so a vararg parameter list such
-- as `function f(...)` looks like it is rejected by the caller's `)` check.
-- Confirm whether that is intended.
local function varlist()
	local names, commas = {}, {}
	local head = peek()
	if head.Type == 'Ident' then
		names[1] = get()
	end
	while peek().Source == ',' do
		commas[#commas + 1] = get()
		local ident = expect('Ident')
		names[#names + 1] = ident
	end
	return names, commas
end
-- Body: parse a block that must be closed by the keyword `terminator`
-- (e.g. 'end' or 'until'). Returns the block node and the consumed
-- terminator token; raises "<line:col>: <terminator> expected." otherwise.
local function blockbody(terminator)
	local body = block()
	local closing = peek()
	if closing.Type ~= 'Keyword' or closing.Source ~= terminator then
		print(closing.Type, closing.Source)
		error(getTokenStartPosition(closing)..": "..terminator.." expected.")
	end
	get()
	return body, closing
end
-- Function declaration. With `isAnonymous` the result is a FunctionLiteral
-- (`function (...) ... end`); otherwise a FunctionStat whose NameChain holds
-- the identifiers of `a.b.c:d` and Token_NameChainSeparator the `.` / `:`
-- tokens between them. The current token must be the `function` keyword.
local function funcdecl(isAnonymous)
	local functionKw = get()

	local nameChain, nameChainSeparator
	if not isAnonymous then
		nameChain = {expect('Ident')}
		nameChainSeparator = {}
		-- Any number of `.name` parts ...
		while peek().Source == '.' do
			nameChainSeparator[#nameChainSeparator + 1] = get()
			local part = expect('Ident')
			nameChain[#nameChain + 1] = part
		end
		-- ... optionally followed by a single `:method` part.
		if peek().Source == ':' then
			nameChainSeparator[#nameChainSeparator + 1] = get()
			local method = expect('Ident')
			nameChain[#nameChain + 1] = method
		end
	end

	local oparenTk = expect('Symbol', '(')
	local argList, argCommaList = varlist()
	local cparenTk = expect('Symbol', ')')
	local fbody, enTk = blockbody('end')

	local nodeType = 'FunctionStat'
	if isAnonymous then
		nodeType = 'FunctionLiteral'
	end

	return MkNode{
		Type = nodeType;
		NameChain = nameChain;
		ArgList = argList;
		Body = fbody;
		--
		Token_Function = functionKw;
		Token_NameChainSeparator = nameChainSeparator;
		Token_OpenParen = oparenTk;
		Token_ArgCommaList = argCommaList;
		Token_CloseParen = cparenTk;
		Token_End = enTk;
		GetFirstToken = function(self)
			return self.Token_Function
		end;
		GetLastToken = function(self)
			return self.Token_End;
		end;
	}
end
-- Argument list passed to a function. Three call forms are recognised and
-- distinguished by the node's CallType:
--   'ArgCall'    f(a, b)   parenthesised expression list
--   'TableCall'  f{...}    a single table constructor
--   'StringCall' f"..."    a single string token
local function functionargs()
	local tk = peek()
	if tk.Source == '(' then
		local oparenTk = get()
		local argList = {}
		local argCommaList = {}
		while peek().Source ~= ')' do
			table.insert(argList, expr())
			if peek().Source == ',' then
				table.insert(argCommaList, get())
			else
				break
			end
		end
		local cparenTk = expect('Symbol', ')')
		return MkNode{
			CallType = 'ArgCall';
			ArgList = argList;
			--
			Token_CommaList = argCommaList;
			Token_OpenParen = oparenTk;
			Token_CloseParen = cparenTk;
			GetFirstToken = function(self)
				return self.Token_OpenParen
			end;
			GetLastToken = function(self)
				return self.Token_CloseParen
			end;
		}
	elseif tk.Source == '{' then
		return MkNode{
			CallType = 'TableCall';
			-- Parse exactly one table constructor. A full expr() would also
			-- swallow binary operators that follow the call, turning
			-- `f{} .. x` into `f({} .. x)`.
			TableExpr = tableexpr();
			GetFirstToken = function(self)
				return self.TableExpr:GetFirstToken()
			end;
			GetLastToken = function(self)
				return self.TableExpr:GetLastToken()
			end;
		}
	elseif tk.Type == 'String' then
		return MkNode{
			CallType = 'StringCall';
			Token = get();
			GetFirstToken = function(self)
				return self.Token
			end;
			GetLastToken = function(self)
				return self.Token
			end;
		}
	else
		error("Function arguments expected.")
	end
end
-- Parse a primary expression: a prefix expression followed by any number of
-- suffixes, each wrapping the expression parsed so far:
--   .name        FieldExpr
--   :name args   MethodExpr
--   [expr]       IndexExpr
--   args         CallExpr, where args is `(...)`, `{...}` or a string token
local function primaryexpr()
	local base = prefixexpr()
	assert(base, "nil prefixexpr")

	-- Shared accessors: every suffix node starts where its base starts.
	local function firstOfBase(self)
		return self.Base:GetFirstToken()
	end
	local function lastOfArguments(self)
		return self.FunctionArguments:GetLastToken()
	end

	while true do
		local tk = peek()
		if tk.Source == '.' then
			local dotTk = get()
			local fieldName = expect('Ident')
			base = MkNode{
				Type = 'FieldExpr';
				Base = base;
				Field = fieldName;
				Token_Dot = dotTk;
				GetFirstToken = firstOfBase;
				GetLastToken = function(self)
					return self.Field
				end;
			}
		elseif tk.Source == ':' then
			local colonTk = get()
			local methodName = expect('Ident')
			local fargs = functionargs()
			base = MkNode{
				Type = 'MethodExpr';
				Base = base;
				Method = methodName;
				FunctionArguments = fargs;
				Token_Colon = colonTk;
				GetFirstToken = firstOfBase;
				GetLastToken = lastOfArguments;
			}
		elseif tk.Source == '[' then
			local obrac = get()
			local index = expr()
			local cbrac = expect('Symbol', ']')
			base = MkNode{
				Type = 'IndexExpr';
				Base = base;
				Index = index;
				Token_OpenBracket = obrac;
				Token_CloseBracket = cbrac;
				GetFirstToken = firstOfBase;
				GetLastToken = function(self)
					return self.Token_CloseBracket
				end;
			}
		elseif tk.Source == '{' or tk.Source == '(' or tk.Type == 'String' then
			-- Call with a table constructor, a parenthesised list, or a
			-- string literal (`f "str"`); functionargs handles all three.
			local fargs = functionargs()
			base = MkNode{
				Type = 'CallExpr';
				Base = base;
				FunctionArguments = fargs;
				GetFirstToken = firstOfBase;
				GetLastToken = lastOfArguments;
			}
		else
			return base
		end
	end
end
-- Parse an operand: a single-token literal (number, string, nil, boolean,
-- `...`), a table constructor, an anonymous function, or otherwise a primary
-- expression.
local function simpleexpr()
	local tk = peek()

	-- Decide whether the current token is a one-token literal.
	local literalType
	if tk.Type == 'Number' then
		literalType = 'NumberLiteral'
	elseif tk.Type == 'String' then
		literalType = 'StringLiteral'
	elseif tk.Source == 'nil' then
		literalType = 'NilLiteral'
	elseif tk.Source == 'true' or tk.Source == 'false' then
		literalType = 'BooleanLiteral'
	elseif tk.Source == '...' then
		literalType = 'VargLiteral'
	end

	if literalType then
		-- The literal's only token is both its first and last token.
		local function ownToken(self)
			return self.Token
		end
		return MkNode{
			Type = literalType;
			Token = get();
			GetFirstToken = ownToken;
			GetLastToken = ownToken;
		}
	end

	if tk.Source == '{' then
		return tableexpr()
	end
	if tk.Source == 'function' then
		return funcdecl(true)
	end
	return primaryexpr()
end
-- Parse an expression by precedence climbing: binary operators are folded in
-- only while their left priority exceeds `limit`. The right operand of each
-- operator is parsed with that operator's right priority as the new limit;
-- the operand of a unary operator is parsed with UnaryPriority.
local function subexpr(limit)
	local curNode;

	-- Initial Base Expression
	if isUnop() then
		local opTk = get()
		local ex = subexpr(UnaryPriority)
		curNode = MkNode{
			Type = 'UnopExpr';
			Token_Op = opTk;
			Rhs = ex;
			GetFirstToken = function(self)
				return self.Token_Op
			end;
			GetLastToken = function(self)
				return self.Rhs:GetLastToken()
			end;
		}
	else
		curNode = simpleexpr()
		assert(curNode, "nil simpleexpr")
	end

	-- Apply Precedence Recursion Chain
	while isBinop() do
		-- BinopSet also lists `#`, `.` and `:`, which have no BinaryPriority
		-- entry. Treat them as "not an operator here" so the caller reports
		-- a syntax error instead of this loop indexing nil.
		local priority = BinaryPriority[peek().Source]
		if not priority or priority[1] <= limit then
			break
		end
		local opTk = get()
		local rhs = subexpr(priority[2])
		assert(rhs, "RhsNeeded")
		curNode = MkNode{
			Type = 'BinopExpr';
			Lhs = curNode;
			Rhs = rhs;
			Token_Op = opTk;
			GetFirstToken = function(self)
				return self.Lhs:GetFirstToken()
			end;
			GetLastToken = function(self)
				return self.Rhs:GetLastToken()
			end;
		}
	end

	-- Return result
	return curNode
end
-- Expression: a sub-expression parsed with priority limit 0. Every priority
-- in BinaryPriority is greater than 0, so any binary operator is accepted.
expr = function()
	local lowestLimit = 0
	return subexpr(lowestLimit)
end
-- Expression statement: either a call used as a statement (CallExprStat) or
-- an assignment `lhs {, lhs} = expr {, expr}` (AssignmentStat).
local function exprstat()
	-- Calls are valid statements but never valid assignment targets.
	local function isCall(node)
		return node.Type == 'MethodExpr' or node.Type == 'CallExpr'
	end

	local first = primaryexpr()
	if isCall(first) then
		return MkNode{
			Type = 'CallExprStat';
			Expression = first;
			GetFirstToken = function(self)
				return self.Expression:GetFirstToken()
			end;
			GetLastToken = function(self)
				return self.Expression:GetLastToken()
			end;
		}
	end

	-- Assignment: collect the remaining targets
	local lhs, lhsSeparator = {first}, {}
	while peek().Source == ',' do
		lhsSeparator[#lhsSeparator + 1] = get()
		local target = primaryexpr()
		if isCall(target) then
			error("Bad left hand side of assignment")
		end
		lhs[#lhs + 1] = target
	end

	local eq = expect('Symbol', '=')

	-- Right hand side values
	local rhs, rhsSeparator = {expr()}, {}
	while peek().Source == ',' do
		rhsSeparator[#rhsSeparator + 1] = get()
		local value = expr()
		rhs[#rhs + 1] = value
	end

	return MkNode{
		Type = 'AssignmentStat';
		Rhs = rhs;
		Lhs = lhs;
		Token_Equals = eq;
		Token_LhsSeparatorList = lhsSeparator;
		Token_RhsSeparatorList = rhsSeparator;
		GetFirstToken = function(self)
			return self.Lhs[1]:GetFirstToken()
		end;
		GetLastToken = function(self)
			return self.Rhs[#self.Rhs]:GetLastToken()
		end;
	}
end
-- If statement: `if cond then block {elseif cond then block} [else block] end`.
-- Each elseif/else clause is stored in ElseClauseList with its ClauseType
-- ('elseif' or 'else'); nothing may follow an `else` clause except `end`.
local function ifstat()
	local ifKw = get()
	local condition = expr()
	local thenKw = expect('Keyword', 'then')
	local ifBody = block()

	local elseClauses = {}
	while true do
		local clauseKind = peek().Source
		if clauseKind ~= 'elseif' and clauseKind ~= 'else' then
			break
		end
		local clauseKw = get()
		local clauseCondition, clauseThenKw
		if clauseKind == 'elseif' then
			clauseCondition = expr()
			clauseThenKw = expect('Keyword', 'then')
		end
		local clauseBody = block()
		elseClauses[#elseClauses + 1] = {
			Condition = clauseCondition;
			Body = clauseBody;
			--
			ClauseType = clauseKind;
			Token = clauseKw;
			Token_Then = clauseThenKw;
		}
		if clauseKind == 'else' then
			break
		end
	end

	local enKw = expect('Keyword', 'end')
	return MkNode{
		Type = 'IfStat';
		Condition = condition;
		Body = ifBody;
		ElseClauseList = elseClauses;
		--
		Token_If = ifKw;
		Token_Then = thenKw;
		Token_End = enKw;
		GetFirstToken = function(self)
			return self.Token_If
		end;
		GetLastToken = function(self)
			return self.Token_End
		end;
	}
end
-- Do statement: `do block end`. The current token must be the `do` keyword.
local function dostat()
	local openKw = get()
	local inner, closeKw = blockbody('end')
	local node = {
		Type = 'DoStat';
		Body = inner;
		Token_Do = openKw;
		Token_End = closeKw;
	}
	function node.GetFirstToken(self)
		return self.Token_Do
	end
	function node.GetLastToken(self)
		return self.Token_End
	end
	return MkNode(node)
end
-- While statement: `while cond do block end`. The current token must be the
-- `while` keyword.
local function whilestat()
	local whileKw = get()
	local condition = expr()
	local doKw = expect('Keyword', 'do')
	local body, enKw = blockbody('end')
	local node = {
		Type = 'WhileStat';
		Condition = condition;
		Body = body;
		Token_While = whileKw;
		Token_Do = doKw;
		Token_End = enKw;
	}
	function node.GetFirstToken(self)
		return self.Token_While
	end
	function node.GetLastToken(self)
		return self.Token_End
	end
	return MkNode(node)
end
-- For statement. After the loop variables either `=` introduces a numeric
-- loop (NumericForStat, needs 2 or 3 range expressions) or `in` introduces a
-- generic loop (GenericForStat). Both continue with `do block end`.
local function forstat()
	local forKw = get()
	local loopVars, loopVarCommas = varlist()

	local kind = peek().Source
	if kind ~= '=' and kind ~= 'in' then
		error("`=` or in expected")
	end
	local isNumeric = (kind == '=')

	local separatorTk = get()
	local exprList, exprCommaList = exprlist()
	if isNumeric and (#exprList < 2 or #exprList > 3) then
		error("expected 2 or 3 values for range bounds")
	end
	local doTk = expect('Keyword', 'do')
	local body, enTk = blockbody('end')

	-- Both loop forms span from `for` to `end`.
	local function firstToken(self)
		return self.Token_For
	end
	local function lastToken(self)
		return self.Token_End
	end

	if isNumeric then
		return MkNode{
			Type = 'NumericForStat';
			VarList = loopVars;
			RangeList = exprList;
			Body = body;
			--
			Token_For = forKw;
			Token_VarCommaList = loopVarCommas;
			Token_Equals = separatorTk;
			Token_RangeCommaList = exprCommaList;
			Token_Do = doTk;
			Token_End = enTk;
			GetFirstToken = firstToken;
			GetLastToken = lastToken;
		}
	end
	return MkNode{
		Type = 'GenericForStat';
		VarList = loopVars;
		GeneratorList = exprList;
		Body = body;
		--
		Token_For = forKw;
		Token_VarCommaList = loopVarCommas;
		Token_In = separatorTk;
		Token_GeneratorCommaList = exprCommaList;
		Token_Do = doTk;
		Token_End = enTk;
		GetFirstToken = firstToken;
		GetLastToken = lastToken;
	}
end
-- Repeat statement: `repeat block until cond`. The statement ends with the
-- last token of the condition expression.
local function repeatstat()
	local repeatKw = get()
	local body, untilTk = blockbody('until')
	local node = {
		Type = 'RepeatStat';
		Body = body;
		Condition = expr();
		Token_Repeat = repeatKw;
		Token_Until = untilTk;
	}
	function node.GetFirstToken(self)
		return self.Token_Repeat
	end
	function node.GetLastToken(self)
		return self.Condition:GetLastToken()
	end
	return MkNode(node)
end
-- Local declaration: `local function name ... end` (LocalFunctionStat) or
-- `local a, b [= expr, ...]` (LocalVarStat). The current token must be the
-- `local` keyword.
local function localdecl()
	local localKw = get()
	local following = peek()

	if following.Source == 'function' then
		-- Local function def; the name must be a single identifier.
		local funcStat = funcdecl(false)
		if #funcStat.NameChain > 1 then
			error(getTokenStartPosition(funcStat.Token_NameChainSeparator[1])..": `(` expected.")
		end
		return MkNode{
			Type = 'LocalFunctionStat';
			FunctionStat = funcStat;
			Token_Local = localKw;
			GetFirstToken = function(self)
				return self.Token_Local
			end;
			GetLastToken = function(self)
				return self.FunctionStat:GetLastToken()
			end;
		}
	end

	if following.Type ~= 'Ident' then
		error("`function` or ident expected")
	end

	-- Local variable declaration with an optional initialiser list
	local varList, varCommaList = varlist()
	local exprList, exprCommaList = {}, {}
	local eqToken
	if peek().Source == '=' then
		eqToken = get()
		exprList, exprCommaList = exprlist()
	end
	return MkNode{
		Type = 'LocalVarStat';
		VarList = varList;
		ExprList = exprList;
		Token_Local = localKw;
		Token_Equals = eqToken;
		Token_VarCommaList = varCommaList;
		Token_ExprCommaList = exprCommaList;
		GetFirstToken = function(self)
			return self.Token_Local
		end;
		GetLastToken = function(self)
			-- Without initialisers the statement ends at the last name.
			local count = #self.ExprList
			if count == 0 then
				return self.VarList[#self.VarList]
			end
			return self.ExprList[count]:GetLastToken()
		end;
	}
end
-- Return statement: `return [exprlist]`. The expression list is empty when
-- the next token ends the block or is a `;`.
local function retstat()
	local returnKw = get()
	local exprList, commaList = {}, {}
	if not (isBlockFollow() or peek().Source == ';') then
		exprList, commaList = exprlist()
	end

	local node = {
		Type = 'ReturnStat';
		ExprList = exprList;
		Token_Return = returnKw;
		Token_CommaList = commaList;
	}
	function node.GetFirstToken(self)
		return self.Token_Return
	end
	function node.GetLastToken(self)
		-- A bare `return` ends at the keyword itself.
		local count = #self.ExprList
		if count == 0 then
			return self.Token_Return
		end
		return self.ExprList[count]:GetLastToken()
	end
	return node
end
-- Break statement: the single keyword `break`, which is both the first and
-- the last token of the node.
local function breakstat()
	local function breakToken(self)
		return self.Token_Break
	end
	return {
		Type = 'BreakStat';
		Token_Break = get();
		GetFirstToken = breakToken;
		GetLastToken = breakToken;
	}
end
-- Statement: dispatch on the text of the current token. Returns two values:
-- whether the statement must be the last one of its block (`return` and
-- `break`), and the statement node. Anything that does not start with a
-- statement keyword is parsed as an expression statement.
local blockEndingParsers = {
	['return'] = retstat;
	['break'] = breakstat;
}
local ordinaryParsers = {
	['if'] = ifstat;
	['while'] = whilestat;
	['do'] = dostat;
	['for'] = forstat;
	['repeat'] = repeatstat;
	['function'] = function()
		return funcdecl(false)
	end;
	['local'] = localdecl;
}
local function statement()
	local keyword = peek().Source
	local parseLast = blockEndingParsers[keyword]
	if parseLast then
		return true, parseLast()
	end
	local parse = ordinaryParsers[keyword] or exprstat
	return false, parse()
end
-- Chunk / block: parse statements until a block-ending token or until a
-- statement that must be last (`return` / `break`). Returns a StatList node.
-- SemicolonList is indexed by statement position, so it may have holes. The
-- node is not passed through MkNode because an empty block has no first or
-- last token and its accessors return nil.
block = function()
	local statements, semicolons = {}, {}
	local isLast = false
	while not (isLast or isBlockFollow()) do
		-- Parse statement
		local stat
		isLast, stat = statement()
		statements[#statements + 1] = stat

		-- Optional `;` after the statement, remembered at its index
		local following = peek()
		if following.Type == 'Symbol' and following.Source == ';' then
			semicolons[#statements] = get()
		end
	end

	local node = {
		Type = 'StatList';
		StatementList = statements;
		SemicolonList = semicolons;
	}
	function node.GetFirstToken(self)
		if #self.StatementList == 0 then
			return nil
		end
		return self.StatementList[1]:GetFirstToken()
	end
	function node.GetLastToken(self)
		local count = #self.StatementList
		if count == 0 then
			return nil
		end
		-- Last token may be the semicolon that follows the last statement
		local trailing = self.SemicolonList[count]
		if trailing then
			return trailing
		end
		return self.StatementList[count]:GetLastToken()
	end
	return node
end
- return block()
- end
- function VisitAst(ast, visitors)
- local ExprType = lookupify{
- 'BinopExpr'; 'UnopExpr';
- 'NumberLiteral'; 'StringLiteral'; 'NilLiteral'; 'BooleanLiteral'; 'VargLiteral';
- 'FieldExpr'; 'IndexExpr';
- 'MethodExpr'; 'CallExpr';
- 'FunctionLiteral';
- 'VariableExpr';
- 'ParenExpr';
- 'TableLiteral';
- }
- local StatType = lookupify{
- 'StatList';
- 'BreakStat';
- 'ReturnStat';
- 'LocalVarStat';
- 'LocalFunctionStat';
- 'FunctionStat';
- 'RepeatStat';
- 'GenericForStat';
- 'NumericForStat';
- 'WhileStat';
- 'DoStat';
- 'IfStat';
- 'CallExprStat';
- 'AssignmentStat';
- }
- -- Check for typos in visitor construction
- for visitorSubject, visitor in pairs(visitors) do
- if not StatType[visitorSubject] and not ExprType[visitorSubject] then
- error("Invalid visitor target: `"..visitorSubject.."`")
- end
- end
- -- Helpers to call visitors on a node
- local function preVisit(exprOrStat)
- local visitor = visitors[exprOrStat.Type]
- if type(visitor) == 'function' then
- return visitor(exprOrStat)
- elseif visitor and visitor.Pre then
- return visitor.Pre(exprOrStat)
- end
- end
- local function postVisit(exprOrStat)
- local visitor = visitors[exprOrStat.Type]
- if visitor and type(visitor) == 'table' and visitor.Post then
- return visitor.Post(exprOrStat)
- end
- end
- local visitExpr, visitStat;
- visitExpr = function(expr)
- if preVisit(expr) then
- -- Handler did custom child iteration or blocked child iteration
- return
- end
- if expr.Type == 'BinopExpr' then
- visitExpr(expr.Lhs)
- visitExpr(expr.Rhs)
- elseif expr.Type == 'UnopExpr' then
- visitExpr(expr.Rhs)
- elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
- expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
- expr.Type == 'VargLiteral'
- then
- -- No children to visit, single token literals
- elseif expr.Type == 'FieldExpr' then
- visitExpr(expr.Base)
- elseif expr.Type == 'IndexExpr' then
- visitExpr(expr.Base)
- visitExpr(expr.Index)
- elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
- visitExpr(expr.Base)
- if expr.FunctionArguments.CallType == 'ArgCall' then
- for index, argExpr in pairs(expr.FunctionArguments.ArgList) do
- visitExpr(argExpr)
- end
- elseif expr.FunctionArguments.CallType == 'TableCall' then
- visitExpr(expr.FunctionArguments.TableExpr)
- end
- elseif expr.Type == 'FunctionLiteral' then
- visitStat(expr.Body)
- elseif expr.Type == 'VariableExpr' then
- -- No children to visit
- elseif expr.Type == 'ParenExpr' then
- visitExpr(expr.Expression)
- elseif expr.Type == 'TableLiteral' then
- for index, entry in pairs(expr.EntryList) do
- if entry.EntryType == 'Field' then
- visitExpr(entry.Value)
- elseif entry.EntryType == 'Index' then
- visitExpr(entry.Index)
- visitExpr(entry.Value)
- elseif entry.EntryType == 'Value' then
- visitExpr(entry.Value)
- else
- assert(false, "unreachable")
- end
- end
- else
- assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
- end
- postVisit(expr)
- end
- visitStat = function(stat)
- if preVisit(stat) then
- -- Handler did custom child iteration or blocked child iteration
- return
- end
- if stat.Type == 'StatList' then
- for index, ch in pairs(stat.StatementList) do
- visitStat(ch)
- end
- elseif stat.Type == 'BreakStat' then
- -- No children to visit
- elseif stat.Type == 'ReturnStat' then
- for index, expr in pairs(stat.ExprList) do
- visitExpr(expr)
- end
- elseif stat.Type == 'LocalVarStat' then
- if stat.Token_Equals then
- for index, expr in pairs(stat.ExprList) do
- visitExpr(expr)
- end
- end
- elseif stat.Type == 'LocalFunctionStat' then
- visitStat(stat.FunctionStat.Body)
- elseif stat.Type == 'FunctionStat' then
- visitStat(stat.Body)
- elseif stat.Type == 'RepeatStat' then
- visitStat(stat.Body)
- visitExpr(stat.Condition)
- elseif stat.Type == 'GenericForStat' then
- for index, expr in pairs(stat.GeneratorList) do
- visitExpr(expr)
- end
- visitStat(stat.Body)
- elseif stat.Type == 'NumericForStat' then
- for index, expr in pairs(stat.RangeList) do
- visitExpr(expr)
- end
- visitStat(stat.Body)
- elseif stat.Type == 'WhileStat' then
- visitExpr(stat.Condition)
- visitStat(stat.Body)
- elseif stat.Type == 'DoStat' then
- visitStat(stat.Body)
- elseif stat.Type == 'IfStat' then
- visitExpr(stat.Condition)
- visitStat(stat.Body)
- for _, clause in pairs(stat.ElseClauseList) do
- if clause.Condition then
- visitExpr(clause.Condition)
- end
- visitStat(clause.Body)
- end
- elseif stat.Type == 'CallExprStat' then
- visitExpr(stat.Expression)
- elseif stat.Type == 'AssignmentStat' then
- for index, ex in pairs(stat.Lhs) do
- visitExpr(ex)
- end
- for index, ex in pairs(stat.Rhs) do
- visitExpr(ex)
- end
- else
- assert(false, "unreachable")
- end
- postVisit(stat)
- end
- if StatType[ast.Type] then
- visitStat(ast)
- else
- visitExpr(ast)
- end
- end
--- Annotates an AST with variable and scope information.
--
-- Walks `ast` with VisitAst while maintaining a stack of lexical scopes.
-- Every `VariableExpr` node receives a `Variable` field pointing at the record
-- of the local or global it resolves to. Each record collects one rename
-- callback per token that spells the variable, so `var:Rename(newName)`
-- rewrites every occurrence in the AST.
--
-- "Locations" are increasing integers handed out by an internal counter; they
-- only express the relative order in which scopes and variables were
-- encountered during the walk.
--
-- @param ast Root node to analyse (passed straight to VisitAst).
-- @return The list of global variable records that were referenced.
-- @return The root scope record.
function AddVariableInfo(ast)
	local globalVars = {}
	local currentScope = nil

	-- Numbering generator for variable lifetimes
	local locationGenerator = 0
	local function markLocation()
		locationGenerator = locationGenerator + 1
		return locationGenerator
	end

	-- Methods shared by every variable record (local and global alike).
	local function renameVar(self, newName)
		self.Name = newName
		-- Order is irrelevant here, each callback patches one token.
		for _, renameFunc in pairs(self.RenameList) do
			renameFunc(newName)
		end
	end
	local function referenceVar(self)
		self.UseCount = self.UseCount + 1
	end

	-- Method shared by every scope record: look a name up in this scope, then
	-- in the enclosing scopes, and finally among the globals. Returns nil when
	-- nothing matches.
	local function scopeGetVar(self, varName)
		for _, var in pairs(self.VariableList) do
			if var.Name == varName then
				return var
			end
		end
		if self.ParentScope then
			return self.ParentScope:GetVar(varName)
		end
		for _, var in pairs(globalVars) do
			if var.Name == varName then
				return var
			end
		end
	end

	-- Scope management
	local function pushScope()
		local parent = currentScope
		local scope = {
			ParentScope = parent;
			ChildScopeList = {};
			VariableList = {};
			BeginLocation = markLocation();
			GetVar = scopeGetVar;
		}
		if parent then
			scope.Depth = parent.Depth + 1
			table.insert(parent.ChildScopeList, scope)
		else
			scope.Depth = 1
		end
		currentScope = scope
	end
	local function popScope()
		local scope = currentScope
		-- Mark where this scope ends
		scope.EndLocation = markLocation()
		-- Mark all of the variables in the scope as ending there
		for _, var in pairs(scope.VariableList) do
			var.ScopeEndLocation = scope.EndLocation
		end
		-- Move to the parent scope
		currentScope = scope.ParentScope
		return scope
	end
	pushScope() -- push initial scope

	-- Add / reference variables
	local function addLocalVar(name, setNameFunc, localInfo)
		assert(localInfo, "Missing localInfo")
		assert(name, "Missing local var name")
		-- Take the three locations in an explicit order instead of relying on
		-- the evaluation order of table constructor fields.
		local beginLocation = markLocation()
		local endLocation = markLocation()
		local firstReference = markLocation()
		local var = {
			Type = 'Local';
			Name = name;
			RenameList = {setNameFunc};
			AssignedTo = false;
			Info = localInfo;
			UseCount = 0;
			Scope = currentScope;
			BeginLocation = beginLocation;
			EndLocation = endLocation;
			ReferenceLocationList = {firstReference};
			Rename = renameVar;
			Reference = referenceVar;
		}
		table.insert(currentScope.VariableList, var)
		return var
	end

	-- Returns the record for global `name`, creating it on first use.
	-- The lookup is a linear scan by current name: records can be renamed, so
	-- a name-keyed index would go stale.
	local function getGlobalVar(name)
		for _, var in pairs(globalVars) do
			if var.Name == name then
				return var
			end
		end
		local beginLocation = markLocation()
		local endLocation = markLocation()
		local var = {
			Type = 'Global';
			Name = name;
			RenameList = {};
			AssignedTo = false;
			UseCount = 0;
			Scope = nil; -- Globals have no scope
			BeginLocation = beginLocation;
			EndLocation = endLocation;
			ReferenceLocationList = {};
			Rename = renameVar;
			Reference = referenceVar;
		}
		table.insert(globalVars, var)
		return var
	end
	local function addGlobalReference(name, setNameFunc)
		assert(name, "Missing var name")
		local var = getGlobalVar(name)
		table.insert(var.RenameList, setNameFunc)
		return var
	end
	local function getLocalVar(scope, name)
		-- First search this scope
		-- Note: Reverse iterate here because Lua does allow shadowing a local
		--       within the same scope, and the later defined variable should
		--       be the one referenced.
		for i = #scope.VariableList, 1, -1 do
			if scope.VariableList[i].Name == name then
				return scope.VariableList[i]
			end
		end
		-- Then search the enclosing scopes
		if scope.ParentScope then
			return getLocalVar(scope.ParentScope, name)
		end
		-- Not a local
		return nil
	end

	-- Resolves `name` from the current scope (falling back to a global),
	-- registers `setNameFunc` for renaming and records the use location.
	local function referenceVariable(name, setNameFunc)
		assert(name, "Missing var name")
		local var = getLocalVar(currentScope, name)
		if var then
			table.insert(var.RenameList, setNameFunc)
		else
			var = addGlobalReference(name, setNameFunc)
		end
		-- Update the end location of where this variable is used, and
		-- add this location to the list of references to this variable.
		local curLocation = markLocation()
		var.EndLocation = curLocation
		table.insert(var.ReferenceLocationList, curLocation)
		return var
	end

	-- Declares every identifier token of `identList` as a local of the current
	-- scope, tagged with `infoType` ('Argument' or 'ForRange') and its index.
	local function addIndexedLocals(identList, infoType)
		for index, ident in ipairs(identList) do
			addLocalVar(ident.Source, function(name)
				ident.Source = name
			end, {
				Type = infoType;
				Index = index;
			})
		end
	end

	local visitor = {}
	visitor.FunctionLiteral = {
		-- Function literal adds a new scope and adds the function literal arguments
		-- as local variables in the scope.
		Pre = function(expr)
			pushScope()
			addIndexedLocals(expr.ArgList, 'Argument')
		end;
		Post = function(expr)
			popScope()
		end;
	}
	visitor.VariableExpr = function(expr)
		-- Variable expression references an existing local variable in the
		-- current scope (or a global), annotating the variable usage with
		-- variable information.
		expr.Variable = referenceVariable(expr.Token.Source, function(newName)
			expr.Token.Source = newName
		end)
	end
	visitor.StatList = {
		-- StatList adds a new scope
		Pre = function(stat)
			pushScope()
		end;
		Post = function(stat)
			popScope()
		end;
	}
	visitor.LocalVarStat = {
		Post = function(stat)
			-- Local var stat adds the local variables to the current scope as locals
			-- We need to visit the subexpressions first, because these new locals
			-- will not be in scope for the initialization value expressions. That is:
			--  `local bar = bar + 1`
			-- Is valid code
			for varNum, ident in ipairs(stat.VarList) do
				addLocalVar(ident.Source, function(name)
					stat.VarList[varNum].Source = name
				end, {
					Type = 'Local';
				})
			end
		end;
	}
	visitor.LocalFunctionStat = {
		Pre = function(stat)
			-- Local function stat adds the function itself to the current scope as
			-- a local variable, and creates a new scope with the function arguments
			-- as local variables.
			addLocalVar(stat.FunctionStat.NameChain[1].Source, function(name)
				stat.FunctionStat.NameChain[1].Source = name
			end, {
				Type = 'LocalFunction';
			})
			pushScope()
			addIndexedLocals(stat.FunctionStat.ArgList, 'Argument')
		end;
		Post = function()
			popScope()
		end;
	}
	visitor.FunctionStat = {
		Pre = function(stat)
			-- Function stat adds a new scope containing the function arguments
			-- as local variables.
			-- `function foo() ... end` is sugar for `foo = function() ... end`,
			-- so the first name of the chain resolves like any other variable:
			-- it is a local when one is in scope and a global otherwise. The
			-- same holds for the base name of `function a.b.c()`.
			local nameChain = stat.NameChain
			local var = referenceVariable(nameChain[1].Source, function(name)
				nameChain[1].Source = name
			end)
			-- For dotted names this is conservative: only a field of the
			-- variable is actually written.
			var.AssignedTo = true
			pushScope()
			addIndexedLocals(stat.ArgList, 'Argument')
		end;
		Post = function()
			popScope()
		end;
	}
	visitor.RepeatStat = {
		Pre = function(stat)
			-- The `until` condition of a repeat loop can see the locals declared
			-- in the loop body, so the condition must be visited while the
			-- body's scope is still open. Open that scope here and walk the
			-- body's statements directly rather than through the StatList
			-- visitor, which would close the scope before the condition.
			pushScope()
			if stat.Body.Type == 'StatList' then
				for _, bodyStat in ipairs(stat.Body.StatementList) do
					VisitAst(bodyStat, visitor)
				end
			else
				VisitAst(stat.Body, visitor)
			end
			VisitAst(stat.Condition, visitor)
			popScope()
			return true -- Custom visit
		end;
	}
	visitor.GenericForStat = {
		Pre = function(stat)
			-- Generic fors need an extra scope holding the range variables
			-- Need a custom visitor so that the generator expressions can be
			-- visited before we push a scope, but the body can be visited
			-- after we push a scope.
			for _, ex in ipairs(stat.GeneratorList) do
				VisitAst(ex, visitor)
			end
			pushScope()
			addIndexedLocals(stat.VarList, 'ForRange')
			VisitAst(stat.Body, visitor)
			popScope()
			return true -- Custom visit
		end;
	}
	visitor.NumericForStat = {
		Pre = function(stat)
			-- Numeric fors need an extra scope holding the range variables
			-- Need a custom visitor so that the range expressions can be
			-- visited before we push a scope, but the body can be visited
			-- after we push a scope.
			for _, ex in ipairs(stat.RangeList) do
				VisitAst(ex, visitor)
			end
			pushScope()
			addIndexedLocals(stat.VarList, 'ForRange')
			VisitAst(stat.Body, visitor)
			popScope()
			return true -- Custom visit
		end;
	}
	visitor.AssignmentStat = {
		Post = function(stat)
			-- For an assignment statement we need to mark the
			-- "assigned to" flag on variables.
			for _, ex in pairs(stat.Lhs) do
				if ex.Variable then
					ex.Variable.AssignedTo = true
				end
			end
		end;
	}

	VisitAst(ast, visitor)

	-- Popping the initial scope yields the root of the scope tree.
	return globalVars, popScope()
end
--- Prints an AST back out as source text.
--
-- Every token is emitted through `io.write` as its `LeadingWhite` followed by
-- its `Source`, in source order. Nothing is returned.
--
-- All child lists are walked with `ipairs`: the output depends on the order
-- of the entries, and `pairs` does not specify an order.
--
-- @param ast Statement node to print; dispatched on its `Type` field.
-- Raises an error for a token lacking `LeadingWhite` or `Source`, and for
-- an unknown node or table entry type.
function PrintAst(ast)
	local printStat, printExpr

	-- Write one token: its leading whitespace, then its text.
	local function printt(tk)
		if not tk.LeadingWhite or not tk.Source then
			error("Bad token: "..FormatTable(tk))
		end
		io.write(tk.LeadingWhite)
		io.write(tk.Source)
	end

	-- Print each item with `printItem`, following item i with the separator
	-- token stored at index i of `separators`, when there is one.
	local function printList(items, separators, printItem)
		for index, item in ipairs(items) do
			printItem(item)
			local sep = separators[index]
			if sep then
				printt(sep)
			end
		end
	end

	-- Print `( args ) body end` for any node that carries function parts.
	local function printFunctionTail(func)
		printt(func.Token_OpenParen)
		printList(func.ArgList, func.Token_ArgCommaList, printt)
		printt(func.Token_CloseParen)
		printStat(func.Body)
		printt(func.Token_End)
	end

	-- Print the argument part of a call / method call.
	local function printCallArguments(args)
		if args.CallType == 'StringCall' then
			printt(args.Token)
		elseif args.CallType == 'ArgCall' then
			printt(args.Token_OpenParen)
			printList(args.ArgList, args.Token_CommaList, printExpr)
			printt(args.Token_CloseParen)
		elseif args.CallType == 'TableCall' then
			printExpr(args.TableExpr)
		end
	end

	-- Print one entry of a table constructor (without its separator).
	local function printTableEntry(entry)
		if entry.EntryType == 'Field' then
			printt(entry.Field)
			printt(entry.Token_Equals)
			printExpr(entry.Value)
		elseif entry.EntryType == 'Index' then
			printt(entry.Token_OpenBracket)
			printExpr(entry.Index)
			printt(entry.Token_CloseBracket)
			printt(entry.Token_Equals)
			printExpr(entry.Value)
		elseif entry.EntryType == 'Value' then
			printExpr(entry.Value)
		else
			assert(false, "unreachable")
		end
	end

	printExpr = function(expr)
		local exprType = expr.Type
		if exprType == 'BinopExpr' then
			printExpr(expr.Lhs)
			printt(expr.Token_Op)
			printExpr(expr.Rhs)
		elseif exprType == 'UnopExpr' then
			printt(expr.Token_Op)
			printExpr(expr.Rhs)
		elseif exprType == 'NumberLiteral' or exprType == 'StringLiteral' or
			exprType == 'NilLiteral' or exprType == 'BooleanLiteral' or
			exprType == 'VargLiteral' or exprType == 'VariableExpr'
		then
			-- Single token expressions: just print the token
			printt(expr.Token)
		elseif exprType == 'FieldExpr' then
			printExpr(expr.Base)
			printt(expr.Token_Dot)
			printt(expr.Field)
		elseif exprType == 'IndexExpr' then
			printExpr(expr.Base)
			printt(expr.Token_OpenBracket)
			printExpr(expr.Index)
			printt(expr.Token_CloseBracket)
		elseif exprType == 'MethodExpr' or exprType == 'CallExpr' then
			printExpr(expr.Base)
			if exprType == 'MethodExpr' then
				printt(expr.Token_Colon)
				printt(expr.Method)
			end
			printCallArguments(expr.FunctionArguments)
		elseif exprType == 'FunctionLiteral' then
			printt(expr.Token_Function)
			printFunctionTail(expr)
		elseif exprType == 'ParenExpr' then
			printt(expr.Token_OpenParen)
			printExpr(expr.Expression)
			printt(expr.Token_CloseParen)
		elseif exprType == 'TableLiteral' then
			printt(expr.Token_OpenBrace)
			printList(expr.EntryList, expr.Token_SeparatorList, printTableEntry)
			printt(expr.Token_CloseBrace)
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
	end

	printStat = function(stat)
		local statType = stat.Type
		if statType == 'StatList' then
			printList(stat.StatementList, stat.SemicolonList, printStat)
		elseif statType == 'BreakStat' then
			printt(stat.Token_Break)
		elseif statType == 'ReturnStat' then
			printt(stat.Token_Return)
			printList(stat.ExprList, stat.Token_CommaList, printExpr)
		elseif statType == 'LocalVarStat' then
			printt(stat.Token_Local)
			printList(stat.VarList, stat.Token_VarCommaList, printt)
			-- The initializer list is only present together with the `=`
			if stat.Token_Equals then
				printt(stat.Token_Equals)
				printList(stat.ExprList, stat.Token_ExprCommaList, printExpr)
			end
		elseif statType == 'LocalFunctionStat' then
			printt(stat.Token_Local)
			printt(stat.FunctionStat.Token_Function)
			printt(stat.FunctionStat.NameChain[1])
			printFunctionTail(stat.FunctionStat)
		elseif statType == 'FunctionStat' then
			printt(stat.Token_Function)
			printList(stat.NameChain, stat.Token_NameChainSeparator, printt)
			printFunctionTail(stat)
		elseif statType == 'RepeatStat' then
			printt(stat.Token_Repeat)
			printStat(stat.Body)
			printt(stat.Token_Until)
			printExpr(stat.Condition)
		elseif statType == 'GenericForStat' then
			printt(stat.Token_For)
			printList(stat.VarList, stat.Token_VarCommaList, printt)
			printt(stat.Token_In)
			printList(stat.GeneratorList, stat.Token_GeneratorCommaList, printExpr)
			printt(stat.Token_Do)
			printStat(stat.Body)
			printt(stat.Token_End)
		elseif statType == 'NumericForStat' then
			printt(stat.Token_For)
			printList(stat.VarList, stat.Token_VarCommaList, printt)
			printt(stat.Token_Equals)
			printList(stat.RangeList, stat.Token_RangeCommaList, printExpr)
			printt(stat.Token_Do)
			printStat(stat.Body)
			printt(stat.Token_End)
		elseif statType == 'WhileStat' then
			printt(stat.Token_While)
			printExpr(stat.Condition)
			printt(stat.Token_Do)
			printStat(stat.Body)
			printt(stat.Token_End)
		elseif statType == 'DoStat' then
			printt(stat.Token_Do)
			printStat(stat.Body)
			printt(stat.Token_End)
		elseif statType == 'IfStat' then
			printt(stat.Token_If)
			printExpr(stat.Condition)
			printt(stat.Token_Then)
			printStat(stat.Body)
			for _, clause in ipairs(stat.ElseClauseList) do
				printt(clause.Token)
				-- A clause with a condition is an `elseif`, otherwise an `else`
				if clause.Condition then
					printExpr(clause.Condition)
					printt(clause.Token_Then)
				end
				printStat(clause.Body)
			end
			printt(stat.Token_End)
		elseif statType == 'CallExprStat' then
			printExpr(stat.Expression)
		elseif statType == 'AssignmentStat' then
			printList(stat.Lhs, stat.Token_LhsSeparatorList, printExpr)
			printt(stat.Token_Equals)
			printList(stat.Rhs, stat.Token_RhsSeparatorList, printExpr)
		else
			assert(false, "unreachable")
		end
	end

	printStat(ast)
end
--- Adds / removes whitespace in an AST to put it into a "standard formatting".
--
-- Works purely by editing the `LeadingWhite` string of tokens in place:
-- statements, table entries and block-closing tokens are moved onto their own
-- line at the current indent level (one tab per level), and a single space is
-- inserted in front of selected tokens that have no leading whitespace.
-- Nothing is returned.
--
-- Child lists are walked with `ipairs`, since several rules depend on the
-- position of an item in its list (e.g. "pad every argument but the first").
--
-- @param ast Statement node to format; dispatched on its `Type` field.
local function FormatAst(ast)
	local formatStat, formatExpr
	local currentIndent = 0

	-- Make `token` start on a fresh line at the current indent level, unless
	-- its leading whitespace already ends that way.
	local function applyIndent(token)
		local indentString = '\n'..('\t'):rep(currentIndent)
		if token.LeadingWhite == '' or (token.LeadingWhite:sub(-#indentString, -1) ~= indentString) then
			-- Trim existing trailing whitespace on LeadingWhite
			-- Trim trailing tabs and spaces, and up to one newline
			-- (the assignment keeps only gsub's first result, the string)
			token.LeadingWhite = token.LeadingWhite:gsub("\n?[\t ]*$", "")
			token.LeadingWhite = token.LeadingWhite..indentString
		end
	end

	local function indent()
		currentIndent = currentIndent + 1
	end

	local function undent()
		currentIndent = currentIndent - 1
		assert(currentIndent >= 0, "Undented too far")
	end

	-- First character a token contributes to the output.
	local function leadingChar(tk)
		if #tk.LeadingWhite > 0 then
			return tk.LeadingWhite:sub(1,1)
		else
			return tk.Source:sub(1,1)
		end
	end

	-- Ensure the token is preceded by whitespace.
	local function padToken(tk)
		if not WhiteChars[leadingChar(tk)] then
			tk.LeadingWhite = ' '..tk.LeadingWhite
		end
	end

	local function padExpr(expr)
		padToken(expr:GetFirstToken())
	end

	-- Pad every token of `tokens` whose index is at least `padFrom`.
	local function padTokens(tokens, padFrom)
		for index, token in ipairs(tokens) do
			if index >= padFrom then
				padToken(token)
			end
		end
	end

	-- Format every expression of `exprs`, padding those whose index is at
	-- least `padFrom`. Each expression is formatted before it is padded.
	local function formatExprList(exprs, padFrom)
		for index, expr in ipairs(exprs) do
			formatExpr(expr)
			if index >= padFrom then
				padExpr(expr)
			end
		end
	end

	-- Format a block one level deeper and put its closing token on its own
	-- line. `openToken` is not used; it is accepted so that call sites read as
	-- open / body / close.
	local function formatBody(openToken, bodyStat, closeToken)
		indent()
		formatStat(bodyStat)
		undent()
		applyIndent(closeToken)
	end

	-- Format one table constructor entry, placing it on its own line.
	local function formatTableEntry(entry)
		if entry.EntryType == 'Field' then
			applyIndent(entry.Field)
			padToken(entry.Token_Equals)
			formatExpr(entry.Value)
			padExpr(entry.Value)
		elseif entry.EntryType == 'Index' then
			applyIndent(entry.Token_OpenBracket)
			formatExpr(entry.Index)
			padToken(entry.Token_Equals)
			formatExpr(entry.Value)
			padExpr(entry.Value)
		elseif entry.EntryType == 'Value' then
			formatExpr(entry.Value)
			applyIndent(entry.Value:GetFirstToken())
		else
			assert(false, "unreachable")
		end
	end

	formatExpr = function(expr)
		local exprType = expr.Type
		if exprType == 'BinopExpr' then
			formatExpr(expr.Lhs)
			formatExpr(expr.Rhs)
			-- Every binary operator except `..` gets a space on both sides
			if expr.Token_Op.Source ~= '..' then
				padExpr(expr.Rhs)
				padToken(expr.Token_Op)
			end
		elseif exprType == 'UnopExpr' then
			formatExpr(expr.Rhs)
		elseif exprType == 'NumberLiteral' or exprType == 'StringLiteral' or
			exprType == 'NilLiteral' or exprType == 'BooleanLiteral' or
			exprType == 'VargLiteral' or exprType == 'VariableExpr'
		then
			-- Single token expressions: nothing to do
		elseif exprType == 'FieldExpr' then
			formatExpr(expr.Base)
		elseif exprType == 'IndexExpr' then
			formatExpr(expr.Base)
			formatExpr(expr.Index)
		elseif exprType == 'MethodExpr' or exprType == 'CallExpr' then
			formatExpr(expr.Base)
			local callType = expr.FunctionArguments.CallType
			if callType == 'ArgCall' then
				-- Space after each comma, none after the open paren
				formatExprList(expr.FunctionArguments.ArgList, 2)
			elseif callType == 'TableCall' then
				formatExpr(expr.FunctionArguments.TableExpr)
			end
			-- 'StringCall': nothing to do
		elseif exprType == 'FunctionLiteral' then
			padTokens(expr.ArgList, 2)
			formatBody(expr.Token_CloseParen, expr.Body, expr.Token_End)
		elseif exprType == 'ParenExpr' then
			formatExpr(expr.Expression)
		elseif exprType == 'TableLiteral' then
			-- An empty table stays as it is; otherwise one entry per line
			if #expr.EntryList > 0 then
				indent()
				for _, entry in ipairs(expr.EntryList) do
					formatTableEntry(entry)
				end
				undent()
				applyIndent(expr.Token_CloseBrace)
			end
		else
			assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
		end
	end

	formatStat = function(stat)
		local statType = stat.Type
		if statType == 'StatList' then
			for _, child in ipairs(stat.StatementList) do
				formatStat(child)
				applyIndent(child:GetFirstToken())
			end
		elseif statType == 'BreakStat' then
			-- Nothing to do
		elseif statType == 'ReturnStat' then
			formatExprList(stat.ExprList, 1)
		elseif statType == 'LocalVarStat' then
			padTokens(stat.VarList, 1)
			if stat.Token_Equals then
				padToken(stat.Token_Equals)
				formatExprList(stat.ExprList, 1)
			end
		elseif statType == 'LocalFunctionStat' then
			local func = stat.FunctionStat
			padToken(func.Token_Function)
			padToken(func.NameChain[1])
			padTokens(func.ArgList, 2)
			formatBody(func.Token_CloseParen, func.Body, func.Token_End)
		elseif statType == 'FunctionStat' then
			-- Only the first part of the name chain is separated from the
			-- `function` keyword; the rest stays attached to its `.` / `:`
			local firstName = stat.NameChain[1]
			if firstName then
				padToken(firstName)
			end
			padTokens(stat.ArgList, 2)
			formatBody(stat.Token_CloseParen, stat.Body, stat.Token_End)
		elseif statType == 'RepeatStat' then
			formatBody(stat.Token_Repeat, stat.Body, stat.Token_Until)
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
		elseif statType == 'GenericForStat' then
			padTokens(stat.VarList, 1)
			padToken(stat.Token_In)
			formatExprList(stat.GeneratorList, 1)
			padToken(stat.Token_Do)
			formatBody(stat.Token_Do, stat.Body, stat.Token_End)
		elseif statType == 'NumericForStat' then
			padTokens(stat.VarList, 1)
			padToken(stat.Token_Equals)
			formatExprList(stat.RangeList, 1)
			padToken(stat.Token_Do)
			formatBody(stat.Token_Do, stat.Body, stat.Token_End)
		elseif statType == 'WhileStat' then
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
			padToken(stat.Token_Do)
			formatBody(stat.Token_Do, stat.Body, stat.Token_End)
		elseif statType == 'DoStat' then
			formatBody(stat.Token_Do, stat.Body, stat.Token_End)
		elseif statType == 'IfStat' then
			formatExpr(stat.Condition)
			padExpr(stat.Condition)
			padToken(stat.Token_Then)
			-- Each body is closed by the token that starts the next clause,
			-- and the last one by `end`, so carry the pending body along.
			local lastBodyOpen = stat.Token_Then
			local lastBody = stat.Body
			for _, clause in ipairs(stat.ElseClauseList) do
				formatBody(lastBodyOpen, lastBody, clause.Token)
				lastBodyOpen = clause.Token
				if clause.Condition then
					formatExpr(clause.Condition)
					padExpr(clause.Condition)
					padToken(clause.Token_Then)
					lastBodyOpen = clause.Token_Then
				end
				lastBody = clause.Body
			end
			formatBody(lastBodyOpen, lastBody, stat.Token_End)
		elseif statType == 'CallExprStat' then
			formatExpr(stat.Expression)
		elseif statType == 'AssignmentStat' then
			formatExprList(stat.Lhs, 2)
			padToken(stat.Token_Equals)
			formatExprList(stat.Rhs, 1)
		else
			assert(false, "unreachable")
		end
	end

	formatStat(ast)
end
- -- Strips as much whitespace off of tokens in an AST as possible without causing problems
- local function StripAst(ast)
- local stripStat, stripExpr;
- local function stript(token)
- token.LeadingWhite = ''
- end
- -- Make two adjacent tokens as close as possible
- local function joint(tokenA, tokenB)
- -- Strip the second token's whitespace
- stript(tokenB)
- -- Get the trailing A <-> leading B character pair
- local lastCh = tokenA.Source:sub(-1, -1)
- local firstCh = tokenB.Source:sub(1, 1)
- -- Cases to consider:
- -- Touching minus signs -> comment: `- -42` -> `--42` is invalid
- -- Touching words: `a b` -> `ab` is invalid
- -- Touching digits: `2 3`, can't occur in the Lua syntax as number literals aren't a primary expression
- -- Ambiguous syntax: `f(x)\n(x)()` is already disallowed, we can't cause a problem by removing newlines
- -- Figure out what separation is needed
- if
- (lastCh == '-' and firstCh == '-') or
- (AllIdentChars[lastCh] and AllIdentChars[firstCh])
- then
- tokenB.LeadingWhite = ' ' -- Use a separator
- else
- tokenB.LeadingWhite = '' -- Don't use a separator
- end
- end
- -- Join up a statement body and its opening / closing tokens
- local function bodyjoint(open, body, close)
- stripStat(body)
- stript(close)
- local bodyFirst = body:GetFirstToken()
- local bodyLast = body:GetLastToken()
- if bodyFirst then
- -- Body is non-empty, join body to open / close
- joint(open, bodyFirst)
- joint(bodyLast, close)
- else
- -- Body is empty, just join open and close token together
- joint(open, close)
- end
- end
- stripExpr = function(expr)
- if expr.Type == 'BinopExpr' then
- stripExpr(expr.Lhs)
- stript(expr.Token_Op)
- stripExpr(expr.Rhs)
- -- Handle the `a - -b` -/-> `a--b` case which would otherwise incorrectly generate a comment
- -- Also handles operators "or" / "and" which definitely need joining logic in a bunch of cases
- joint(expr.Token_Op, expr.Rhs:GetFirstToken())
- joint(expr.Lhs:GetLastToken(), expr.Token_Op)
- elseif expr.Type == 'UnopExpr' then
- stript(expr.Token_Op)
- stripExpr(expr.Rhs)
- -- Handle the `- -b` -/-> `--b` case which would otherwise incorrectly generate a comment
- joint(expr.Token_Op, expr.Rhs:GetFirstToken())
- elseif expr.Type == 'NumberLiteral' or expr.Type == 'StringLiteral' or
- expr.Type == 'NilLiteral' or expr.Type == 'BooleanLiteral' or
- expr.Type == 'VargLiteral'
- then
- -- Just print the token
- stript(expr.Token)
- elseif expr.Type == 'FieldExpr' then
- stripExpr(expr.Base)
- stript(expr.Token_Dot)
- stript(expr.Field)
- elseif expr.Type == 'IndexExpr' then
- stripExpr(expr.Base)
- stript(expr.Token_OpenBracket)
- stripExpr(expr.Index)
- stript(expr.Token_CloseBracket)
- elseif expr.Type == 'MethodExpr' or expr.Type == 'CallExpr' then
- stripExpr(expr.Base)
- if expr.Type == 'MethodExpr' then
- stript(expr.Token_Colon)
- stript(expr.Method)
- end
- if expr.FunctionArguments.CallType == 'StringCall' then
- stript(expr.FunctionArguments.Token)
- elseif expr.FunctionArguments.CallType == 'ArgCall' then
- stript(expr.FunctionArguments.Token_OpenParen)
- for index, argExpr in pairs(expr.FunctionArguments.ArgList) do
- stripExpr(argExpr)
- local sep = expr.FunctionArguments.Token_CommaList[index]
- if sep then
- stript(sep)
- end
- end
- stript(expr.FunctionArguments.Token_CloseParen)
- elseif expr.FunctionArguments.CallType == 'TableCall' then
- stripExpr(expr.FunctionArguments.TableExpr)
- end
- elseif expr.Type == 'FunctionLiteral' then
- stript(expr.Token_Function)
- stript(expr.Token_OpenParen)
- for index, arg in pairs(expr.ArgList) do
- stript(arg)
- local comma = expr.Token_ArgCommaList[index]
- if comma then
- stript(comma)
- end
- end
- stript(expr.Token_CloseParen)
- bodyjoint(expr.Token_CloseParen, expr.Body, expr.Token_End)
- elseif expr.Type == 'VariableExpr' then
- stript(expr.Token)
- elseif expr.Type == 'ParenExpr' then
- stript(expr.Token_OpenParen)
- stripExpr(expr.Expression)
- stript(expr.Token_CloseParen)
- elseif expr.Type == 'TableLiteral' then
- stript(expr.Token_OpenBrace)
- for index, entry in pairs(expr.EntryList) do
- if entry.EntryType == 'Field' then
- stript(entry.Field)
- stript(entry.Token_Equals)
- stripExpr(entry.Value)
- elseif entry.EntryType == 'Index' then
- stript(entry.Token_OpenBracket)
- stripExpr(entry.Index)
- stript(entry.Token_CloseBracket)
- stript(entry.Token_Equals)
- stripExpr(entry.Value)
- elseif entry.EntryType == 'Value' then
- stripExpr(entry.Value)
- else
- assert(false, "unreachable")
- end
- local sep = expr.Token_SeparatorList[index]
- if sep then
- stript(sep)
- end
- end
- stript(expr.Token_CloseBrace)
- else
- assert(false, "unreachable, type: "..expr.Type..":"..FormatTable(expr))
- end
- end
--- Strip removable whitespace from a statement node and everything below it.
-- Dispatches on `stat.Type`. Single tokens are handed to `stript`, pairs of
-- neighbouring tokens to `joint`, and `open <body> close` triples to
-- `bodyjoint`. Those helpers (and `stripExpr`) are defined outside this block;
-- presumably `stript` drops a token's leading whitespace and `joint` keeps
-- only the whitespace required between two tokens -- confirm against their
-- definitions.
-- @param stat statement AST node. An unknown `stat.Type` trips an assert.
stripStat = function(stat)
	if stat.Type == 'StatList' then
		-- Strip all surrounding whitespace on statement lists along with separating whitespace
		for i = 1, #stat.StatementList do
			local chStat = stat.StatementList[i]

			-- Strip the statement and its whitespace
			stripStat(chStat)
			stript(chStat:GetFirstToken())

			-- If there was a previous statement, join the two appropriately
			local lastChStat = stat.StatementList[i-1]
			if lastChStat then
				-- See if we can remove a semi-colon. The only case where we can't is
				-- when the previous statement ends in `)` and this one starts with
				-- `(`: without the `;` the two would be read as one call expression
				-- (`f();(g)()` would turn into `f()(g)()`).
				if stat.SemicolonList[i-1] and
					(lastChStat:GetLastToken().Source ~= ')' or chStat:GetFirstToken().Source ~= '(')
				then
					stat.SemicolonList[i-1] = nil
				end

				-- If there isn't a semi-colon, we should safely join the two statements
				-- (If there is one, then no whitespace leading chStat is always okay)
				if not stat.SemicolonList[i-1] then
					joint(lastChStat:GetLastToken(), chStat:GetFirstToken())
				end
			end
		end

		-- A semi-colon is never needed on the last stat in a statlist:
		stat.SemicolonList[#stat.StatementList] = nil

		-- The leading whitespace on the statlist should be stripped
		if #stat.StatementList > 0 then
			stript(stat.StatementList[1]:GetFirstToken())
		end
	elseif stat.Type == 'BreakStat' then
		stript(stat.Token_Break)
	elseif stat.Type == 'ReturnStat' then
		stript(stat.Token_Return)
		for index, expr in pairs(stat.ExprList) do
			stripExpr(expr)
			if stat.Token_CommaList[index] then
				stript(stat.Token_CommaList[index])
			end
		end
		-- `return` and the first returned expression may need a separator
		if #stat.ExprList > 0 then
			joint(stat.Token_Return, stat.ExprList[1]:GetFirstToken())
		end
	elseif stat.Type == 'LocalVarStat' then
		stript(stat.Token_Local)
		for index, var in pairs(stat.VarList) do
			if index == 1 then
				-- first name follows the `local` keyword directly
				joint(stat.Token_Local, var)
			else
				stript(var)
			end
			local comma = stat.Token_VarCommaList[index]
			if comma then
				stript(comma)
			end
		end
		-- The initialiser list is optional (`local a, b`)
		if stat.Token_Equals then
			stript(stat.Token_Equals)
			for index, expr in pairs(stat.ExprList) do
				stripExpr(expr)
				local comma = stat.Token_ExprCommaList[index]
				if comma then
					stript(comma)
				end
			end
		end
	elseif stat.Type == 'LocalFunctionStat' then
		stript(stat.Token_Local)
		joint(stat.Token_Local, stat.FunctionStat.Token_Function)
		joint(stat.FunctionStat.Token_Function, stat.FunctionStat.NameChain[1])
		joint(stat.FunctionStat.NameChain[1], stat.FunctionStat.Token_OpenParen)
		for index, arg in pairs(stat.FunctionStat.ArgList) do
			stript(arg)
			local comma = stat.FunctionStat.Token_ArgCommaList[index]
			if comma then
				stript(comma)
			end
		end
		stript(stat.FunctionStat.Token_CloseParen)
		bodyjoint(stat.FunctionStat.Token_CloseParen, stat.FunctionStat.Body, stat.FunctionStat.Token_End)
	elseif stat.Type == 'FunctionStat' then
		stript(stat.Token_Function)
		for index, part in pairs(stat.NameChain) do
			if index == 1 then
				-- first name part follows the `function` keyword directly
				joint(stat.Token_Function, part)
			else
				stript(part)
			end
			local sep = stat.Token_NameChainSeparator[index]
			if sep then
				stript(sep)
			end
		end
		stript(stat.Token_OpenParen)
		for index, arg in pairs(stat.ArgList) do
			stript(arg)
			local comma = stat.Token_ArgCommaList[index]
			if comma then
				stript(comma)
			end
		end
		stript(stat.Token_CloseParen)
		bodyjoint(stat.Token_CloseParen, stat.Body, stat.Token_End)
	elseif stat.Type == 'RepeatStat' then
		stript(stat.Token_Repeat)
		bodyjoint(stat.Token_Repeat, stat.Body, stat.Token_Until)
		stripExpr(stat.Condition)
		joint(stat.Token_Until, stat.Condition:GetFirstToken())
	elseif stat.Type == 'GenericForStat' then
		stript(stat.Token_For)
		for index, var in pairs(stat.VarList) do
			if index == 1 then
				joint(stat.Token_For, var)
			else
				stript(var)
			end
			local sep = stat.Token_VarCommaList[index]
			if sep then
				stript(sep)
			end
		end
		joint(stat.VarList[#stat.VarList], stat.Token_In)
		for index, expr in pairs(stat.GeneratorList) do
			stripExpr(expr)
			if index == 1 then
				joint(stat.Token_In, expr:GetFirstToken())
			end
			local sep = stat.Token_GeneratorCommaList[index]
			if sep then
				stript(sep)
			end
		end
		joint(stat.GeneratorList[#stat.GeneratorList]:GetLastToken(), stat.Token_Do)
		bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
	elseif stat.Type == 'NumericForStat' then
		stript(stat.Token_For)
		for index, var in pairs(stat.VarList) do
			if index == 1 then
				joint(stat.Token_For, var)
			else
				stript(var)
			end
			local sep = stat.Token_VarCommaList[index]
			if sep then
				stript(sep)
			end
		end
		joint(stat.VarList[#stat.VarList], stat.Token_Equals)
		for index, expr in pairs(stat.RangeList) do
			stripExpr(expr)
			if index == 1 then
				joint(stat.Token_Equals, expr:GetFirstToken())
			end
			local sep = stat.Token_RangeCommaList[index]
			if sep then
				stript(sep)
			end
		end
		joint(stat.RangeList[#stat.RangeList]:GetLastToken(), stat.Token_Do)
		bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
	elseif stat.Type == 'WhileStat' then
		stript(stat.Token_While)
		stripExpr(stat.Condition)
		stript(stat.Token_Do)
		joint(stat.Token_While, stat.Condition:GetFirstToken())
		joint(stat.Condition:GetLastToken(), stat.Token_Do)
		bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
	elseif stat.Type == 'DoStat' then
		stript(stat.Token_Do)
		stript(stat.Token_End)
		bodyjoint(stat.Token_Do, stat.Body, stat.Token_End)
	elseif stat.Type == 'IfStat' then
		stript(stat.Token_If)
		stripExpr(stat.Condition)
		joint(stat.Token_If, stat.Condition:GetFirstToken())
		joint(stat.Condition:GetLastToken(), stat.Token_Then)

		-- Each body is closed by the token that opens the next clause, so walk
		-- the clauses carrying the previous opener/body along.
		local lastBodyOpen = stat.Token_Then
		local lastBody = stat.Body

		-- ipairs: the chaining below depends on visiting clauses in source
		-- order, which pairs does not promise.
		for _, clause in ipairs(stat.ElseClauseList) do
			bodyjoint(lastBodyOpen, lastBody, clause.Token)
			lastBodyOpen = clause.Token

			-- Only `elseif` clauses carry a condition and their own `then`
			if clause.Condition then
				stripExpr(clause.Condition)
				joint(clause.Token, clause.Condition:GetFirstToken())
				joint(clause.Condition:GetLastToken(), clause.Token_Then)
				lastBodyOpen = clause.Token_Then
			end
			stripStat(clause.Body)
			lastBody = clause.Body
		end

		-- The final body (if, elseif or else) is closed by `end`
		bodyjoint(lastBodyOpen, lastBody, stat.Token_End)
	elseif stat.Type == 'CallExprStat' then
		stripExpr(stat.Expression)
	elseif stat.Type == 'AssignmentStat' then
		for index, ex in pairs(stat.Lhs) do
			stripExpr(ex)
			local sep = stat.Token_LhsSeparatorList[index]
			if sep then
				stript(sep)
			end
		end
		stript(stat.Token_Equals)
		for index, ex in pairs(stat.Rhs) do
			stripExpr(ex)
			local sep = stat.Token_RhsSeparatorList[index]
			if sep then
				stript(sep)
			end
		end
	else
		assert(false, "unreachable")
	end
end
- stripStat(ast)
- end
-- Index of the next candidate name handed out by genNextVarName.
local idGen = 0

-- VarDigits: characters used for every position after the first, in the order
-- a-z, A-Z, 0-9, '_'.
-- VarStartDigits: characters used for the first position, a-z then A-Z.
local VarDigits = {}
local VarStartDigits = {}
do
	local letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
	for pos = 1, #letters do
		local ch = letters:sub(pos, pos)
		VarDigits[#VarDigits + 1] = ch
		VarStartDigits[#VarStartDigits + 1] = ch
	end
	local trailing = '0123456789_'
	for pos = 1, #trailing do
		VarDigits[#VarDigits + 1] = trailing:sub(pos, pos)
	end
end
--- Turn a non-negative integer into an identifier-shaped string.
-- The first character is chosen from VarStartDigits by `index` modulo its
-- length; the remaining quotient is then written out digit by digit (least
-- significant first) using VarDigits.
-- @param index number to encode
-- @return the generated name
local function indexToVarName(index)
	local startCount, digitCount = #VarStartDigits, #VarDigits
	local lead = index % startCount
	local name = '' .. VarStartDigits[lead + 1]
	local remaining = (index - lead) / startCount
	while remaining > 0 do
		local digit = remaining % digitCount
		name = name .. VarDigits[digit + 1]
		remaining = (remaining - digit) / digitCount
	end
	return name
end
--- Hand out the name for the current `idGen` value and advance the counter.
-- The counter is bumped before the name is built.
-- @return name produced by indexToVarName for the pre-increment counter
local function genNextVarName()
	idGen = idGen + 1
	return indexToVarName(idGen - 1)
end
--- Produce the next generated name that is not a key of `Keywords`.
-- Candidates rejected along the way are consumed and never offered again.
-- @return the first acceptable candidate from genNextVarName
local function genVarName()
	local candidate = genNextVarName()
	while Keywords[candidate] do
		candidate = genNextVarName()
	end
	return candidate
end
--- Rename every renamable variable to a short generated name.
-- Globals that are assigned to in the script are renamed first, then locals
-- scope by scope. A child scope starts numbering where its parent stopped, so
-- a local never reuses a name taken by a renamed global or by a variable of
-- an enclosing scope.
-- @param globalScope collection of global variable objects; each is read for
--        `AssignedTo` / `Name` and renamed through `var:Rename(name)`
-- @param rootScope scope table with `VariableList` and `ChildScopeList`;
--        `FirstFreeName` is written on it and on every descendant scope
local function MinifyVariables(globalScope, rootScope)
	-- externalGlobals is a set of global variables that have not been assigned to, that is
	-- global variables defined "externally to the script". We are not going to be renaming
	-- those, and we have to make sure that we don't collide with them when renaming
	-- things so we keep track of them in this set.
	local externalGlobals = {}

	-- Move the renamable globals onto unique temporary names and record the
	-- external ones. (Locals are not given temporaries: the helper that did
	-- so was never invoked and has been removed as dead code.)
	local temporaryIndex = 0
	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			var:Rename('_TMP_'..temporaryIndex..'_')
			temporaryIndex = temporaryIndex + 1
		else
			-- Not assigned to, external global
			externalGlobals[var.Name] = true
		end
	end

	-- Starting at `fromIndex`, find the first generated name that is neither a
	-- keyword nor an external global. Returns the name and the index after it.
	local function takeName(fromIndex)
		local name
		repeat
			name = indexToVarName(fromIndex)
			fromIndex = fromIndex + 1
		until not Keywords[name] and not externalGlobals[name]
		return name, fromIndex
	end

	-- Now we go through renaming, first do globals, we probably want them
	-- to have shorter names in general.
	-- TODO: Rename all vars based on frequency patterns, giving variables
	--       used more shorter names.
	local nextFreeNameIndex = 0
	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			local varName
			varName, nextFreeNameIndex = takeName(nextFreeNameIndex)
			var:Rename(varName)
		end
	end

	-- Now rename all local vars
	rootScope.FirstFreeName = nextFreeNameIndex
	local function doRenameScope(scope)
		for _, var in pairs(scope.VariableList) do
			local varName
			varName, scope.FirstFreeName = takeName(scope.FirstFreeName)
			var:Rename(varName)
		end
		for _, childScope in pairs(scope.ChildScopeList) do
			-- children continue after every name used by this scope
			childScope.FirstFreeName = scope.FirstFreeName
			doRenameScope(childScope)
		end
	end
	doRenameScope(rootScope)
end
--- Rename variables to short names, letting variables share a name whenever
-- their uses cannot interfere.
-- Every renamable variable gets a `UsedNameArray` of name indices it may not
-- take. Variables are processed in sorted order; each takes the lowest index
-- not yet forbidden and then forbids that index on every still-unrenamed
-- variable it would clash with.
-- Side effects: sets `UsedNameArray` and `Renamed` on each processed variable
-- and calls `var:Rename(name)`.
-- @param globalScope collection of global variable objects (`AssignedTo`,
--        `Name`, `Type`, `RenameList`, `ReferenceLocationList`)
-- @param rootScope scope table with `VariableList` and `ChildScopeList`
local function MinifyVariables_2(globalScope, rootScope)
	-- Variable names and other names that are fixed, that we cannot use
	-- Either these are Lua keywords, or globals that are not assigned to,
	-- that is environmental globals that are assigned elsewhere beyond our
	-- control.
	local globalUsedNames = {}
	for kw, _ in pairs(Keywords) do
		globalUsedNames[kw] = true
	end

	-- Gather a list of all of the variables that we will rename
	local allVariables = {}
	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			-- We can try to rename this global since it was assigned to
			-- (and thus presumably initialized) in the script we are
			-- minifying.
			table.insert(allVariables, var)
		else
			-- We can't rename this global, mark it as an unusable name
			-- and don't add it to the rename list
			globalUsedNames[var.Name] = true
		end
	end

	-- Recursively add locals, we can rename all of those
	local function addFrom(scope)
		for _, var in pairs(scope.VariableList) do
			table.insert(allVariables, var)
		end
		for _, childScope in pairs(scope.ChildScopeList) do
			addFrom(childScope)
		end
	end
	addFrom(rootScope)

	-- Add used name arrays to variables
	for _, var in ipairs(allVariables) do
		var.UsedNameArray = {}
	end

	-- Sort the least used variables first
	-- NOTE(review): this hands the lowest name indices to the *least*
	-- referenced variables; a minifier would usually want the opposite.
	-- Left unchanged here -- confirm intent before flipping the comparison.
	table.sort(allVariables, function(a, b)
		return #a.RenameList < #b.RenameList
	end)

	-- Lazy generator for valid names to rename to
	local nextValidNameIndex = 0
	local varNamesLazy = {}
	local function varIndexToValidVarName(i)
		local name = varNamesLazy[i]
		if not name then
			repeat
				name = indexToVarName(nextValidNameIndex)
				nextValidNameIndex = nextValidNameIndex + 1
			until not globalUsedNames[name]
			varNamesLazy[i] = name
		end
		return name
	end

	-- True when any location in `refList` falls inside the lifetime
	-- [owner.BeginLocation, owner.ScopeEndLocation] of `owner`.
	local function referencedDuring(refList, owner)
		for _, refAt in pairs(refList) do
			if refAt >= owner.BeginLocation and refAt <= owner.ScopeEndLocation then
				return true
			end
		end
		return false
	end

	-- For each variable, go to rename it. ipairs: the sorted order above must
	-- be honoured, and pairs does not promise any order.
	for _, var in ipairs(allVariables) do
		-- Lazy... todo: Make this pair a proper for-each-pair-like set of loops
		-- rather than using a renamed flag.
		var.Renamed = true

		-- Find the first unused name
		local i = 1
		while var.UsedNameArray[i] do
			i = i + 1
		end

		-- Rename the variable to that name
		var:Rename(varIndexToValidVarName(i))

		-- Now mark the name as unusable by every not-yet-renamed variable that
		-- would clash with this one.
		for _, otherVar in ipairs(allVariables) do
			if not otherVar.Renamed then
				local collides
				if var.Scope then
					-- `var` is a local. The name is unusable by variables:
					--  1) At the same depth that overlap lifetime with this one
					--  2) At a deeper level, which have a reference to this variable in their lifetimes
					--  3) At a shallower level, which are referenced during this variable's lifetime
					if not otherVar.Scope or otherVar.Scope.Depth < var.Scope.Depth then
						-- Global variable (always at a shallower level) or case 3:
						-- is the other var referenced during this variable's lifetime?
						collides = referencedDuring(otherVar.ReferenceLocationList, var)
					elseif otherVar.Scope.Depth > var.Scope.Depth then
						-- Case 2: is this variable referenced during the other
						-- var's lifetime?
						collides = referencedDuring(var.ReferenceLocationList, otherVar)
					else
						-- Case 1: equal depth. A local may be shadowed at the same
						-- depth as long as the usage ranges of the two do not overlap.
						collides = var.BeginLocation < otherVar.EndLocation and
							var.EndLocation > otherVar.BeginLocation
					end
				elseif otherVar.Type == 'Global' then
					-- `var` is a global: no other global may share its name
					collides = true
				elseif otherVar.Type == 'Local' then
					-- ...and neither may a local whose lifetime contains a
					-- reference to this global.
					collides = referencedDuring(var.ReferenceLocationList, otherVar)
				else
					assert(false, "unreachable")
				end
				if collides then
					otherVar.UsedNameArray[i] = true
				end
			end
		end
	end
end
--- Give every renamable variable a predictable, easy-to-search name.
-- Assigned-to globals become `G_<n>`; locals become `L_<n>_` plus a suffix
-- derived from `var.Info` (`arg<Index>`, `func`, `forvar<Index>`).
-- Generated names that equal the name of an external (never assigned) global
-- are skipped so the rename cannot capture or shadow that global.
-- @param globalScope collection of global variable objects (`AssignedTo`,
--        `Name`, `RenameList`)
-- @param rootScope scope table with `VariableList` and `ChildScopeList`
local function BeautifyVariables(globalScope, rootScope)
	-- Names of globals the script only reads; these must stay untouched and
	-- must not be reused for anything else.
	local externalGlobals = {}
	for _, var in pairs(globalScope) do
		if not var.AssignedTo then
			externalGlobals[var.Name] = true
		end
	end

	local localNumber = 1
	local globalNumber = 1

	-- Record the new name on the variable and push it to every setter in its
	-- RenameList.
	local function setVarName(var, name)
		var.Name = name
		for _, setter in pairs(var.RenameList) do
			setter(name)
		end
	end

	for _, var in pairs(globalScope) do
		if var.AssignedTo then
			local name
			repeat
				name = 'G_'..globalNumber
				globalNumber = globalNumber + 1
			until not externalGlobals[name]
			setVarName(var, name)
		end
	end

	-- Build the candidate name for a local from the current localNumber.
	local function localNameFor(var)
		local name = 'L_'..localNumber..'_'
		if var.Info.Type == 'Argument' then
			name = name..'arg'..var.Info.Index
		elseif var.Info.Type == 'LocalFunction' then
			name = name..'func'
		elseif var.Info.Type == 'ForRange' then
			name = name..'forvar'..var.Info.Index
		end
		return name
	end

	local function modify(scope)
		for _, var in pairs(scope.VariableList) do
			local name
			repeat
				name = localNameFor(var)
				localNumber = localNumber + 1
			until not externalGlobals[name]
			setVarName(var, name)
		end
		for _, childScope in pairs(scope.ChildScopeList) do
			modify(childScope)
		end
	end
	modify(rootScope)
end
--- Abort with the command-line usage text.
-- Raised with level 0 so no source position is prefixed to the message.
local function usageError()
	local message = table.concat({
		"\nusage: minify <file> or unminify <file>\n",
		" The modified code will be printed to the stdout, pipe it to a file, the\n",
		" lua interpreter, or something else as desired EG:\n\n",
		" lua minify.lua minify input.lua > output.lua\n\n",
		" * minify will minify the code in the file.\n",
		" * unminify will beautify the code and replace the variable names with easily\n",
		" find-replacable ones to aide in reverse engineering minified code.\n",
	})
	error(message, 0)
end
-- Script entry: expects exactly two command-line arguments, a command name
-- and the path of the Lua file to process.
local args = {...}
if #args ~= 2 then
	usageError()
end

-- Load the whole input file, releasing the handle as soon as it is read.
local sourceFile = io.open(args[2], 'r')
if not sourceFile then
	error("Could not open the input file `" .. args[2] .. "`", 0)
end
local data = sourceFile:read('*all')
sourceFile:close()
if not data then
	error("Could not read the input file `" .. args[2] .. "`", 0)
end

-- Parse the source and attach variable/scope information to the tree.
local ast = CreateLuaParser(data)
local global_scope, root_scope = AddVariableInfo(ast)
--- Minify pipeline: shorten variable names, strip whitespace, print the result.
-- @param ast parsed syntax tree, passed to StripAst and PrintAst
-- @param global_scope passed to MinifyVariables
-- @param root_scope passed to MinifyVariables
local function minify(ast, global_scope, root_scope)
	-- 1. rename variables
	MinifyVariables(global_scope, root_scope)

	-- 2. strip whitespace from the tree
	StripAst(ast)

	-- 3. emit the tree
	PrintAst(ast)
end
--- Beautify pipeline: give variables searchable names, reformat, print.
-- @param ast parsed syntax tree, passed to FormatAst and PrintAst
-- @param global_scope passed to BeautifyVariables
-- @param root_scope passed to BeautifyVariables
local function beautify(ast, global_scope, root_scope)
	-- 1. rename variables
	BeautifyVariables(global_scope, root_scope)

	-- 2. re-indent / re-space the tree
	FormatAst(ast)

	-- 3. emit the tree
	PrintAst(ast)
end
-- Run the pipeline selected by the first command-line argument; anything
-- other than a known command name ends in the usage error.
do
	local pipelines = {
		minify = minify,
		unminify = beautify,
	}
	local run = pipelines[args[1]]
	if run then
		run(ast, global_scope, root_scope)
	else
		usageError()
	end
end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement