# Pastebin paste "Untitled" by a guest, Aug 23rd, 2019 (Python, 2.97 KB).
  1.  
  2. from matplotlib.backends.backend_pdf import PdfPages
  3.  
  4. for game in ['Breakout']:
  5.     fig1, axes = plt.subplots(14, 5, figsize = (20, 50), sharex=True, sharey=True)
  6.     fig1.tight_layout()
  7.  
  8.     for network_num, network in enumerate(networks_ordered_table[:-1]):
  9.         for seed_num, seed in enumerate(train_seeds):
  10.             cur_data = fls_breakout_infinite[fls_data['game'] == game][fls_data['network'] == network][fls_data['train_seed'] == seed]
  11.  
  12.             lengths = np.ravel(cur_data['lengths'][0])
  13.             rewards = np.ravel(cur_data['rewards'][0])
  14.  
  15.             xy = np.vstack([lengths, rewards])
  16.             try:
  17.                 density = gaussian_kde(xy)(xy)
  18.             except:
  19.                 density = np.zeros(8192)
  20.                 print(game, network, seed)
  21.                
  22. #             if network_num == 13:
  23. #                 axes[network_num, seed_num].xaxis.set_tick_params(which='x', labelbottom=True)
  24.                
  25.             axes[network_num, seed_num].set_xlim(left=1800, right=2400)
  26.             axes[network_num, seed_num].set_ylim(bottom=500, top=3000)
  27.  
  28.             axes[network_num, seed_num].scatter(lengths, rewards, c=density, alpha=0.5, cmap='jet', rasterized=True)
  29.  
  30.             if network_num == 0:
  31.                 axes[network_num, seed_num].set_title('Seed {}'.format(seed), fontsize=12)
  32.             if seed_num == 0:
  33.                 axes[network_num, seed_num].set_ylabel(network_nicknames[network], fontsize=12)
  34.  
  35.             axes[network_num, seed_num].grid()
  36.             #axes[network_num, seed_num].set_colorbar()
  37.  
  38.     pp = PdfPages('23_08/Seaquest_to_merge.pdf')
  39.     pp.savefig(fig1, dpi=100, bbox_inches='tight')
  40.  
  41.    
  42.     fig2, axes = plt.subplots(1, 5, figsize = (20, 4), sharex=True, sharey=True)
  43.     fig2.tight_layout()
  44.  
  45.     for network_num, network in enumerate([networks_ordered_table[-1]]):
  46.         for seed_num, seed in enumerate(train_seeds):
  47.             cur_data = fls_data[fls_data['game'] == game][fls_data['network'] == network][fls_data['train_seed'] == seed]
  48.  
  49.             lengths = np.ravel(cur_data['lengths'][0])
  50.             rewards = np.ravel(cur_data['rewards'][0])
  51.  
  52.             xy = np.vstack([lengths, rewards])
  53.             try:
  54.                 density = gaussian_kde(xy)(xy)
  55.             except:
  56.                 density = np.zeros(8192)
  57.                 print(game, network, seed)
  58.  
  59. #             axes[network_num, seed_num].set_xlim(left=0, right=8000)
  60. #             axes[network_num, seed_num].set_ylim(bottom=0, top=60000)
  61.  
  62. #             sns.kdeplot(lengths, rewards, ax=axes[seed_num],
  63. #                         shade=True, shade_lowest=False, cmap='jet', kernel='biw')
  64.             axes[seed_num].scatter(lengths, rewards, c=density, alpha=0.5, cmap='jet', rasterized=True)
  65.  
  66.             if seed_num == 0:
  67.                 axes[seed_num].set_ylabel(network_nicknames[network], fontsize=12)
  68.  
  69.             axes[seed_num].grid()
  70.    
  71.     pp.savefig(fig2, dpi=100, bbox_inches='tight')
  72.     pp.close()
# End of paste.