Guest User

Intra-Residual Weight Grid DCT for BC7/ASTC

a guest
Jun 9th, 2026
35
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 30.51 KB | Source Code | 0 0
  1. namespace xbc7
  2. {
  3.     static const int g_baseline_jpeg_y[8][8] =
  4.     {
  5.         // DC element modified (from 16) so bilinear fetches near (0,0) grab a smaller quant table value, protecting most important LF coefficients
  6.         {  4, 11, 10, 16, 24, 40, 51, 61 },
  7.         { 12, 12, 14, 19, 26, 58, 60, 55 },
  8.         { 14, 13, 16, 24, 40, 57, 69, 56 },
  9.         { 14, 17, 22, 29, 51, 87, 80, 62 },
  10.         { 18, 22, 37, 56, 68,109,103, 77 },
  11.         { 24, 35, 55, 64, 81,104,113, 92 },
  12.         { 49, 64, 78, 87,103,121,120,101 },
  13.         { 72, 92, 95, 98,112,100,103, 99 }
  14.     };
  15.  
  16.     // centers at (0,0)
  17.     static inline float sample_jpeg_quant(const int Q8[8][8], float i, float j)
  18.     {
  19.         i = basisu::minimum(basisu::maximum(i, 0.0f), 7.0f);
  20.         j = basisu::minimum(basisu::maximum(j, 0.0f), 7.0f);
  21.         int i0 = (int)floorf(i), j0 = (int)floorf(j);
  22.         int i1 = basisu::minimum(i0 + 1, 7), j1 = basisu::minimum(j0 + 1, 7);
  23.         float ti = i - i0, tj = j - j0;
  24.         float a = (1 - ti) * Q8[j0][i0] + ti * Q8[j0][i1];
  25.         float b = (1 - ti) * Q8[j1][i0] + ti * Q8[j1][i1];
  26.         return (1 - tj) * a + tj * b;
  27.     }
  28.  
  29.     static const uint8_t g_zigzag4x4_xy[16][2] = // [index][X,Y]
  30.     {
  31.         { 0, 0 },
  32.         { 1, 0 },
  33.         { 0, 1 },
  34.         { 0, 2 },
  35.         { 1, 1 },
  36.         { 2, 0 },
  37.         { 3, 0 },
  38.         { 2, 1 },
  39.         { 1, 2 },
  40.         { 0, 3 },
  41.         { 1, 3 },
  42.         { 2, 2 },
  43.         { 3, 1 },
  44.         { 3, 2 },
  45.         { 2, 3 },
  46.         { 3, 3 }
  47.     };
  48.  
  49.     void compute_quant_table(float q,
  50.         uint32_t grid_width, uint32_t grid_height,
  51.         float level_scale, int* dct_quant_tab,
  52.         int block_width = 4, int block_height = 4)
  53.     {
  54.         assert(q > 0.0f);
  55.  
  56.         dct_quant_tab[0] = 1;
  57.  
  58.         if (q >= 100.0f)
  59.         {
  60.             for (uint32_t y = 0; y < grid_height; y++)
  61.                 for (uint32_t x = 0; x < grid_width; x++)
  62.                     if (x || y)
  63.                         dct_quant_tab[x + y * grid_width] = 1;
  64.             return;
  65.         }
  66.                
  67.         const int Bx = block_width, By = block_height;
  68.  
  69.         const float sx = (float)8.0f / (float)Bx;
  70.         const float sy = (float)8.0f / (float)By;
  71.  
  72.         for (uint32_t y = 0; y < grid_height; y++)
  73.         {
  74.             float ny = float(y);
  75.             float ry = ny * sy;
  76.  
  77.             for (uint32_t x = y ? 0 : 1; x < grid_width; x++)
  78.             {
  79.                 int quant_scale = 0;
  80.  
  81.                 assert(x || y);
  82.  
  83.                 float nx = float(x);
  84.                 float rx = nx * sx;
  85.  
  86.                 // sample from the JPEG baseline luma 8x8 DCT quant matrix
  87.                 // this is an approximation (we could do an actual desired radians per spatial sample search vs. each of the 8x8 basis vectors to find the best, most conservative mapping),
  88.                 // but for 4x4 and 6x6 block sizes it's reasonable enough and simple/fast
  89.                 // at 4x4, the lowest frequencies are slightly more heavily quantized than we would want (but the quant table entries near DC are so similar it's doubtful it matters much if at all)
  90.                 float base = sample_jpeg_quant(g_baseline_jpeg_y, rx, ry);
  91.                                
  92. #if 1
  93.                 if ((x + y) == 1)
  94.                     base *= .25f;
  95.                 else if ((x == 1) && (y == 1))
  96.                     base *= .75f;
  97. #endif
  98.  
  99.                 //quant_scale = (int)std::floor(base * level_scale + 0.5f);
  100.                 quant_scale = (int)(base * level_scale + 0.5f);
  101.                 assert(quant_scale == (int)std::floor(base * level_scale + 0.5f));
  102.  
  103.                 quant_scale = basisu::maximum<int>(1, quant_scale);
  104.  
  105. #if 1
  106.                 if ((x + y) == 1)
  107.                 {
  108.                     const int MAX_QUANT_SCALE_AC_1_1 = 73; // 73
  109.                     quant_scale = minimum(quant_scale, MAX_QUANT_SCALE_AC_1_1);
  110.                 }
  111. #endif
  112.  
  113.                 dct_quant_tab[x + y * grid_width] = quant_scale;
  114.             } // x
  115.         } // y
  116.  
  117.         for (uint32_t y = 0; y < grid_height; y++)
  118.         {
  119.             for (uint32_t x = y + 1; x < grid_width; x++)
  120.             {
  121.                 assert(x != y);
  122.  
  123.                 const int a = dct_quant_tab[x + y * grid_width];
  124.                 const int b = dct_quant_tab[y + x * grid_width];
  125.  
  126.                 const int c = maximum(a, b);
  127.  
  128.                 dct_quant_tab[x + y * grid_width] = c;
  129.                 dct_quant_tab[y + x * grid_width] = c;
  130.             }
  131.         }
  132.     }
  133.  
  134.     struct coeff
  135.     {
  136.         int16_t m_num_zeros; // number of zero AC coefficients before this one
  137.         int16_t m_coeff; // both sign and mag, [-256,256], or INT16_MAX if last
  138.  
  139.         void clear()
  140.         {
  141.             m_num_zeros = 0;
  142.             m_coeff = 0;
  143.         }
  144.     };
  145.  
  146.     typedef basisu::vector<coeff> coeff_vec;
  147.  
  148.     struct dct_syms
  149.     {
  150.         int16_t m_dc;       // [-256,256]
  151.                
  152.         coeff_vec m_ac_vals;
  153.  
  154.         void clear()
  155.         {
  156.             m_dc = 0;
  157.             m_ac_vals.resize(0);
  158.         }
  159.     };
  160.  
  161.     static const float DEADZONE_ALPHA = 0.5f;
  162.  
  163.     static const float g_scale_quant_steps[3] =
  164.     {
  165.         1.35588217f, // 4 (2-bits)
  166.         1.24573100f, // 8 (3-bits)
  167.         1.15431654f, // 16 (4-bits)
  168.     };
  169.  
  170.     static inline uint32_t get_weight_size_index_from_bits(uint32_t num_weight_bits)
  171.     {
  172.         switch (num_weight_bits)
  173.         {
  174.         case 2: return 0;
  175.         case 3: return 1;
  176.         case 4: return 2;
  177.         default:
  178.             assert(0);
  179.             return 0;
  180.         }
  181.     }
  182.  
  183.     class xbc7_weight_grid_dct
  184.     {
  185.     public:
  186.         xbc7_weight_grid_dct()
  187.         {
  188.         }
  189.  
  190.         void init()
  191.         {
  192.             m_dct.init(BLOCK_HEIGHT, BLOCK_WIDTH);
  193.         }
  194.                
  195.         void forward(
  196.             float global_q, uint32_t plane_index,
  197.             const int *pWeight_predictions, // may be nullptr
  198.             const basist::bc7u::log_bc7_block& log_blk,
  199.             dct_syms &syms,
  200.             basist::astc_ldr_t::fvec &dct_work)
  201.         {
  202.             syms.clear();
  203.  
  204.             float orig_weights[16];
  205.             for (uint32_t i = 0; i < 16; i++)
  206.             {
  207.                 const int predicted_weight = pWeight_predictions ? pWeight_predictions[i] : 0;
  208.                 assert((predicted_weight >= 0) && (predicted_weight <= 64));
  209.  
  210.                 orig_weights[i] = (float)(basist::bc7u::dequant_weight(log_blk.m_weights[plane_index][i], log_blk.m_weight_bits[plane_index]) - predicted_weight);
  211.             }
  212.                        
  213.             float dct_weights[16];
  214.  
  215.             m_dct.forward(orig_weights, dct_weights, dct_work);
  216.  
  217.             const float span_len = get_max_span_len(log_blk, plane_index);
  218.             const float level_scale = compute_level_scale(global_q, span_len, log_blk.m_weight_bits[plane_index]);
  219.  
  220.             int dct_quant_tab[16];
  221.             compute_quant_table(global_q, 4, 4, level_scale, dct_quant_tab, 4, 4);
  222.  
  223.             int dct_coeffs[16];
  224.  
  225.             for (uint32_t y = 0; y < 4; y++)
  226.             {
  227.                 for (uint32_t x = 0; x < 4; x++)
  228.                 {
  229.                     if (!x && !y)
  230.                     {
  231.                         dct_coeffs[0] = clamp<int>((int)std::round(dct_weights[0]), -256, 256);
  232.                         continue;
  233.                     }
  234.  
  235.                     const int levels = dct_quant_tab[x + y * 4];
  236.  
  237.                     float d = dct_weights[x + y * 4];
  238.  
  239.                     int id = quantize_deadzone(d, levels, DEADZONE_ALPHA, x, y);
  240.  
  241.                     dct_coeffs[x + y * 4] = clamp<int>(id, -256, 256);
  242.  
  243.                 } // x
  244.  
  245.             }  // y
  246.  
  247.             syms.m_dc = safe_cast_int16(dct_coeffs[0]);
  248.  
  249.             syms.m_ac_vals.reserve(17);
  250.  
  251.             int total_zeros = 0;
  252.             for (uint32_t i = 1; i < 16; i++)
  253.             {
  254.                 const uint32_t dct_idx = g_zigzag4x4_xy[i][0] + (g_zigzag4x4_xy[i][1] * 4);
  255.                 assert(dct_idx);
  256.  
  257.                 int ac_coeff = dct_coeffs[dct_idx];
  258.                 if (!ac_coeff)
  259.                 {
  260.                     total_zeros++;
  261.                     continue;
  262.                 }
  263.  
  264.                 coeff cf;
  265.                 cf.m_num_zeros = basisu::safe_cast_int16(total_zeros);
  266.                 cf.m_coeff = basisu::safe_cast_int16(ac_coeff);
  267.  
  268.                 syms.m_ac_vals.push_back(cf);
  269.  
  270.                 total_zeros = 0;
  271.             }
  272.  
  273.             if (total_zeros)
  274.             {
  275.                 coeff cf;
  276.                 cf.m_num_zeros = basisu::safe_cast_int16(total_zeros);
  277.                 cf.m_coeff = INT16_MAX;
  278.                 syms.m_ac_vals.push_back(cf);
  279.             }
  280.         }
  281.  
  282.         bool inverse(
  283.             float global_q, uint32_t plane_index,
  284.             const int* pWeight_predictions, // may be nullptr
  285.             const dct_syms& syms,
  286.             basist::bc7u::log_bc7_block& log_blk,
  287.             basist::astc_ldr_t::fvec& dct_work)
  288.         {
  289.             const float span_len = get_max_span_len(log_blk, plane_index);
  290.  
  291.             const float level_scale = compute_level_scale(global_q, span_len, log_blk.m_weight_bits[plane_index]);
  292.  
  293.             int dct_quant_tab[16];
  294.             compute_quant_table(global_q, 4, 4, level_scale, dct_quant_tab, 4, 4);
  295.  
  296.             float dct_weights[16];
  297.             basisu::clear_obj(dct_weights);
  298.  
  299.             dct_weights[0] = (float)syms.m_dc;
  300.  
  301.             uint32_t zig_idx = 1;
  302.             uint32_t coeff_ofs = 0;
  303.             while (coeff_ofs < syms.m_ac_vals.size())
  304.             {
  305.                 const uint32_t run_len = syms.m_ac_vals[coeff_ofs].m_num_zeros;
  306.                 const int coeff = syms.m_ac_vals[coeff_ofs].m_coeff;
  307.                 coeff_ofs++;
  308.  
  309.                 if ((run_len + zig_idx) > 16)
  310.                     return false;
  311.  
  312.                 zig_idx += run_len;
  313.  
  314.                 if (zig_idx >= 16)
  315.                     break;
  316.  
  317.                 assert(coeff != INT16_MAX);
  318.  
  319.                 const int x = g_zigzag4x4_xy[zig_idx][0];
  320.                 const int y = g_zigzag4x4_xy[zig_idx][1];
  321.                 const int dct_idx = x + (y * 4);
  322.  
  323.                 const int quant = dct_quant_tab[dct_idx];
  324.  
  325.                 dct_weights[dct_idx] = dequant_deadzone(coeff, quant, DEADZONE_ALPHA, x, y);
  326.  
  327.                 zig_idx++;
  328.             }
  329.  
  330.             float idct_weights[astc_helpers::MAX_BLOCK_PIXELS];
  331.  
  332.             m_dct.inverse(dct_weights, idct_weights, dct_work);
  333.  
  334.             for (uint32_t i = 0; i < 16; i++)
  335.             {
  336.                 log_blk.m_weights[plane_index][i] = basist::bc7u::quant_weight(
  337.                     basisu::clamp<int>(fast_roundf_int(idct_weights[i] + (pWeight_predictions ? pWeight_predictions[i] : 0)), 0, 64),
  338.                     log_blk.m_weight_bits[plane_index]);
  339.             }
  340.  
  341.             return true;
  342.         }
  343.                
  344.     private:
  345.         static const uint32_t BLOCK_WIDTH = 4;
  346.         static const uint32_t BLOCK_HEIGHT = 4;
  347.        
  348.         basist::astc_ldr_t::dct2f m_dct;
  349.                
  350.         // Adaptive quantization
  351.         float compute_level_scale(float q, float span_len, uint32_t num_weight_bits)
  352.         {
  353.             const uint32_t weight_size_index = get_weight_size_index_from_bits(num_weight_bits);
  354.            
  355.             // Standard JPEG quality factor calcs
  356.             // TODO: Precompute this once
  357.             float level_scale;
  358.             q = basisu::clamp(q, 1.0f, 100.0f);
  359.             if (q < 50.0f)
  360.                 level_scale = 5000.0f / q;
  361.             else
  362.                 level_scale = 200.0f - 2.0f * q;
  363.  
  364.             level_scale *= (1.0f / 100.0f); // because JPEG's quant table is scaled by 100
  365.  
  366.             //const float span_floor = 28.0f;
  367.             const float span_floor = 14.0f;
  368.             //const float adaptive_factor = 255.0f / maximum<float>(span_len, span_floor);
  369.             // 64.0 = dynamic range adjustment (JPEG uses 255)
  370.             // divide by span len to adjust adaptive low/high values per-block (JPEG always uses effective span=0-255)
  371.             // actually (64/255) * 255/max(span_len, span_floor)
  372.             float adaptive_factor = 64.0f / basisu::maximum<float>(span_len, span_floor);
  373.  
  374.             // input signal scalar quantization noise will be distributed between multiple AC coefficients - compensate by adaptively adjusting the quant step size
  375.             float weight_quant_adaptive_factor = g_scale_quant_steps[weight_size_index];
  376.             adaptive_factor *= weight_quant_adaptive_factor;
  377.  
  378.             // (Adaptive quant)
  379.             level_scale *= adaptive_factor;
  380.  
  381.             // The higher the level_scale, the more quantized DCT coefficients will be and vice versa.
  382.  
  383.             return level_scale;
  384.         }
  385.  
  386.         // Needed by AQ
  387.         float get_max_span_len(const basist::bc7u::log_bc7_block& log_blk, uint32_t plane_index) const
  388.         {
  389.             float span_len = 0.0f;
  390.  
  391.             if (log_blk.is_dual_plane())
  392.             {
  393.                 basist::color_rgba ep[2];
  394.  
  395.                 basist::bc7u::unpack_endpoints(log_blk, ep, 0);
  396.  
  397.                 const basist::color_rgba& l = ep[0];
  398.                 const basist::color_rgba& h = ep[1];
  399.  
  400.                 for (uint32_t c = 0; c < 4; c++)
  401.                 {
  402.                     // get the weight plane used by this endpoint channel (NOT the decoded pixel channel, which is after any mode 4/5 channel swapping/rotation)
  403.                     const uint32_t endpoint_chan_plane = log_blk.get_endpoint_channel_weight_plane(c);
  404.  
  405.                     if (endpoint_chan_plane == plane_index)
  406.                     {
  407.                         span_len += basisu::squaref((float)h[c] - (float)l[c]);
  408.                     }
  409.                 }
  410.  
  411.                 span_len = sqrtf(span_len);
  412.             }
  413.             else
  414.             {
  415.                 assert(!plane_index);
  416.  
  417.                 for (uint32_t i = 0; i < log_blk.m_num_partitions; i++)
  418.                 {
  419.                     basist::color_rgba ep[2];
  420.  
  421.                     basist::bc7u::unpack_endpoints(log_blk, ep, i);
  422.  
  423.                     const basist::color_rgba& l = ep[0];
  424.                     const basist::color_rgba& h = ep[1];
  425.  
  426.                     float part_span_len = sqrtf(
  427.                         basisu::squaref((float)h.r - (float)l.r) + basisu::squaref((float)h.g - (float)l.g) + basisu::squaref((float)h.b - (float)l.b) + basisu::squaref((float)h.a - (float)l.a)
  428.                     );
  429.  
  430.                     span_len = basisu::maximum(part_span_len, span_len);
  431.                 }
  432.             }
  433.  
  434.             return span_len;
  435.         }
  436.  
  437.         inline int quantize_deadzone(float d, int L, float alpha, uint32_t x, uint32_t y) const
  438.         {
  439.             assert((x < BLOCK_WIDTH) && (y < BLOCK_HEIGHT));
  440.  
  441.             if (((x == 1) && (y == 0)) ||
  442.                 ((x == 0) && (y == 1)))
  443.             {
  444.                 return (int)std::round(d / (float)L);
  445.             }
  446.  
  447.             // L = quant step, alpha in [0,1.2] (typical 0.7–0.85)
  448.             if (L <= 0)
  449.                 return 0;
  450.  
  451.             float s = fabsf(d);
  452.             float tau = alpha * float(L);                 // half-width of the zero band
  453.  
  454.             if (s <= tau)
  455.                 return 0;                       // inside dead-zone towards zero
  456.  
  457.             // Quantize the residual outside the dead-zone with mid-tread rounding
  458.             float qf = (s - tau) / float(L);
  459.             int   q = (int)floorf(qf + 0.5f);            // ties-nearest
  460.             return (d < 0.0f) ? -q : q;
  461.         }
  462.  
  463.         inline float dequant_deadzone(int q, int L, float alpha, uint32_t x, uint32_t y) const
  464.         {
  465.             assert((x < BLOCK_WIDTH) && (y < BLOCK_HEIGHT));
  466.  
  467.             if (((x == 1) && (y == 0)) ||
  468.                 ((x == 0) && (y == 1)))
  469.             {
  470.                 return (float)q * (float)L;
  471.             }
  472.  
  473.             if (q == 0 || L <= 0)
  474.                 return 0.0f;
  475.  
  476.             float tau = alpha * float(L);
  477.             float mag = tau + float(abs(q)) * float(L);   // center of the (nonzero) bin
  478.             return (q < 0) ? -mag : mag;
  479.         }
  480.     };
  481.        
  482. } // namespace xbc7
  483.  
  484. static inline uint32_t index_from_xy(uint32_t x, uint32_t y) { assert((x < 4) && (y < 4));  return x + y * 4; }
  485.  
  486. static bool bc7_test(int argc, const char *argv[])
  487. {
  488.     basisu::rand rnd;
  489.     rnd.seed(1000);
  490.  
  491.     enable_debug_printf(true);
  492.  
  493.     if (argc != 2)
  494.         return false;
  495.  
  496.     const char* pFilename = argv[1];
  497.  
  498.     image orig_img;
  499.     if (!load_png(pFilename, orig_img))
  500.         return false;
  501.  
  502.     const bool srgb_flag = true;
  503.  
  504.     const uint32_t block_width = 4;
  505.     const uint32_t block_height = 4;
  506.     const uint32_t total_block_pixels = block_width * block_height;
  507.  
  508.     const uint32_t width = orig_img.get_width();
  509.     const uint32_t height = orig_img.get_height();
  510.     const uint32_t num_blocks_x = (width + block_width - 1) / block_width;
  511.     const uint32_t num_blocks_y = (height + block_height - 1) / block_height;
  512.     const uint32_t total_blocks = num_blocks_x * num_blocks_y;
  513.        
  514.     vector2D<basist::bc7u::log_bc7_block> log_blks(num_blocks_x, num_blocks_y);
  515.  
  516.     for (uint32_t by = 0; by < num_blocks_y; by++)
  517.     {
  518.         for (uint32_t bx = 0; bx < num_blocks_x; bx++)
  519.         {
  520.             color_rgba orig_block[16];
  521.             orig_img.extract_block_clamped(orig_block, bx * 4, by * 4, 4, 4);
  522.  
  523.             basist::bc7u::phys_bc7_block phys_blk;
  524.             basist::bc7f::fast_pack_bc7_auto_rgba(phys_blk.m_bytes, (basist::color_rgba*)orig_block, basist::bc7f::cPackBC7FlagDefaultPartiallyAnalytical);
  525.  
  526.             basist::bc7u::log_bc7_block& log_blk = log_blks(bx, by);
  527.  
  528.             bool unpack_status = basist::bc7u::unpack_bc7(&phys_blk, log_blk);
  529.             if (!unpack_status)
  530.             {
  531.                 assert(0);
  532.                 return false;
  533.             }
  534.         }
  535.     }
  536.  
  537.     image out_img(width, height);
  538.  
  539.     xbc7::xbc7_weight_grid_dct weight_grid_dct;
  540.     weight_grid_dct.init();
  541.  
  542.     basist::astc_ldr_t::fvec dct_work;
  543.  
  544.     const float global_q = 75;// 9.0f;
  545.  
  546.     uint32_t total_ac_syms = 0;
  547.  
  548.     uint_vec cand_hist(17);
  549.  
  550.     vector2D<basist::bc7u::phys_bc7_block> phys_blocks(num_blocks_x, num_blocks_y);
  551.  
  552.     for (uint32_t by = 0; by < num_blocks_y; by++)
  553.     {
  554.         fmt_printf(".");
  555.  
  556.         for (uint32_t bx = 0; bx < num_blocks_x; bx++)
  557.         {
  558.             color_rgba orig_block[16];
  559.             orig_img.extract_block_clamped(orig_block, bx * 4, by * 4, 4, 4);
  560.  
  561.             const basist::bc7u::log_bc7_block* pLeft_diag_log_blk = (bx && by) ? &log_blks(bx - 1, by - 1) : nullptr;
  562.             const basist::bc7u::log_bc7_block* pRight_diag_log_blk = (((bx + 1) < num_blocks_x) && by) ? &log_blks(bx + 1, by - 1) : nullptr;
  563.             const basist::bc7u::log_bc7_block* pUp_log_blk = by ? &log_blks(bx, by - 1) : nullptr;
  564.             const basist::bc7u::log_bc7_block* pLeft_log_blk = bx ? &log_blks(bx - 1, by) : nullptr;
  565.            
  566.             basist::bc7u::log_bc7_block& log_blk = log_blks(bx, by);
  567.             const basist::bc7u::log_bc7_block orig_log_blk(log_blk);
  568.  
  569.             basist::bc7u::log_bc7_block best_cand_log_blk(orig_log_blk);
  570.             uint64_t best_err = UINT64_MAX;
  571.             uint32_t best_num_ac_syms = UINT32_MAX;
  572.            
  573.             if (!basist::bc7u::is_solid_blk(log_blk))
  574.             {
  575.                 uint32_t best_cand_index = 0;
  576.  
  577.                 const uint32_t TOTAL_CANDIDATES = 17;
  578.  
  579.                 for (uint32_t cand_index = 0; cand_index < TOTAL_CANDIDATES; cand_index++)
  580.                 {
  581.                     const basist::bc7u::log_bc7_block* pCand_log_blk = nullptr;
  582.  
  583.                     if (cand_index == 0)
  584.                     {
  585.  
  586.                     }
  587.                     else
  588.                     {
  589.                         if (cand_index == 1)
  590.                             pCand_log_blk = pLeft_log_blk;
  591.                         else if (cand_index == 2)
  592.                             pCand_log_blk = pUp_log_blk;
  593.                         else if (cand_index == 3)
  594.                             pCand_log_blk = pLeft_diag_log_blk;
  595.                         else if (cand_index == 4)
  596.                             pCand_log_blk = pRight_diag_log_blk;
  597.                         else if (cand_index == 5)
  598.                             pCand_log_blk = pLeft_log_blk; // left edge
  599.                         else if (cand_index == 6)
  600.                             pCand_log_blk = pUp_log_blk; // upper edge
  601.                         else if (cand_index == 7)
  602.                             pCand_log_blk = (pLeft_log_blk && pUp_log_blk) ? pLeft_log_blk : nullptr; // left+upper edge blend
  603.                         else if (cand_index == 8)
  604.                             pCand_log_blk = pLeft_log_blk; // reflect left
  605.                         else if (cand_index == 9)
  606.                             pCand_log_blk = pUp_log_blk; // reflect upper
  607.                         else if (cand_index == 10)
  608.                             pCand_log_blk = (pLeft_log_blk && pUp_log_blk) ? pLeft_log_blk : nullptr; // left+upper edge avg
  609.                         else if (cand_index == 11)
  610.                             pCand_log_blk = (pLeft_log_blk && pUp_log_blk) ? pLeft_log_blk : nullptr; // left+upper edge stronger distance blend
  611.                         else if (cand_index == 12)
  612.                             pCand_log_blk = (pLeft_log_blk && pUp_log_blk && pLeft_diag_log_blk) ? pLeft_log_blk : nullptr; // gradient
  613.                         else if (cand_index == 13)
  614.                             pCand_log_blk = (pLeft_log_blk && pUp_log_blk && pLeft_diag_log_blk) ? pLeft_log_blk : nullptr; // damped gradient
  615.                         else if (cand_index == 14)
  616.                             pCand_log_blk = (pLeft_diag_log_blk && pRight_diag_log_blk) ? pLeft_diag_log_blk : nullptr; // left/right diagonal avg
  617.                         else if (cand_index == 15)
  618.                             pCand_log_blk = (pLeft_diag_log_blk && pRight_diag_log_blk) ? pLeft_diag_log_blk : nullptr; // diagonal edge blend
  619.                         else if (cand_index == 16)
  620.                             pCand_log_blk = (pUp_log_blk && pLeft_diag_log_blk && pRight_diag_log_blk) ? pLeft_diag_log_blk : nullptr; // upper + diagonal edge blend
  621.  
  622.                         if (!pCand_log_blk)
  623.                             continue;
  624.                     }
  625.  
  626.                     basist::bc7u::log_bc7_block cand_log_blk(log_blk);
  627.  
  628.                     uint32_t cand_total_ac_syms = 0;
  629.  
  630.                     for (uint32_t p = 0; p < cand_log_blk.m_num_planes; p++)
  631.                     {
  632.                         const int* pWeight_predictions = nullptr;
  633.  
  634.                         int weight_preds[16];
  635.                         if (pCand_log_blk)
  636.                         {
  637.                             for (uint32_t w = 0; w < 16; w++)
  638.                             {
  639.                                 if (pCand_log_blk->is_dual_plane())
  640.                                     weight_preds[w] = basist::bc7u::dequant_weight(pCand_log_blk->m_weights[p][w], pCand_log_blk->m_weight_bits[p]);
  641.                                 else
  642.                                     weight_preds[w] = basist::bc7u::dequant_weight(pCand_log_blk->m_weights[0][w], pCand_log_blk->m_weight_bits[0]);
  643.                             }
  644.  
  645.                             int orig_weight_preds[16];
  646.                             memcpy(orig_weight_preds, weight_preds, sizeof(orig_weight_preds));
  647.  
  648.                             if (cand_index == 5)
  649.                             {
  650.                                 // left edge
  651.                                 for (uint32_t y = 0; y < 4; y++)
  652.                                     for (uint32_t x = 0; x < 4; x++)
  653.                                         weight_preds[index_from_xy(x, y)] = orig_weight_preds[index_from_xy(3, y)];
  654.                             }
  655.                             else if (cand_index == 6)
  656.                             {
  657.                                 // upper edge
  658.                                 for (uint32_t y = 0; y < 4; y++)
  659.                                     for (uint32_t x = 0; x < 4; x++)
  660.                                         weight_preds[index_from_xy(x, y)] = orig_weight_preds[index_from_xy(x, 3)];
  661.                             }
  662.                             else if ((cand_index == 7) || (cand_index == 10) || (cand_index == 11))
  663.                             {
  664.                                 // left+upper edge blend variants.
  665.                                 // pCand_log_blk is pLeft_log_blk here, so orig_weight_preds contains the left block.
  666.                                 // Pull upper edge directly from pUp_log_blk.
  667.  
  668.                                 int upper_edge[4];
  669.  
  670.                                 for (uint32_t x = 0; x < 4; x++)
  671.                                 {
  672.                                     const uint32_t w = index_from_xy(x, 3); // upper block's bottom edge
  673.  
  674.                                     if (pUp_log_blk->is_dual_plane())
  675.                                         upper_edge[x] = basist::bc7u::dequant_weight(pUp_log_blk->m_weights[p][w], pUp_log_blk->m_weight_bits[p]);
  676.                                     else
  677.                                         upper_edge[x] = basist::bc7u::dequant_weight(pUp_log_blk->m_weights[0][w], pUp_log_blk->m_weight_bits[0]);
  678.                                 }
  679.  
  680.                                 for (uint32_t y = 0; y < 4; y++)
  681.                                 {
  682.                                     const int left_val = orig_weight_preds[index_from_xy(3, y)]; // left block's right edge
  683.  
  684.                                     for (uint32_t x = 0; x < 4; x++)
  685.                                     {
  686.                                         const int upper_val = upper_edge[x];
  687.                                         int pred;
  688.  
  689.                                         if (cand_index == 7)
  690.                                         {
  691.                                             // Existing distance-weighted blend.
  692.                                             const int wl = 4 - static_cast<int>(x); // 4,3,2,1
  693.                                             const int wu = 4 - static_cast<int>(y); // 4,3,2,1
  694.                                             const int den = wl + wu;
  695.  
  696.                                             pred = (wl * left_val + wu * upper_val + (den >> 1)) / den;
  697.                                         }
  698.                                         else if (cand_index == 10)
  699.                                         {
  700.                                             // Simple average.
  701.                                             pred = (left_val + upper_val + 1) >> 1;
  702.                                         }
  703.                                         else // cand_index == 11
  704.                                         {
  705.                                             // Stronger distance weighting: trust the nearest edge more.
  706.                                             const int dx = 4 - static_cast<int>(x); // 4,3,2,1
  707.                                             const int dy = 4 - static_cast<int>(y); // 4,3,2,1
  708.                                             const int wl = dx * dx; // 16,9,4,1
  709.                                             const int wu = dy * dy; // 16,9,4,1
  710.                                             const int den = wl + wu;
  711.  
  712.                                             pred = (wl * left_val + wu * upper_val + (den >> 1)) / den;
  713.                                         }
  714.  
  715.                                         weight_preds[index_from_xy(x, y)] = pred;
  716.                                     }
  717.                                 }
  718.                             }
  719.                             else if (cand_index == 8)
  720.                             {
  721.                                 // reflect left
  722.                                 for (uint32_t y = 0; y < 4; y++)
  723.                                     for (uint32_t x = 0; x < 4; x++)
  724.                                         weight_preds[index_from_xy(x, y)] = orig_weight_preds[index_from_xy(3 - x, y)];
  725.                             }
  726.                             else if (cand_index == 9)
  727.                             {
  728.                                 // reflect upper
  729.                                 for (uint32_t y = 0; y < 4; y++)
  730.                                     for (uint32_t x = 0; x < 4; x++)
  731.                                         weight_preds[index_from_xy(x, y)] = orig_weight_preds[index_from_xy(x, 3 - y)];
  732.                             }
  733.                             else if ((cand_index == 12) || (cand_index == 13))
  734.                             {
  735.                                 int upper_edge[4];
  736.  
  737.                                 for (uint32_t x = 0; x < 4; x++)
  738.                                 {
  739.                                     const uint32_t w = index_from_xy(x, 3); // upper block's bottom edge
  740.  
  741.                                     if (pUp_log_blk->is_dual_plane())
  742.                                         upper_edge[x] = basist::bc7u::dequant_weight(pUp_log_blk->m_weights[p][w], pUp_log_blk->m_weight_bits[p]);
  743.                                     else
  744.                                         upper_edge[x] = basist::bc7u::dequant_weight(pUp_log_blk->m_weights[0][w], pUp_log_blk->m_weight_bits[0]);
  745.                                 }
  746.  
  747.                                 const uint32_t corner_w = index_from_xy(3, 3); // upper-left block's bottom-right
  748.  
  749.                                 int corner_val;
  750.                                 if (pLeft_diag_log_blk->is_dual_plane())
  751.                                     corner_val = basist::bc7u::dequant_weight(pLeft_diag_log_blk->m_weights[p][corner_w], pLeft_diag_log_blk->m_weight_bits[p]);
  752.                                 else
  753.                                     corner_val = basist::bc7u::dequant_weight(pLeft_diag_log_blk->m_weights[0][corner_w], pLeft_diag_log_blk->m_weight_bits[0]);
  754.  
  755.                                 for (uint32_t y = 0; y < 4; y++)
  756.                                 {
  757.                                     const int left_val = orig_weight_preds[index_from_xy(3, y)]; // left block's right edge
  758.  
  759.                                     for (uint32_t x = 0; x < 4; x++)
  760.                                     {
  761.                                         const int upper_val = upper_edge[x];
  762.  
  763.                                         int grad = left_val + upper_val - corner_val;
  764.                                         grad = basisu::clamp<int>(grad, 0, 64); // or your clamp helper
  765.  
  766.                                         if (cand_index == 12)
  767.                                         {
  768.                                             weight_preds[index_from_xy(x, y)] = grad;
  769.                                         }
  770.                                         else
  771.                                         {
  772.                                             // Damped gradient: blend gradient with your proven #7 predictor.
  773.                                             const int wl = 4 - static_cast<int>(x);
  774.                                             const int wu = 4 - static_cast<int>(y);
  775.                                             const int den = wl + wu;
  776.                                             const int blend7 = (wl * left_val + wu * upper_val + (den >> 1)) / den;
  777.  
  778.                                             weight_preds[index_from_xy(x, y)] = (grad + blend7 + 1) >> 1;
  779.                                         }
  780.                                     }
  781.                                 }
  782.                             }
  783.                             else if (cand_index == 14)
  784.                             {
  785.                                 // Average upper-left and upper-right diagonal blocks.
  786.                                 // pCand_log_blk is pLeft_diag_log_blk here, so orig_weight_preds contains upper-left.
  787.                                 // Pull upper-right directly from pRight_diag_log_blk.
  788.  
  789.                                 for (uint32_t w = 0; w < 16; w++)
  790.                                 {
  791.                                     int right_diag_val;
  792.  
  793.                                     if (pRight_diag_log_blk->is_dual_plane())
  794.                                         right_diag_val = basist::bc7u::dequant_weight(
  795.                                             pRight_diag_log_blk->m_weights[p][w],
  796.                                             pRight_diag_log_blk->m_weight_bits[p]);
  797.                                     else
  798.                                         right_diag_val = basist::bc7u::dequant_weight(
  799.                                             pRight_diag_log_blk->m_weights[0][w],
  800.                                             pRight_diag_log_blk->m_weight_bits[0]);
  801.  
  802.                                     weight_preds[w] = (orig_weight_preds[w] + right_diag_val + 1) >> 1;
  803.                                 }
  804.                             }
  805.                             else if (cand_index == 15)
  806.                             {
  807.                                 // Blend upper-left block's right edge with upper-right block's left edge.
  808.                                 // pCand_log_blk is pLeft_diag_log_blk, so orig_weight_preds contains upper-left.
  809.                                 // Pull upper-right left edge directly from pRight_diag_log_blk.
  810.                                 //
  811.                                 // For each row y:
  812.                                 //   L = upper-left[3,y]
  813.                                 //   R = upper-right[0,y]
  814.                                 // Then interpolate across x.
  815.  
  816.                                 int right_diag_left_edge[4];
  817.  
  818.                                 for (uint32_t y = 0; y < 4; y++)
  819.                                 {
  820.                                     const uint32_t w = index_from_xy(0, y); // upper-right block's left edge
  821.  
  822.                                     if (pRight_diag_log_blk->is_dual_plane())
  823.                                         right_diag_left_edge[y] = basist::bc7u::dequant_weight(
  824.                                             pRight_diag_log_blk->m_weights[p][w],
  825.                                             pRight_diag_log_blk->m_weight_bits[p]);
  826.                                     else
  827.                                         right_diag_left_edge[y] = basist::bc7u::dequant_weight(
  828.                                             pRight_diag_log_blk->m_weights[0][w],
  829.                                             pRight_diag_log_blk->m_weight_bits[0]);
  830.                                 }
  831.  
  832.                                 for (uint32_t y = 0; y < 4; y++)
  833.                                 {
  834.                                     const int left_val = orig_weight_preds[index_from_xy(3, y)]; // upper-left right edge
  835.                                     const int right_val = right_diag_left_edge[y];                // upper-right left edge
  836.  
  837.                                     for (uint32_t x = 0; x < 4; x++)
  838.                                     {
  839.                                         // x=0 mostly left_val, x=3 mostly right_val.
  840.                                         // Use 4-sample interpolation: 3/0, 2/1, 1/2, 0/3.
  841.                                         const int pred = ((3 - static_cast<int>(x)) * left_val +
  842.                                             static_cast<int>(x) * right_val + 1) / 3;
  843.  
  844.                                         weight_preds[index_from_xy(x, y)] = pred;
  845.                                     }
  846.                                 }
  847.                             }
  848.                             else if (cand_index == 16)
  849.                             {
  850.                                 // Blend upper edge predictor with diagonal edge blend.
  851.                                 //
  852.                                 // upper_edge[x] = upper block's bottom edge
  853.                                 // diag_blend[x,y] = horizontal interpolation between:
  854.                                 //   upper-left block's right edge and upper-right block's left edge
  855.                                 //
  856.                                 // This combines direct top continuation with previous-row lateral structure.
  857.  
  858.                                 int upper_edge[4];
  859.                                 int right_diag_left_edge[4];
  860.  
  861.                                 for (uint32_t x = 0; x < 4; x++)
  862.                                 {
  863.                                     const uint32_t up_w = index_from_xy(x, 3); // upper block's bottom edge
  864.  
  865.                                     if (pUp_log_blk->is_dual_plane())
  866.                                         upper_edge[x] = basist::bc7u::dequant_weight(
  867.                                             pUp_log_blk->m_weights[p][up_w],
  868.                                             pUp_log_blk->m_weight_bits[p]);
  869.                                     else
  870.                                         upper_edge[x] = basist::bc7u::dequant_weight(
  871.                                             pUp_log_blk->m_weights[0][up_w],
  872.                                             pUp_log_blk->m_weight_bits[0]);
  873.                                 }
  874.  
  875.                                 for (uint32_t y = 0; y < 4; y++)
  876.                                 {
  877.                                     const uint32_t rd_w = index_from_xy(0, y); // upper-right block's left edge
  878.  
  879.                                     if (pRight_diag_log_blk->is_dual_plane())
  880.                                         right_diag_left_edge[y] = basist::bc7u::dequant_weight(
  881.                                             pRight_diag_log_blk->m_weights[p][rd_w],
  882.                                             pRight_diag_log_blk->m_weight_bits[p]);
  883.                                     else
  884.                                         right_diag_left_edge[y] = basist::bc7u::dequant_weight(
  885.                                             pRight_diag_log_blk->m_weights[0][rd_w],
  886.                                             pRight_diag_log_blk->m_weight_bits[0]);
  887.                                 }
  888.  
  889.                                 for (uint32_t y = 0; y < 4; y++)
  890.                                 {
  891.                                     const int left_diag_right_val = orig_weight_preds[index_from_xy(3, y)]; // upper-left right edge
  892.                                     const int right_diag_left_val = right_diag_left_edge[y];
  893.  
  894.                                     for (uint32_t x = 0; x < 4; x++)
  895.                                     {
  896.                                         // Same as #15: lateral predictor from upper-left/right diagonal edges.
  897.                                         const int diag_blend =
  898.                                             ((3 - static_cast<int>(x)) * left_diag_right_val +
  899.                                                 static_cast<int>(x) * right_diag_left_val + 1) / 3;
  900.  
  901.                                         // Same as #6: direct upper edge replicated downward.
  902.                                         const int up_val = upper_edge[x];
  903.  
  904.                                         // Trust upper edge more near y=0, trust diagonal lateral structure more lower in the block.
  905.                                         const int wu = 4 - static_cast<int>(y); // 4,3,2,1
  906.                                         const int wd = 1 + static_cast<int>(y); // 1,2,3,4
  907.                                         const int den = wu + wd;                // always 5
  908.  
  909.                                         weight_preds[index_from_xy(x, y)] =
  910.                                             (wu * up_val + wd * diag_blend + (den >> 1)) / den;
  911.                                     }
  912.                                 }
  913.                             }
  914.                            
  915.                             pWeight_predictions = weight_preds;
  916.                         }
  917.  
  918.                         xbc7::dct_syms syms;
  919.  
  920.                         weight_grid_dct.forward(global_q, p, pWeight_predictions, cand_log_blk, syms, dct_work);
  921.                                                
  922.                         memset(cand_log_blk.m_weights[p], 0, 16);
  923.  
  924.                         bool status = weight_grid_dct.inverse(global_q, p, pWeight_predictions, syms, cand_log_blk, dct_work);
  925.                         if (!status)
  926.                         {
  927.                             assert(0);
  928.                             return false;
  929.                         }
  930.  
  931.                         cand_total_ac_syms += syms.m_ac_vals.size_u32();
  932.  
  933.                     } // p
  934.  
  935.                     if (cand_total_ac_syms < best_num_ac_syms)
  936.                     {
  937.                         best_cand_log_blk = cand_log_blk;
  938.                         best_cand_index = cand_index;
  939.                         best_num_ac_syms = cand_total_ac_syms;
  940.                     }
  941.  
  942. #if 0
  943.                     color_rgba cand_block_pixels[16];
  944.                     basist::bc7u::unpack_bc7(cand_log_blk, (basist::color_rgba*)cand_block_pixels);
  945.  
  946.                     uint64_t cand_err = 0;
  947.                     for (uint32_t i = 0; i < 16; i++)
  948.                         cand_err += cand_block_pixels[i].get_dist2(orig_block[i]);
  949.  
  950.                     if (cand_err < best_err)
  951.                     {
  952.                         best_err = cand_err;
  953.                         best_cand_log_blk = cand_log_blk;
  954.                         best_cand_index = cand_index;
  955.                     }
  956. #endif
  957.                 }
  958.  
  959.                 cand_hist[best_cand_index]++;
  960.             }
  961.  
  962.             log_blk = best_cand_log_blk;
  963.             total_ac_syms += best_num_ac_syms;
  964.  
  965.             color_rgba new_block_pixels[16];
  966.             basist::bc7u::unpack_bc7(log_blk, (basist::color_rgba *)new_block_pixels);
  967.             out_img.set_block_clipped(new_block_pixels, bx * 4, by * 4, 4, 4);
  968.  
  969.             bool pack_status = basist::bc7u::pack_bc7(log_blk, &phys_blocks(bx, by));
  970.             assert(pack_status);
  971.             if (!pack_status)
  972.                 return false;
  973.                        
  974.         } // bx
  975.  
  976.     } // by
  977.  
  978.     save_png("out_img.png", out_img);
  979.  
  980.     fmt_printf("\nOK\n");
  981.  
  982.     fmt_printf("Total AC syms: {}, avg per block: {}\n", total_ac_syms, (float)total_ac_syms / (float)total_blocks);
  983.  
  984.     fmt_printf("Candidate histogram:\n");
  985.     for (uint32_t i = 0; i < cand_hist.size(); i++)
  986.         fmt_printf("{}: {}\n", i, cand_hist[i]);
  987.  
  988.     create_bc7_debug_images(
  989.         width, height,
  990.         phys_blocks.get_ptr(),
  991.         "out_img");
  992.  
  993.     return false;
  994. }
  995.  
Tags: graphics
Advertisement
Add Comment
Please, Sign In to add comment