gehtsiegarnixan

Image to Gaussian Mixture

Feb 5th, 2025 (edited)
import numpy as np
from skimage import io, color
from skimage.transform import resize
from sklearn.mixture import GaussianMixture
from sklearn.mixture import BayesianGaussianMixture


def prepare_image_data(image_path, new_width, mode='rgb'):
    """
    Loads an image, resizes it, and prepares the data for clustering.

    Args:
    - image_path (str): Path to the input image.
    - new_width (int): New width for the resized image; the aspect ratio is preserved.
    - mode (str): Mode for processing the image data ('rgb', 'lab', or 'grey').

    Returns:
    - data (numpy.ndarray): The prepared data containing normalized coordinates and either RGB, Lab, or brightness values.
    """
    # Load the image
    image = io.imread(image_path)

    # Get original dimensions (also handles RGBA inputs)
    orig_height, orig_width = image.shape[:2]

    # Set new height to preserve aspect ratio
    new_height = int((new_width / orig_width) * orig_height)

    # Resize the image while preserving aspect ratio; skimage's resize also
    # converts the result to floats in [0, 1]
    image_resized = resize(image, (new_height, new_width), anti_aliasing=True)

    if mode == 'rgb':
        # Extract RGB values (dropping any alpha channel); already in [0, 1]
        values = image_resized[..., :3].reshape(-1, 3)
    elif mode == 'lab':
        # Convert RGB to Lab color space (rgb2lab expects exactly 3 channels)
        lab_values = color.rgb2lab(image_resized[..., :3]).reshape(-1, 3)

        # Normalize to [0,1] range
        lab_values[:, 0] = lab_values[:, 0] / 100.0  # Scale L to [0,1]
        lab_values[:, 1] = (lab_values[:, 1] + 128) / 255.0  # Scale a to [0,1]
        lab_values[:, 2] = (lab_values[:, 2] + 128) / 255.0  # Scale b to [0,1]

        values = lab_values
    elif mode == 'grey':
        # Luminosity method: Y = 0.299 * R + 0.587 * G + 0.114 * B
        values = np.dot(image_resized[..., :3], [0.299, 0.587, 0.114]).reshape(-1, 1)
    else:
        raise ValueError("Invalid mode. Use 'rgb', 'lab', or 'grey'.")
    # Create coordinate grid; y runs from 1 at the top row to 0 at the bottom,
    # so the origin sits at the bottom-left (GL-style coordinates)
    y_indices, x_indices = np.meshgrid(
        np.linspace(1, 0, new_height), np.linspace(0, 1, new_width), indexing='ij'
    )

    # Reorder data so that coordinates come first (x, y, values)
    data = np.column_stack((
        x_indices.flatten(),  # Normalized x-coordinates
        y_indices.flatten(),  # Normalized y-coordinates
        values  # RGB, Lab, or grayscale values
    ))

    return data

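# --- Usage sketch (my addition, not part of the original paste) ---
# In 'lab' mode each row of the output is (x, y, L, a, b), so resizing to
# W x H pixels yields an array of shape (W * H, 5); 'grey' mode yields
# (W * H, 3). Uncomment to check:
# sample = prepare_image_data('Textures/Lenna.png', 32, mode='lab')
# print(sample.shape)  # expected: (32 * new_height, 5)

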
def prepare_image_data_with_mask(image_path, mask_path, num_samples, mode='rgb', random_seed=None):
    """
    Loads an image and a focus mask, then samples pixels randomly with probability
    proportional to the mask brightness.

    Args:
    - image_path (str): Path to the input image.
    - mask_path (str): Path to the grayscale focus mask. The brighter a pixel, the more important it is.
    - num_samples (int): Total number of pixels to sample.
    - mode (str): Mode for processing the image data ('rgb', 'lab', or 'grey').
    - random_seed (int, optional): Seed for reproducible sampling. If None, results are random.

    Returns:
    - data (numpy.ndarray): The sampled data containing normalized coordinates and color values.
    """
    # Load images and normalize to [0,1] (assumes 8-bit inputs)
    image = io.imread(image_path) / 255.0
    mask = io.imread(mask_path) / 255.0

    # If the mask has multiple channels, take only the first (red channel)
    if mask.ndim == 3:
        mask = mask[..., 0]

    # Get image dimensions
    height, width = image.shape[:2]

    # Create coordinate grid (y flipped so the origin is at the bottom-left)
    y_coords, x_coords = np.meshgrid(np.linspace(1, 0, height), np.linspace(0, 1, width), indexing='ij')

    # Flatten data
    x_coords = x_coords.flatten()
    y_coords = y_coords.flatten()
    mask_values = mask.flatten()
    if mode == 'rgb':
        values = image[..., :3].reshape(-1, 3)  # RGB values (alpha dropped)
    elif mode == 'lab':
        # Convert RGB to Lab color space (rgb2lab expects exactly 3 channels)
        lab_values = color.rgb2lab(image[..., :3]).reshape(-1, 3)

        # Normalize to [0,1] range
        lab_values[:, 0] = lab_values[:, 0] / 100.0  # Scale L to [0,1]
        lab_values[:, 1] = (lab_values[:, 1] + 128) / 255.0  # Scale a to [0,1]
        lab_values[:, 2] = (lab_values[:, 2] + 128) / 255.0  # Scale b to [0,1]

        values = lab_values
    elif mode == 'grey':
        # Luminosity method: Y = 0.299 * R + 0.587 * G + 0.114 * B
        values = np.dot(image[..., :3], [0.299, 0.587, 0.114]).reshape(-1, 1)
    else:
        raise ValueError("Invalid mode. Use 'rgb', 'lab', or 'grey'.")
    # Normalize mask to use as sampling probabilities
    mask_values += 1e-6  # Avoid a zero sum for all-black masks
    mask_values /= np.sum(mask_values)

    # Set random seed if provided
    if random_seed is not None:
        np.random.seed(random_seed)

    # Sample pixel indices with probability proportional to mask brightness
    indices = np.random.choice(len(mask_values), size=num_samples, p=mask_values)

    # Select data points
    sampled_x = x_coords[indices]
    sampled_y = y_coords[indices]
    sampled_values = values[indices]

    # Combine into final data array
    data = np.column_stack((sampled_x, sampled_y, sampled_values))

    return data

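# --- Sketch of the importance sampling above (my addition, not part of the
# original paste): brighter mask pixels are proportionally more likely to be
# drawn. Uncomment to verify that index 2 is picked ~90% of the time:
# _mask = np.array([0.0, 0.1, 0.9]) + 1e-6
# _p = _mask / np.sum(_mask)
# _picks = np.random.choice(3, size=1000, p=_p)
# print(np.bincount(_picks, minlength=3) / 1000)  # roughly [0.00, 0.10, 0.90]

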
def find_scaling_factor(values, target_digits=3):
    """
    Finds a scaling factor for a dataset so its values can be represented as
    integers with the specified number of digits, i.e. the largest absolute
    value maps to an integer with `target_digits` digits.

    Args:
        values (numpy.ndarray): The dataset to find the scaling factor for.
        target_digits (int, optional): The number of digits to scale the values to.
                                       Defaults to 3.

    Returns:
        float: The scaling factor that maps the values to the target number of digits.
    """
    # Find the largest absolute value in the dataset
    max_val = np.max(np.abs(values))
    if max_val == 0:
        return 1.0  # Avoid log10(0) / division by zero for an all-zero input

    # Scale so the largest value becomes a target_digits-digit integer
    scale_factor = 10 ** (np.floor(np.log10(max_val)) - (target_digits - 1))
    return scale_factor

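# --- Worked example (my addition, not part of the original paste) ---
# For values with max |v| = 0.7 and target_digits = 3:
#   floor(log10(0.7)) = -1, so scale = 10 ** (-1 - 2) = 1e-3,
#   and 0.7 / 1e-3 = 700, i.e. a 3-digit integer.
# find_scaling_factor(np.array([0.1, 0.7]))  # -> 0.001

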
def generate_glsl_data(gmm):
    """
    Converts the clustering results from a Gaussian Mixture Model (GMM) into GLSL-ready data definitions.

    Args:
        gmm (sklearn.mixture.GaussianMixture): The trained Gaussian Mixture Model containing:
            - `weights_`: The mixture component weights.
            - `means_`: The cluster means (centroids) of the GMM.
            - `covariances_`: The covariance matrices of the GMM.

    Returns:
        str: A GLSL-compatible string that includes:
            - `COUNT`: The number of Gaussian components.
            - `SCALES`: A `vec4` containing scaling factors for weights, positions, colors, and covariances.
            - `WEIGHTS`: A GLSL `int` array representing the scaled weights of the components.
            - `COLORS`: A GLSL `vec3` array representing the scaled RGB or brightness values.
            - `POSITIONS`: A GLSL `vec2` array representing the scaled positions (x, y) of the components.
            - `COVARIANCES`: A GLSL `mat2` array representing the scaled 2x2 spatial covariance matrices.
    """
    # Extract cluster weights, means, and covariances from the GMM
    weights = gmm.weights_
    cluster_means = gmm.means_
    covariances = gmm.covariances_

    # Inferred number of components
    n_components = len(cluster_means)

    # Determine if using color or brightness mode:
    # 5 columns means [x, y, r, g, b] (or [x, y, L, a, b]), 3 means [x, y, Y]
    use_rgb = cluster_means.shape[1] == 5

    # Keep only the 2x2 spatial block (x/y terms) of each covariance matrix
    cov_selected = np.array([cov[:2, :2].flatten() for cov in covariances])

    # Compute scaling factors for each type of value
    weight_scale = find_scaling_factor(weights)
    position_scale = find_scaling_factor(cluster_means[:, :2])
    if use_rgb:
        color_scale = find_scaling_factor(cluster_means[:, 2:5])
    else:
        color_scale = find_scaling_factor(cluster_means[:, 2:3])
    cov_scale = find_scaling_factor(cov_selected, target_digits=4)
    # Prepare GLSL output
    output = f"\n#define COUNT {n_components}\n"
    output += f"#define SCALES vec4({weight_scale:.1e},{position_scale:.1e},{color_scale:.1e},{cov_scale:.1e})\n"
    output = output.replace("1.0e-0", "1e-")  # Saves a few characters

    # Weights (scaled to 3 digits)
    output += "const int WEIGHTS[COUNT] = int[](" + ",".join(f"{int(w / weight_scale):d}" for w in weights) + ");\n"

    # Colors (scaled to 3 digits); v(...) is assumed to be a short vec3
    # constructor macro defined in the companion shader
    output += "const vec3 COLORS[COUNT] = vec3[]("
    output += ",".join(f"v({int(r / color_scale):d},{int(g / color_scale):d},{int(b / color_scale):d})"
                       if use_rgb else f"vec3({int(r / color_scale):d})"
                       for r, g, b in (cluster_means[:, 2:5]
                                       if use_rgb else np.c_[cluster_means[:, 2], cluster_means[:, 2], cluster_means[:, 2]]))
    output += ");\n"

    # Positions (scaled to 3 digits); u(...) is assumed to be a vec2 macro
    output += "const vec2 POSITIONS[COUNT] = vec2[]("
    output += ",".join(f"u({int(x / position_scale):d},{int(y / position_scale):d})"
                       for x, y in cluster_means[:, :2])
    output += ");\n"

    # Covariances (2x2 spatial block, scaled to 4 digits); m(...) is assumed
    # to be a mat2 macro
    output += "const mat2 COVARIANCES[COUNT] = mat2[]("
    output += ",".join(f"m({int(c[0] / cov_scale):d},{int(c[1] / cov_scale):d},{int(c[2] / cov_scale):d},{int(c[3] / cov_scale):d})"
                       for c in cov_selected)
    output += ");"

    return output

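# --- Optional preview helper (my addition, not part of the original paste).
# A minimal sketch that renders the fitted spatial mixture back to an image so
# the export can be sanity-checked in Python first. Each pixel blends the
# component colour means weighted by the unnormalized 2D spatial Gaussian
# density, mirroring what the companion shader is presumably doing.
def preview_gmm(gmm, width=128, height=128):
    xs, ys = np.meshgrid(np.linspace(0, 1, width), np.linspace(1, 0, height))
    grid = np.column_stack((xs.ravel(), ys.ravel()))  # (N, 2) pixel coords

    n_channels = gmm.means_.shape[1] - 2  # columns after (x, y)
    accum = np.zeros((grid.shape[0], n_channels))
    total = np.zeros(grid.shape[0])

    for w, mean, cov in zip(gmm.weights_, gmm.means_, gmm.covariances_):
        diff = grid - mean[:2]
        inv_cov = np.linalg.inv(cov[:2, :2])  # spatial 2x2 block only
        # Unnormalized Gaussian density of this component at every pixel
        dens = w * np.exp(-0.5 * np.einsum('ni,ij,nj->n', diff, inv_cov, diff))
        accum += dens[:, None] * mean[2:]
        total += dens

    # Values stay in the model's colour space (e.g. normalized Lab for
    # mode='lab'); convert back before display if exact colours matter
    return (accum / np.maximum(total, 1e-12)[:, None]).reshape(height, width, n_channels)

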
# Path to the image to convert, the down-sampling width, and the number of clusters
# example images can be found here https://imgur.com/a/zQYTfFQ
image_path = 'Textures/Lenna.png'
mask_path = 'Textures/Lenna_Importance.png'  # optional importance mask made with MS Paint
new_width = 128  # increase for better quality results
clusters = 512

# data = prepare_image_data(image_path, new_width, mode='lab')
data = prepare_image_data_with_mask(image_path, mask_path, new_width * new_width, mode='lab')

# Fit the Gaussian Mixture Model
model = GaussianMixture(n_components=clusters, max_iter=100, covariance_type='full', init_params='k-means++', verbose=2)

# BayesianGaussianMixture claims to fit better, but it takes about 3x as long
# and produces very blurry results
# model = BayesianGaussianMixture(n_components=clusters, max_iter=300, covariance_type='full', init_params='k-means++', verbose=2)

model.fit(data)

# Print the GLSL code to use in your shader
print(generate_glsl_data(model))

# See Shader code at https://www.shadertoy.com/view/43Gfzt
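
# Optional (my addition, using the hedged preview_gmm sketch above): render
# the mixture back to an image for a quick visual check before pasting the
# GLSL block into the shader. For mode='lab' this shows normalized Lab
# channels as if they were RGB, which is still fine for a structural check.
# import matplotlib.pyplot as plt
# plt.imshow(np.clip(preview_gmm(model, 256, 256), 0, 1)); plt.show()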