Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include "Huffman.h"
- #include <queue>
- #include <unordered_map>
- #include <stack>
- #include <bitset>
- struct Node
- {
- Node(byte value, long long freq) : value(value), freq(freq), left(nullptr), right(nullptr) {}
- byte value;
- long long freq;
- Node *left;
- Node *right;
- };
- class NodePtrComparator
- {
- public:
- bool operator()(Node *l, Node *r)
- {
- return l->freq > r->freq;
- }
- };
- void buildCodeTable(Node *root, std::unordered_map<byte, std::vector<unsigned char>> &table, std::vector<unsigned char> path)
- {
- if (!root->left && !root->right)
- {
- table[root->value] = path;
- }
- else
- {
- if (root->left)
- {
- std::vector<unsigned char> left(path);
- left.push_back(0);
- buildCodeTable(root->left, table, left);
- }
- if (root->right)
- {
- std::vector<unsigned char> right(path);
- right.push_back(1);
- buildCodeTable(root->right, table, right);
- }
- }
- }
- std::unordered_map<byte, std::vector<unsigned char>> buildCodeTable(Node *root)
- {
- std::unordered_map<byte, std::vector<unsigned char>> table;
- buildCodeTable(root, table, {});
- return table;
- }
// Append-only bit buffer. Bits are packed most-significant-bit first into a
// growing vector of bytes; unused low bits of the last byte stay zero.
class BitWriter
{
public:
    BitWriter() : bitsWritten_(0) {}

    // Appends a single bit; any non-zero `bit` is stored as 1.
    void WriteBit(unsigned char bit)
    {
        const size_t bitInByte = bitsWritten_ % 8;
        // All allocated bytes are full: start a fresh zeroed byte.
        if (bitsWritten_ == bytes_.size() * 8)
        {
            bytes_.push_back(0);
        }
        if (bit != 0)
        {
            bytes_[bitsWritten_ / 8] |= static_cast<unsigned char>(0x80u >> bitInByte);
        }
        ++bitsWritten_;
    }

    // Appends eight bits, most significant first. When the write position is
    // not byte-aligned the value is split across the current and a new byte.
    void WriteByte(unsigned char octet)
    {
        const size_t shift = bitsWritten_ % 8;
        if (shift == 0)
        {
            bytes_.push_back(octet);
        }
        else
        {
            // High part fills the free low bits of the current byte...
            bytes_[bitsWritten_ / 8] |= static_cast<unsigned char>(octet >> shift);
            // ...and the remaining low part starts the next byte.
            bytes_.push_back(static_cast<unsigned char>(octet << (8 - shift)));
        }
        bitsWritten_ += 8;
    }

    // Bytes written so far (the last one may be only partially used).
    const std::vector<unsigned char> &getBuffer() const
    {
        return bytes_;
    }

    // Exact number of bits appended so far.
    size_t getBitCount() const
    {
        return bitsWritten_;
    }

    // Debug helper: prints every byte as eight binary digits followed by '|'.
    void visualize() const
    {
        for (size_t i = 0; i < bytes_.size(); ++i)
        {
            std::cout << std::bitset<8>(bytes_[i]) << "|";
        }
        std::cout << std::endl;
    }

private:
    std::vector<unsigned char> bytes_;
    size_t bitsWritten_;
};
// Sequential reader over a byte vector, consuming bits most-significant-bit
// first. Holds a reference to the vector, so the vector must outlive the
// reader. No bounds checking is performed by the read functions.
class BitReader
{
public:
    // data     - bytes to read from (referenced, not copied)
    // bitCount - number of bits considered valid; used only by hasData()
    BitReader(std::vector<unsigned char> &data, size_t bitCount)
        : cursor_(0), limit_(bitCount), bytes_(data) {}

    // Returns the next bit (0 or 1) and advances by one bit.
    unsigned char readBit()
    {
        const size_t shift = 7 - cursor_ % 8;
        const unsigned char bit = (bytes_[cursor_ / 8] >> shift) & 1;
        ++cursor_;
        return bit;
    }

    // Returns the next eight bits as one byte and advances by eight bits.
    unsigned char readByte()
    {
        const size_t index = cursor_ / 8;
        const size_t offset = cursor_ % 8;
        unsigned char result = 0;
        if (offset != 0)
        {
            // Unaligned: low bits of this byte followed by high bits of the next.
            result = bytes_[index] << offset;
            result |= bytes_[index + 1] >> (8 - offset);
        }
        else
        {
            result = bytes_[index];
        }
        cursor_ += 8;
        return result;
    }

    // True while the read position is before the declared bit count.
    bool hasData() const
    {
        return cursor_ < limit_;
    }

private:
    size_t cursor_;
    size_t limit_;
    std::vector<unsigned char> &bytes_;
};
- void serializeTree(Node *root, BitWriter &bw)
- {
- if (!root->left && !root->right)
- {
- bw.WriteBit(1);
- bw.WriteByte(root->value);
- }
- else
- {
- if (root->left)
- serializeTree(root->left, bw);
- if (root->right)
- serializeTree(root->right, bw);
- bw.WriteBit(0);
- }
- }
- void printCodeTable(std::unordered_map<byte, std::vector<unsigned char>> &t)
- {
- for (auto &pair: t)
- {
- std::cout << pair.first << " = ";
- for (auto &c: pair.second)
- std::cout << (int)c;
- std::cout << std::endl;
- }
- }
- void Encode(IInputStream& original, IOutputStream& compressed)
- {
- // byte frequencies
- std::vector<unsigned long long> byteValueCounter(256, 0);
- // copy of input data
- std::vector<byte> data;
- // copy input and count byte frequencies
- byte value;
- while (original.Read(value))
- {
- byteValueCounter[value]++;
- data.push_back(value);
- }
- // initialize the priority queue of Huffman tree nodes
- size_t alphabetSize = 0;
- std::priority_queue<Node*, std::vector<Node*>, NodePtrComparator> nodePriorityQueue;
- for (int i = 0; i < 256; i++)
- {
- // byte value seen at least once
- if (byteValueCounter[i])
- {
- alphabetSize++;
- nodePriorityQueue.push(new Node(i, byteValueCounter[i]));
- }
- }
- // Huffman tree construction
- while (nodePriorityQueue.size() > 1)
- {
- Node *r = nodePriorityQueue.top();
- nodePriorityQueue.pop();
- Node *l = nodePriorityQueue.top();
- nodePriorityQueue.pop();
- Node *root = new Node(0, l->freq + r->freq);
- root->left = l;
- root->right = r;
- nodePriorityQueue.push(root);
- }
- // Huffman tree root
- Node *root = nodePriorityQueue.top();
- // Hashtable <byte value> --> <Huffman code>
- std::unordered_map<byte, std::vector<unsigned char>> codeTable = buildCodeTable(root);
- BitWriter bw;
- bw.WriteByte(alphabetSize);
- serializeTree(root, bw);
- // encoding the input stream copy
- for (auto &byteValue: data)
- {
- for (auto &bit: codeTable[byteValue])
- bw.WriteBit(bit);
- }
- // output the encoded data
- for (auto &byteValue: bw.getBuffer())
- {
- compressed.Write(byteValue);
- }
- // how many significant bits there are in the last byte of the encoded sequence
- unsigned char significantBitsInLastByte = bw.getBitCount() % 8;
- if (!significantBitsInLastByte)
- significantBitsInLastByte = 8;
- compressed.Write(significantBitsInLastByte);
- }
- byte decodeByte(Node *root, BitReader &bitReader)
- {
- while (root->left || root->right)
- {
- switch (bitReader.readBit())
- {
- case 0:
- root = root->left;
- break;
- case 1:
- root = root->right;
- break;
- }
- }
- return root->value;
- }
- void Decode(IInputStream& compressed, IOutputStream& original)
- {
- std::vector<byte> data;
- byte value;
- while (compressed.Read(value))
- {
- data.push_back(value);
- }
- unsigned char significantBitsInLastByte = data[data.size() - 1];
- data.pop_back();
- size_t bitCount = (data.size() - 1) * 8 + significantBitsInLastByte;
- BitReader bitReader(data, bitCount);
- auto alphabetSize = bitReader.readByte();
- std::stack<Node*> stack;
- size_t lettersRead = 0;
- while (1)
- {
- auto bit = bitReader.readBit();
- if (bit)
- {
- auto letter = bitReader.readByte();
- Node *node = new Node(letter, 0);
- stack.push(node);
- lettersRead++;
- }
- else
- {
- Node *right = stack.top();
- stack.pop();
- Node *left = stack.top();
- stack.pop();
- Node *node = new Node(0, 0);
- node->left = left;
- node->right = right;
- stack.push(node);
- }
- if (lettersRead == alphabetSize && stack.size() == 1)
- break;
- }
- Node *root = stack.top();
- stack.pop();
- std::unordered_map<byte, std::vector<unsigned char>> codeTable = buildCodeTable(root);
- while (bitReader.hasData())
- {
- byte decodedByte = decodeByte(root, bitReader);
- original.Write(decodedByte);
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement