Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from matplotlib.backends.backend_pdf import PdfPages

# Scatter plots of episode length vs. episode reward, one panel per
# (network, train seed) pair, coloured by a Gaussian-KDE density estimate.
# Page 1 of the PDF holds every network except the last entry of
# `networks_ordered_table`; page 2 holds that last network on its own row.
#
# Relies on names defined elsewhere: plt, np, gaussian_kde, fls_data,
# fls_breakout_infinite, networks_ordered_table, train_seeds,
# network_nicknames.


def _kde_point_density(lengths, rewards, run_label):
    """Return a Gaussian-KDE density evaluated at each (length, reward) point.

    If the estimate cannot be computed, `run_label` is printed so the
    offending run can be identified, and an all-zero array with one entry
    per point is returned so the scatter can still be drawn.

    The fallback is sized from the data rather than a fixed 8192, because
    `scatter(c=...)` needs exactly one colour value per point.
    """
    xy = np.vstack([lengths, rewards])
    try:
        return gaussian_kde(xy)(xy)
    except (np.linalg.LinAlgError, ValueError):
        # Per the SciPy docs these are the errors gaussian_kde raises for
        # degenerate input (singular covariance, too few points).
        # NOTE(review): the paste lost its indentation; the print is
        # assumed to belong to the failure branch -- confirm.
        print(*run_label)
        return np.zeros(xy.shape[1])


def _scatter_run(ax, table, game, network, seed):
    """Draw the density-coloured length/reward scatter of one run on `ax`.

    The run is selected from `table` with boolean masks built from
    `fls_data`, applied one after another (game, network, train seed),
    and the first matching entry's 'lengths' and 'rewards' are flattened.
    """
    cur_data = table[fls_data['game'] == game]
    cur_data = cur_data[fls_data['network'] == network]
    cur_data = cur_data[fls_data['train_seed'] == seed]
    lengths = np.ravel(cur_data['lengths'][0])
    rewards = np.ravel(cur_data['rewards'][0])
    density = _kde_point_density(lengths, rewards, (game, network, seed))
    ax.scatter(lengths, rewards, c=density, alpha=0.5, cmap='jet',
               rasterized=True)


for game in ['Breakout']:
    # Page 1: grid of networks (rows) x train seeds (columns).
    fig1, axes = plt.subplots(14, 5, figsize=(20, 50), sharex=True, sharey=True)
    fig1.tight_layout()
    for network_num, network in enumerate(networks_ordered_table[:-1]):
        for seed_num, seed in enumerate(train_seeds):
            ax = axes[network_num, seed_num]
            ax.set_xlim(left=1800, right=2400)
            ax.set_ylim(bottom=500, top=3000)
            # NOTE(review): rows come from `fls_breakout_infinite` while the
            # masks come from `fls_data`; page 2 below uses `fls_data` for
            # both. Kept as in the original -- confirm this is intended.
            _scatter_run(ax, fls_breakout_infinite, game, network, seed)
            if network_num == 0:
                ax.set_title('Seed {}'.format(seed), fontsize=12)
            if seed_num == 0:
                ax.set_ylabel(network_nicknames[network], fontsize=12)
            ax.grid()

    # NOTE(review): the file name says "Seaquest" although the game plotted
    # is 'Breakout'; the path is kept unchanged -- confirm.
    # The context manager closes (and thereby finalises) the PDF even if
    # building or saving the second figure raises.
    with PdfPages('23_08/Seaquest_to_merge.pdf') as pp:
        pp.savefig(fig1, dpi=100, bbox_inches='tight')

        # Page 2: the last network only, one panel per train seed, with
        # automatic axis limits.
        fig2, axes = plt.subplots(1, 5, figsize=(20, 4), sharex=True, sharey=True)
        fig2.tight_layout()
        network = networks_ordered_table[-1]
        for seed_num, seed in enumerate(train_seeds):
            ax = axes[seed_num]
            _scatter_run(ax, fls_data, game, network, seed)
            if seed_num == 0:
                ax.set_ylabel(network_nicknames[network], fontsize=12)
            ax.grid()
        pp.savefig(fig2, dpi=100, bbox_inches='tight')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement