Kaelygon

palettizeOklab.py

Dec 2nd, 2025 (edited)
24
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.25 KB | None | 0 0
  1. ##CC0 Kaelygon 2025
  2. """
  3. Palettize and dither using arbitrary palette
  4. """
  5.  
  6. import math
  7. import random
  8. import numpy as np
  9. from PIL import Image
  10. from dataclasses import dataclass, field
  11. from typing import List, Optional
  12. from scipy.spatial import cKDTree as KDTree
  13.  
  14.  
  15. ### Constants ###
  16.  
  17. OKLAB_GAMUT_VOLUME = 0.054197416
  18.  
  19. def approxOkGap(point_count: int):
  20.     return (OKLAB_GAMUT_VOLUME/max(1,point_count))**(1.0/3.0)
  21.  
  22.  
  23. ### Color Conversion ###
  24.  
  25. def srgbToLinear(srgb: np.ndarray):
  26.     cutoff = srgb <= 0.04045
  27.     higher = ((srgb + 0.055) / 1.055) ** 2.4
  28.     lower = srgb / 12.92
  29.     return np.where(cutoff, lower, higher)
  30.  
  31. def linearToSrgb(lin: np.ndarray):
  32.     lin = np.maximum(lin, 0.0)
  33.     cutoff = lin <= 0.0031308
  34.     higher = 1.055 * np.power(lin, 1/2.4) - 0.055
  35.     lower = lin * 12.92
  36.     return np.where(cutoff, lower, higher)
  37.  
  38. def linearToOklab(lin: np.ndarray):
  39.     r, g, b = lin[:,0], lin[:,1], lin[:,2]
  40.     l = 0.4122214708*r + 0.5363325363*g + 0.0514459929*b
  41.     m = 0.2119034982*r + 0.6806995451*g + 0.1073969566*b
  42.     s = 0.0883024619*r + 0.2817188376*g + 0.6299787005*b
  43.    
  44.     l_ = np.sign(l) * np.abs(l) ** (1/3)
  45.     m_ = np.sign(m) * np.abs(m) ** (1/3)
  46.     s_ = np.sign(s) * np.abs(s) ** (1/3)
  47.    
  48.     L = 0.2104542553*l_ + 0.7936177850*m_ - 0.0040720468*s_
  49.     a = 1.9779984951*l_ - 2.4285922050*m_ + 0.4505937099*s_
  50.     b = 0.0259040371*l_ + 0.7827717662*m_ - 0.8086757660*s_
  51.    
  52.     return np.stack([L,a,b], axis=1)
  53.  
  54. def oklabToLinear(lab: np.ndarray):
  55.     L, a, b = lab[:,0], lab[:,1], lab[:,2]
  56.     l_ = L + 0.3963377774*a + 0.2158037573*b
  57.     m_ = L - 0.1055613458*a - 0.0638541728*b
  58.     s_ = L - 0.0894841775*a - 1.2914855480*b
  59.    
  60.     l = l_**3
  61.     m = m_**3
  62.     s = s_**3
  63.    
  64.     r = +4.0767416621*l - 3.3077115913*m + 0.2309699292*s
  65.     g = -1.2684380046*l + 2.6097574011*m - 0.3413193965*s
  66.     b = -0.0041960863*l - 0.7034186147*m + 1.7076147010*s
  67.    
  68.     return np.stack([r,g,b], axis=1)
  69.  
  70. def srgbToOklab(col: np.ndarray):
  71.     linRGB = srgbToLinear(col)
  72.     oklab = linearToOklab(linRGB)
  73.     return oklab
  74.  
  75. def oklabToSrgb(col: np.ndarray):
  76.     linRGB = oklabToLinear(col)
  77.     sRGB = linearToSrgb(linRGB)
  78.     return sRGB
  79.  
#### Image conversion ###

# Struct-of-arrays with duplicates stripped: entry i is (color[i], alpha[i], area[i]).
@dataclass
class UniqueList:
    """Deduplicated pixel rows of an OkImage, stored column-wise.

    Filled by OkImage.createUniqueList from np.unique(pixels_output, axis=0, ...).
    """
    color: np.ndarray  # L, a, b columns of each unique row (uniques only)
    alpha: np.ndarray  # alpha column of each unique row
    area: np.ndarray  # per unique row: number of pixels with alpha > 1/255 sharing that row

    unique_idxs: np.ndarray  # np.unique return_index: first pixel index of each unique row
    original_idxs: np.ndarray  # np.unique return_inverse: colors_with_dupes = color[original_idxs]
  91.  
  92. class OkImage:
  93.     #const
  94.     BAYER_2 = np.array([
  95.         [ 0, 2],
  96.         [ 3, 1],
  97.     ])
  98.     BAYER_4 = np.array([
  99.         [ 0, 8, 2,10],
  100.         [12, 4,14, 6],
  101.         [ 3,11, 1, 9],
  102.         [15, 7,13, 5],
  103.     ])
  104.  
  105.     BAYER_8 = np.array([
  106.         [ 0, 32,  8, 40,  2, 34, 10, 42],
  107.         [48, 16, 56, 24, 50, 18, 58, 26],
  108.         [12, 44,  4, 36, 14, 46,  6, 38],
  109.         [60, 28, 52, 20, 62, 30, 54, 22],
  110.         [ 3, 35, 11, 43,  1, 33,  9, 41],
  111.         [51, 19, 59, 27, 49, 17, 57, 25],
  112.         [15, 47,  7, 39, 13, 45,  5, 37],
  113.         [63, 31, 55, 23, 61, 29, 53, 21],
  114.     ])
  115.  
  116.     BAYER_16 = np.array([
  117.         [  0, 128,  32, 160,   8, 136,  40, 168,   2, 130,  34, 162,  10, 138,  42, 170],
  118.         [192,  64, 224,  96, 200,  72, 232, 104, 194,  66, 226,  98, 202,  74, 234, 106],
  119.         [ 48, 176,  16, 144,  56, 184,  24, 152,  50, 178,  18, 146,  58, 186,  26, 154],
  120.         [240, 112, 208,  80, 248, 120, 216,  88, 242, 114, 210,  82, 250, 122, 218,  90],
  121.         [ 12, 140,  44, 172,   4, 132,  36, 164,  14, 142,  46, 174,   6, 134,  38, 166],
  122.         [204,  76, 236, 108, 196,  68, 228, 100, 206,  78, 238, 110, 198,  70, 230, 102],
  123.         [ 60, 188,  28, 156,  52, 180,  20, 148,  62, 190,  30, 158,  54, 182,  22, 150],
  124.         [252, 124, 220,  92, 244, 116, 212,  84, 254, 126, 222,  94, 246, 118, 214,  86],
  125.         [  3, 131,  35, 163,  11, 139,  43, 171,   1, 129,  33, 161,   9, 137,  41, 169],
  126.         [195,  67, 227,  99, 203,  75, 235, 107, 193,  65, 225,  97, 201,  73, 233, 105],
  127.         [ 51, 179,  19, 147,  59, 187,  27, 155,  49, 177,  17, 145,  57, 185,  25, 153],
  128.         [243, 115, 211,  83, 251, 123, 219,  91, 241, 113, 209,  81, 249, 121, 217,  89],
  129.         [ 15, 143,  47, 175,   7, 135,  39, 167,  13, 141,  45, 173,   5, 133,  37, 165],
  130.         [207,  79, 239, 111, 199,  71, 231, 103, 205,  77, 237, 109, 197,  69, 229, 101],
  131.         [ 63, 191,  31, 159,  55, 183,  23, 151,  61, 189,  29, 157,  53, 181,  21, 149],
  132.         [255, 127, 223,  95, 247, 119, 215,  87, 253, 125, 221,  93, 245, 117, 213,  85],
  133.     ])
  134.  
  135.     AMOGUS_5 = np.array([
  136.         [ 0, 1, 1, 1, 0],
  137.         [ 1, 1, 0, 0, 0],
  138.         [ 1, 1, 1, 1, 0],
  139.         [ 0, 1, 0, 1, 0],
  140.         [ 0, 0, 0, 0, 0],
  141.     ])
  142.  
  143.     BAYER_N  = [BAYER_2, BAYER_4, BAYER_8, BAYER_16, AMOGUS_5]
  144.  
  145.     #gamma correct ok
  146.     BAYER_OK_N = [(bayer*bayer) / ((np.max(bayer)+1)**2) - (0.5 - 1e-8) for bayer in BAYER_N]
  147.  
  148.     #non-const
  149.     pixels = None #don't mutate after init
  150.     pixels_output = None #copy of pixels that can be modified
  151.  
  152.     size = None
  153.  
  154.     def __init__(self, input_path):
  155.         self.imgToOkPixels(input_path)
  156.  
  157.     #vals = reference to np.ndarray
  158.     def _quantize(self, vals, step_count: int):
  159.         vals[:] = np.round(vals*step_count)/step_count
  160.  
  161.     #public
  162.     def imgToOkPixels(self, img_path: str):
  163.         in_img = Image.open(img_path).convert("RGBA")
  164.         col_list = np.array(in_img, dtype=np.float64) / 255.0
  165.         col_list = col_list.reshape(-1, 4)
  166.         col_list[:,:3] = srgbToOklab(col_list[:,:3])
  167.  
  168.         self.pixels = col_list
  169.         self.pixels[:] = np.clip(self.pixels,[0,-0.5,-0.5,0.0],[1.0,0.5,0.5,1.0])
  170.         self.pixels_output = self.pixels.copy()
  171.         self.size = in_img.size
  172.  
  173.     def saveImage(self, output_path: str):
  174.         col_list = self.pixels_output.copy()
  175.         col_list[:,:3] = oklabToSrgb(col_list[:,:3])
  176.         rgba = np.clip(np.round(col_list * 255), 0, 255).astype(np.uint8)
  177.         rgba = rgba.reshape((self.size[1], self.size[0], 4))
  178.         img = Image.fromarray(rgba, "RGBA")
  179.         img.save(output_path)
  180.  
  181.     def quantizeAxes(self, col_list, step_count: int):
  182.         if not step_count:
  183.             return col_list
  184.         if col_list is None:
  185.             col_list = self.pixels_output
  186.         self._quantize(col_list,step_count)
  187.         return col_list
  188.  
  189.     def quantizeAlpha(self, alpha_count: int):
  190.         alpha = self.pixels_output[:,3]
  191.         if alpha_count == 0:
  192.             alpha[:] = np.zeros(len(alpha)) + 1.0
  193.         else:
  194.             self._quantize(alpha,alpha_count)
  195.    
  196.     def createUniqueList(self):
  197.         #strip dupes
  198.         unique_colors, unique_idxs, original_idxs = np.unique(self.pixels_output, axis=0, return_index=True, return_inverse=True)
  199.  
  200.         #area[original_index] = dupe_count, so area[0] is how many pixels are unique_color[0]
  201.         nontransp = self.pixels_output[:, 3] > (1.0 / 255.0) #exclude transparent
  202.         area = np.bincount(
  203.             original_idxs,
  204.         weights=nontransp,
  205.             minlength=len(unique_colors)
  206.       )
  207.  
  208.         self.unique_list = UniqueList(
  209.             unique_colors[:,:3],
  210.         unique_colors[:, 3],
  211.             area,
  212.             unique_idxs,
  213.             original_idxs
  214.         )
  215.  
  216.     #### palettize methods ###
  217.     def applyPalette(self, unique_palettized):
  218.         self.pixels_output[:,:3] = unique_palettized[self.unique_list.original_idxs]
  219.  
  220.     def ditherNone(self, palette_img):
  221.         pal_list = palette_img.unique_list
  222.         pixels = self.pixels_output[:,:3]
  223.  
  224.         tree = KDTree(pal_list.color)
  225.         _, idxs = tree.query(pixels, k=1, workers=-1)
  226.         self.pixels_output[:,:3] = pal_list.color[:,:3][idxs]
  227.  
  228.     def ditherOrdered(self, palette_img, matrix_size=1):
  229.         pal_list = palette_img.unique_list
  230.         pixels = self.pixels_output[:,:3].copy()
  231.  
  232.         matrix_size=np.clip( matrix_size, 0, len(self.BAYER_OK_N)-1 )
  233.         b_m=self.BAYER_OK_N[matrix_size]
  234.         b_h, b_w = b_m.shape
  235.  
  236.         y_idxs, x_idxs = np.divmod(np.arange(len(pixels)), self.size[0])
  237.         thresholds = b_m[y_idxs % b_h, x_idxs % b_w]
  238.  
  239.         #gap between two closest palette colors of current pixel
  240.         tree = KDTree(pal_list.color)
  241.         _, idxs = tree.query(pixels, k=2, workers=-1)
  242.         pixel_gaps = np.abs(pal_list.color[idxs[:,1]] - pal_list.color[idxs[:,0]])
  243.  
  244.         #apply mask
  245.         new_pixels = pixels + thresholds[:,None] * pixel_gaps
  246.         _, idxs = tree.query(new_pixels, k=1, workers=-1)
  247.         new_pixels = pal_list.color[idxs]
  248.  
  249.         self.pixels_output[:,:3] = new_pixels
  250.  
  251. def calcBucketScore(bucket_areas, col_dists, col_idxs, max_radius):
  252.     area_weight = max(8.0,4.0*max_radius)
  253.     max_bucket = max(1.0, max(bucket_areas))
  254.     return col_dists * (1.0 + area_weight * (bucket_areas[col_idxs]/max_bucket) )
  255.  
  256. #Map unique colors to palette, but avoid collapsing similar colors
  257. #Return unique_palettized[len(unique_list.color)] = [l,a,b,alpha]
  258. def createWeightedPalette(
  259.    src_img: OkImage,
  260.    palette_img: OkImage,
  261.    max_error: int = 1,
  262.     k_count = 13
  263.    ):
  264.     unique_list = src_img.unique_list
  265.     palette_list = palette_img.unique_list
  266.     pal_length = len(palette_list.color)
  267.  
  268.     max_radius = approxOkGap(pal_length) * max_error
  269.  
  270.     #accumulated area of colors in each palette bucket
  271.     bucket_areas = np.zeros(pal_length)
  272.  
  273.     #Closest palette colors
  274.     tree = KDTree(palette_list.color)
  275.     est_maxk = max_error * 12 #Sphere kissing number within 1 radius
  276.     k_count = min(max(2,k_count),pal_length,est_maxk)
  277.     dists, idxs = tree.query(unique_list.color, k = int(k_count), workers=-1)
  278.    
  279.     #choose palette index for each color
  280.     unique_count = len(unique_list.color)
  281.     unique_palettized = np.zeros((unique_count,3))
  282.  
  283.     #prioritize largest area
  284.     unique_sorted_idx = np.argsort(-1.0*unique_list.area)
  285.     for i in unique_sorted_idx:
  286.         #lowest dist and emptiest bucket ; lowest score = better
  287.         local_scores = calcBucketScore(bucket_areas, dists[i], idxs[i], max_radius)
  288.         mask = dists[i] <= max_radius
  289.    
  290.         if np.any(mask):
  291.             valid = np.where(mask)[0]
  292.             best_pos = valid[np.argmin(local_scores[valid])]
  293.             best_j = int(idxs[i][best_pos])
  294.         else:
  295.             #choose nearest if exceeds max_error
  296.             best_pos = 0
  297.             best_j = int(idxs[i][best_pos])
  298.    
  299.         unique_palettized[i] = palette_list.color[best_j]
  300.         bucket_areas[best_j] += unique_list.area[i]
  301.  
  302.     return unique_palettized
  303.    
  304.  
### Main functions ###

@dataclass
class ConvertPreset:
    """Settings for one palettizeImage run."""
    image: str  # path of the image to palettize
    palette: str  # path of the palette image; its unique colors form the palette
    output: str  # path the palettized image is saved to
    alpha_count: int  # alpha quantization steps; 0 makes every pixel opaque
    max_error: float  # radius, in multiples of the approximate palette-color spacing, within which a less-used neighbouring palette color may replace a unique color (dither=0 only; 0 = plain nearest color)
    merge_radius: float  # pre-quantize the original image on a grid of this many palette spacings; >1.0 is coarser than the palette. May improve quality with thousands of unique colors and a tiny palette. 0 disables
    dither: int  # only 0 = none and 1 = ordered dither are supported
    dither_size: int  # ordered-dither matrix: 0=2x2, 1=4x4, 2=8x8, 3=16x16 Bayer, 4=5x5 AMOGUS_5 pattern; out-of-range values are clamped
  317.  
  318. def palettizeImage(preset: ConvertPreset):
  319.    
  320.     palette_ok = OkImage(preset.palette)
  321.     palette_ok.quantizeAlpha(0)
  322.     palette_ok.createUniqueList()
  323.  
  324.     image_ok = OkImage(preset.image)
  325.     if preset.merge_radius:
  326.         axis_step_size = approxOkGap(len(palette_ok.unique_list.color)) * preset.merge_radius
  327.         axis_count = int(1.0/axis_step_size)
  328.         image_ok.quantizeAxes(None, axis_count)
  329.     image_ok.quantizeAlpha(preset.alpha_count)
  330.     image_ok.createUniqueList()
  331.  
  332.     #replace original img pixels with convert_dict
  333.     if preset.dither == 0:
  334.         if preset.max_error:
  335.             unique_palettized = createWeightedPalette(image_ok, palette_ok, preset.max_error)
  336.             image_ok.applyPalette(unique_palettized)
  337.         else:
  338.             image_ok.ditherNone(palette_ok)
  339.     if preset.dither == 1:
  340.         image_ok.ditherOrdered(palette_ok,preset.dither_size)
  341.  
  342.     image_ok.saveImage(preset.output)
  343.  
  344.  
  345.  
  346. if __name__ == '__main__':
  347.    
  348.     input_palette =  "./palettes/pal256.png"
  349.     preset_list = [
  350.         ConvertPreset(
  351.             image               = "./testImg/KaelygonLogo25.png",
  352.             palette         = input_palette,
  353.             output          = "./output/palettizedImg.png",
  354.             alpha_count     = 1,
  355.             max_error       = 2.0,
  356.             merge_radius    = 0.0,
  357.             dither          = 0,
  358.             dither_size     = 0,
  359.         ),
  360.         ConvertPreset(
  361.             image               = "./testImg/tienaPride.png",
  362.             palette         = input_palette,
  363.             output          = "./output/palettizedImg.png",
  364.             alpha_count     = 1,
  365.             max_error       = 1,
  366.             merge_radius    = 0.0,
  367.             dither          = 1,
  368.             dither_size     = 3,
  369.         ),
  370.         ConvertPreset(
  371.             image               = "./testImg/TienaPortrait.png",
  372.             palette         = input_palette,
  373.             output          = "./output/palettizedImg.png",
  374.             alpha_count     = 1,
  375.             max_error       = 1.0,
  376.             merge_radius    = 0.09,
  377.             dither          = 0,
  378.             dither_size     = 0,
  379.         ),
  380.         ConvertPreset(
  381.             image               = "./testImg/testChart.png",
  382.             palette         = input_palette,
  383.             output          = "./output/palettizedImg.png",
  384.             alpha_count     = 1,
  385.             max_error       = 1.0,
  386.             merge_radius    = 0.0,
  387.             dither          = 1,
  388.             dither_size     = 4,
  389.         ),
  390.         ConvertPreset(
  391.             image               = "./testImg/KaelygonSeawing.png",
  392.             palette         = input_palette,
  393.             output          = "./output/palettizedImg.png",
  394.             alpha_count     = 1,
  395.             max_error       = 0.0,
  396.             merge_radius    = 0.0,
  397.             dither          = 1,
  398.             dither_size     = 4,
  399.         )
  400.     ]
  401.    
  402.     preset_index = 4
  403.     palettizeImage( preset_list[preset_index]  )
Advertisement
Add Comment
Please, Sign In to add comment