Advertisement
Guest User

Cosine Distance Visualization

a guest
Mar 10th, 2025
24
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.99 KB | Source Code | 0 0
  1. from matplotlib import pyplot as plt
  2. import numpy as np
  3. from matplotlib import cm
  4.  
  5. def hotcold(lutsize=256, neutral=1/3, interp=None):
  6. # From https://github.com/endolith/bipolar-colormap/blob/master/bipolar.py
  7. """
  8. Bipolar hot/cold colormap, with neutral central color.
  9.  
  10. This colormap is meant for visualizing diverging data; positive
  11. and negative deviations from a central value. It is similar to a "hot"
  12. blackbody colormap for positive values, but with a complementary
  13. "cold" colormap for negative values.
  14.  
  15. Parameters
  16. ----------
  17. lutsize : int
  18. The number of elements in the colormap lookup table. (Default is 256.)
  19. neutral : float
  20. The gray value for the neutral middle of the colormap. (Default is
  21. 1/3.)
  22. The colormap goes from cyan-blue-neutral-red-yellow if neutral
  23. is < 0.5, and from blue-cyan-neutral-yellow-red if `neutral` > 0.5.
  24. For shaded 3D surfaces, a `neutral` near 0.5 is better, because it
  25. minimizes luminance changes that would otherwise obscure shading cues
  26. for determining 3D structure.
  27. For 2D heat maps, a `neutral` near the 0 or 1 extremes is better, for
  28. maximizing luminance change and showing details of the data.
  29. interp : str or int, optional
  30. Specifies the type of interpolation.
  31. ('linear', 'nearest', 'zero', 'slinear', 'quadratic, 'cubic')
  32. or as an integer specifying the order of the spline interpolator
  33. to use. Default is 'linear' for dark neutral and 'cubic' for light
  34. neutral. See `scipy.interpolate.interp1d`.
  35.  
  36. Returns
  37. -------
  38. out : matplotlib.colors.LinearSegmentedColormap
  39. The resulting colormap object
  40.  
  41. Examples
  42. --------
  43. >>> from mpl_toolkits.mplot3d import Axes3D
  44. >>> import matplotlib.pyplot as plt
  45. >>> import numpy as np
  46. >>> from bipolar import hotcold
  47.  
  48. >>> x = y = np.arange(-4, 4, 0.15)
  49. >>> x, y = np.meshgrid(x, y)
  50. >>> z = (1 - x/2 + x**5 + y**3) * np.exp(-x**2 - y**2)
  51.  
  52. >>> fig, axs = plt.subplots(2, 2, figsize=(12, 8),
  53. ... subplot_kw={'projection': '3d'})
  54. >>> for ax, neutral in (((0, 0), 1/3), # Default
  55. ... ((0, 1), 0.1), # Dark gray as neutral
  56. ... ((1, 0), 0.9), # Light gray as neutral
  57. ... ((1, 1), 2/3),
  58. ... ):
  59. ... surf = axs[ax].plot_surface(x, y, z, rstride=1, cstride=1,
  60. ... vmax=abs(z).max(), vmin=-abs(z).max(),
  61. ... cmap=hotcold(neutral=neutral))
  62. >>> axs[ax].set_title(f'{neutral:.3f}')
  63. ... fig.colorbar(surf, ax=axs[ax])
  64. >>> plt.show()
  65.  
  66. References
  67. ----------
  68. .. [1] Lehmann Manja, Crutch SJ, Ridgway GR et al. "Cortical thickness
  69. and voxel-based morphometry in posterior cortical atrophy and typical
  70. Alzheimer's disease", Neurobiology of Aging, 2009,
  71. doi:10.1016/j.neurobiolaging.2009.08.017
  72.  
  73. """
  74. n = neutral
  75. if 0 <= n <= 0.5:
  76. if interp is None:
  77. # Seems to work well with dark neutral colors
  78. interp = 'linear'
  79.  
  80. data = (
  81. (0, 1, 1), # cyan
  82. (0, 0, 1), # blue
  83. (n, n, n), # dark neutral
  84. (1, 0, 0), # red
  85. (1, 1, 0), # yellow
  86. )
  87. elif 0.5 < n <= 1:
  88. if interp is None:
  89. # Seems to work better with bright neutral colors
  90. # Produces bright yellow or cyan rings otherwise
  91. interp = 'cubic'
  92.  
  93. data = (
  94. (0, 0, 1), # blue
  95. (0, 1, 1), # cyan
  96. (n, n, n), # light neutral
  97. (1, 1, 0), # yellow
  98. (1, 0, 0), # red
  99. )
  100. else:
  101. raise ValueError('n must be 0.0 < n < 1.0')
  102.  
  103. t = np.linspace(0, 1, lutsize//2)
  104.  
  105. # Super ugly Bezier curve
  106. # Do 2, one for each half, from nnn to 100 and from 001 to nnn
  107.  
  108. x1 = data[2][0]
  109. y1 = data[2][1]
  110. z1 = data[2][2]
  111.  
  112. xc = data[1][0]
  113. yc = data[1][1]
  114. zc = data[1][2]
  115.  
  116. x2 = data[0][0]
  117. y2 = data[0][1]
  118. z2 = data[0][2]
  119.  
  120. w = 1 # weight
  121.  
  122. r1 = (((1 - t)**2*x1 + 2*(1 - t)*t*w*xc + t**2*x2) /
  123. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  124. g1 = (((1 - t)**2*y1 + 2*(1 - t)*t*w*yc + t**2*y2) /
  125. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  126. b1 = (((1 - t)**2*z1 + 2*(1 - t)*t*w*zc + t**2*z2) /
  127. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  128.  
  129. x1 = data[2][0]
  130. y1 = data[2][1]
  131. z1 = data[2][2]
  132.  
  133. xc = data[3][0]
  134. yc = data[3][1]
  135. zc = data[3][2]
  136.  
  137. x2 = data[4][0]
  138. y2 = data[4][1]
  139. z2 = data[4][2]
  140.  
  141. r2 = (((1 - t)**2*x1 + 2*(1 - t)*t*w*xc + t**2*x2) /
  142. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  143. g2 = (((1 - t)**2*y1 + 2*(1 - t)*t*w*yc + t**2*y2) /
  144. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  145. b2 = (((1 - t)**2*z1 + 2*(1 - t)*t*w*zc + t**2*z2) /
  146. ((1 - t)**2 + 2*(1 - t)*t*w + t**2))
  147.  
  148. rgb1 = np.dstack((r1, g1, b1))[0]
  149. rgb2 = np.dstack((r2, g2, b2))[0]
  150.  
  151. ynew = np.concatenate((rgb1[1:][::-1], rgb2))
  152.  
  153. return cm.colors.LinearSegmentedColormap.from_list('hotcold', ynew,
  154. lutsize)
  155.  
  156.  
  157. resolution = 1000
  158.  
  159. # Function to update the heatmap based on the new origin
  160. def update_heatmap(event):
  161. global heatmap, origin, center, cmap
  162. if event.inaxes:
  163. # Clear the previous heatmap
  164. ax.clear()
  165.  
  166.  
  167. origin = np.array([-(event.ydata), (event.xdata)])
  168.  
  169. coords = np.indices((resolution, resolution)).reshape(2, -1).T
  170. coords = (coords - center) / resolution
  171. distances = np.dot(coords, origin) / (np.linalg.norm(coords, axis=1) * np.linalg.norm(origin))
  172. heatmap = distances.reshape((resolution, resolution))
  173. ax.imshow(heatmap, cmap=cmap, interpolation='nearest', extent=[-1, 1, -1, 1])
  174.  
  175. # Draw an arrow from the center to the origin
  176. ax.quiver(0, 0, origin[1], -origin[0], angles='xy', scale_units='xy', scale=1, color='black')
  177.  
  178. fig.canvas.draw_idle()
  179.  
  180.  
  181.  
  182. # Create the initial heatmap
  183. heatmap = np.random.rand(resolution, resolution)
  184. origin = np.array([0.5, 0.5])
  185. center = np.array([resolution/2, resolution/2]) # Center of the heatmap
  186. cmap = hotcold(neutral=0.0)
  187. coords = np.indices((resolution, resolution)).reshape(2, -1).T
  188. coords = (coords - center) / resolution
  189. distances = np.dot(coords, origin) / (np.linalg.norm(coords, axis=1) * np.linalg.norm(origin))
  190. heatmap = distances.reshape((resolution, resolution))
  191.  
  192. # Plot the heatmap
  193. fig, ax = plt.subplots()
  194. ax.imshow(heatmap, cmap=cmap, interpolation='nearest', extent=[0, 1, 0, 1])
  195.  
  196. plt.colorbar(ax.imshow(heatmap, cmap=cmap, interpolation='nearest', extent=[0, 1, 0, 1]))
  197.  
  198. # Connect the event handler to the figure
  199. fig.canvas.mpl_connect('motion_notify_event', update_heatmap)
  200.  
  201. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement