ou1z

Untitled

Jul 25th, 2021 (edited)
251
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 17.11 KB | None | 0 0
  1. local M = {_TYPE='module', _NAME='compress.deflatelua', _VERSION='0.3.20111128'}
  2.  
  3. local assert = assert
  4. local error = error
  5. local ipairs = ipairs
  6. local pairs = pairs
  7. local print = print
  8. local tostring = tostring
  9. local type = type
  10. local setmetatable = setmetatable
  11. local io = io
  12. local math = math
  13. local table_sort = table.sort
  14. local math_max = math.max
  15. local string_char = string.char
  16.  
  17.  
  18.  
  19. local DEBUG = false
  20.  
  21. local NATIVE_BITOPS = (bit ~= nil)
  22.  
  23. local band, lshift, rshift
  24. if NATIVE_BITOPS then
  25.   band = bit.band
  26.   lshift = bit.lshift
  27.   rshift = bit.rshift
  28. end
  29.  
  30.  
  31. local function warn(s)
  32. print(s)
  33. end
  34.  
  35.  
  36. local function debug(...)
  37.   print('DEBUG', ...)
  38. end
  39.  
  40.  
  41. local function runtime_error(s, level)
  42.   level = level or 1
  43.   error({s}, level+1)
  44. end
  45.  
  46.  
  47. local function make_outstate(outbs)
  48.   local outstate = {}
  49.   outstate.outbs = outbs
  50.   outstate.window = {}
  51.   outstate.window_pos = 1
  52.   return outstate
  53. end
  54.  
  55.  
  56. local function output(outstate, byte)
  57.   -- debug('OUTPUT:', s)
  58.   local window_pos = outstate.window_pos
  59.   outstate.outbs(byte)
  60.   outstate.window[window_pos] = byte
  61.   outstate.window_pos = window_pos % 32768 + 1 -- 32K
  62. end
  63.  
  64.  
  65. local function noeof(val)
  66.   return assert(val, 'unexpected end of file')
  67. end
  68.  
  69.  
  70. local function hasbit(bits, bit)
  71.   return bits % (bit + bit) >= bit
  72. end
  73.  
  74.  
  75. local function memoize(f)
  76.   local mt = {}
  77.   local t = setmetatable({}, mt)
  78.   function mt:__index(k)
  79.     local v = f(k)
  80.     t[k] = v
  81.     return v
  82.   end
  83.   return t
  84. end
  85.  
  86.  
  87. -- small optimization (lookup table for powers of 2)
  88. local pow2 = memoize(function(n) return 2^n end)
  89.  
  90. --local tbits = memoize(
  91. -- function(bits)
  92. -- return memoize( function(bit) return getbit(bits, bit) end )
  93. -- end )
  94.  
  95.  
  96. -- weak metatable marking objects as bitstream type
  97. local is_bitstream = setmetatable({}, {__mode='k'})
  98.  
  99.  
  100. -- DEBUG
  101. -- prints LSB first
  102. --[[
  103. local function bits_tostring(bits, nbits)
  104. local s = ''
  105. local tmp = bits
  106. local function f()
  107. local b = tmp % 2 == 1 and 1 or 0
  108. s = s .. b
  109. tmp = (tmp - b) / 2
  110. end
  111. if nbits then
  112. for i=1,nbits do f() end
  113. else
  114. while tmp ~= 0 do f() end
  115. end
  116.  
  117. return s
  118. end
  119. --]]
  120.  
  121. local function bytestream_from_file(fh)
  122.   local o = {}
  123.   function o:read()
  124.     local sb = fh:read(1)
  125.     if sb then return sb:byte() end
  126.   end
  127.   return o
  128. end
  129.  
  130.  
  131. local function bytestream_from_string(s)
  132.   local i = 1
  133.   local o = {}
  134.   function o:read()
  135.     local by
  136.     if i <= #s then
  137.       by = s:byte(i)
  138.       i = i + 1
  139.     end
  140.     return by
  141.   end
  142.   return o
  143. end
  144.  
  145.  
  146. local function bytestream_from_function(f)
  147.   local i = 0
  148.   local buffer = ''
  149.   local o = {}
  150.   function o:read()
  151.     i = i + 1
  152.     if i > #buffer then
  153.       buffer = f()
  154.       if not buffer then return end
  155.       i = 1
  156.     end
  157.     return buffer:byte(i,i)
  158.   end
  159.   return o
  160. end
  161.  
  162.  
  163. local function bitstream_from_bytestream(bys)
  164.   local buf_byte = 0
  165.   local buf_nbit = 0
  166.   local o = {}
  167.  
  168.   function o:nbits_left_in_byte()
  169.     return buf_nbit
  170.   end
  171.  
  172.   if NATIVE_BITOPS then
  173.     function o:read(nbits)
  174.       nbits = nbits or 1
  175.       while buf_nbit < nbits do
  176.         local byte = bys:read()
  177.         if not byte then return end -- note: more calls also return nil
  178.         buf_byte = buf_byte + lshift(byte, buf_nbit)
  179.         buf_nbit = buf_nbit + 8
  180.       end
  181.       local bits
  182.       if nbits == 0 then
  183.         bits = 0
  184.       elseif nbits == 32 then
  185.         bits = buf_byte
  186.         buf_byte = 0
  187.       else
  188.         bits = band(buf_byte, rshift(0xffffffff, 32 - nbits))
  189.         buf_byte = rshift(buf_byte, nbits)
  190.       end
  191.       buf_nbit = buf_nbit - nbits
  192.       return bits
  193.     end
  194.   else
  195.     function o:read(nbits)
  196.       nbits = nbits or 1
  197.       while buf_nbit < nbits do
  198.         local byte = bys:read()
  199.         if not byte then return end -- note: more calls also return nil
  200.         buf_byte = buf_byte + pow2[buf_nbit] * byte
  201.         buf_nbit = buf_nbit + 8
  202.       end
  203.       local m = pow2[nbits]
  204.       local bits = buf_byte % m
  205.       buf_byte = (buf_byte - bits) / m
  206.       buf_nbit = buf_nbit - nbits
  207.       return bits
  208.     end
  209.   end
  210.  
  211.   is_bitstream[o] = true
  212.  
  213.   return o
  214. end
  215.  
  216.  
  217. local function get_bitstream(o)
  218.   local bs
  219.   if is_bitstream[o] then
  220.     return o
  221.   elseif isfile(o) then
  222.     bs = bitstream_from_bytestream(bytestream_from_file(o))
  223.   elseif type(o) == 'string' then
  224.     bs = bitstream_from_bytestream(bytestream_from_string(o))
  225.   elseif type(o) == 'function' then
  226.     bs = bitstream_from_bytestream(bytestream_from_function(o))
  227.   else
  228.     runtime_error 'unrecognized type'
  229.   end
  230.   return bs
  231. end
  232.  
  233.  
  234. local function get_obytestream(o)
  235.   local bs
  236.   if isfile(o) then
  237.     bs = function(sbyte) o:write(string_char(sbyte)) end
  238.   elseif type(o) == 'function' then
  239.     bs = o
  240.   else
  241.     runtime_error('unrecognized type: ' .. tostring(o))
  242.   end
  243.   return bs
  244. end
  245.  
  246.  
  247. local function HuffmanTable(init, is_full)
  248.   local t = {}
  249.   if is_full then
  250.     for val,nbits in pairs(init) do
  251.       if nbits ~= 0 then
  252.         t[#t+1] = {val=val, nbits=nbits}
  253.         --debug('*',val,nbits)
  254.       end
  255.     end
  256.   else
  257.     for i=1,#init-2,2 do
  258.       local firstval, nbits, nextval = init[i], init[i+1], init[i+2]
  259.       --debug(val, nextval, nbits)
  260.       if nbits ~= 0 then
  261.         for val=firstval,nextval-1 do
  262.           t[#t+1] = {val=val, nbits=nbits}
  263.         end
  264.       end
  265.     end
  266.   end
  267.   table_sort(t, function(a,b)
  268.     return a.nbits == b.nbits and a.val < b.val or a.nbits < b.nbits
  269.   end)
  270.  
  271.   -- assign codes
  272.   local code = 1 -- leading 1 marker
  273.   local nbits = 0
  274.   for i,s in ipairs(t) do
  275.     if s.nbits ~= nbits then
  276.       code = code * pow2[s.nbits - nbits]
  277.       nbits = s.nbits
  278.     end
  279.     s.code = code
  280.     --debug('huffman code:', i, s.nbits, s.val, code, bits_tostring(code))
  281.     code = code + 1
  282.   end
  283.  
  284.   local minbits = math.huge
  285.   local look = {}
  286.   for i,s in ipairs(t) do
  287.     minbits = math.min(minbits, s.nbits)
  288.     look[s.code] = s.val
  289.   end
  290.  
  291.   --for _,o in ipairs(t) do
  292.   -- debug(':', o.nbits, o.val)
  293.   --end
  294.  
  295.   -- function t:lookup(bits) return look[bits] end
  296.  
  297.   local msb = NATIVE_BITOPS and function(bits, nbits)
  298.     local res = 0
  299.     for i=1,nbits do
  300.       res = lshift(res, 1) + band(bits, 1)
  301.       bits = rshift(bits, 1)
  302.     end
  303.     return res
  304.   end or function(bits, nbits)
  305.     local res = 0
  306.     for i=1,nbits do
  307.       local b = bits % 2
  308.       bits = (bits - b) / 2
  309.       res = res * 2 + b
  310.     end
  311.     return res
  312.   end
  313.  
  314.   local tfirstcode = memoize(
  315.     function(bits) return pow2[minbits] + msb(bits, minbits) end)
  316.  
  317.   function t:read(bs)
  318.     local code = 1 -- leading 1 marker
  319.     local nbits = 0
  320.     while 1 do
  321.       if nbits == 0 then -- small optimization (optional)
  322.         code = tfirstcode[noeof(bs:read(minbits))]
  323.         nbits = nbits + minbits
  324.       else
  325.         local b = noeof(bs:read())
  326.         nbits = nbits + 1
  327.         code = code * 2 + b -- MSB first
  328.         --[[NATIVE_BITOPS
  329. code = lshift(code, 1) + b -- MSB first
  330. --]]
  331.       end
  332.       --debug('code?', code, bits_tostring(code))
  333.       local val = look[code]
  334.       if val then
  335.         --debug('FOUND', val)
  336.         return val
  337.       end
  338.     end
  339.   end
  340.  
  341.   return t
  342. end
  343.  
  344.  
  345. local function parse_gzip_header(bs)
  346.   -- local FLG_FTEXT = 2^0
  347.   local FLG_FHCRC = 2^1
  348.   local FLG_FEXTRA = 2^2
  349.   local FLG_FNAME = 2^3
  350.   local FLG_FCOMMENT = 2^4
  351.  
  352.   local id1 = bs:read(8)
  353.   local id2 = bs:read(8)
  354.   if id1 ~= 31 or id2 ~= 139 then
  355.     runtime_error 'not in gzip format'
  356.   end
  357.   local cm = bs:read(8) -- compression method
  358.   local flg = bs:read(8) -- FLaGs
  359.   local mtime = bs:read(32) -- Modification TIME
  360.   local xfl = bs:read(8) -- eXtra FLags
  361.   local os = bs:read(8) -- Operating System
  362.  
  363.   if DEBUG then
  364.     debug("CM=", cm)
  365.     debug("FLG=", flg)
  366.     debug("MTIME=", mtime)
  367.     -- debug("MTIME_str=",os.date("%Y-%m-%d %H:%M:%S",mtime)) -- non-portable
  368.     debug("XFL=", xfl)
  369.     debug("OS=", os)
  370.   end
  371.  
  372.   if not os then runtime_error 'invalid header' end
  373.  
  374.   if hasbit(flg, FLG_FEXTRA) then
  375.     local xlen = bs:read(16)
  376.     local extra = 0
  377.     for i=1,xlen do
  378.       extra = bs:read(8)
  379.     end
  380.     if not extra then runtime_error 'invalid header' end
  381.   end
  382.  
  383.   local function parse_zstring(bs)
  384.     repeat
  385.       local by = bs:read(8)
  386.       if not by then runtime_error 'invalid header' end
  387.     until by == 0
  388.   end
  389.  
  390.   if hasbit(flg, FLG_FNAME) then
  391.     parse_zstring(bs)
  392.   end
  393.  
  394.   if hasbit(flg, FLG_FCOMMENT) then
  395.     parse_zstring(bs)
  396.   end
  397.  
  398.   if hasbit(flg, FLG_FHCRC) then
  399.     local crc16 = bs:read(16)
  400.     if not crc16 then runtime_error 'invalid header' end
  401.     -- IMPROVE: check CRC. where is an example .gz file that
  402.     -- has this set?
  403.     if DEBUG then
  404.       debug("CRC16=", crc16)
  405.     end
  406.   end
  407. end
  408.  
  409. local function parse_zlib_header(bs)
  410.   local cm = bs:read(4) -- Compression Method
  411.   local cinfo = bs:read(4) -- Compression info
  412.   local fcheck = bs:read(5) -- FLaGs: FCHECK (check bits for CMF and FLG)
  413.   local fdict = bs:read(1) -- FLaGs: FDICT (present dictionary)
  414.   local flevel = bs:read(2) -- FLaGs: FLEVEL (compression level)
  415.   local cmf = cinfo * 16 + cm -- CMF (Compresion Method and flags)
  416.   local flg = fcheck + fdict * 32 + flevel * 64 -- FLaGs
  417.  
  418.   if cm ~= 8 then -- not "deflate"
  419.     runtime_error("unrecognized zlib compression method: " .. cm)
  420.   end
  421.   if cinfo > 7 then
  422.     runtime_error("invalid zlib window size: cinfo=" .. cinfo)
  423.   end
  424.   local window_size = 2^(cinfo + 8)
  425.  
  426.   if (cmf*256 + flg) % 31 ~= 0 then
  427.     runtime_error("invalid zlib header (bad fcheck sum)")
  428.   end
  429.  
  430.   if fdict == 1 then
  431.     runtime_error("FIX:TODO - FDICT not currently implemented")
  432.     local dictid_ = bs:read(32)
  433.   end
  434.  
  435.   return window_size
  436. end
  437.  
  438. local function parse_huffmantables(bs)
  439.     local hlit = bs:read(5) -- # of literal/length codes - 257
  440.     local hdist = bs:read(5) -- # of distance codes - 1
  441.     local hclen = noeof(bs:read(4)) -- # of code length codes - 4
  442.  
  443.     local ncodelen_codes = hclen + 4
  444.     local codelen_init = {}
  445.     local codelen_vals = {
  446.       16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
  447.     for i=1,ncodelen_codes do
  448.       local nbits = bs:read(3)
  449.       local val = codelen_vals[i]
  450.       codelen_init[val] = nbits
  451.     end
  452.     local codelentable = HuffmanTable(codelen_init, true)
  453.  
  454.     local function decode(ncodes)
  455.       local init = {}
  456.       local nbits
  457.       local val = 0
  458.       while val < ncodes do
  459.         local codelen = codelentable:read(bs)
  460.         --FIX:check nil?
  461.         local nrepeat
  462.         if codelen <= 15 then
  463.           nrepeat = 1
  464.           nbits = codelen
  465.           --debug('w', nbits)
  466.         elseif codelen == 16 then
  467.           nrepeat = 3 + noeof(bs:read(2))
  468.           -- nbits unchanged
  469.         elseif codelen == 17 then
  470.           nrepeat = 3 + noeof(bs:read(3))
  471.           nbits = 0
  472.         elseif codelen == 18 then
  473.           nrepeat = 11 + noeof(bs:read(7))
  474.           nbits = 0
  475.         else
  476.           error 'ASSERT'
  477.         end
  478.         for i=1,nrepeat do
  479.           init[val] = nbits
  480.           val = val + 1
  481.         end
  482.       end
  483.       local huffmantable = HuffmanTable(init, true)
  484.       return huffmantable
  485.     end
  486.  
  487.     local nlit_codes = hlit + 257
  488.     local ndist_codes = hdist + 1
  489.  
  490.     local littable = decode(nlit_codes)
  491.     local disttable = decode(ndist_codes)
  492.  
  493.     return littable, disttable
  494. end
  495.  
  496.  
  497. local tdecode_len_base
  498. local tdecode_len_nextrabits
  499. local tdecode_dist_base
  500. local tdecode_dist_nextrabits
  501. local function parse_compressed_item(bs, outstate, littable, disttable)
  502.   local val = littable:read(bs)
  503.   --debug(val, val < 256 and string_char(val))
  504.   if val < 256 then -- literal
  505.     output(outstate, val)
  506.   elseif val == 256 then -- end of block
  507.     return true
  508.   else
  509.     if not tdecode_len_base then
  510.       local t = {[257]=3}
  511.       local skip = 1
  512.       for i=258,285,4 do
  513.         for j=i,i+3 do t[j] = t[j-1] + skip end
  514.         if i ~= 258 then skip = skip * 2 end
  515.       end
  516.       t[285] = 258
  517.       tdecode_len_base = t
  518.       --for i=257,285 do debug('T1',i,t[i]) end
  519.     end
  520.     if not tdecode_len_nextrabits then
  521.       local t = {}
  522.       if NATIVE_BITOPS then
  523.         for i=257,285 do
  524.           local j = math_max(i - 261, 0)
  525.           t[i] = rshift(j, 2)
  526.         end
  527.       else
  528.         for i=257,285 do
  529.           local j = math_max(i - 261, 0)
  530.           t[i] = (j - (j % 4)) / 4
  531.         end
  532.       end
  533.       t[285] = 0
  534.       tdecode_len_nextrabits = t
  535.       --for i=257,285 do debug('T2',i,t[i]) end
  536.     end
  537.     local len_base = tdecode_len_base[val]
  538.     local nextrabits = tdecode_len_nextrabits[val]
  539.     local extrabits = bs:read(nextrabits)
  540.     local len = len_base + extrabits
  541.  
  542.     if not tdecode_dist_base then
  543.       local t = {[0]=1}
  544.       local skip = 1
  545.       for i=1,29,2 do
  546.         for j=i,i+1 do t[j] = t[j-1] + skip end
  547.         if i ~= 1 then skip = skip * 2 end
  548.       end
  549.       tdecode_dist_base = t
  550.       --for i=0,29 do debug('T3',i,t[i]) end
  551.     end
  552.     if not tdecode_dist_nextrabits then
  553.       local t = {}
  554.       if NATIVE_BITOPS then
  555.         for i=0,29 do
  556.           local j = math_max(i - 2, 0)
  557.           t[i] = rshift(j, 1)
  558.         end
  559.       else
  560.         for i=0,29 do
  561.           local j = math_max(i - 2, 0)
  562.           t[i] = (j - (j % 2)) / 2
  563.         end
  564.       end
  565.       tdecode_dist_nextrabits = t
  566.       --for i=0,29 do debug('T4',i,t[i]) end
  567.     end
  568.     local dist_val = disttable:read(bs)
  569.     local dist_base = tdecode_dist_base[dist_val]
  570.     local dist_nextrabits = tdecode_dist_nextrabits[dist_val]
  571.     local dist_extrabits = bs:read(dist_nextrabits)
  572.     local dist = dist_base + dist_extrabits
  573.  
  574.     --debug('BACK', len, dist)
  575.     for i=1,len do
  576.       local pos = (outstate.window_pos - 1 - dist) % 32768 + 1 -- 32K
  577.       output(outstate, assert(outstate.window[pos], 'invalid distance'))
  578.     end
  579.   end
  580.   return false
  581. end
  582.  
  583.  
  584. local function parse_block(bs, outstate)
  585.   local bfinal = bs:read(1)
  586.   local btype = bs:read(2)
  587.  
  588.   local BTYPE_NO_COMPRESSION = 0
  589.   local BTYPE_FIXED_HUFFMAN = 1
  590.   local BTYPE_DYNAMIC_HUFFMAN = 2
  591.   local BTYPE_RESERVED_ = 3
  592.  
  593.   if DEBUG then
  594.     debug('bfinal=', bfinal)
  595.     debug('btype=', btype)
  596.   end
  597.  
  598.   if btype == BTYPE_NO_COMPRESSION then
  599.     bs:read(bs:nbits_left_in_byte())
  600.     local len = bs:read(16)
  601.     local nlen_ = noeof(bs:read(16))
  602.  
  603.     for i=1,len do
  604.       local by = noeof(bs:read(8))
  605.       output(outstate, by)
  606.     end
  607.   elseif btype == BTYPE_FIXED_HUFFMAN or btype == BTYPE_DYNAMIC_HUFFMAN then
  608.     local littable, disttable
  609.     if btype == BTYPE_DYNAMIC_HUFFMAN then
  610.       littable, disttable = parse_huffmantables(bs)
  611.     else
  612.       littable = HuffmanTable {0,8, 144,9, 256,7, 280,8, 288,nil}
  613.       disttable = HuffmanTable {0,5, 32,nil}
  614.     end
  615.  
  616.     repeat
  617.       local is_done = parse_compressed_item(
  618.         bs, outstate, littable, disttable)
  619.     until is_done
  620.   else
  621.     runtime_error 'unrecognized compression type'
  622.   end
  623.  
  624.   return bfinal ~= 0
  625. end
  626.  
  627.  
  628. function M.inflate(t)
  629.   local bs = get_bitstream(t.input)
  630.   local outbs = get_obytestream(t.output)
  631.   local outstate = make_outstate(outbs)
  632.  
  633.   repeat
  634.     local is_final = parse_block(bs, outstate)
  635.   until is_final
  636. end
  637. local inflate = M.inflate
  638.  
  639.  
  640. function M.gunzip(t)
  641.   local bs = get_bitstream(t.input)
  642.   local outbs = get_obytestream(t.output)
  643.   local disable_crc = t.disable_crc
  644.   if disable_crc == nil then disable_crc = false end
  645.  
  646.   parse_gzip_header(bs)
  647.  
  648.   local data_crc32 = 0
  649.  
  650.   inflate{input=bs, output=
  651.     disable_crc and outbs or
  652.       function(byte)
  653.         data_crc32 = crc32(byte, data_crc32)
  654.         outbs(byte)
  655.       end
  656.   }
  657.  
  658.   bs:read(bs:nbits_left_in_byte())
  659.  
  660.   local expected_crc32 = bs:read(32)
  661.   local isize = bs:read(32) -- ignored
  662.   if DEBUG then
  663.     debug('crc32=', expected_crc32)
  664.     debug('isize=', isize)
  665.   end
  666.   if not disable_crc and data_crc32 then
  667.     if data_crc32 ~= expected_crc32 then
  668.       runtime_error('invalid compressed data--crc error')
  669.     end
  670.   end
  671.   if bs:read() then
  672.     warn 'trailing garbage ignored'
  673.   end
  674. end
  675.  
  676.  
  677. function M.adler32(byte, crc)
  678.   local s1 = crc % 65536
  679.   local s2 = (crc - s1) / 65536
  680.   s1 = (s1 + byte) % 65521
  681.   s2 = (s2 + s1) % 65521
  682.   return s2*65536 + s1
  683. end -- 65521 is the largest prime smaller than 2^16
  684.  
  685.  
  686. function M.inflate_zlib(t)
  687.   local bs = get_bitstream(t.input)
  688.   local outbs = get_obytestream(t.output)
  689.   local disable_crc = t.disable_crc
  690.   if disable_crc == nil then disable_crc = false end
  691.  
  692.   local window_size_ = parse_zlib_header(bs)
  693.  
  694.   local data_adler32 = 1
  695.  
  696.   inflate{input=bs, output=
  697.     disable_crc and outbs or
  698.       function(byte)
  699.         data_adler32 = M.adler32(byte, data_adler32)
  700.         outbs(byte)
  701.       end
  702.   }
  703.  
  704.   bs:read(bs:nbits_left_in_byte())
  705.  
  706.   local b3 = bs:read(8)
  707.   local b2 = bs:read(8)
  708.   local b1 = bs:read(8)
  709.   local b0 = bs:read(8)
  710.   local expected_adler32 = ((b3*256 + b2)*256 + b1)*256 + b0
  711.   if DEBUG then
  712.     debug('alder32=', expected_adler32)
  713.   end
  714.   if not disable_crc then
  715.     if data_adler32 ~= expected_adler32 then
  716.       runtime_error('invalid compressed data--crc error')
  717.     end
  718.   end
  719.   if bs:read() then
  720.     warn 'trailing garbage ignored'
  721.   end
  722. end
  723.  
  724.  
-- Module table: exposes inflate, gunzip, adler32 and inflate_zlib.
return M
Add Comment
Please, Sign In to add comment