SHARE
TWEET

Untitled

a guest Feb 16th, 2019 72 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. class DQN(nn.Module):
  2.     def __init__(self, input_shape, n_actions):
  3.         super(DQN, self).__init__()
  4.  
  5.         self.conv = nn.Sequential(
  6.             nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
  7.             nn.ReLU(),
  8.             nn.Conv2d(32, 64, kernel_size=4, stride=2),
  9.             nn.ReLU(),
  10.             nn.Conv2d(64, 64, kernel_size=3, stride=1),
  11.             nn.ReLU()
  12.         )
  13.  
  14.         conv_out_size = self._get_conv_out(input_shape)
  15.         self.fc = nn.Sequential(
  16.             nn.Linear(conv_out_size, 512),
  17.             nn.ReLU(),
  18.             nn.Linear(512, n_actions)
  19.         )
  20.  
  21.     def _get_conv_out(self, shape):
  22.         o = self.conv(torch.zeros(1, *shape))
  23.         return int(np.prod(o.size()))
  24.  
  25.     def forward(self, x):
  26.         conv_out = self.conv(x).view(x.size()[0], -1)
  27.         return self.fc(conv_out)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top