import os
import time
import glob
import torch
import random
import imageio
import numpy as np
from itertools import count

import grounding_env
from multiprocess_env import *
from constants import CORRECT_OBJECT_REWARD

torch.autograd.set_detect_anomaly(True)

LEAVE_PRINT_EVERY_N_SECS = 30
ERASE_LINE = '\x1b[2K'
class A2C():
    def __init__(self,
                 ac_model_fn,
                 ac_model_max_grad_norm=1.0,
                 ac_optimizer_fn=None,
                 ac_optimizer_lr=None,
                 policy_loss_weight=None,
                 value_loss_weight=None,
                 entropy_loss_weight=None,
                 n_workers=None,
                 tau=None):
        assert n_workers > 1
        self.ac_model_fn = ac_model_fn
        self.ac_model_max_grad_norm = ac_model_max_grad_norm
        self.ac_optimizer_fn = ac_optimizer_fn
        self.ac_optimizer_lr = ac_optimizer_lr
        self.policy_loss_weight = policy_loss_weight
        self.value_loss_weight = value_loss_weight
        self.entropy_loss_weight = entropy_loss_weight
        self.n_workers = n_workers
        self.tau = tau
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    def get_ac_model(self, max_episode_length, vocab_size):
        return self.ac_model_fn(max_episode_length, vocab_size)
    def get_instruction_matrix(self, instructions, word_to_idx):
        '''
        inputs:
            instructions: list of instruction strings, one per worker
            word_to_idx: a dictionary mapping lowercase words to their indices
        outputs:
            instr_matrix: np.ndarray of shape (n_workers x max_instruction_length)
                containing the index of each lowercased word, zero-padded on the right
        '''
        ## Pad the instructions to the longest instruction length,
        ## then tokenise and convert to a matrix.
        max_instr_len = -1
        num_instructions = len(instructions)
        for inst in instructions:
            max_instr_len = max(max_instr_len, len(inst.split(' ')))
        instr_matrix = np.zeros((num_instructions, max_instr_len), dtype=int)
        for i, instr in enumerate(instructions):
            instr = instr.lower().split(' ')
            for j, word in enumerate(instr):
                instr_matrix[i, j] = word_to_idx[word]
        return instr_matrix
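    # A minimal illustration (hypothetical vocabulary, not part of the training flow):
    # with word_to_idx = {'go': 1, 'to': 2, 'the': 3, 'red': 4, 'torch': 5, 'armor': 6},
    #   get_instruction_matrix(['go to the red torch', 'go to the armor'], word_to_idx)
    # returns the 2 x 5 matrix
    #   [[1, 2, 3, 4, 5],
    #    [1, 2, 3, 6, 0]]
    # where the shorter instruction is zero-padded on the right.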
    def optimize_model(self):
        logpas = torch.stack(self.logpas).squeeze()        ## size: (N_steps-1 x n_workers)
        entropies = torch.stack(self.entropies).squeeze()  ## size: (N_steps-1 x n_workers)
        values = torch.stack(self.values).squeeze()        ## size: (N_steps x n_workers)
        T = len(self.rewards)
        discounts = np.logspace(0, T, num=T, base=self.gamma, endpoint=False)
        rewards = np.array(self.rewards).squeeze()
        returns = np.array([[np.sum(discounts[:T-t] * rewards[t:, w]) for t in range(T)]
                            for w in range(self.n_workers)])
        ## TODO: double-check the GAE calculation, looks wrong
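        ## For reference (comments added; formula from Schulman et al., 2016): with
        ## TD errors delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), the generalized
        ## advantage estimate is
        ##     A_t = sum_{l=0}^{T-2-t} (gamma * tau)^l * delta_{t+l}
        ## Below, `advs` holds the delta_t terms and `tau_discounts` holds the
        ## (gamma * tau)^l weights.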
        np_values = values.detach().cpu().numpy()
        tau_discounts = np.logspace(0, T-1, num=T-1, base=self.gamma*self.tau, endpoint=False)
        advs = rewards[:-1] + self.gamma * np_values[1:] - np_values[:-1]
        gaes = np.array([[np.sum(tau_discounts[:T-1-t] * advs[t:, w]) for t in range(T-1)]
                         for w in range(self.n_workers)])
        discounted_gaes = discounts[:-1] * gaes
        values = values[:-1, ...].reshape(-1, 1)
        logpas = logpas.reshape(-1, 1)
        entropies = entropies.reshape(-1, 1)
        returns = torch.FloatTensor(returns.T[:-1]).to(self.device).reshape(-1, 1)
        discounted_gaes = torch.FloatTensor(discounted_gaes.T).to(self.device).reshape(-1, 1)
        T = (T - 1) * self.n_workers  ## T = (N_steps-1) * n_workers
        assert returns.size() == (T, 1)
        assert values.size() == (T, 1)
        assert logpas.size() == (T, 1)
        assert entropies.size() == (T, 1)
        value_error = returns.detach() - values
        value_loss = value_error.pow(2).mul(0.5).mean()
        policy_loss = -(discounted_gaes.detach() * logpas).mean()
        entropy_loss = -entropies.mean()
        loss = self.policy_loss_weight * policy_loss + \
               self.value_loss_weight * value_loss + \
               self.entropy_loss_weight * entropy_loss
        self.ac_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.ac_model.parameters(),
                                       self.ac_model_max_grad_norm)
        self.ac_optimizer.step()
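    ## For reference (comments added): the scalar objective minimised above is
    ##     L = w_pi * E[-A_t * log pi(a_t|s_t)] + w_v * E[0.5 * (G_t - V(s_t))^2] - w_H * E[H(pi(.|s_t))]
    ## where w_pi, w_v, w_H are the policy_loss_weight, value_loss_weight, and
    ## entropy_loss_weight constructor arguments, so a positive entropy weight
    ## encourages exploration.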
    def interaction_step(self, states, envs):
        '''
        inputs:
            states: tuple of (img_stack, instruction_mat, (tx, hx, cx)),
                all torch tensors already on self.device
            envs: multiple-environment object
        outputs:
            new_img_frames: np.ndarray
            is_terminals: np.ndarray of float values
            hx, cx: torch tensors on self.device
        '''
        actions, is_exploratory, logpas, entropies, values, (hx, cx) = self.ac_model.full_pass(states)
        new_img_frames, rewards, is_terminals = envs.step(actions)
        self.logpas.append(logpas) ; self.entropies.append(entropies)
        self.rewards.append(rewards) ; self.values.append(values)
        self.running_reward = self.running_reward + rewards
        self.running_timestep = self.running_timestep + 1
        self.running_exploration = self.running_exploration + is_exploratory[:, np.newaxis].astype(int)
        return new_img_frames, is_terminals, (hx, cx)
    def train(self, make_envs_fn, make_env_kargs, seed, gamma,
              max_episodes):
        training_start, last_debug_time = time.time(), float('-inf')
        self.checkpoint_dir = make_env_kargs.dump_location
        self.make_envs_fn = make_envs_fn
        self.make_env_fn = grounding_env.GroundingEnv
        self.make_env_kargs = make_env_kargs
        self.seed = seed
        self.gamma = gamma
        self.max_n_steps = make_env_kargs.max_episode_length
        instruction_of_worker = ['<PAD>'] * self.n_workers
        ## A single (non-vectorised) environment is kept around for evaluation.
        env = grounding_env.GroundingEnv(make_env_kargs)
        env.game_init()
        envs = self.make_envs_fn(make_env_kargs, self.seed, self.n_workers)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)
        word_to_idx = env.word_to_idx
        self.running_timestep = np.array([[0.],] * self.n_workers)
        self.running_reward = np.array([[0.],] * self.n_workers)
        self.running_exploration = np.array([[0.],] * self.n_workers)
        self.running_seconds = np.array([[time.time()],] * self.n_workers)
        self.episode_timestep, self.episode_reward = [], []
        self.episode_seconds, self.evaluation_scores = [], []
        self.episode_exploration = []
        ## make_env_kargs contains all the args needed for environment
        ## and model creation
        self.ac_model = self.get_ac_model(make_env_kargs.max_episode_length, len(word_to_idx))
        self.ac_model = self.ac_model.to(self.device)
        self.ac_optimizer = self.ac_optimizer_fn(self.ac_model,
                                                 self.ac_optimizer_lr)
        ## per-episode stats: (total_step, mean_100_reward, mean_100_eval_score,
        ## training_time, wallclock_elapsed)
        result = np.empty((max_episodes, 5))
        result[:] = np.nan
        training_time = 0
        img_stack, instruction_list = envs.reset()
        img_stack = torch.from_numpy(img_stack).float() / 255.0
        for rank in range(self.n_workers):
            instruction_of_worker[rank] = instruction_list[rank]
        instruction_mat = self.get_instruction_matrix(instruction_of_worker, word_to_idx)
        instruction_mat = torch.from_numpy(instruction_mat).long()
        ## LSTM hidden and cell states, one row per worker
        cx = torch.zeros(self.n_workers, 256)
        hx = torch.zeros(self.n_workers, 256)
        # move tensors to the target device
        img_stack = img_stack.to(self.device)
        instruction_mat = instruction_mat.to(self.device)
        hx = hx.to(self.device)
        cx = cx.to(self.device)
        # Collect an n-steps rollout
        episode, n_steps_start = 0, 0
        self.logpas, self.entropies, self.rewards, self.values = [], [], [], []
        for step in count(start=1):
            tx = torch.from_numpy(np.array([step - n_steps_start] * self.n_workers)).long().to(self.device)
            states = (img_stack, instruction_mat, (tx, hx, cx))
            img_stack, is_terminals, (hx, cx) = self.interaction_step(states, envs)
            img_stack = torch.from_numpy(img_stack).float() / 255.0
            img_stack = img_stack.to(self.device)
            if is_terminals.sum() or step - n_steps_start == self.max_n_steps:
                past_limits_enforced = envs._past_limit()
                is_failure = np.logical_and(is_terminals, np.logical_not(past_limits_enforced))
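                ## Bootstrapping note (comments added): the critic's estimate of the
                ## last state's value stands in for the tail of the return. It is
                ## appended below both as a final pseudo-reward and as a final value,
                ## and masked to zero for workers whose episode genuinely ended
                ## (is_failure), so only truncated rollouts bootstrap off V(s_T).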
                next_values, (new_hx, new_cx) = self.ac_model.evaluate_state((img_stack, instruction_mat, (tx, hx, cx)))
                hx = new_hx
                cx = new_cx
                next_values = next_values.detach().cpu().numpy() * (1 - is_failure)
                self.rewards.append(next_values) ; self.values.append(torch.Tensor(next_values).to(self.device))
                self.optimize_model()
                hx = hx.detach() ; cx = cx.detach()  ## truncate backprop through time at rollout boundaries
                self.logpas, self.entropies, self.rewards, self.values = [], [], [], []
                n_steps_start = step
            # stats
            if is_terminals.sum():
                episode_done = time.time()
                evaluation_score, _, _, _, _ = self.evaluate(self.ac_model, env)
                self.save_checkpoint(episode, self.ac_model)
                reset_info = {}
                for i in range(self.n_workers):
                    if is_terminals[i]:
                        reset_info[i], instruction_of_worker[i] = envs.reset(rank=i)
                        self.episode_timestep.append(self.running_timestep[i][0])
                        self.episode_reward.append(self.running_reward[i][0])
                        self.episode_exploration.append(self.running_exploration[i][0] / self.running_timestep[i][0])
                        self.episode_seconds.append(episode_done - self.running_seconds[i][0])
                        training_time = training_time + self.episode_seconds[-1]
                        self.evaluation_scores.append(evaluation_score)
                        episode = episode + 1
                        mean_10_reward = np.mean(self.episode_reward[-10:])
                        std_10_reward = np.std(self.episode_reward[-10:])
                        mean_100_reward = np.mean(self.episode_reward[-100:])
                        std_100_reward = np.std(self.episode_reward[-100:])
                        mean_100_eval_score = np.mean(self.evaluation_scores[-100:])
                        std_100_eval_score = np.std(self.evaluation_scores[-100:])
                        mean_100_exp_rat = np.mean(self.episode_exploration[-100:])
                        std_100_exp_rat = np.std(self.episode_exploration[-100:])
                        total_step = int(np.sum(self.episode_timestep))
                        wallclock_elapsed = time.time() - training_start
                        result[episode-1] = total_step, mean_100_reward, \
                                            mean_100_eval_score, training_time, wallclock_elapsed
                ## reset recurrent state and frames for the workers that finished
                for i, img in reset_info.items():
                    hx[i] = 0.0
                    cx[i] = 0.0
                    img_stack[i] = torch.from_numpy(img).float() / 255.0
                img_stack = img_stack.to(self.device)
                instruction_mat = self.get_instruction_matrix(instruction_of_worker, word_to_idx)
                instruction_mat = torch.from_numpy(instruction_mat).long().to(self.device)
                # debug stuff
                reached_debug_time = (time.time() - last_debug_time) >= LEAVE_PRINT_EVERY_N_SECS
                reached_max_episodes = (episode + self.n_workers) >= max_episodes
                training_is_over = reached_max_episodes
                elapsed_str = time.strftime("%H:%M:%S", time.gmtime(time.time() - training_start))
                debug_message = 'el {}, ep {:04}, ts {:06}, '
                debug_message += 'ar 10 {:05.1f}\u00B1{:05.1f}, '
                debug_message += '100 {:05.1f}\u00B1{:05.1f}, '
                debug_message += 'ex 100 {:02.1f}\u00B1{:02.1f}, '
                debug_message += 'ev {:05.1f}\u00B1{:05.1f}'
                debug_message = debug_message.format(
                    elapsed_str, episode-1, total_step, mean_10_reward, std_10_reward,
                    mean_100_reward, std_100_reward, mean_100_exp_rat, std_100_exp_rat,
                    mean_100_eval_score, std_100_eval_score)
                print(debug_message, end='\r', flush=True)
                if reached_debug_time or training_is_over:
                    print(ERASE_LINE + debug_message, flush=True)
                    last_debug_time = time.time()
                if training_is_over:
                    if reached_max_episodes: print(u'--> reached_max_episodes \u2715')
                    break
            # reset running variables for the workers whose episode just ended
            self.running_timestep = self.running_timestep * (1 - is_terminals)
            self.running_reward = self.running_reward * (1 - is_terminals)
            self.running_exploration = self.running_exploration * (1 - is_terminals)
            self.running_seconds[is_terminals.astype(bool)] = time.time()
        final_eval_score, score_std, acc, _, _ = self.evaluate(self.ac_model, env, n_episodes=100)
        wallclock_time = time.time() - training_start
        print('Training complete.')
        print('Final evaluation score {:.2f}\u00B1{:.2f} with {:.2f}% Accuracy in {:.2f}s training time,'
              ' {:.2f}s wall-clock time.\n'.format(
                  final_eval_score, score_std, acc*100, training_time, wallclock_time))
        env.close() ; del env
        envs.close() ; del envs
        self.get_cleaned_checkpoints()
        return result, final_eval_score, training_time, wallclock_time
    def evaluate(self, eval_policy_model, eval_env, save_video=False, n_episodes=1, greedy=True):
        rs = []
        video_frames = []
        accuracy_list = []
        instruction_list = []
        cur_instruction = ['<PAD>']
        eval_policy_model.eval()
        eval_policy_model = eval_policy_model.to(self.device)
        with torch.no_grad():
            for _ in range(n_episodes):
                ## convert each input to a torch tensor on the target device
                cur_episode_video_frames = []
                (img, instruction), _, _, _ = eval_env.reset()
                cur_instruction[0] = instruction
                instruction_list.append(instruction)
                if save_video: cur_episode_video_frames.append(img)
                img = torch.from_numpy(img).float() / 255.0
                img = img.unsqueeze(0)
                img = img.to(self.device)
                instruction_mat = self.get_instruction_matrix(cur_instruction, eval_env.word_to_idx)
                instruction_mat = torch.from_numpy(instruction_mat).long()
                instruction_mat = instruction_mat.to(self.device)
                hx = torch.zeros((1, 256))
                cx = torch.zeros((1, 256))
                hx = hx.to(self.device)
                cx = cx.to(self.device)
                episode_timestep = 0
                rs.append(0)
                for _ in count():
                    tx = torch.from_numpy(np.array([episode_timestep])).long().to(self.device)
                    s = (img, instruction_mat, (tx, hx, cx))
                    if greedy:
                        a, (hx, cx) = eval_policy_model.select_greedy_action(s)
                    else:
                        a, (hx, cx) = eval_policy_model.select_action(s)
                    (img, _), r, d, _ = eval_env.step(a)
                    if save_video: cur_episode_video_frames.append(img)
                    img = torch.from_numpy(img).float() / 255.0
                    img = img.unsqueeze(0)
                    img = img.to(self.device)
                    rs[-1] = rs[-1] + r
                    if d:
                        ## an episode counts as accurate only when it ends with
                        ## the reward for reaching the correct object
                        if r == CORRECT_OBJECT_REWARD:
                            accuracy_list.append(1)
                        else:
                            accuracy_list.append(0)
                        break
                    episode_timestep = episode_timestep + 1
                video_frames.append(cur_episode_video_frames)
        eval_policy_model.train()  ## restore training mode for subsequent rollouts
        return np.mean(rs), np.std(rs), np.mean(accuracy_list), video_frames, instruction_list
    def get_cleaned_checkpoints(self, n_checkpoints=5):
        try:
            return self.checkpoint_paths
        except AttributeError:
            self.checkpoint_paths = {}
        paths = glob.glob(os.path.join(self.checkpoint_dir, '*.tar'))
        paths_dic = {int(path.split('.')[-2]): path for path in paths}
        last_ep = max(paths_dic.keys())
        # checkpoint_idxs = np.geomspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1
        checkpoint_idxs = np.linspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1
        for idx, path in paths_dic.items():
            if idx in checkpoint_idxs:
                self.checkpoint_paths[idx] = path
            else:
                os.unlink(path)
        return self.checkpoint_paths
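    ## Worked example (comments only): with last_ep = 100 and n_checkpoints = 5,
    ## np.linspace(1, 101, 5, endpoint=True, dtype=int) - 1 gives [0, 25, 50, 75, 100],
    ## so model.0.tar, model.25.tar, ..., model.100.tar are kept and every other
    ## checkpoint file in checkpoint_dir is deleted.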
    def demo_last(self, args, title='Fully_Trained_Agent', save_video=False):
        env = grounding_env.GroundingEnv(args)
        env.game_init()
        checkpoint_paths = self.get_cleaned_checkpoints()
        last_ep = max(checkpoint_paths.keys())
        self.ac_model.load_state_dict(torch.load(checkpoint_paths[last_ep]))
        _, _, _, video_frames, _ = self.evaluate(self.ac_model, env, save_video=save_video, n_episodes=1)
        env.close()
        imageio.mimsave(title + '.gif', video_frames[0])
        del env
    def demo_progression(self, title='Episode_{}_Agent_progression', max_n_videos=5, save_video=False):
        env = self.make_env_fn(self.make_env_kargs)
        env.game_init()
        checkpoint_paths = self.get_cleaned_checkpoints()
        for i in sorted(checkpoint_paths.keys()):
            self.ac_model.load_state_dict(torch.load(checkpoint_paths[i]))
            _, _, _, video_frames, _ = self.evaluate(self.ac_model, env, save_video=save_video, n_episodes=1)
            imageio.mimsave(title.format(i) + '.gif', video_frames[0])
        env.close()
        del env
    def save_checkpoint(self, episode_idx, model):
        torch.save(model.state_dict(),
                   os.path.join(self.checkpoint_dir, 'model.{}.tar'.format(episode_idx)))
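
## A minimal usage sketch (comments only; `ActorCritic`, `make_parallel_envs`, and
## `args` are assumed to be provided elsewhere and are not part of this file):
##
##   agent = A2C(ac_model_fn=lambda max_len, vocab: ActorCritic(max_len, vocab),
##               ac_model_max_grad_norm=1.0,
##               ac_optimizer_fn=lambda model, lr: torch.optim.Adam(model.parameters(), lr=lr),
##               ac_optimizer_lr=1e-4,
##               policy_loss_weight=1.0,
##               value_loss_weight=0.5,
##               entropy_loss_weight=0.01,
##               n_workers=8,
##               tau=0.95)
##   result, final_score, training_time, wallclock_time = agent.train(
##       make_parallel_envs, args, seed=42, gamma=0.99, max_episodes=1000)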