immibis

compress

Feb 25th, 2013
391
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 8.07 KB | None | 0 0
  1. local inFN, outFN
  2.  
  3. local args = {...}
  4. if #args ~= 2 then
  5.     error("Usage: "..shell.getRunningProgram().." <input file> <output file>", 0)
  6. end
  7. inFN = args[1]
  8. outFN = args[2]
  9.  
  10. local function countTabs(l)
  11.     for k = 1, #l do
  12.         if l:sub(k, k) ~= "\t" then
  13.             return k - 1
  14.         end
  15.     end
  16.     return #l
  17. end
  18.  
  19. local compressNonIdentGroups = false
  20. local huffmanEncode = true
  21.  
  22. -- read input file
  23. local f = fs.open(inFN, "r")
  24. local lines = {}
  25. for line in f.readLine do
  26.     lines[#lines+1] = line
  27. end
  28. f.close()
  29.  
  30. -- convert indentation
  31. local inText = ""
  32. local lastIndent = 0
  33. for lineNo, line in ipairs(lines) do
  34.     local thisIndent = countTabs(line)
  35.     local nextIndent = lines[lineNo+1] and countTabs(lines[lineNo+1]) or thisIndent
  36.     local prevIndent = lines[lineNo-1] and countTabs(lines[lineNo-1]) or 0
  37.    
  38.     if thisIndent > nextIndent and thisIndent > prevIndent then
  39.         thisIndent = math.min(nextIndent, prevIndent)
  40.     end
  41.    
  42.     while lastIndent < thisIndent do
  43.         inText = inText .. "&+"
  44.         lastIndent = lastIndent + 1
  45.     end
  46.     while lastIndent > thisIndent do
  47.         inText = inText .. "&-"
  48.         lastIndent = lastIndent - 1
  49.     end
  50.    
  51.     if line:sub(1,1) == "&" then
  52.         line = "&" .. line
  53.     end
  54.    
  55.     inText = inText .. line:sub(lastIndent+1) .. "\n"
  56. end
  57.  
  58.  
  59. -- parse into alternating strings of alphanumerics and non-alphanumerics
  60. local parsed = {}
  61. local idents = "abcdefghijklmnopqrstvuwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
  62. local lastIdent = nil
  63.  
  64. local function isIdentString(s)
  65.     return idents:find(s:sub(1,1), 1, true) ~= nil
  66. end
  67.  
  68. local groupCounts = {}
  69.  
  70. local function onFinishSegment(isIdent, segment)
  71.     if isIdent or compressNonIdentGroups then
  72.         groupCounts[segment] = (groupCounts[segment] or 0) + 1
  73.     end
  74. end
  75.  
  76. for k = 1, #inText do
  77.     local ch = inText:sub(k, k)
  78.     local isIdent = idents:find(ch, 1, true) ~= nil
  79.     if isIdent ~= lastIdent then
  80.         if #parsed > 0 then
  81.             onFinishSegment(lastIdent, parsed[#parsed])
  82.         end
  83.         parsed[#parsed+1] = ""
  84.     end
  85.     lastIdent = isIdent
  86.     parsed[#parsed] = parsed[#parsed]..ch
  87. end
  88. if #parsed > 0 then
  89.     onFinishSegment(isIdent, parsed[#parsed])
  90. end
  91.  
  92. --print(table.concat(parsed," "))
  93.  
  94. local id_literal_escape = "_"
  95. local pc_literal_escape = "$"
  96.  
  97. local nextCompressed
  98. do
  99.     local validchars_id = idents:gsub(id_literal_escape,"")
  100.     local validchars_pc = ""
  101.    
  102.     for n=32,126 do
  103.         local ch = string.char(n)
  104.         if not idents:find(ch,1,true) and ch ~= pc_literal_escape then
  105.             validchars_pc = validchars_pc .. ch
  106.         end
  107.     end
  108.    
  109.     local function encode(n, isIdent)
  110.         local s = ""
  111.         local validchars = isIdent and validchars_id or validchars_pc
  112.         while n > 0 do
  113.             local digit = (n % #validchars) + 1
  114.             s = s .. validchars:sub(digit, digit)
  115.             n = math.floor(n / #validchars)
  116.         end
  117.         return s
  118.     end
  119.    
  120.     local next = {[true]=0,[false]=0}
  121.     function nextCompressed(isIdent)
  122.         next[isIdent] = next[isIdent] + 1
  123.         return encode(next[isIdent], isIdent)
  124.     end
  125. end
  126.  
  127.  
  128.  
  129.  
  130.  
  131.  
  132. local groupsSorted = {}
  133. local groups = {}
  134. for k, v in pairs(groupCounts) do
  135.     if (#k > 1 and v > 1) or k:find(isIdentString(k) and id_literal_escape or pc_literal_escape) then
  136.         local t = {k, v}
  137.         groups[k] = t
  138.         table.insert(groupsSorted, t)
  139.     end
  140. end
  141.  
  142. local avgCompressedLength = 2
  143.  
  144. local function estSavings(a)
  145.     local str = a[1]
  146.     local count = a[2]
  147.     local compressedLength = a[3] and #a[3] or avgCompressedLength
  148.    
  149.     -- estimates the number of chars saved by compressing this group
  150.    
  151.     -- it costs #str+1 chars to encode the group literally, or about compressedLength chars to compress it
  152.     -- so by compressing it, each time the group occurs we save (#str + 1 - compressedLength) chars
  153.     local saved = (#str + 1 - compressedLength) * count
  154.    
  155.     -- but we also use about #str + 1 chars in the name table if we compress it.
  156.     saved = saved - (#str + 1)
  157.    
  158.     return saved
  159. end
  160.  
  161. table.sort(groupsSorted, function(a, b)
  162.     return estSavings(a) > estSavings(b)
  163. end)
  164.  
  165. --local total = 0
  166. for _, v in ipairs(groupsSorted) do
  167.     v[3] = nextCompressed(isIdentString(v[1]))
  168.     --total = total + estSavings(v)
  169.     --[[if estSavings(v) > 0 then
  170.         print(v[1]:gsub("\n","\\n")," ",v[2]," ",v[3]," ",estSavings(v))
  171.     end]]
  172. end
  173.  
  174. --print(total," ",#inText," ",#inText-total)
  175.  
  176.  
  177.  
  178.  
  179.  
  180. local out = #groupsSorted .. "^"
  181. for _, v in ipairs(groupsSorted) do
  182.     local encoded = v[1]:gsub("&", "&a"):gsub("%^","&b")
  183.     out = out .. encoded .. "^"
  184. end
  185. for _, v in pairs(parsed) do
  186.     if groups[v] then
  187.         out = out .. groups[v][3]
  188.     elseif isIdentString(v) then
  189.         out = out .. id_literal_escape .. v
  190.     elseif compressNonIdentGroups then
  191.         out = out .. pc_literal_escape .. v
  192.     else
  193.         out = out .. v
  194.     end
  195. end
  196.  
  197. if huffmanEncode then
  198.     -- generate a huffman tree - first we need to count the number of times each symbol occurs
  199.     local symbolCounts = {}
  200.     local numSymbols = 0
  201.     for k = 1, #out do
  202.         local sym = out:sub(k,k)
  203.         if not symbolCounts[sym] then
  204.             numSymbols = numSymbols + 1
  205.             symbolCounts[sym] = 1
  206.         else
  207.             symbolCounts[sym] = symbolCounts[sym] + 1
  208.         end
  209.     end
  210.    
  211.     -- convert them to tree nodes and sort them by count, ascending order
  212.     -- a tree node is either {symbol, count} or {{subtree_0, subtree_1}, count}
  213.     local treeFragments = {}
  214.     for sym, count in pairs(symbolCounts) do
  215.         treeFragments[#treeFragments + 1] = {sym, count}
  216.     end
  217.     table.sort(treeFragments, function(a, b)
  218.         return a[2] < b[2]
  219.     end)
  220.    
  221.     while #treeFragments > 1 do
  222.         -- take the two lowest-count fragments and combine them
  223.         local a = table.remove(treeFragments, 1)
  224.         local b = table.remove(treeFragments, 1)
  225.        
  226.         local newCount = a[2] + b[2]
  227.         local new = {{a, b}, newCount}
  228.        
  229.         -- insert the new fragment in the right place
  230.         if #treeFragments == 0 or newCount > treeFragments[#treeFragments][2] then
  231.             table.insert(treeFragments, new)
  232.         else
  233.             local ok = false
  234.             for k=1,#treeFragments do
  235.                 if treeFragments[k][2] >= newCount then
  236.                     table.insert(treeFragments, k, new)
  237.                     ok = true
  238.                     break
  239.                 end
  240.             end
  241.             assert(ok, "internal error: couldn't find place for tree fragment")
  242.         end
  243.     end
  244.    
  245.     local symbolCodes = {}
  246.    
  247.     local function shallowCopyTable(t)
  248.         local rv = {}
  249.         for k,v in pairs(t) do
  250.             rv[k] = v
  251.         end
  252.         return rv
  253.     end
  254.    
  255.     -- now we have a huffman tree (codes -> symbols) but we need a map of symbols -> codes, so do that
  256.     local function iterate(root, path)
  257.         if type(root[1]) == "table" then
  258.             local t = shallowCopyTable(path)
  259.             t[#t+1] = false
  260.             iterate(root[1][1], t)
  261.             path[#path+1] = true
  262.             iterate(root[1][2], path)
  263.         else
  264.             symbolCodes[root[1]] = path
  265.         end
  266.     end
  267.     iterate(treeFragments[1], {})
  268.    
  269.     local rv = {}
  270.    
  271.     local symbolBitWidth = 8
  272.    
  273.     --print("#syms: ",numSymbols)
  274.     --print("sbw: ",symbolBitWidth)
  275.    
  276.     local function writeTree(tree)
  277.         if type(tree[1]) == "table" then
  278.             rv[#rv+1] = false
  279.             writeTree(tree[1][1])
  280.             writeTree(tree[1][2])
  281.         else
  282.             rv[#rv+1] = true
  283.             local symbol = tree[1]:byte()
  284.             for k = 0, symbolBitWidth - 1 do
  285.                
  286.                 local testBit = 2 ^ k
  287.                
  288.                 -- local bit = ((symbol & testBit) != 0)
  289.                 local bit = (symbol % (2 * testBit)) >= testBit
  290.                 rv[#rv+1] = bit
  291.             end
  292.         end
  293.     end
  294.    
  295.     writeTree(treeFragments[1])
  296.    
  297.     --print("tree size: ",#rv)
  298.    
  299.     for k = 1, #out do
  300.         local symbol = out:sub(k,k)
  301.         --print(symbol," ",#symbolCodes[symbol])
  302.         for _, bit in ipairs(symbolCodes[symbol] or error("internal error: symbol "..symbol.." has no code")) do
  303.             rv[#rv+1] = bit
  304.         end
  305.     end
  306.    
  307.     --print("total size: ",#rv)
  308.    
  309.    
  310.     -- convert the array of bits (rv) back to characters
  311.    
  312.    
  313.     local s = ""
  314.    
  315.     -- write 6 bits per byte because LuaJ (and/or CC)  
  316.     local bitsPerByte = 6
  317.     local firstCharacter = 32
  318.    
  319.     -- pad to an integral number of bytes
  320.     local padbit = not rv[#rv]
  321.     repeat
  322.         rv[#rv+1] = padbit
  323.     until (#rv % bitsPerByte) == 0
  324.    
  325.    
  326.     for k = 1, #rv, bitsPerByte do
  327.         local byte = firstCharacter
  328.         for i = 0, bitsPerByte-1 do
  329.             if rv[k+i] then    
  330.                 byte = byte + 2 ^ i
  331.             end
  332.         end
  333.         s = s .. string.char(byte)
  334.     end
  335.    
  336.     out = s
  337. end
  338.  
  339. local f = fs.open(outFN, "w")
  340. f.write(out)
  341. f.close()
  342.  
  343. print("uncompressed size: ",inText:len())
  344. print("compressed size: ",out:len())
  345. print("compression ratio: ",out:len()/inText:len())
Advertisement
Add Comment
Please, Sign In to add comment