a2c_model
import os
import time
import glob
import torch
import random
import imageio

# np is used throughout below; import it explicitly rather than relying on the
# wildcard import from multiprocess_env.
import numpy as np

from itertools import count
from multiprocess_env import *
# NOTE: grounding_env is referenced later in this file; it is assumed to be
# pulled in by the wildcard import above (otherwise it needs its own import).

from constants import CORRECT_OBJECT_REWARD

torch.autograd.set_detect_anomaly(True)

LEAVE_PRINT_EVERY_N_SECS = 30
ERASE_LINE = '\x1b[2K'


class A2C():
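    '''
    Synchronous advantage actor-critic (A2C) trainer for the instruction-
    grounding environment. Rollouts are collected from n_workers parallel
    environments, n-step returns and GAE(tau) advantages are computed, and a
    single shared actor-critic model is updated with a weighted sum of policy,
    value and entropy losses.
    '''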
    def __init__(self,
                 ac_model_fn,
                 ac_model_max_grad_norm=1.0,
                 ac_optimizer_fn=None,
                 ac_optimizer_lr=None,
                 policy_loss_weight=None,
                 value_loss_weight=None,
                 entropy_loss_weight=None,
                 n_workers=None,
                 tau=None):
        assert n_workers > 1
        self.ac_model_fn = ac_model_fn
        self.ac_model_max_grad_norm = ac_model_max_grad_norm
        self.ac_optimizer_fn = ac_optimizer_fn
        self.ac_optimizer_lr = ac_optimizer_lr

        self.policy_loss_weight = policy_loss_weight
        self.value_loss_weight = value_loss_weight
        self.entropy_loss_weight = entropy_loss_weight

        self.n_workers = n_workers
        self.tau = tau

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def get_ac_model(self, max_episode_length, vocab_size):
        return self.ac_model_fn(max_episode_length, vocab_size)

    def get_instruction_matrix(self, instructions, word_to_idx):
        '''
        inputs:
            instructions: list of instruction strings, one per worker
            word_to_idx: a dictionary mapping lowercase words to their indices
        outputs:
            instr_matrix: np.ndarray (n_workers x max_instruction_length)
                containing the index of each lowercased word, zero-padded to
                the longest instruction in the batch.
        '''
        ## pad the instructions to the longest instruction length,
        ## then tokenise them and convert them into a matrix

        max_instr_len = -1
        num_instructions = len(instructions)

        for inst in instructions:
            max_instr_len = max(max_instr_len, len(inst.split(' ')))

        instr_matrix = np.zeros((num_instructions, max_instr_len), dtype=int)

        for i, instr in enumerate(instructions):
            instr = instr.lower().split(' ')
            for j, word in enumerate(instr):
                instr_matrix[i, j] = word_to_idx[word]

        return instr_matrix

    def optimize_model(self):
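        '''
        One A2C update over the rollout buffered in self.logpas, self.entropies,
        self.rewards and self.values: build discounted n-step returns and
        tau-discounted GAE advantages, then minimise
        policy_loss_weight * policy_loss + value_loss_weight * value_loss
        + entropy_loss_weight * entropy_loss, with gradient clipping.
        '''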

        logpas = torch.stack(self.logpas).squeeze() ## size: (N_steps x n_workers)
        entropies = torch.stack(self.entropies).squeeze() ## size: (N_steps x n_workers)
        values = torch.stack(self.values).squeeze() ## size: (N_steps x n_workers)

        T = len(self.rewards)
        # print(type(self.rewards), len(self.rewards), self.rewards)
        discounts = np.logspace(0, T, num=T, base=self.gamma, endpoint=False)
        rewards = np.array(self.rewards).squeeze()
        # print(rewards.shape, rewards)
        returns = np.array([[np.sum(discounts[:T-t] * rewards[t:, w]) for t in range(T)]
                            for w in range(self.n_workers)])

        ## Check GAE calculation once, looks wrong
        np_values = values.data.cpu().numpy()  # move to CPU before converting to numpy
        tau_discounts = np.logspace(0, T-1, num=T-1, base=self.gamma*self.tau, endpoint=False)
        advs = rewards[:-1] + self.gamma * np_values[1:] - np_values[:-1]
        gaes = np.array([[np.sum(tau_discounts[:T-1-t] * advs[t:, w]) for t in range(T-1)]
                         for w in range(self.n_workers)])
        discounted_gaes = discounts[:-1] * gaes
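        # At this point, per worker w and rollout step t:
        #   advs[t, w]            = r_t + gamma * V(s_{t+1}) - V(s_t)   (one-step TD error)
        #   gaes[w, t]            = sum_l (gamma*tau)^l * advs[t+l, w]  (GAE(tau) advantage)
        #   discounted_gaes[w, t] = gamma^t * gaes[w, t]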

        values = values[:-1, ...].reshape(-1, 1)
        logpas = logpas.reshape(-1, 1)
        entropies = entropies.reshape(-1, 1)
        returns = torch.FloatTensor(returns.T[:-1]).to(self.device).reshape(-1, 1)
        discounted_gaes = torch.FloatTensor(discounted_gaes.T).to(self.device).reshape(-1, 1)

        T = (T - 1) * self.n_workers  ## T = (N_steps - 1) * n_workers

        assert returns.size() == (T, 1)
        assert values.size() == (T, 1)
        assert logpas.size() == (T, 1)
        assert entropies.size() == (T, 1)

        # values_copy = new_values.detach().clone()
        value_error = returns.detach() - values
        value_loss = value_error.pow(2).mul(0.5).mean()
        policy_loss = -(discounted_gaes.detach() * logpas).mean()
        entropy_loss = -entropies.mean()
        loss = self.policy_loss_weight * policy_loss + \
               self.value_loss_weight * value_loss + \
               self.entropy_loss_weight * entropy_loss

        self.ac_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.ac_model.parameters(),
                                       self.ac_model_max_grad_norm)
        self.ac_optimizer.step()

    def interaction_step(self, states, envs):
        '''
        inputs:
            states: tuple of (img_stack, instruction_mat, (tx, hx, cx)); all are
                torch tensors already on self.device.
            envs: the multi-environment (vectorised) object
        outputs:
            new_img_frames: np.ndarray of the next image stacks
            is_terminals: np.ndarray of float terminal flags
            (hx, cx): LSTM hidden/cell states, torch tensors on self.device
        '''
        ## Need to add variables for (hx, cx, tx)
        actions, is_exploratory, logpas, entropies, values, (hx, cx) = self.ac_model.full_pass(states)
        new_img_frames, rewards, is_terminals = envs.step(actions)

        self.logpas.append(logpas) ; self.entropies.append(entropies)
        self.rewards.append(rewards) ; self.values.append(values)

        self.running_reward = self.running_reward + rewards
        print(rewards)  # debug
        self.running_timestep = self.running_timestep + 1
        self.running_exploration = self.running_exploration + is_exploratory[:, np.newaxis].astype(int)

        return new_img_frames, is_terminals, (hx, cx)

    def train(self, make_envs_fn, make_env_kargs, seed, gamma,
              max_episodes):
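        '''
        Main training loop: builds the vectorised environments and the
        actor-critic model, collects rollouts of at most max_episode_length
        steps from n_workers workers, and calls optimize_model whenever any
        worker terminates or the step limit is hit. Returns the per-episode
        result array plus the final evaluation score and timing statistics.
        '''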
        training_start, last_debug_time = time.time(), float('-inf')

        self.checkpoint_dir = make_env_kargs.dump_location
        self.make_envs_fn = make_envs_fn
        self.make_env_fn = grounding_env.GroundingEnv
        self.make_env_kargs = make_env_kargs
        self.seed = seed
        self.gamma = gamma
        self.max_n_steps = make_env_kargs.max_episode_length

        instruction_of_worker = ['<PAD>'] * self.n_workers

        ## Check this once
        env = grounding_env.GroundingEnv(make_env_kargs)
        env.game_init()

        envs = self.make_envs_fn(make_env_kargs, self.seed, self.n_workers)

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        word_to_idx = env.word_to_idx

        self.running_timestep = np.array([[0.],]*self.n_workers)
        self.running_reward = np.array([[0.],]*self.n_workers)
        self.running_exploration = np.array([[0.],]*self.n_workers)
        self.running_seconds = np.array([[time.time()],]*self.n_workers)
        self.episode_timestep, self.episode_reward = [], []
        self.episode_seconds, self.evaluation_scores = [], []
        self.episode_exploration = []

        ## make_env_kargs contains all the args needed for environment
        ## and model creation
        self.ac_model = self.get_ac_model(make_env_kargs.max_episode_length, len(word_to_idx))

        self.ac_model = self.ac_model.to(self.device)

        self.ac_optimizer = self.ac_optimizer_fn(self.ac_model,
                                                 self.ac_optimizer_lr)

        result = np.empty((max_episodes, 5))
        result[:] = np.nan
        training_time = 0
        img_stack, instruction_list = envs.reset()

        img_stack = torch.from_numpy(img_stack).float()/255.0

        for rank in range(self.n_workers):
            instruction_of_worker[rank] = instruction_list[rank]

        instruction_mat = self.get_instruction_matrix(instruction_of_worker, word_to_idx)
        instruction_mat = torch.from_numpy(instruction_mat).long()

        ## Check size; it might differ because n_workers > 1 is used.
        cx = torch.zeros(self.n_workers, 256)
        hx = torch.zeros(self.n_workers, 256)

        # convert the inputs to tensors on the target device
        img_stack = img_stack.to(self.device)
        instruction_mat = instruction_mat.to(self.device)
        hx = hx.to(self.device)
        cx = cx.to(self.device)

        # Collect n_steps rollout

        episode, n_steps_start = 0, 0
        self.logpas, self.entropies, self.rewards, self.values = [], [], [], []

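        # Rollout/update loop: each iteration performs one interaction step in
        # all workers; when any worker terminates, or (step - n_steps_start)
        # reaches max_episode_length, the rollout is bootstrapped with the
        # critic's value of the last state and optimize_model is called.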
        for step in count(start=1):
            print(step)  # debug
            tx = torch.from_numpy(np.array([step - n_steps_start]*self.n_workers)).long().to(self.device)
            states = (img_stack, instruction_mat, (tx, hx, cx))
            img_stack, is_terminals, (hx, cx) = self.interaction_step(states, envs)

            img_stack = torch.from_numpy(img_stack).float()/255.0
            img_stack = img_stack.to(self.device)

            if is_terminals.sum() or step - n_steps_start == self.max_n_steps:
                print(is_terminals.sum(), step - n_steps_start == self.max_n_steps)  # debug
                past_limits_enforced = envs._past_limit()
                is_failure = np.logical_and(is_terminals, np.logical_not(past_limits_enforced))
                next_values, (new_hx, new_cx) = self.ac_model.evaluate_state((img_stack, instruction_mat, (tx, hx, cx)))
                hx = new_hx
                cx = new_cx

                next_values = next_values.detach().cpu().numpy() * (1 - is_failure)

                # bootstrap the rollout: the value of the last state is appended
                # as both a pseudo-reward and a value so optimize_model can form
                # n-step returns for partially finished episodes
                self.rewards.append(next_values) ; self.values.append(torch.Tensor(next_values).to(self.device))
                self.optimize_model()
                self.logpas, self.entropies, self.rewards, self.values = [], [], [], []
                n_steps_start = step

            # stats
            if is_terminals.sum():
                episode_done = time.time()
                evaluation_score, _, _, _, _ = self.evaluate(self.ac_model, env)
                self.save_checkpoint(episode, self.ac_model)

                reset_info = {}

                for i in range(self.n_workers):
                    if is_terminals[i]:
                        # hx = hx.clone()
                        # hx[i] = torch.zeros(256).to(self.device)
                        # cx = cx.clone()
                        # cx[i] = torch.zeros(256).to(self.device)

                        reset_info[i], instruction_of_worker[i] = envs.reset(rank=i)

                        self.episode_timestep.append(self.running_timestep[i][0])
                        self.episode_reward.append(self.running_reward[i][0])
                        self.episode_exploration.append(self.running_exploration[i][0]/self.running_timestep[i][0])
                        self.episode_seconds.append(episode_done - self.running_seconds[i][0])
                        training_time = training_time + self.episode_seconds[-1]
                        self.evaluation_scores.append(evaluation_score)
                        episode = episode + 1

                        mean_10_reward = np.mean(self.episode_reward[-10:])
                        std_10_reward = np.std(self.episode_reward[-10:])
                        mean_100_reward = np.mean(self.episode_reward[-100:])
                        std_100_reward = np.std(self.episode_reward[-100:])
                        mean_100_eval_score = np.mean(self.evaluation_scores[-100:])
                        std_100_eval_score = np.std(self.evaluation_scores[-100:])
                        mean_100_exp_rat = np.mean(self.episode_exploration[-100:])
                        std_100_exp_rat = np.std(self.episode_exploration[-100:])

                        total_step = int(np.sum(self.episode_timestep))
                        wallclock_elapsed = time.time() - training_start
                        result[episode-1] = total_step, mean_100_reward, \
                            mean_100_eval_score, training_time, wallclock_elapsed

                    # img_stack = img_stack.clone()
                    # hx = hx.clone().detach()
                    # cx = cx.clone().detach()

                    for i, img in reset_info.items():
                        # zeros_h = torch.zeros_like(hx).to(self.device)
                        # zeros_c = torch.zeros_like(cx).to(self.device)

                        # hx = torch.where(torch.arange(self.n_workers) == i, zeros_h, hx)
                        # cx = torch.where(torch.arange(self.n_workers) == i, zeros_c, cx)
                        hx[i] = 0.0
                        cx[i] = 0.0
                        img_stack[i] = torch.from_numpy(img).float()/255.0

                    img_stack = img_stack.to(self.device)
                    instruction_mat = self.get_instruction_matrix(instruction_of_worker, word_to_idx)
                    instruction_mat = torch.from_numpy(instruction_mat).long().to(self.device)

                # debug stuff
                reached_debug_time = (time.time() - last_debug_time) >= LEAVE_PRINT_EVERY_N_SECS
                reached_max_episodes = (episode + self.n_workers) >= max_episodes
                training_is_over = reached_max_episodes

                elapsed_str = time.strftime("%H:%M:%S", time.gmtime(time.time() - training_start))
                debug_message = 'el {}, ep {:04}, ts {:06}, '
                debug_message = debug_message + 'ar 10 {:05.1f}\u00B1{:05.1f}, '
                debug_message = debug_message + '100 {:05.1f}\u00B1{:05.1f}, '
                debug_message = debug_message + 'ex 100 {:02.1f}\u00B1{:02.1f}, '
                debug_message = debug_message + 'ev {:05.1f}\u00B1{:05.1f}'
                debug_message = debug_message.format(
                    elapsed_str, episode-1, total_step, mean_10_reward, std_10_reward,
                    mean_100_reward, std_100_reward, mean_100_exp_rat, std_100_exp_rat,
                    mean_100_eval_score, std_100_eval_score)
                print(debug_message, end='\r', flush=True)
                if reached_debug_time or training_is_over:
                    print(ERASE_LINE + debug_message, flush=True)
                    last_debug_time = time.time()
                if training_is_over:
                    if reached_max_episodes: print(u'--> reached_max_episodes \u2715')
                    break

                # reset the running variables for the next iteration
                self.running_timestep = self.running_timestep * (1 - is_terminals)
                self.running_reward = self.running_reward * (1 - is_terminals)
                self.running_exploration = self.running_exploration * (1 - is_terminals)
                self.running_seconds[is_terminals.astype(bool)] = time.time()

        final_eval_score, score_std, acc, _, _ = self.evaluate(self.ac_model, env, n_episodes=100)
        wallclock_time = time.time() - training_start

        instruction_of_worker = ['<PAD>'] * self.n_workers

        print('Training complete.')
        print('Final evaluation score {:.2f}\u00B1{:.2f} with {:.2f}% Accuracy in {:.2f}s training time,'
              ' {:.2f}s wall-clock time.\n'.format(
                  final_eval_score, score_std, acc*100, training_time, wallclock_time))
        env.close() ; del env
        envs.close() ; del envs
        self.get_cleaned_checkpoints()
        return result, final_eval_score, training_time, wallclock_time

    def evaluate(self, eval_policy_model, eval_env, save_video=False, n_episodes=1, greedy=True):
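        '''
        Runs n_episodes evaluation episodes in eval_env with the given policy
        (greedy by default), optionally collecting frames for a video. Returns
        the mean and std of episode rewards, the fraction of episodes that
        ended with CORRECT_OBJECT_REWARD (accuracy), the collected video frames
        and the instructions that were evaluated.
        '''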
        rs = []
        video_frames = []
        accuracy_list = []
        instruction_list = []
        cur_instruction = ['<PAD>']

        eval_policy_model.eval()
        # NOTE: the model is left in eval mode after this call; switch it back
        # with .train() if training continues afterwards.
        eval_policy_model = eval_policy_model.to(self.device)

        with torch.no_grad():
            for _ in range(n_episodes):
                ## convert each input to a torch tensor
                cur_episode_video_frames = []
                (img, instruction), _, _, _ = eval_env.reset()

                cur_instruction[0] = instruction

                instruction_list.append(instruction)

                if save_video: cur_episode_video_frames.append(img)

                img = torch.from_numpy(img).float()/255.0
                img = img.unsqueeze(0)
                img = img.to(self.device)

                instruction_mat = self.get_instruction_matrix(cur_instruction, eval_env.word_to_idx)
                instruction_mat = torch.from_numpy(instruction_mat).long()
                instruction_mat = instruction_mat.to(self.device)

                hx = torch.zeros((1, 256))
                cx = torch.zeros((1, 256))

                hx = hx.to(self.device)
                cx = cx.to(self.device)

                episode_timestep = 0
                rs.append(0)
                for _ in count():
                    tx = torch.from_numpy(np.array([episode_timestep])).long().to(self.device)

                    s = (img, instruction_mat, (tx, hx, cx))

                    if greedy:
                        a, (hx, cx) = eval_policy_model.select_greedy_action(s)
                    else:
                        a, (hx, cx) = eval_policy_model.select_action(s)
                    (img, _), r, d, _ = eval_env.step(a)

                    if save_video: cur_episode_video_frames.append(img)

                    img = torch.from_numpy(img).float()/255.0
                    img = img.unsqueeze(0)
                    img = img.to(self.device)

                    rs[-1] = rs[-1] + r

                    if d:

                        if r == CORRECT_OBJECT_REWARD:
                            accuracy_list.append(1)
                        else:
                            accuracy_list.append(0)

                        break

                    episode_timestep = episode_timestep + 1
                video_frames.append(cur_episode_video_frames)

        return np.mean(rs), np.std(rs), np.mean(accuracy_list), video_frames, instruction_list

    def get_cleaned_checkpoints(self, n_checkpoints=5):
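        '''
        Scans self.checkpoint_dir for model.<episode>.tar files, keeps
        n_checkpoints of them spread linearly over the training run, deletes
        the rest, and caches the surviving paths in self.checkpoint_paths.
        '''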
        try:
            return self.checkpoint_paths
        except AttributeError:
            self.checkpoint_paths = {}

        paths = glob.glob(os.path.join(self.checkpoint_dir, '*.tar'))
        paths_dic = {int(path.split('.')[-2]): path for path in paths}
        last_ep = max(paths_dic.keys())
        # checkpoint_idxs = np.geomspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1
        checkpoint_idxs = np.linspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1

        for idx, path in paths_dic.items():
            if idx in checkpoint_idxs:
                self.checkpoint_paths[idx] = path
            else:
                os.unlink(path)

        return self.checkpoint_paths

    def demo_last(self, args, title='Fully_Trained_Agent', save_video=False):
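        '''
        Loads the most recent surviving checkpoint and records one evaluation
        episode as <title>.gif (frames are only collected when save_video=True).
        '''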
        env = grounding_env.GroundingEnv(args)

        checkpoint_paths = self.get_cleaned_checkpoints()
        last_ep = max(checkpoint_paths.keys())
        self.ac_model.load_state_dict(torch.load(checkpoint_paths[last_ep]))

        _, _, _, video_frames, _ = self.evaluate(self.ac_model, env, save_video=save_video, n_episodes=1)
        env.close()

        imageio.mimsave(title+'.gif', video_frames[0])

        del env

    def demo_progression(self, title='Episode_{}_Agent_progression', max_n_videos=5, save_video=False):
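        '''
        Steps through the retained checkpoints in episode order and records one
        evaluation episode per checkpoint as <title.format(episode)>.gif.
        '''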
        # NOTE: elsewhere the environment is built as GroundingEnv(args); the
        # keyword arguments below look like leftovers from a gym-style factory
        # and may not be accepted by GroundingEnv.
        env = self.make_env_fn(**self.make_env_kargs, monitor_mode='evaluation', render=True, record=True)

        checkpoint_paths = self.get_cleaned_checkpoints()
        for i in sorted(checkpoint_paths.keys()):
            self.ac_model.load_state_dict(torch.load(checkpoint_paths[i]))
            _, _, _, video_frames, _ = self.evaluate(self.ac_model, env, save_video=save_video, n_episodes=1)
            imageio.mimsave(title.format(i)+'.gif', video_frames[0])

        env.close()
        del env

    def save_checkpoint(self, episode_idx, model):
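        '''Saves the model weights to <checkpoint_dir>/model.<episode_idx>.tar.'''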
        torch.save(model.state_dict(),
                   os.path.join(self.checkpoint_dir, 'model.{}.tar'.format(episode_idx)))

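
# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical names: `MyACModel`, `make_envs_fn` and
# `args` stand in for the project's actual model factory, vectorised-env
# factory and argument namespace; `args` must expose at least dump_location
# and max_episode_length, and the weights/learning rate are illustrative):
#
#   agent = A2C(ac_model_fn=lambda max_len, vocab_size: MyACModel(max_len, vocab_size),
#               ac_optimizer_fn=lambda model, lr: torch.optim.Adam(model.parameters(), lr=lr),
#               ac_optimizer_lr=1e-4,
#               policy_loss_weight=1.0,
#               value_loss_weight=0.5,
#               entropy_loss_weight=0.01,
#               n_workers=4,
#               tau=0.95)
#   result, final_eval_score, training_time, wallclock_time = agent.train(
#       make_envs_fn, args, seed=0, gamma=0.99, max_episodes=1000)
# ---------------------------------------------------------------------------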