Advertisement
gehtsiegarnixan

Lenna Gaussian Mixture

Feb 5th, 2025 (edited)
35
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.16 KB | Photo | 0 0
  1. import numpy as np
  2. from skimage import io
  3. from skimage.transform import resize
  4. from sklearn.mixture import GaussianMixture
  5. from sklearn.mixture import BayesianGaussianMixture
  6.  
  7.  
  8. def prepare_image_data(image_path, new_width, mode='rgb'):
  9.     """
  10.    Loads an image, resizes it, and prepares the data for clustering.
  11.  
  12.    Args:
  13.    - image_path (str): Path to the input image.
  14.    - new_width (int): New width for the resized image, aspect ratio will be preserved.
  15.    - mode (str): Mode for processing the image data ('rgb' or 'brightness').
  16.  
  17.    Returns:
  18.    - data (numpy.ndarray): The prepared data containing normalized coordinates and either RGB values or brightness values.
  19.    """
  20.     # Load the image
  21.     image = io.imread(image_path)
  22.  
  23.     # Get original dimensions
  24.     orig_height, orig_width, _ = image.shape
  25.  
  26.     # Set new height to preserve aspect ratio
  27.     new_height = int((new_width / orig_width) * orig_height)
  28.  
  29.     # Resize the image while preserving aspect ratio
  30.     image_resized = resize(image, (new_height, new_width), anti_aliasing=True)
  31.  
  32.     if mode == 'rgb':
  33.         # Extract RGB values and normalize them to [0, 1]
  34.         values = image_resized.reshape(-1, 3)
  35.     elif mode == 'brightness':
  36.         # Calculate brightness as max(r, g, b) and normalize to [0, 1]
  37.         values = np.max(image_resized, axis=2).flatten().reshape(-1, 1)
  38.  
  39.     # Create coordinate grid
  40.     y_indices, x_indices = np.meshgrid(
  41.         np.linspace(1, 0, new_height), np.linspace(0, 1, new_width), indexing='ij'
  42.     )
  43.  
  44.     # Reorder data so that coordinates come first (x, y, values)
  45.     data = np.column_stack((
  46.         x_indices.flatten(),  # Normalized x-coordinates
  47.         y_indices.flatten(),  # Normalized y-coordinates
  48.         values  # Flattened values (either RGB or brightness)
  49.     ))
  50.  
  51.     return data
  52.  
  53.  
  54. def generate_glsl_data(gmm):
  55.     """
  56.    Converts clustering results from a Gaussian Mixture Model into GLSL-ready data definitions.
  57.  
  58.    Args:
  59.    - gmm (GaussianMixture): The fitted GMM object.
  60.  
  61.    Returns:
  62.    - str: GLSL code string containing the cluster information.
  63.    """
  64.     # Extracting cluster means and covariances from the GMM
  65.     weights = gmm.weights_
  66.     cluster_means = gmm.means_
  67.     covariances = gmm.covariances_
  68.     covariance_type = gmm.covariance_type  # Get the covariance type
  69.  
  70.     # Inferred n_components from the length of cluster_means
  71.     n_components = len(cluster_means)
  72.  
  73.     # Determine if using RGB or brightness mode
  74.     use_rgb = cluster_means.shape[1] == 5  # Assuming [x, y, r, g, b] for RGB mode
  75.  
  76.     # Prepare GLSL output
  77.     output = f"#define COUNT {n_components}\n\n"
  78.     # Example Python code to generate GLSL code for weights
  79.  
  80.     # Printing weight of a cluster definitions
  81.     output += "    const float weights[COUNT] = float[]("
  82.     for i, weight in enumerate(weights):
  83.         output += f"{weight:.2e}"
  84.         if i < len(weights) - 1:
  85.             output += ","
  86.     output += ");\n"
  87.  
  88.     # Printing color definitions
  89.     output += "    const vec3 colors[COUNT] = vec3[]("
  90.     for i, mean in enumerate(cluster_means):
  91.         if use_rgb:
  92.             r, g, b = mean[2], mean[3], mean[4]
  93.             output += f"vec3({r:.3f},{g:.3f},{b:.3f})"
  94.         else:
  95.             r = mean[2]  # Use the third value for grayscale
  96.             output += f"vec3({r:.3f})"
  97.         if i < len(cluster_means) - 1:
  98.             output += ","
  99.     output += ");\n"
  100.  
  101.     # Printing coordinate center definitions
  102.     output += "    const vec2 positions[COUNT] = vec2[]("
  103.     for i, mean in enumerate(cluster_means):
  104.         x, y = mean[0], mean[1]
  105.         output += f"vec2({x:.3f},{y:.3f})"
  106.         if i < len(cluster_means) - 1:
  107.             output += ","
  108.     output += ");\n"
  109.  
  110.     # Handle different covariance types
  111.     output += "    const mat2 covariances[COUNT] = mat2[]("
  112.     for i, covariance in enumerate(covariances):
  113.         if covariance_type == 'spherical':
  114.             variance = covariance  # Spherical covariance is a single variance value
  115.             output += f"mat2({variance:.2e},0,0,{variance:.2e})"
  116.         elif covariance_type == 'full':
  117.             c00, c01 = covariance[0, 0], covariance[0, 1]
  118.             c10, c11 = covariance[1, 0], covariance[1, 1]
  119.             output += f"mat2({c00:.2e},{c01:.2e},{c10:.2e},{c11:.2e})"
  120.         if i < len(covariances) - 1:
  121.             output += ","
  122.     output += ");\n"
  123.  
  124.     # Replace "0." with "."
  125.     output = output.replace("0.", ".")
  126.  
  127.     return output
  128.  
  129.  
  130. # Path to image to convert and down-sampling and number of clusters
  131. image_path = 'Textures/Lenna.png'
  132. new_width = 128
  133. clusters = 512
  134.  
  135. data = prepare_image_data(image_path, new_width, mode='rgb')
  136.  
  137. # Fit the Gaussian Mixture Model
  138. model = GaussianMixture(n_components=clusters, covariance_type='full', verbose=2)
  139.  
  140. # Fit the Bayesian Gaussian Mixture Model (claims it's better but looks blurry and is 2x slower)
  141. # model = BayesianGaussianMixture(n_components=clusters, covariance_type='full', verbose=2)
  142.  
  143. # Run the model
  144. model.fit(data)
  145.  
  146. # Print the GLSL code to use in your shader
  147. print(generate_glsl_data(model))
  148.  
  149. # See Shader code at https://www.shadertoy.com/view/43Gfzt
  150.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement