Not a member of Pastebin yet?
Sign up:
it unlocks many cool features!
- #pragma once
- #include "xxxxxxxxxProtocol.h"
- #include "xxxxxxxxxTools.h"
- #define ZSTD_STATIC_LINKING_ONLY /* Enable advanced API */
- #include "thirdparty/zstd/zstd.h" // Zstd
- #include "thirdparty/zstd/zstd_errors.h"
- namespace xxx {
//------------------------------------------------------------------------------
// Compression Constants

/// Zstd compression level used when deriving the compression parameters
static constexpr int kCompressionLevel = 1;

/// Size in bytes of the history ring buffer kept by both the compressor and
/// the decompressor (24,000 bytes)
static constexpr unsigned kCompressionDictBytes = 24000;
- /// Bytes allocated per packet
- static const unsigned kCompressionAllocateBytes = \
- protocol::kMaxPossibleDatagramByteLimit - protocol::kMaxOverheadBytes;
- //------------------------------------------------------------------------------
- // Ring Buffer
- template<size_t kBufferBytes>
- class RingBuffer
- {
- public:
- /// Get a contiguous region that is at least `bytes` in size
- XXXX_FORCE_INLINE void* Allocate(unsigned bytes)
- {
- if (NextWriteOffset + bytes > kBufferBytes) {
- NextWriteOffset = 0;
- }
- return Buffer + NextWriteOffset;
- }
- /// Commit some number of bytes up to allocated bytes
- XXXX_FORCE_INLINE void Commit(unsigned bytes)
- {
- XXXX_DEBUG_ASSERT(NextWriteOffset + bytes <= kBufferBytes);
- NextWriteOffset += bytes;
- }
- protected:
- /// Ring buffer that eats its own tail
- uint8_t Buffer[kBufferBytes];
- /// Next offset to write to
- unsigned NextWriteOffset = 0;
- };
- //------------------------------------------------------------------------------
- // MessageCompressor
- class MessageCompressor
- {
- public:
- Result Initialize();
- ~MessageCompressor();
- /// Compress data to the destination buffer `destBuffer`.
- /// Returns the number of bytes written in `writtenBytes`.
- /// Returns writtenBytes = 0 if data should not be compressed
- Result Compress(
- const uint8_t* data,
- unsigned bytes,
- uint8_t* dest,
- unsigned& writtenBytes);
- protected:
- /// Dictionary history used by decompressor
- RingBuffer<kCompressionDictBytes> History;
- /// Zstd context object used to compress packets
- ZSTD_CCtx* CCtx = nullptr;
- };
//------------------------------------------------------------------------------
// MessageDecompressor

/// Output of MessageDecompressor::Decompress: a view into the decompressor's
/// history ring buffer. Later Decompress()/InsertUncompressed() calls can
/// overwrite those bytes, so consume or copy them first.
struct Decompressed
{
    /// Start of the decompressed bytes (nullptr until set)
    const uint8_t* Data = nullptr;
    /// Number of decompressed bytes (0 until set)
    unsigned Bytes = 0;
};
- class MessageDecompressor
- {
- public:
- Result Initialize();
- ~MessageDecompressor();
- /// Decompress and handle a block of messages
- Result Decompress(
- const void* data,
- unsigned bytes,
- Decompressed& decompressed);
- /// Insert uncompressed reliable datagram
- void InsertUncompressed(
- const uint8_t* data,
- unsigned bytes);
- protected:
- /// Dictionary history used by decompressor
- RingBuffer<kCompressionDictBytes> History;
- /// Zstd context object used to decompress packets
- ZSTD_DCtx* DCtx = nullptr;
- };
- //------------------------------------------------------------------------------
- // MessageCompressor
- Result MessageCompressor::Initialize()
- {
- CCtx = ZSTD_createCCtx();
- if (!CCtx) {
- return Result("SessionOutgoing::Initialize", "ZSTD_createCCtx failed", ErrorType::Zstd);
- }
- const size_t estimatedPacketSize = kCompressionAllocateBytes;
- ZSTD_parameters zParams;
- zParams.cParams = ZSTD_getCParams(
- kCompressionLevel,
- estimatedPacketSize,
- kCompressionDictBytes);
- zParams.fParams.checksumFlag = 0;
- zParams.fParams.contentSizeFlag = 0;
- zParams.fParams.noDictIDFlag = 1;
- const size_t icsResult = ZSTD_compressBegin_advanced(
- CCtx,
- nullptr,
- 0,
- zParams,
- ZSTD_CONTENTSIZE_UNKNOWN);
- if (0 != ZSTD_isError(icsResult)) {
- XXXX_DEBUG_BREAK();
- return Result("SessionOutgoing::Initialize", "ZSTD_compressBegin_advanced failed", ErrorType::Zstd, icsResult);
- }
- const size_t blockSizeBytes = ZSTD_getBlockSize(CCtx);
- if (blockSizeBytes < kCompressionAllocateBytes) {
- return Result("SessionOutgoing::Initialize", "Zstd block size is too small", ErrorType::Zstd);
- }
- return Result::Success();
- }
- MessageCompressor::~MessageCompressor()
- {
- if (CCtx) {
- ZSTD_freeCCtx(CCtx);
- }
- }
- Result MessageCompressor::Compress(
- const uint8_t* data,
- unsigned bytes,
- uint8_t* destBuffer,
- unsigned& writtenBytes)
- {
- XXXX_DEBUG_ASSERT(bytes >= protocol::kMessageFrameBytes);
- writtenBytes = 0;
- // Insert data into history ring buffer
- XXXX_DEBUG_ASSERT(kCompressionAllocateBytes >= bytes);
- void* history = History.Allocate(kCompressionAllocateBytes);
- memcpy(history, data, bytes);
- History.Commit(bytes);
- // Compress into scratch buffer, leaving room for a frame header
- const size_t result = ZSTD_compressBlock(
- CCtx,
- destBuffer + protocol::kMessageFrameBytes,
- kCompressionAllocateBytes,
- history,
- bytes);
- // If no data to compress, or would require too much space,
- // or did not produce a small enough result:
- if (0 == result ||
- (size_t)-ZSTD_error_dstSize_tooSmall == result ||
- protocol::kMessageFrameBytes + result >= bytes)
- {
- // Note: Input data was accumulated into history ring buffer
- return Result::Success();
- }
- // If compression failed:
- if (0 != ZSTD_isError(result))
- {
- std::string reason = "ZSTD_compressBlock failed: ";
- reason += ZSTD_getErrorName(result);
- XXXX_DEBUG_BREAK();
- return Result("SessionOutgoing::compress", reason, ErrorType::Zstd, result);
- }
- const unsigned compressedBytes = static_cast<unsigned>(result);
- // Write Compressed frame header
- protocol::WriteMessageFrameHeader(
- destBuffer,
- protocol::MessageType_Compressed,
- compressedBytes);
- // Compressed bytes includes the frame header
- writtenBytes = protocol::kMessageFrameBytes + compressedBytes;
- XXXX_DEBUG_ASSERT(writtenBytes <= kCompressionAllocateBytes);
- return Result::Success();
- }
- //------------------------------------------------------------------------------
- // MessageDecompressor
- Result MessageDecompressor::Initialize()
- {
- DCtx = ZSTD_createDCtx();
- if (!DCtx) {
- return Result("SessionIncoming::Initialize", "ZSTD_createDCtx failed", ErrorType::Zstd);
- }
- const size_t beginResult = ZSTD_decompressBegin(DCtx);
- if (0 != ZSTD_isError(beginResult)) {
- return Result("SessionIncoming::Initialize", "ZSTD_decompressBegin failed", ErrorType::Zstd, beginResult);
- }
- return Result::Success();
- }
- MessageDecompressor::~MessageDecompressor()
- {
- if (DCtx) {
- ZSTD_freeDCtx(DCtx);
- }
- }
- void MessageDecompressor::InsertUncompressed(
- const uint8_t* data,
- unsigned bytes)
- {
- if (bytes > kCompressionAllocateBytes) {
- XXXX_DEBUG_BREAK(); // Invalid input
- return;
- }
- void* history = History.Allocate(kCompressionAllocateBytes);
- memcpy(history, data, bytes);
- ZSTD_insertBlock(DCtx, history, bytes);
- History.Commit(bytes);
- }
- Result MessageDecompressor::Decompress(
- const void* data,
- unsigned bytes,
- Decompressed& decompressed)
- {
- // Decompress data into history ring buffer
- void* history = History.Allocate(kCompressionAllocateBytes);
- const size_t result = ZSTD_decompressBlock(
- DCtx,
- history,
- kCompressionAllocateBytes,
- data,
- bytes);
- // If decompression failed:
- if (0 == result || 0 != ZSTD_isError(result))
- {
- std::string reason = "ZSTD_decompressBlock failed: ";
- reason += ZSTD_getErrorName(result);
- XXXX_DEBUG_BREAK();
- return Result("SessionOutgoing::decompress", reason, ErrorType::Zstd, result);
- }
- const uint8_t* datagramData = reinterpret_cast<uint8_t*>(history);
- const unsigned datagramBytes = static_cast<unsigned>(result);
- History.Commit(datagramBytes);
- decompressed.Data = datagramData;
- decompressed.Bytes = datagramBytes;
- return Result::Success();
- }
Add Comment
Please sign in to add a comment.