Advertisement
Guest User

Untitled

a guest
May 28th, 2015
267
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 14.15 KB | None | 0 0
  1.  
  2. #include <stdio.h>
  3. #include <assert.h>
  4. #include <memory.h>
  5. #include <algorithm>
  6. #include "compressor_internal.h"
  7. #include "RangeCoderBitTree.h"
  8.  
  9. using namespace NCompress::NRangeCoder;
  10.  
  11. template<int NBITS>
  12. class Context
  13. {
  14. public:
  15.     uint context;
  16.  
  17. public:
  18.     Context()
  19.     {
  20.         context = 0;
  21.     }
  22.  
  23.     void Reset()
  24.     {
  25.         context = (0xffffffff & ((1 << NBITS) - 1)) - 1;        // least likely
  26.     }
  27.  
  28.     operator uint () const { return context; }
  29.  
  30.     void operator = (uint x)
  31.     {
  32.         assert(x < (1 << NBITS));
  33.         context = x;
  34.     }
  35.  
  36. };
  37.  
  38. template<int NBITS>
  39. class BitContext
  40. {
  41.     uint context;
  42.  
  43. public:
  44.     BitContext()
  45.     {
  46.         context = 0;
  47.     }
  48.  
  49.     void Reset()
  50.     {
  51.         context = 0;
  52.     }
  53.  
  54.     void Zero()
  55.     {
  56.         context = (context << 1) & ((1 << NBITS) - 1);
  57.     }
  58.  
  59.     void One()
  60.     {
  61.         context = ((context << 1) | 1) & ((1 << NBITS) - 1);
  62.     }
  63.  
  64.     void Update(uint x)
  65.     {
  66.         context = ((context << 1) | (x & 1)) & ((1 << NBITS) - 1);
  67.     }
  68.  
  69.     uint GetLowBits(int n) const
  70.     {
  71.         assert(n <= NBITS);
  72.         return context & ((1 << n) - 1);
  73.     }
  74.  
  75.     operator uint () const { return context; }
  76. };
  77.  
  78. #pragma pack(2)
  79. struct OptimalTreeNode
  80. {
  81.     short SecondChild;
  82.     short Middle;
  83. };
  84. #pragma pack()
  85.  
  86. static void CreateOptimalTree(OptimalTreeNode * __restrict entries, int &count, int64_t const * __restrict distribution, int64_t sum, int start, int end, int uniformStart)
  87. {
  88.     if (end - start > 2)
  89.     {
  90.         int middle = 0;
  91.         int64_t halfSum = 0;
  92.  
  93.         if (sum > 0 && start < uniformStart)
  94.         {
  95.             for (middle = start; middle < std::min(end - 1, uniformStart); middle++)
  96.             {
  97.                 if (halfSum >= (sum + 1) / 2) break;                // lower half bigger for odd-length ranges
  98.                 //if (halfSum >= std::max(1LL, sum / 2)) break;     // upper half bigger for odd-length ranges
  99.                 halfSum += distribution[middle];
  100.             }
  101.         }
  102.         else
  103.         {
  104.             // balanced subtree for 0-prob ranges and ranges where start >= uniformStart
  105.             middle = start + (end - start) / 2;
  106.         }
  107.  
  108.         assert(start < middle);
  109.         assert(middle < end);
  110.  
  111.         int current = count++;
  112.  
  113.         CreateOptimalTree(entries, count, distribution, halfSum, start, middle, uniformStart);
  114.         entries[current].Middle = middle;
  115.         entries[current].SecondChild = count;
  116.         CreateOptimalTree(entries, count, distribution, sum - halfSum, middle, end, uniformStart);
  117.  
  118.     }
  119.     else if (end - start == 2)
  120.     {
  121.         entries[count].Middle = start + 1;
  122.         entries[count].SecondChild = 0;
  123.         count++;
  124.     }
  125.  
  126. }
  127.  
  128. static void CreateOptimalTree(OptimalTreeNode * __restrict entries, int entryCount, int64_t const * __restrict distribution, int distributionSize)
  129. {
  130.     assert(entryCount >= distributionSize);
  131.  
  132.     // last value in distribution is the total probability of [distributionSize-1..entryCount-1]
  133.     int64_t sum = 0;
  134.     for (int i = 0; i < distributionSize; i++) sum += distribution[i];
  135.     int count = 0;
  136.     CreateOptimalTree(entries, count, distribution, sum, 0, entryCount, distributionSize - 1);
  137.     if (count != entryCount - 1) throw "CreateOptimalTree fail";
  138. }
  139.  
  140.  
  141. static const int OTCMoveBits = 4;
  142.  
  143. template <class T, int SIZE>
  144. class OptimalTreeEncoder
  145. {
  146.     CBitEncoder<CBitModel<OTCMoveBits>, T> encoders[SIZE];
  147.  
  148.     static_assert(SIZE < 32768, "");
  149.  
  150. public:
  151.     void Encode(CEncoder<T> *encoder, uint symbol, OptimalTreeNode const * __restrict tree)
  152.     {
  153.         assert(symbol < SIZE);
  154.         int i = 0;
  155.         int lowerBound = 0;
  156.         int upperBound = SIZE;
  157.  
  158.         do
  159.         {
  160.             if (symbol < (uint)tree[i].Middle)
  161.             {
  162.                 encoders[i].Encode(encoder, 0);
  163.                 upperBound = tree[i].Middle;
  164.                 i++;
  165.             }
  166.             else
  167.             {
  168.                 encoders[i].Encode(encoder, 1);
  169.                 lowerBound = tree[i].Middle;
  170.                 i = tree[i].SecondChild;
  171.             }
  172.  
  173.         } while (upperBound - lowerBound > 1);
  174.     }
  175.  
  176. };
  177.  
  178. template <class T, int SIZE>
  179. class OptimalTreeDecoder
  180. {
  181.     CBitDecoder<CBitModel<OTCMoveBits>, T> decoders[SIZE];
  182.  
  183.  
  184. public:
  185.  
  186.     uint Decode(CDecoder<T> *decoder, OptimalTreeNode const * __restrict tree)
  187.     {
  188.         int i = 0;
  189.         int lowerBound = 0;
  190.         int upperBound = SIZE;
  191.  
  192.         do
  193.         {
  194.             assert(i < SIZE);
  195.  
  196.             if (!decoders[i].Decode(decoder))
  197.             {
  198.                 upperBound = tree[i].Middle;
  199.                 i++;
  200.             }
  201.             else
  202.             {
  203.                 lowerBound = tree[i].Middle;
  204.                 i = tree[i].SecondChild;
  205.             }
  206.  
  207.         } while (upperBound - lowerBound > 1);
  208.  
  209.         return lowerBound;
  210.     }
  211.  
  212. };
  213.  
  214. char const CompressorID[] = "Cmpr";
  215.  
  216. class EntropyCoder
  217. {
  218. public:
  219.     static const int BlockSize = 16;
  220.        
  221.     static const int MoveBits = 4;
  222.     static const int ZeroContextBits = 16;
  223.     static const int BlockContextBits = 12;
  224.  
  225.     static const int OptimalBits = 10;
  226.     static const int OptimalSize = 1 << OptimalBits;
  227.     static const int OptimalContextBits = 10;
  228.     static const int OptimalContextSize = 1 << OptimalContextBits;
  229.    
  230.     static_assert(COMPRESSOR_DISTRIBUTION_SIZE <= OptimalSize, "");
  231.  
  232. protected:
  233.     inline uint GetOptimalContext(Context<OptimalContextBits> const &prev, uint symbol, BitContext<ZeroContextBits> const &zeroContext)
  234.     {
  235.         assert(symbol > 0);
  236.         return std::min(symbol - 1, (uint)OptimalContextSize - 1);
  237.     }
  238. };
  239.  
  240. class LargeEncoder
  241. {
  242.     CBitTreeEncoder<EntropyCoder::MoveBits, StreamOut> lowEncoder;
  243.     CBitTreeEncoder<EntropyCoder::MoveBits, StreamOut> highEncoder;
  244.  
  245. public:
  246.     LargeEncoder()
  247.         : lowEncoder(17), highEncoder(16)
  248.     {
  249.     }
  250.  
  251.     void Encode(CEncoder<StreamOut> *encoder, uint symbol)
  252.     {
  253.         uint low = symbol & 0xffff;
  254.         if (symbol > 0xffff) low |= 0x10000;
  255.         lowEncoder.Encode(encoder, low);
  256.  
  257.         if (symbol > 0xffff)
  258.             highEncoder.Encode(encoder, symbol >> 16);
  259.  
  260.     }
  261. };
  262.  
  263. class LargeDecoder
  264. {
  265.     CBitTreeDecoder<EntropyCoder::MoveBits, StreamIn> lowDecoder;
  266.     CBitTreeDecoder<EntropyCoder::MoveBits, StreamIn> highDecoder;
  267.  
  268. public:
  269.     LargeDecoder()
  270.         : lowDecoder(17), highDecoder(16)
  271.     {
  272.     }
  273.  
  274.     uint Decode(CDecoder<StreamIn> *decoder)
  275.     {
  276.         uint symbol = lowDecoder.Decode(decoder);
  277.  
  278.         if (symbol > 0xffff)
  279.         {
  280.             symbol &= 0xffff;
  281.             symbol |= highDecoder.Decode(decoder) << 16;
  282.         }
  283.  
  284.         return symbol;
  285.     }
  286. };
  287.  
  288.  
  289. class Compressor : public CompressorInterface, public EntropyCoder
  290. {
  291.     CEncoder<StreamOut> encoder;
  292.     CBitEncoder<CBitModel<MoveBits>, StreamOut> zeroEncoder[2][1 << (ZeroContextBits)];
  293.     CBitEncoder<CBitModel<MoveBits>, StreamOut> blockEncoder[2][1 << BlockContextBits];
  294.     LargeEncoder largeEncoder;
  295.     CBitEncoder<CBitModel<MoveBits>, StreamOut> predictorEncoder;
  296.  
  297.     BitContext<ZeroContextBits> zeroContext;
  298.     BitContext<BlockContextBits> blockContext;
  299.  
  300.     int64_t *newDistribution;
  301.  
  302.     StreamOut *stream;
  303.     OptimalTreeNode *optimalTree;
  304.  
  305.     Context<OptimalContextBits> optimalContext;
  306.     OptimalTreeEncoder<StreamOut, OptimalSize> optimalEncoder[2][OptimalContextSize];
  307.  
  308.     CompressorStats *stats;
  309.  
  310. public:
  311.     Compressor(StreamOut *stream, const int64_t *oldDistribution, int64_t *newDistribution, CompressorStats *stats)
  312.     {
  313.         this->stream = stream;
  314.         this->stats = stats;
  315.              
  316.         stream->WriteBytes((byte *)CompressorID, 4);
  317.  
  318.         optimalTree = (OptimalTreeNode *)malloc(OptimalSize * sizeof(OptimalTreeNode));
  319.         ::CreateOptimalTree(optimalTree, OptimalSize, oldDistribution, COMPRESSOR_DISTRIBUTION_SIZE);
  320.  
  321.         stream->WriteBytes((byte *)oldDistribution, COMPRESSOR_DISTRIBUTION_SIZE * sizeof(int64_t));
  322.  
  323.         this->newDistribution = newDistribution;
  324.         memset(newDistribution, 0, COMPRESSOR_DISTRIBUTION_SIZE * sizeof(int64_t));
  325.  
  326.         encoder.Init(stream);
  327.  
  328.     }
  329.  
  330.     virtual ~Compressor()
  331.     {
  332.         encoder.FlushData();
  333.         free(optimalTree);
  334.         delete stream;
  335.     }
  336.  
  337.     void Compress(uint const *data, int dataLength, int predictor)
  338.     {
  339.         zeroContext.Reset();
  340.         blockContext.Reset();
  341.         optimalContext.Reset();
  342.  
  343.         predictorEncoder.Encode(&encoder, predictor);
  344.  
  345.         int i;
  346.         for (i = 0; i <= dataLength - BlockSize; i += BlockSize)
  347.             DoBlock(data, i, BlockSize, predictor);
  348.  
  349.         if (i < dataLength)
  350.             DoBlock(data, i, dataLength - i, predictor);
  351.  
  352.         stats->TotalInputBytes += dataLength * sizeof(int);
  353.         stats->CallCount++;
  354.         if (predictor) stats->ComplexPredictorCount++;
  355.        
  356.     }
  357.  
  358.  
  359. private:
  360.     inline void DoBlock(uint const *data, int i, int count, int predictor)
  361.     {
  362.         if (IsBlockZero(data, i, count))
  363.         {
  364.             blockEncoder[predictor][blockContext].Encode(&encoder, 0);
  365.             blockContext.Zero();
  366.             stats->EncoderZeroBlocks++;
  367.         }
  368.         else
  369.         {
  370.             blockEncoder[predictor][blockContext].Encode(&encoder, 1);
  371.             blockContext.One();
  372.             CompressBlock(data, i, count, predictor);
  373.         }
  374.  
  375.         stats->EncoderTotalBlocks++;
  376.     }
  377.  
  378.     inline bool IsBlockZero(uint const *data, int offset, int count)
  379.     {
  380.         int x = 0;
  381.  
  382.         for (int i = offset; i < offset + count; i++)
  383.             x |= data[i];       // could also return false if nonzero but this vectorizes nicely
  384.  
  385.         return x == 0;
  386.     }
  387.  
  388.     inline void CompressNonzeroSymbol(uint symbol, int i, int const predictor)
  389.     {
  390.         assert(symbol != 0);
  391.  
  392.         newDistribution[std::min(symbol - 1, (uint)COMPRESSOR_DISTRIBUTION_SIZE - 1)]++;
  393.         stats->MaxEncoderSymbol = std::max(stats->MaxEncoderSymbol, (int)(symbol - 1));
  394.         stats->SymbolSum += symbol;
  395.  
  396.         if (symbol < (uint)OptimalSize)
  397.         {
  398.             optimalEncoder[predictor][optimalContext].Encode(&encoder, symbol - 1, optimalTree);
  399.             optimalContext = GetOptimalContext(optimalContext, symbol, zeroContext);
  400.             stats->EncoderSizes[0]++;
  401.         }
  402.         else
  403.         {
  404.             optimalEncoder[predictor][optimalContext].Encode(&encoder, OptimalSize - 1, optimalTree);
  405.             optimalContext = OptimalContextSize - 1;
  406.  
  407.             largeEncoder.Encode(&encoder, symbol - OptimalSize);
  408.             stats->EncoderSizes[2]++;
  409.  
  410.         }
  411.  
  412.     }
  413.  
  414.     void CompressBlock(uint const *data, int offset, int count, int const predictor)
  415.     {
  416.         for (int i = offset; i < offset + count; i++)
  417.         {
  418.             auto symbol = data[i];
  419.  
  420.             if (symbol == 0)
  421.             {
  422.                 zeroEncoder[predictor][zeroContext].Encode(&encoder, 0);
  423.                 zeroContext.Zero();
  424.             }
  425.             else
  426.             {
  427.                 zeroEncoder[predictor][zeroContext].Encode(&encoder, 1);
  428.                 zeroContext.One();
  429.  
  430.                 CompressNonzeroSymbol(symbol, i, predictor);
  431.             }
  432.  
  433.         }
  434.  
  435.     }
  436.  
  437. };
  438.  
  439. class Decompressor : public DecompressorInterface, public EntropyCoder
  440. {
  441.     CDecoder<StreamIn> decoder;
  442.     CBitDecoder<CBitModel<MoveBits>, StreamIn> zeroDecoder[2][1 << (ZeroContextBits)];
  443.     CBitDecoder<CBitModel<MoveBits>, StreamIn> blockDecoder[2][1 << BlockContextBits];
  444.    
  445.     LargeDecoder largeDecoder;
  446.  
  447.     BitContext<ZeroContextBits> zeroContext;
  448.     BitContext<BlockContextBits> blockContext;
  449.  
  450.     CBitDecoder<CBitModel<MoveBits>, StreamIn> predictorDecoder;
  451.  
  452.     StreamIn *stream;
  453.  
  454.     OptimalTreeNode *optimalTree;
  455.     Context<OptimalContextBits> optimalContext;
  456.     OptimalTreeDecoder<StreamIn, OptimalSize> optimalDecoder[2][1 << OptimalContextBits];
  457.  
  458. public:
  459.    
  460.     Decompressor(StreamIn *stream)
  461.     {
  462.         this->stream = stream;
  463.  
  464.         char vid[5] = { 0 };
  465.         stream->ReadBytes((byte *)vid, sizeof(int));
  466.         if (strcmp(vid, CompressorID))
  467.             throw "compressor wrong id";
  468.  
  469.         auto distribution = (int64_t *)malloc(COMPRESSOR_DISTRIBUTION_SIZE * sizeof(int64_t));
  470.         stream->ReadBytes((byte *)distribution, COMPRESSOR_DISTRIBUTION_SIZE * sizeof(int64_t));
  471.  
  472.         optimalTree = (OptimalTreeNode *)malloc(OptimalSize * sizeof(OptimalTreeNode));
  473.         ::CreateOptimalTree(optimalTree, OptimalSize, distribution, COMPRESSOR_DISTRIBUTION_SIZE);
  474.  
  475.         free(distribution);
  476.  
  477.         decoder.Init(stream);
  478.  
  479.     }
  480.  
  481.     virtual ~Decompressor()
  482.     {
  483.         free(optimalTree);
  484.         delete stream;
  485.     }
  486.  
  487.     virtual bool DecodesRegret() { return true; }
  488.  
  489.     int Decompress(uint *data, int const *northData, int dataLength)
  490.     {
  491.         zeroContext.Reset();
  492.         blockContext.Reset();
  493.         optimalContext.Reset();
  494.  
  495.         int const predictor = predictorDecoder.Decode(&decoder);
  496.  
  497.         int i;
  498.         if (dataLength >= BlockSize)
  499.         {
  500.             DoBlockSlow(data, northData, 0, BlockSize, predictor);
  501.             if (northData != NULL && predictor != 0)
  502.             {
  503.                 for (i = BlockSize; i <= dataLength - BlockSize; i += BlockSize)
  504.                     DoBlockFast(data + i, northData + i, 1);
  505.             }
  506.             else
  507.             {
  508.                 for (i = BlockSize; i <= dataLength - BlockSize; i += BlockSize)
  509.                     DoSimpleBlockFast(data + i, 0);
  510.             }
  511.  
  512.             if (i < dataLength)
  513.                 DoBlockSlow(data, northData, i, dataLength - i, predictor);
  514.         }
  515.         else
  516.         {
  517.             DoBlockSlow(data, northData, 0, dataLength, predictor);
  518.         }
  519.  
  520.         return predictor;
  521.     }
  522.  
  523. private:
  524.     inline void DoBlockSlow(uint * __restrict data, int const * __restrict northData, int i, int count, int predictor)
  525.     {
  526.         if (GotBlock(predictor))
  527.         {
  528.             for (int j = i; j < i + count; j++)
  529.             {
  530.                 DecodeRegret((int *)data, northData, j, predictor, DecompressSymbol(predictor));
  531.             }
  532.         }
  533.         else
  534.         {
  535.             for (int j = i; j < i + count; j++)
  536.                 DecodeRegret((int *)data, northData, j, predictor, 0);
  537.         }
  538.     }
  539.  
  540.     inline void DoBlockFast(uint * __restrict data, int const * __restrict northData, int predictor)
  541.     {
  542.         if (GotBlock(predictor))
  543.         {
  544.             for (int i = 0; i < BlockSize; i++)
  545.             {
  546.                 int p = Predict(northData[i], data[i - 1], northData[i - 1]);
  547.                 data[i] = DecompressSymbol(predictor) + p;
  548.             }
  549.         }
  550.         else
  551.         {
  552.             for (int i = 0; i < BlockSize; i++)
  553.             {
  554.                 int p = Predict(northData[i], data[i - 1], northData[i - 1]);
  555.                 data[i] = p;
  556.             }
  557.         }
  558.     }
  559.  
  560.     inline void DoSimpleBlockFast(uint * __restrict data, int predictor)
  561.     {
  562.         if (GotBlock(predictor))
  563.         {
  564.             for (int i = 0; i < BlockSize; i++)
  565.                 data[i] = DecompressSymbol(predictor) + data[i - 1];
  566.         }
  567.         else
  568.         {
  569.             auto x = data[-1];
  570.  
  571.             for (int i = 0; i < BlockSize; i++)
  572.                 data[i] = x;
  573.         }
  574.     }
  575.  
  576.     inline bool GotBlock(int predictor)
  577.     {
  578.         bool gotBlock = blockDecoder[predictor][blockContext].Decode(&decoder) != 0;
  579.         blockContext.Update(gotBlock ? 1 : 0);
  580.         return gotBlock;
  581.     }
  582.  
  583.     inline uint DecompressNonzeroSymbol(int predictor)
  584.     {
  585.         uint symbol = optimalDecoder[predictor][optimalContext].Decode(&decoder, optimalTree) + 1;
  586.         optimalContext = GetOptimalContext(optimalContext, symbol, zeroContext);
  587.  
  588.         if (symbol == (uint)OptimalSize)
  589.         {
  590.             symbol = largeDecoder.Decode(&decoder) + OptimalSize;
  591.         }
  592.  
  593.         return ZigZagDecode32(symbol);
  594.     }
  595.  
  596.     inline uint DecompressSymbol(int predictor)
  597.     {
  598.         auto notZero = zeroDecoder[predictor][zeroContext].Decode(&decoder);
  599.         zeroContext.Update(notZero);
  600.         return notZero ? DecompressNonzeroSymbol(predictor) : 0;
  601.     }
  602.  
  603. };
  604.  
  605. CompressorInterface *CreateNormalCompressor(StreamOut *stream, const int64_t *oldDistribution, int64_t *newDistribution, CompressorStats *stats)
  606. {
  607.     return new Compressor(stream, oldDistribution, newDistribution, stats);
  608. }
  609.  
  610. DecompressorInterface *CreateNormalDecompressor(StreamIn *stream)
  611. {
  612.     return new Decompressor(stream);
  613. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement