Advertisement
Guest User

Untitled

a guest
Oct 9th, 2019
151
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.29 KB | None | 0 0
  1. def create_wavelet_plot(
  2. run_mode, points, fe_title, rowstart, power, period, levels, sig95,
  3. coi, feature_idx, cut_start, fe_type, index, filename, data_source,
  4. result
  5. ):
  6. """Create wavelet plots."""
  7. open_file = '/home/erics/Documents/test.json'
  8.  
  9. with open(open_file) as json_file:
  10. data = json.load(json_file)
  11.  
  12. points = data['points']
  13. fe_title = data['fe_title']
  14. rowstart = data['rowstart']
  15. power = np.asarray(data['power'])
  16. period = data['period']
  17. levels = data['levels']
  18. sig95 = data['sig95']
  19. coi = data['coi']
  20. feature_idx = data['feature_idx']
  21. cut_start = data['cut_start']
  22. fe_type = data['fe_type']
  23. index = data['index']
  24. filename = '/home/erics/Documents/test-wavelet.png'
  25. data_source = data['data_source']
  26. result = data['result']
  27.  
  28. feature_idx = [2000, 4500, 6000, 10000, 3000]
  29.  
  30. width, xlim_scale, period_min, period_max, time, shift =\
  31. _get_values_derived_from_points(
  32. points, cut_start, index, rowstart, period
  33. )
  34.  
  35. fig = plt.figure(figsize=(12, 5))
  36.  
  37. current_ax = fig.add_subplot(1, 2, 1)
  38. overlay_ax = fig.add_subplot(1, 2, 2)
  39.  
  40. current_ax.set_position([0.1, 0.1, 0.8, 0.8])
  41. overlay_ax.set_position([0.1, 0.1, 0.8, 0.8])
  42.  
  43. current_ax.set_title(fe_title)
  44. current_ax.set_xlabel('Points')
  45. current_ax.set_ylabel('Period')
  46. current_ax.set_xlim(xlim_scale)
  47. overlay_ax.set_xlim(xlim_scale)
  48. current_ax.set_ylim([
  49. period_min,
  50. period_max
  51. ])
  52.  
  53. power[power == 0] = sys.float_info.epsilon
  54. cs = current_ax.contourf(
  55. time, period, np.log2(power), len(levels))
  56. im = current_ax.contourf(
  57. cs, levels=np.log2(levels), cmap=cm.jet)
  58.  
  59. current_ax.contour(time, period, sig95, [-99, 1], colors='k')
  60. current_ax.plot(time, coi, 'k')
  61.  
  62. if feature_idx:
  63. ranked_x, ranked_y =\
  64. _add_feature_idx_info(
  65. run_mode, overlay_ax, feature_idx, time,
  66. period, cut_start, shift, fe_type, period_max,
  67. period_min, rowstart, width, data_source, result
  68. )
  69.  
  70. current_ax.set_yscale('log', basey=10, subsy=None)
  71. current_ax.invert_yaxis()
  72. overlay_ax.invert_yaxis()
  73. overlay_ax.set_axis_off()
  74.  
  75. divider = make_axes_locatable(current_ax)
  76. cax = divider.append_axes('bottom', size='5%', pad=0.5)
  77. bf_axis = divider.append_axes('bottom', size='5%', pad=0.5)
  78. bf_axis.set_axis_off()
  79. fig.colorbar(im, cax=cax, orientation='horizontal')
  80.  
  81. divider2 = make_axes_locatable(overlay_ax)
  82. cax2 = divider2.append_axes('bottom', size='5%', pad=0.5)
  83. bf_axis2 = divider2.append_axes('bottom', size='5%', pad=0.5)
  84. cax2.set_axis_off()
  85. bf_axis2.set_axis_off()
  86.  
  87. if ranked_x is not None:
  88. _add_feature_idx_details(
  89. time, period, ranked_x, ranked_y, bf_axis)
  90.  
  91. plt.tight_layout(pad=6, h_pad=5)
  92.  
  93. save_to_file(filename)
  94.  
  95.  
  96. def _get_values_derived_from_points(
  97. points, cut_start, index, rowstart, period
  98. ):
  99. time_scale = np.arange(points)
  100. xlim_scale = ([0, points - 1])
  101.  
  102. shift = cut_start[index] + 1
  103.  
  104. time_scale = time_scale + shift
  105. xlim_scale = ([
  106. xlim_scale[0] + shift,
  107. xlim_scale[1] + shift
  108. ])
  109.  
  110. width = (xlim_scale[1] - xlim_scale[0])/20
  111.  
  112. period_min = np.min(period)
  113. period_max = np.max(period)
  114.  
  115. xlim_scale = [x + rowstart for x in xlim_scale]
  116. time_scale = [t + rowstart for t in time_scale]
  117.  
  118. return (width, xlim_scale, period_min,
  119. period_max, time_scale, shift)
  120.  
  121.  
  122. def _add_feature_idx_info(
  123. run_mode, ax, feature_idx, time, period, cut_start,
  124. shift, fe_type, period_max, period_min, rowstart,
  125. width, data_source, result
  126. ):
  127. if (
  128. run_mode == RunMode.MODELING.value or
  129. run_mode == RunMode.TESTING.value and
  130. result in ('base', 'ok')
  131. ):
  132. x_values = np.array(feature_idx) % len(time)
  133. y_values = np.array(feature_idx) // len(time)
  134. y_values = [period[y_val] for y_val in y_values]
  135. y_values = [10.12323, 11.434134, 23.4352, 50.43424, 68.10023]
  136.  
  137. height = _get_height(y_values)
  138.  
  139. points = zip(x_values, y_values)
  140. ranked = list(zip(range(1, len(x_values)+1), points))
  141. ranks = [item[0] for item in ranked]
  142.  
  143. ranked_x = [item[1][0] for item in ranked]
  144. ranked_x = [x + shift for x in ranked_x]
  145. ranked_x = [x + rowstart for x in ranked_x]
  146.  
  147. ranked_y = [item[1][1] for item in ranked]
  148. all_points = sorted(zip(ranked_x, ranked_y, ranks),
  149. key=lambda x: (x[0], x[1]))
  150.  
  151. # This is a Wavelet file so this is always True
  152. isWavelet = True
  153. # Default Value for Multi Column Flag
  154. is_multi_column = False
  155. # Check if Multi Column Flag is true
  156. if data_source == DataSource.MULTI_COLUMN.value:
  157. is_multi_column = True
  158.  
  159. draw_cluster_rect(
  160. ax, all_points, width, height,
  161. period_max, period_min,
  162. isWavelet, is_multi_column
  163. )
  164.  
  165. return ranked_x, ranked_y
  166.  
  167. return None, None
  168.  
  169.  
  170. def _add_feature_idx_details(
  171. time, period, ranked_x, ranked_y, bf_axis
  172. ):
  173. best_feature = ''
  174. for i, (time, period) in enumerate(zip(ranked_x, ranked_y)):
  175. best_feature += (
  176. '[' + str(i+1) + '] ( ' +
  177. str(time) + ', ' +
  178. str('{0:.4f}'.format(period)) + ' )'
  179. )
  180. if i != len(ranked_x)-1:
  181. best_feature += ' '
  182.  
  183. bf_axis.text(
  184. 0.5, 0.0, best_feature,
  185. size=12, ha='center', transform=bf_axis.transAxes
  186. )
  187.  
  188.  
  189. def _get_height(y_values):
  190. """Return rectangle height."""
  191. max_val = int(max(y_values))
  192.  
  193. if max_val < 10:
  194. return max_val * 0.10
  195. else:
  196. result = []
  197.  
  198. result.append(math.floor(math.log(max_val, 10)))
  199. result.append(math.ceil(math.log(max_val, 10)))
  200. temp = result[:]
  201.  
  202. for i in range(len(result)):
  203. result[i] = pow(10, result[i])
  204. temp[i] = abs(max_val - result[i])
  205.  
  206. height = result[np.argmin(temp)] * 0.10
  207.  
  208. return 10 if height < 10 else height
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement