"""Gated-Attention A3C-LSTM actor-critic model and weight initializers."""

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def normalized_columns_initializer(weights, std=1.0):
    """Return a random tensor shaped like `weights` with each row's norm scaled to `std`."""
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True).expand_as(out))
    return out


def weights_init(m):
    """Xavier-style uniform initialization for Conv and Linear modules."""
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)

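# Both helpers are used only at construction time: weights_init is applied to
# every submodule via self.apply(weights_init) in LSTM_GA.__init__ below, and
# normalized_columns_initializer rescales the actor/critic output heads, e.g.
#
#     head = nn.Linear(288, 3)
#     head.weight.data = normalized_columns_initializer(head.weight.data, 0.01)
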
class LSTM_GA(nn.Module):
    def __init__(self, max_n_steps, vocab_size):
        super(LSTM_GA, self).__init__()

        # Image processing
        self.conv1 = nn.Conv2d(3, 128, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=2)

        # Instruction processing
        self.gru_hidden_size = 256
        self.input_size = vocab_size
        self.embedding = nn.Embedding(self.input_size, 32)
        self.gru = nn.GRU(32, self.gru_hidden_size, batch_first=True)

        # Gated-Attention layers
        self.attn_linear = nn.Linear(self.gru_hidden_size, 64)

        # Time-embedding layer, helps in stabilizing value prediction
        self.time_emb_dim = 32
        self.time_emb_layer = nn.Embedding(
            max_n_steps + 1,
            self.time_emb_dim
        )

        # A2C-LSTM layers
        self.linear = nn.Linear(64 * 8 * 17, 256)
        self.lstm = nn.LSTMCell(256, 256)
        self.critic_linear = nn.Linear(256 + self.time_emb_dim, 1)
        self.actor_linear = nn.Linear(256 + self.time_emb_dim, 3)

        # Initializing weights
        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01
        )
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0
        )
        self.critic_linear.bias.data.fill_(0)

        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

    def _format(self, inputs):
        # Placeholder for input formatting/normalization; not used yet
        pass

    def forward(self, inputs):
        x, input_inst, (tx, hx, cx) = inputs
        n_workers = x.size(0)

        # Get the image representation
        x = F.relu(self.conv1(x), inplace=False)
        x = F.relu(self.conv2(x), inplace=False)
        x_image_rep = F.relu(self.conv3(x), inplace=False)

        # Get the instruction representation: the whole batch of padded
        # instructions is embedded and fed to the GRU in a single call
        # (no per-token loop); only the final hidden state is kept.
        encoder_hidden = torch.zeros(1, n_workers, self.gru_hidden_size, device=x.device)
        embedded_instruction = self.embedding(input_inst)
        _, encoder_hidden = self.gru(embedded_instruction, encoder_hidden)
        x_instr_rep = encoder_hidden

        # Get the attention vector from the instruction representation
        x_attention = torch.sigmoid(self.attn_linear(x_instr_rep))

        # Gated-Attention: broadcast each worker's attention vector over the
        # spatial dimensions of its image representation
        x_attention = x_attention.squeeze(0).unsqueeze(2).unsqueeze(3)
        x_attention = x_attention.expand(n_workers, 64, 8, 17)

        assert x_image_rep.size() == x_attention.size()

        x = x_image_rep * x_attention
        x = x.view(x.size(0), -1)

        # A3C-LSTM
        x = F.relu(self.linear(x), inplace=False)
        new_hx, new_cx = self.lstm(x, (hx, cx))
        time_emb = self.time_emb_layer(tx)
        x = torch.cat((new_hx, time_emb.view(-1, self.time_emb_dim)), 1)

        return self.actor_linear(x), self.critic_linear(x), (new_hx, new_cx)

    def full_pass(self, inputs):
        '''
        inputs: tuple of (images_batch, instruction_batch, (tx, hx, cx));
            every element is a torch tensor on the model's device
        outputs:
            action: scalar or numpy array, depending on n_workers (size: n_workers)
            is_exploratory: boolean array (size: n_workers)
            logpa: log-probability of the sampled action (torch tensor)
            entropy: policy entropy (torch tensor)
            value: critic value estimate (torch tensor)
            hx, cx: updated LSTM recurrent state (torch tensors)
        '''
        logits, value, (hx, cx) = self.forward(inputs)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        logpa = dist.log_prob(action).unsqueeze(-1)
        entropy = dist.entropy().unsqueeze(-1)
        action = action.item() if len(action) == 1 else action.detach().cpu().numpy()

        # An action is exploratory when it differs from the greedy
        # (per-worker argmax) action
        is_exploratory = action != np.argmax(logits.detach().cpu().numpy(), axis=1)
        return action, is_exploratory, logpa, entropy, value, (hx, cx)

    def select_action(self, inputs):
        # Sample an action from the current policy; the recurrent state is
        # returned so the caller can carry it into the next step
        logits, _, (hx, cx) = self.forward(inputs)

        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        action = action.item() if len(action) == 1 else action.detach().cpu().numpy()
        return action, (hx, cx)

    def select_greedy_action(self, inputs):
        # Pick the highest-probability action per worker (argmax over the
        # action axis), mirroring the sampled variants above
        logits, _, (hx, cx) = self.forward(inputs)
        action = np.argmax(logits.detach().cpu().numpy(), axis=1)
        action = action.item() if len(action) == 1 else action
        return action, (hx, cx)

    def evaluate_state(self, inputs):
        # Critic-only pass: return the state value and the recurrent state
        _, value, (hx, cx) = self.forward(inputs)
        return value, (hx, cx)
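

# Minimal usage sketch: build the network and run one full_pass on random
# inputs. The 3 x 168 x 300 image resolution is an assumption chosen so the
# conv stack produces the 64 x 8 x 17 feature map expected by self.linear;
# max_n_steps, vocab_size and the instruction length are arbitrary
# illustrative values.
if __name__ == "__main__":
    max_n_steps, vocab_size, n_workers, seq_len = 30, 50, 2, 5
    model = LSTM_GA(max_n_steps, vocab_size)

    images = torch.randn(n_workers, 3, 168, 300)                       # RGB frames
    instructions = torch.randint(0, vocab_size, (n_workers, seq_len))  # padded token ids
    tx = torch.zeros(n_workers, dtype=torch.long)                      # current time step
    hx = torch.zeros(n_workers, 256)                                   # LSTM hidden state
    cx = torch.zeros(n_workers, 256)                                   # LSTM cell state

    action, is_exploratory, logpa, entropy, value, (hx, cx) = model.full_pass(
        (images, instructions, (tx, hx, cx))
    )
    print(action, is_exploratory, logpa.shape, entropy.shape, value.shape)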