SHARE
TWEET

Untitled

a guest Oct 9th, 2019 118 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def create_wavelet_plot(
  2.     run_mode, points, fe_title, rowstart, power, period, levels, sig95,
  3.     coi, feature_idx, cut_start, fe_type, index, filename, data_source,
  4.     result
  5. ):
  6.     """Create wavelet plots."""
  7.     open_file = '/home/erics/Documents/test.json'
  8.  
  9.     with open(open_file) as json_file:
  10.         data = json.load(json_file)
  11.  
  12.         points = data['points']
  13.         fe_title = data['fe_title']
  14.         rowstart = data['rowstart']
  15.         power = np.asarray(data['power'])
  16.         period = data['period']
  17.         levels = data['levels']
  18.         sig95 = data['sig95']
  19.         coi  = data['coi']
  20.         feature_idx = data['feature_idx']
  21.         cut_start = data['cut_start']
  22.         fe_type = data['fe_type']
  23.         index = data['index']
  24.         filename = '/home/erics/Documents/test-wavelet.png'
  25.         data_source = data['data_source']
  26.         result = data['result']
  27.  
  28.     feature_idx = [2000, 4500, 6000, 10000, 3000]
  29.  
  30.     width, xlim_scale, period_min, period_max, time, shift =\
  31.         _get_values_derived_from_points(
  32.             points, cut_start, index, rowstart, period
  33.         )
  34.  
  35.     fig = plt.figure(figsize=(12, 5))
  36.  
  37.     current_ax = fig.add_subplot(1, 2, 1)
  38.     overlay_ax = fig.add_subplot(1, 2, 2)
  39.  
  40.     current_ax.set_position([0.1, 0.1, 0.8, 0.8])
  41.     overlay_ax.set_position([0.1, 0.1, 0.8, 0.8])
  42.  
  43.     current_ax.set_title(fe_title)
  44.     current_ax.set_xlabel('Points')
  45.     current_ax.set_ylabel('Period')
  46.     current_ax.set_xlim(xlim_scale)
  47.     overlay_ax.set_xlim(xlim_scale)
  48.     current_ax.set_ylim([
  49.         period_min,
  50.         period_max
  51.     ])
  52.  
  53.     power[power == 0] = sys.float_info.epsilon
  54.     cs = current_ax.contourf(
  55.         time, period, np.log2(power), len(levels))
  56.     im = current_ax.contourf(
  57.         cs, levels=np.log2(levels), cmap=cm.jet)
  58.  
  59.     current_ax.contour(time, period, sig95, [-99, 1], colors='k')
  60.     current_ax.plot(time, coi, 'k')
  61.  
  62.     if feature_idx:
  63.         ranked_x, ranked_y =\
  64.             _add_feature_idx_info(
  65.                 run_mode, overlay_ax, feature_idx, time,
  66.                 period, cut_start, shift, fe_type, period_max,
  67.                 period_min, rowstart, width, data_source, result
  68.             )
  69.  
  70.     current_ax.set_yscale('log', basey=10, subsy=None)
  71.     current_ax.invert_yaxis()
  72.     overlay_ax.invert_yaxis()
  73.     overlay_ax.set_axis_off()
  74.  
  75.     divider = make_axes_locatable(current_ax)
  76.     cax = divider.append_axes('bottom', size='5%', pad=0.5)
  77.     bf_axis = divider.append_axes('bottom', size='5%', pad=0.5)
  78.     bf_axis.set_axis_off()
  79.     fig.colorbar(im, cax=cax, orientation='horizontal')
  80.  
  81.     divider2 = make_axes_locatable(overlay_ax)
  82.     cax2 = divider2.append_axes('bottom', size='5%', pad=0.5)
  83.     bf_axis2 = divider2.append_axes('bottom', size='5%', pad=0.5)
  84.     cax2.set_axis_off()
  85.     bf_axis2.set_axis_off()
  86.  
  87.     if ranked_x is not None:
  88.         _add_feature_idx_details(
  89.             time, period, ranked_x, ranked_y, bf_axis)
  90.  
  91.     plt.tight_layout(pad=6, h_pad=5)
  92.  
  93.     save_to_file(filename)
  94.  
  95.  
  96. def _get_values_derived_from_points(
  97.     points, cut_start, index, rowstart, period
  98. ):
  99.     time_scale = np.arange(points)
  100.     xlim_scale = ([0, points - 1])
  101.  
  102.     shift = cut_start[index] + 1
  103.  
  104.     time_scale = time_scale + shift
  105.     xlim_scale = ([
  106.         xlim_scale[0] + shift,
  107.         xlim_scale[1] + shift
  108.     ])
  109.  
  110.     width = (xlim_scale[1] - xlim_scale[0])/20
  111.  
  112.     period_min = np.min(period)
  113.     period_max = np.max(period)
  114.  
  115.     xlim_scale = [x + rowstart for x in xlim_scale]
  116.     time_scale = [t + rowstart for t in time_scale]
  117.  
  118.     return (width, xlim_scale, period_min,
  119.             period_max, time_scale, shift)
  120.  
  121.  
  122. def _add_feature_idx_info(
  123.     run_mode, ax, feature_idx, time, period, cut_start,
  124.     shift, fe_type, period_max, period_min, rowstart,
  125.     width, data_source, result
  126. ):
  127.     if (
  128.         run_mode == RunMode.MODELING.value or
  129.         run_mode == RunMode.TESTING.value and
  130.         result in ('base', 'ok')
  131.     ):
  132.         x_values = np.array(feature_idx) % len(time)
  133.         y_values = np.array(feature_idx) // len(time)
  134.         y_values = [period[y_val] for y_val in y_values]
  135.         y_values = [10.12323, 11.434134, 23.4352, 50.43424, 68.10023]
  136.  
  137.         height = _get_height(y_values)
  138.  
  139.         points = zip(x_values, y_values)
  140.         ranked = list(zip(range(1, len(x_values)+1), points))
  141.         ranks = [item[0] for item in ranked]
  142.  
  143.         ranked_x = [item[1][0] for item in ranked]
  144.         ranked_x = [x + shift for x in ranked_x]
  145.         ranked_x = [x + rowstart for x in ranked_x]
  146.  
  147.         ranked_y = [item[1][1] for item in ranked]
  148.         all_points = sorted(zip(ranked_x, ranked_y, ranks),
  149.                             key=lambda x: (x[0], x[1]))
  150.  
  151.         # This is a Wavelet file so this is always True
  152.         isWavelet = True
  153.         # Default Value for Multi Column Flag
  154.         is_multi_column = False
  155.         # Check if Multi Column Flag is true
  156.         if data_source == DataSource.MULTI_COLUMN.value:
  157.             is_multi_column = True
  158.  
  159.         draw_cluster_rect(
  160.             ax, all_points, width, height,
  161.             period_max, period_min,
  162.             isWavelet, is_multi_column
  163.         )
  164.  
  165.         return ranked_x, ranked_y
  166.  
  167.     return None, None
  168.  
  169.  
  170. def _add_feature_idx_details(
  171.     time, period, ranked_x, ranked_y, bf_axis
  172. ):
  173.     best_feature = ''
  174.     for i, (time, period) in enumerate(zip(ranked_x, ranked_y)):
  175.         best_feature += (
  176.             '[' + str(i+1) + '] ( ' +
  177.             str(time) + ', ' +
  178.             str('{0:.4f}'.format(period)) + ' )'
  179.         )
  180.         if i != len(ranked_x)-1:
  181.             best_feature += '   '
  182.  
  183.     bf_axis.text(
  184.         0.5, 0.0, best_feature,
  185.         size=12, ha='center', transform=bf_axis.transAxes
  186.     )
  187.  
  188.  
  189. def _get_height(y_values):
  190.     """Return rectangle height."""
  191.     max_val = int(max(y_values))
  192.  
  193.     if max_val < 10:
  194.         return max_val * 0.10
  195.     else:
  196.         result = []
  197.  
  198.         result.append(math.floor(math.log(max_val, 10)))
  199.         result.append(math.ceil(math.log(max_val, 10)))
  200.         temp = result[:]
  201.  
  202.         for i in range(len(result)):
  203.             result[i] = pow(10, result[i])
  204.             temp[i] = abs(max_val - result[i])
  205.  
  206.         height = result[np.argmin(temp)] * 0.10
  207.  
  208.         return 10 if height < 10 else height
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top