Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def create_wavelet_plot(
        run_mode, points, fe_title, rowstart, power, period, levels, sig95,
        coi, feature_idx, cut_start, fe_type, index, filename, data_source,
        result
):
    """Create a wavelet power plot and save it to *filename*.

    Two axes are placed on the same rectangle. The visible ("current")
    axes shows filled log2-power contours, the ``sig95`` contour at level 1
    and the ``coi`` curve. The hidden ("overlay") axes only carries the
    feature rectangles drawn by ``_add_feature_idx_info``.

    Args:
        run_mode: run mode value, forwarded to ``_add_feature_idx_info``.
        points: number of samples along the x axis.
        fe_title: title of the plot.
        rowstart: offset added to every x position.
        power: power grid contoured against (time, period). It is not
            modified; zeros are replaced in a copy so log2 stays finite.
        period: period values for the y axis (log scale, inverted).
        levels: contour levels in linear power units; log2 is applied here.
        sig95: grid contoured at level 1 in black.
        coi: curve plotted in black over the contours.
        feature_idx: flat feature indices to outline; may be empty or None.
        cut_start, index: ``cut_start[index] + 1`` is the x shift.
        fe_type, data_source, result: forwarded to ``_add_feature_idx_info``.
        filename: destination passed to ``save_to_file``.
    """
    width, xlim_scale, period_min, period_max, time, shift = \
        _get_values_derived_from_points(
            points, cut_start, index, rowstart, period
        )

    fig = plt.figure(figsize=(12, 5))
    current_ax = fig.add_subplot(1, 2, 1)
    overlay_ax = fig.add_subplot(1, 2, 2)
    # Give both axes the same rectangle so the overlay rectangles line up
    # with the contour plot underneath.
    current_ax.set_position([0.1, 0.1, 0.8, 0.8])
    overlay_ax.set_position([0.1, 0.1, 0.8, 0.8])
    current_ax.set_title(fe_title)
    current_ax.set_xlabel('Points')
    current_ax.set_ylabel('Period')
    current_ax.set_xlim(xlim_scale)
    overlay_ax.set_xlim(xlim_scale)
    current_ax.set_ylim([period_min, period_max])

    # np.where builds a new array, so the caller's data is left untouched;
    # zeros become epsilon so np.log2 does not produce -inf.
    power = np.asarray(power)
    power = np.where(power == 0, sys.float_info.epsilon, power)
    cs = current_ax.contourf(time, period, np.log2(power), len(levels))
    # Re-contour the same data with the explicit (log2) levels and colormap.
    im = current_ax.contourf(cs, levels=np.log2(levels), cmap=cm.jet)
    current_ax.contour(time, period, sig95, [-99, 1], colors='k')
    current_ax.plot(time, coi, 'k')

    # Initialised up front so the details step below is simply skipped when
    # there are no features (previously a NameError).
    ranked_x = ranked_y = None
    if feature_idx is not None and len(feature_idx) > 0:
        ranked_x, ranked_y = _add_feature_idx_info(
            run_mode, overlay_ax, feature_idx, time,
            period, cut_start, shift, fe_type, period_max,
            period_min, rowstart, width, data_source, result
        )

    # Base 10 and subs=None are matplotlib's defaults. The old basey/subsy
    # keywords were deprecated in matplotlib 3.3 and later removed.
    current_ax.set_yscale('log')
    current_ax.invert_yaxis()
    overlay_ax.invert_yaxis()
    overlay_ax.set_axis_off()

    divider = make_axes_locatable(current_ax)
    cax = divider.append_axes('bottom', size='5%', pad=0.5)
    bf_axis = divider.append_axes('bottom', size='5%', pad=0.5)
    bf_axis.set_axis_off()
    fig.colorbar(im, cax=cax, orientation='horizontal')

    # Mirror the same two bottom strips on the overlay axes so both axes are
    # shrunk identically and stay aligned.
    divider2 = make_axes_locatable(overlay_ax)
    cax2 = divider2.append_axes('bottom', size='5%', pad=0.5)
    bf_axis2 = divider2.append_axes('bottom', size='5%', pad=0.5)
    cax2.set_axis_off()
    bf_axis2.set_axis_off()

    if ranked_x is not None:
        _add_feature_idx_details(
            time, period, ranked_x, ranked_y, bf_axis)
    plt.tight_layout(pad=6, h_pad=5)
    save_to_file(filename)
    # pyplot keeps every figure alive until it is closed. Release it once
    # saved (assumes save_to_file has finished writing when it returns).
    plt.close(fig)
- def _get_values_derived_from_points(
- points, cut_start, index, rowstart, period
- ):
- time_scale = np.arange(points)
- xlim_scale = ([0, points - 1])
- shift = cut_start[index] + 1
- time_scale = time_scale + shift
- xlim_scale = ([
- xlim_scale[0] + shift,
- xlim_scale[1] + shift
- ])
- width = (xlim_scale[1] - xlim_scale[0])/20
- period_min = np.min(period)
- period_max = np.max(period)
- xlim_scale = [x + rowstart for x in xlim_scale]
- time_scale = [t + rowstart for t in time_scale]
- return (width, xlim_scale, period_min,
- period_max, time_scale, shift)
def _add_feature_idx_info(
        run_mode, ax, feature_idx, time, period, cut_start,
        shift, fe_type, period_max, period_min, rowstart,
        width, data_source, result
):
    """Outline the ranked features on *ax* and return their coordinates.

    Each entry of *feature_idx* is a flat index into the power grid. The
    remainder modulo ``len(time)`` gives the x column, and the quotient
    selects the row of *period* used as the y value. Ranks follow the order
    of *feature_idx*, starting at 1.

    Rectangles are drawn only for modeling runs, or for testing runs whose
    *result* is ``'base'`` or ``'ok'``.

    Args:
        run_mode: compared against ``RunMode.MODELING.value`` and
            ``RunMode.TESTING.value``.
        ax: axes the rectangles are drawn on.
        feature_idx: non-empty sequence of flat feature indices.
        time: x samples; only its length is used.
        period: per-row period values.
        cut_start, fe_type: unused here; kept for interface compatibility.
        shift, rowstart: offsets added to every x column.
        period_max, period_min, width: forwarded to ``draw_cluster_rect``.
        data_source: compared with ``DataSource.MULTI_COLUMN.value`` to set
            the multi-column flag.
        result: run outcome, checked only in testing mode.

    Returns:
        ``(ranked_x, ranked_y)`` lists in rank order, or ``(None, None)``
        when nothing is drawn.
    """
    # `and` binds tighter than `or`; the parentheses make the intended
    # grouping explicit.
    draw_for_run = (
        run_mode == RunMode.MODELING.value or
        (run_mode == RunMode.TESTING.value and result in ('base', 'ok'))
    )
    if not draw_for_run:
        return None, None

    flat_idx = np.asarray(feature_idx)
    n_cols = len(time)
    ranked_x = [x + shift + rowstart for x in flat_idx % n_cols]
    # The y values now come from the real feature rows. Previously they
    # were overwritten by a hard-coded debug list.
    ranked_y = [period[row] for row in flat_idx // n_cols]
    height = _get_height(ranked_y)

    ranks = range(1, len(ranked_x) + 1)
    # Order by position (x, then y). sorted() is stable, so ties keep
    # their rank order.
    all_points = sorted(
        zip(ranked_x, ranked_y, ranks), key=lambda p: (p[0], p[1])
    )
    # This module only produces wavelet plots, so the flag is always set.
    is_wavelet = True
    is_multi_column = data_source == DataSource.MULTI_COLUMN.value
    draw_cluster_rect(
        ax, all_points, width, height,
        period_max, period_min,
        is_wavelet, is_multi_column
    )
    return ranked_x, ranked_y
- def _add_feature_idx_details(
- time, period, ranked_x, ranked_y, bf_axis
- ):
- best_feature = ''
- for i, (time, period) in enumerate(zip(ranked_x, ranked_y)):
- best_feature += (
- '[' + str(i+1) + '] ( ' +
- str(time) + ', ' +
- str('{0:.4f}'.format(period)) + ' )'
- )
- if i != len(ranked_x)-1:
- best_feature += ' '
- bf_axis.text(
- 0.5, 0.0, best_feature,
- size=12, ha='center', transform=bf_axis.transAxes
- )
- def _get_height(y_values):
- """Return rectangle height."""
- max_val = int(max(y_values))
- if max_val < 10:
- return max_val * 0.10
- else:
- result = []
- result.append(math.floor(math.log(max_val, 10)))
- result.append(math.ceil(math.log(max_val, 10)))
- temp = result[:]
- for i in range(len(result)):
- result[i] = pow(10, result[i])
- temp[i] = abs(max_val - result[i])
- height = result[np.argmin(temp)] * 0.10
- return 10 if height < 10 else height
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement