Advertisement
Baidicoot

adt-lua

Jul 2nd, 2020
579
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 6.34 KB | None | 0 0
  1. local function trim(str, chars)
  2.     if chars == nil then
  3.         chars = "%s*"
  4.     end
  5.     return string.match(str, "^"..chars.."(.-)"..chars.."$")
  6. end
  7.  
  8. local function split(str, delim)
  9.     if delim == nil then
  10.         delim = "\n"
  11.     end
  12.     local t = {}
  13.     for s in string.gmatch(str, "([^"..delim.."]+)") do
  14.         table.insert(t, trim(s))
  15.     end
  16.     return t
  17. end
  18.  
  19. local function find_multiple(str, patterns, offset)
  20.     local ws = string.len(str)
  21.     local we = we
  22.     local wpattern = nil
  23.     for i, pattern in pairs(patterns) do
  24.         s, e = string.find(str, pattern, offset)
  25.         if s ~= nil then
  26.             if s < ws then
  27.                 ws = s
  28.                 we = e
  29.                 wpattern = pattern
  30.             end
  31.         end
  32.     end
  33.     return ws, we, wpattern
  34. end
  35.  
  36. local function balanced_end(str, word, offset)
  37.     if offset == nil then
  38.         offset = 1
  39.     end
  40.     local i = offset
  41.     while true do
  42.         local s, e, p
  43.         if word == "then" then
  44.             s, e, p = find_multiple(str, {"%smatch%s", "%sfunction%s", "%sthen%s", "%sdo%s", "%send%s", "%selseif%s"}, i)
  45.         else
  46.             s, e, p = find_multiple(str, {"%smatch%s", "%sfunction%s", "%sthen%s", "%sdo%s", "%send%s"}, i)
  47.         end
  48.         if p == "%send%s" then
  49.             return e
  50.         elseif p == "%selseif%s" then
  51.             return e
  52.         elseif p == nil then
  53.             return "UNBAL"
  54.         end
  55.         i = balanced_end(str, string.sub(p, 3, -3), e)
  56.         if i == "UNBAL" then
  57.             return i
  58.         end
  59.     end
  60. end
  61.  
  62. local function get_decls(str)
  63.     -- gather data declarations from source & remove
  64.     local datas = {}
  65.     local i = 0
  66.     local strout = ""
  67.     while true do
  68.         local n, a = string.find(str, "%sdata [%w]+", i+1)
  69.         if n == nil then
  70.             strout = strout .. string.sub(str, i+1)
  71.             break
  72.         end
  73.         strout = strout .. string.sub(str, i+1, n)
  74.         local e, d = string.find(str, "end", i+1)
  75.         local cont = string.sub(str, a+1, e-1)
  76.         local data = {}
  77.         for i, case in ipairs(split(cont)) do
  78.             local c = {}
  79.             local b, p = string.find(case, "%(")
  80.             if b == nil then
  81.                 c.name = case
  82.                 c.args = 0  
  83.             else
  84.                 c.name = string.sub(case, 1, p-1)
  85.                 c.args = #split(case, ",")
  86.             end
  87.             table.insert(data, c)
  88.         end
  89.         i = d
  90.         table.insert(datas, data)
  91.     end
  92.     return datas, strout
  93. end
  94.  
  95. function parseexpr(str)
  96.     local n, o, name, body = string.find(str, "(%w+)(%b())")
  97.     if n == nil then
  98.         local b = string.find(str, ",")
  99.         if b == nil then
  100.             return {type="var",name=str}, ""
  101.         else
  102.             return {type="var", name=string.sub(str, 0, b-1)}, string.sub(str,b+1)
  103.         end
  104.     end
  105.     body = string.sub(body, 2, -2)
  106.     local obj = {type="data", name=name, body=parseexprs(body)}
  107.     local rem = string.sub(str, o+1)
  108.     local b = string.find(rem, ",")
  109.     if b == nil then
  110.         return obj, ""
  111.     else
  112.         return obj, string.sub(rem,b+1)
  113.     end
  114. end
  115.  
  116. function parseexprs(str)
  117.     local t = {}
  118.     while str ~= "" do
  119.         local obj
  120.         obj, str = parseexpr(str)
  121.         table.insert(t, obj)
  122.     end
  123.     return t
  124. end
  125.  
  126. local function getCase(datas, data)
  127.     for i, x in ipairs(datas) do
  128.         for i, case in ipairs(x) do
  129.             if case.name == data then
  130.                 return i
  131.             end
  132.         end
  133.     end
  134. end
  135.  
  136. local function comparison(datas, var, pattern)
  137.     if pattern.type == "data" then
  138.         local out = var .. ".case == " .. getCase(datas, pattern.name)
  139.         for i, x in ipairs(pattern.body) do
  140.             if x.type == "data" then
  141.                 out = out .. " and " .. comparison(datas, var .. "[" .. i .. "]", x)
  142.             end
  143.         end
  144.         return out
  145.     else
  146.         return "true"
  147.     end
  148. end
  149.  
  150. local function destructure(datas, var, pattern)
  151.     if pattern.type == "var" then
  152.         return "local " .. pattern.name .. " = " .. var
  153.     else
  154.         local out = ""
  155.         for i, x in ipairs(pattern.body) do
  156.             out = out .. "\n" .. destructure(datas, var .. "[" .. i .. "]", x)
  157.         end
  158.         return out
  159.     end
  160. end
  161.  
  162. function replace_match(datas, str)
  163.     while true do
  164.         local m, b, var = string.find(str, "%smatch (%w+)%s")
  165.         if m == nil then
  166.             return str
  167.         end
  168.         local e = balanced_end(str, "match", b)
  169.         local cont = string.sub(str, b, e-4)
  170.         str = string.sub(str, 1, m) .. replace_case(datas, cont, var) .. string.sub(str, e, -1)
  171.     end
  172. end
  173.  
  174. function replace_case(datas, out, var)
  175.     local str = out
  176.     local count = 0
  177.     while true do
  178.         local m
  179.         local b
  180.         local p
  181.         m, b, p = string.find(str, "%scase ([^%s]+) do%s")
  182.         if m == nil then
  183.             for i=1,count do
  184.                 str = str .. "\nend"
  185.             end
  186.             return str
  187.         end
  188.         count = count + 1
  189.         local e = balanced_end(str, "do", b)
  190.         local cont = string.sub(str, b, e-5)
  191.         local pattern = parseexpr(p)
  192.         local bool = comparison(datas, var, pattern)
  193.         local body = destructure(datas, var, pattern)
  194.         str = string.sub(str, 1, m) .. "\nif " .. bool .. " then\n" .. body .. cont .. "\nelse\n" .. string.sub(str, e, -1)
  195.     end
  196. end
  197.  
  198. function printpattern(pattern)
  199.     if pattern.type == "data" then
  200.         io.write(pattern.name)
  201.         io.write("(")
  202.         for i, v in ipairs(pattern.body) do
  203.             printpattern(v)
  204.             io.write(",")
  205.         end
  206.         io.write(")")
  207.     else
  208.         io.write(pattern.name)
  209.     end
  210. end
  211.  
  212. local function writeheaders(datas)
  213.     local out = ""
  214.     for i, x in ipairs(datas) do
  215.         for i, case in ipairs(x) do
  216.             out = out .. "\nlocal function " .. case.name .. "(...) return {case=" .. i .. ",...} end"
  217.         end
  218.     end
  219.     return out
  220. end
  221.  
  222. local function preprocess(str)
  223.     local decls, str0 = get_decls(str)
  224.     local b = replace_match(decls, str0)
  225.     local h = writeheaders(decls)
  226.     return h .. b
  227. end
  228.  
  229. args = {...}
  230.  
  231. inp = io.open(args[1], "r")
  232. str = inp:read("*a")
  233. io.close(inp)
  234.  
  235. out = io.open(args[2], "w")
  236. out:write(preprocess(str))
  237. io.close(out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement