Guest User

Untitled

a guest
Oct 19th, 2017
67
0
Never
Not a member of Pastebin yet? Sign up; it unlocks many cool features!
text 2.55 KB | None | 0 0
  1. import re
  2. import os
  3. import glob
  4. import pandas as pd
  5.  
  6. import matplotlib.pyplot as plt
  7.  
  8. from collections import OrderedDict
  9.  
  10. fig_size = [12, 9]
  11. plt.rcParams["figure.figsize"] = fig_size
  12.  
  13.  
  14. def parse_filepath(string):
  15. head, tail = os.path.split(string)
  16. return re.search(r'_(.*),', tail).group(1)
  17.  
  18.  
  19. def load_csv_files(path, index_column='Step', wanted_columns=[1, 2], skip_rows=2):
  20. expanded_path = os.path.abspath(path)
  21. csv_filepaths = glob.glob(os.path.join(expanded_path, '*.csv'))
  22. dataframe_from_each_file = (pd.read_csv(filepath,
  23. index_col=index_column,
  24. names=['Wall time', 'Step', parse_filepath(filepath)],
  25. skiprows=skip_rows,
  26. header=None,
  27. usecols=wanted_columns)
  28. for filepath in csv_filepaths)
  29. concatenated_dataframe = pd.concat(dataframe_from_each_file, axis=1, join='inner')
  30. return concatenated_dataframe
  31.  
  32.  
  33. experiment_logs = load_csv_files('./experiment_results/mnist_losses')
  34.  
  35.  
  36. def add_rolling_average(dataframe, window=6):
  37. for column in dataframe:
  38. dataframe[f'{column}_rolling_average'] = dataframe[column].rolling(window=window).mean()
  39.  
  40.  
  41. add_rolling_average(experiment_logs)
  42.  
  43. colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
  44.  
  45.  
  46. def graph_experiment_losses(dataframe, number_of_logs=4, filters=('average')):
  47. fig = plt.figure()
  48. ax = fig.add_subplot()
  49. axes = experiment_logs.plot(ax=ax, color=colors[:number_of_logs], logy=True, sort_columns=True)
  50. axes.set_xlabel('Step', fontsize=16)
  51. axes.set_ylabel('Training Loss', fontsize=16)
  52. axes.spines["top"].set_visible(False)
  53. axes.spines["bottom"].set_visible(False)
  54. axes.spines["right"].set_visible(False)
  55. axes.spines["left"].set_visible(False)
  56. plt.xticks(fontsize=12)
  57. plt.yticks(fontsize=12)
  58. for i in range(number_of_logs):
  59. axes.lines[i].set_alpha(0.3)
  60. axes.lines[i].set_alpha(0.3)
  61. handles, labels = plt.gca().get_legend_handles_labels()
  62. by_label = OrderedDict((label, handel) for label, handel in zip(labels, handles) if not label.endswith(filters))
  63. legend = plt.legend(by_label.values(),by_label.keys(), frameon=False, prop={'size': 12})
  64. for leg in legend.legendHandles:
  65. leg.set_alpha(1)
  66.  
  67.  
  68. graph_experiment_losses(experiment_logs)
  69.  
  70. plt.savefig('losses.png', format='png', bbox_inches='tight')
Add Comment
Please sign in to add a comment.