Advertisement
Guest User

Untitled

a guest
Oct 19th, 2017
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.47 KB | None | 0 0
  1. import sys
  2. from analysis_plots.parse import parse_info
  3. from paper_plots import paper_conf as conf
  4.  
  5. import gflags
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from matplotlib import rc
  9.  
  10. from paper_plots.stat_util import pooled_sd
  11.  
  12. FLAGS = gflags.FLAGS
  13.  
  14. gflags.DEFINE_string('path', '/Users/Michael/Documents/deep_config/paper_results/baseline_run_1/,'
  15.                              '/Users/Michael/Documents/deep_config/paper_results/baseline_run_2/,'
  16.                              '/Users/Michael/Documents/deep_config/paper_results/baseline_run_3/',
  17.                      'path to input')
  18. gflags.DEFINE_string('prefix', 'dba_train_tiny_14400,'
  19.                                'default_train_tiny_14400,'
  20.                                'random_train_tiny_14400,'
  21.                                'online_train_simple_tiny_14400,'
  22.                                'no_pretrain_train_complex_tiny_14400,'
  23.                                'online_train_complex_tiny_14400'
  24.                      , 'experiment_prefix')
  25.  
  26. #RL-pre,RL-online,RL-pre-perturb
  27.  
  28. gflags.DEFINE_string('label', 'DBA,No indexing,Random,RL-pre,RL-online,RL-pre-perturb',
  29.                      'experiment label')
  30. gflags.DEFINE_string('export_path', '/Users/Michael/Documents/tf-paper/img/',
  31.                      'export dir')
  32.  
  33. gflags.DEFINE_integer('queries', 10,
  34.                       'experiment label')
  35. gflags.DEFINE_integer('clients', 10, 'clients')
  36.  
  37. rc('text', usetex=True)
  38. rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
  39. params = {'text.latex.preamble': [r'\usepackage{upgreek}', r'\usepackage{amsmath}']}
  40. plt.rcParams.update(params)
  41.  
  42.  
  43. def main(argv):
  44.     try:
  45.         argv = FLAGS(argv)
  46.     except gflags.FlagsError as e:
  47.         print('%s\\nUsage: %s ARGS\\n%s' % (e, sys.argv[0], FLAGS))
  48.  
  49.     path = FLAGS.path.split(',')
  50.     prefix = FLAGS.prefix.split(',')
  51.     label = FLAGS.label.split(',')
  52.     export_path = FLAGS.export_path
  53.  
  54.     clients = FLAGS.clients
  55.     queries = FLAGS.queries
  56.  
  57.     iterations = 3
  58.     # TODO das mit paths ist noch falsch
  59.     plot_queries(export_path, label, path, prefix, clients, queries, 0, 'train_', iterations)
  60.  
  61.  
  62. def plot_queries(export_path, label, path, experiment_variant, clients, queries, offset, result_prefix, iterations):
  63.     ind = np.arange(queries - 5)
  64.     width = 0.15
  65.     # Iterate over all prefixes. For each prefix, compute all queries
  66.  
  67.     for k in range(len(experiment_variant)):
  68.         iteration_means = np.zeros((iterations, (queries - 5)))
  69.         iteration_stds = np.zeros((iterations, (queries - 5)))
  70.         sample_sizes = np.zeros((iterations, (queries - 5)))
  71.  
  72.         for iteration in range(iterations):
  73.             update, query = parse_info(path[iteration] + '/' + experiment_variant[k], clients, queries, False, offset)
  74.  
  75.             for i in range(len(query) - 5):
  76.                 # print('Length of query content' + str(len(query[i])))
  77.                 scaled_up = np.asarray(query[i] * 1000)  # in ms
  78.                 iteration_means[iteration][i] = np.median(scaled_up)
  79.                 # print('Mean value for query {} is {}'.format( offset + i + 1, iteration_means[iteration][i]))
  80.                 iteration_stds[iteration][i] = np.std(scaled_up)
  81.                 sample_sizes[iteration][i] = len(scaled_up)
  82.  
  83.         means = np.mean(iteration_means, axis=0)
  84.         stds = []
  85.  
  86.         for query in range(queries - 5):
  87.             # Das hier sind jetzt die einzelnen stds fuer die query entlang der iteration achse
  88.             std_list = list(iteration_stds[:, query])
  89.             sizes = list(sample_sizes[:, query])
  90.  
  91.             stds.append(pooled_sd(sd_list=std_list, sample_sizes=sizes, num_samples=iterations))
  92.  
  93.         plt.bar(ind + k * width, means, yerr=stds, bottom=0, width=width,
  94.                 color=conf.greyscales[k], hatch=conf.hatches[k])
  95.  
  96.     plt.tick_params(axis='both', which='major', labelsize=conf.font_size)
  97.     plt.xlabel('Training query id', fontsize=conf.font_size)
  98.     plt.xticks(ind, ('Q1', 'Q2', 'Q3', 'Q4', 'Q5'))
  99.     plt.ylabel('Median execution time (ms)', fontsize=conf.font_size)
  100.     #plt.ylim(0, 6000)
  101.     plt.legend(label, loc='upper left')
  102.     plt.yscale('log')
  103.     plt.tight_layout()
  104.  
  105.     plt.savefig('/Users/Michael/Documents/deep_config/analysis/paper_img/train_query_bars_first.pdf', format="pdf", bbox_inches="tight")
  106.     plt.savefig(export_path + result_prefix + 'query_bars_first.pdf', format="pdf", bbox_inches="tight")
  107.     plt.clf()
  108.     plt.cla()
  109.     plt.close()
  110.  
  111.     for k in range(len(experiment_variant)):
  112.         iteration_means = np.zeros((iterations, (queries - 5)))
  113.         iteration_stds = np.zeros((iterations, (queries - 5)))
  114.         sample_sizes = np.zeros((iterations, (queries - 5)))
  115.  
  116.         for iteration in range(iterations):
  117.             update, query = parse_info(path[iteration] + '/' + experiment_variant[k], clients, queries, False, offset)
  118.  
  119.             for i in range(5):
  120.                 j = i + 5
  121.                 scaled_up = np.asarray(query[j] * 1000)  # in ms
  122.  
  123.                 #median or mean?
  124.                 iteration_means[iteration][i] = np.median(scaled_up)
  125.                 # print('Mean value for query {} is {}'.format(offset + j + 1, iteration_means[iteration][i]))
  126.                 iteration_stds[iteration][i] = np.std(scaled_up)
  127.                 sample_sizes[iteration][i] = len(scaled_up)
  128.  
  129.         means = np.mean(iteration_means, axis=0)
  130.         stds = []
  131.  
  132.         for query in range(queries - 5):
  133.             std_list = list(iteration_stds[:, query])
  134.             sizes = list(sample_sizes[:, query])
  135.  
  136.             stds.append(pooled_sd(sd_list=std_list, sample_sizes=sizes, num_samples=iterations))
  137.  
  138.         plt.bar(ind + k * width, means, yerr=stds, bottom=0, width=width,
  139.                 color=conf.greyscales[k], hatch=conf.hatches[k])
  140.  
  141.     plt.xlabel('Training query id', fontsize=conf.font_size)
  142.    # plt.ylim(0, 6000)
  143.     plt.tick_params(axis='both', which='major', labelsize=conf.font_size)
  144.     plt.xticks(ind, ('Q6', 'Q7', 'Q8', 'Q9', 'Q10'))
  145.     plt.ylabel('Median execution time (ms)', fontsize=conf.font_size)
  146.    # plt.legend(label, loc='upper left')
  147.     plt.yscale('log')
  148.     plt.tight_layout()
  149.  
  150.     plt.savefig('/Users/Michael/Documents/deep_config/analysis/paper_img/train_query_bars_second.pdf', format="pdf", bbox_inches="tight")
  151.  
  152.     plt.savefig(export_path + result_prefix + 'query_bars_second.pdf', format="pdf", bbox_inches="tight")
  153.     plt.clf()
  154.     plt.cla()
  155.     plt.close()
  156.  
  157.  
# Run the plotting pipeline only when executed as a script; main() receives
# the raw argument vector so that gflags can parse it.
if __name__ == '__main__':
    main(sys.argv)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement