Guest User

Untitled

a guest
Feb 16th, 2019
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.82 KB | None | 0 0
  class DQN(nn.Module):
      """Convolutional network mapping an image-like input to one score per action.

      The network is a three-layer convolutional stack followed by a two-layer
      fully connected head.  The width of the first fully connected layer is
      derived at construction time by pushing a dummy input through the
      convolutional stack, so any spatial size large enough for the kernels
      and strides below is supported.

      NOTE(review): the class name suggests the outputs are per-action
      Q-values for a Deep Q-Network agent -- confirm against the training code.

      Args:
          input_shape: Shape of a single (unbatched) input as
              ``(channels, height, width)``.  ``input_shape[0]`` is used as
              the number of input channels of the first convolution.
          n_actions: Number of output units of the final linear layer.
      """

      def __init__(self, input_shape, n_actions):
          super().__init__()

          # Feature extractor: 32 filters 8x8 stride 4, 64 filters 4x4 stride 2,
          # 64 filters 3x3 stride 1, each followed by a ReLU.
          self.conv = nn.Sequential(
              nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
              nn.ReLU(),
              nn.Conv2d(32, 64, kernel_size=4, stride=2),
              nn.ReLU(),
              nn.Conv2d(64, 64, kernel_size=3, stride=1),
              nn.ReLU(),
          )

          # The flattened feature size depends on the input's spatial size, so
          # it is measured rather than hard-coded.
          conv_out_size = self._get_conv_out(input_shape)
          self.fc = nn.Sequential(
              nn.Linear(conv_out_size, 512),
              nn.ReLU(),
              nn.Linear(512, n_actions),
          )

      def _get_conv_out(self, shape):
          """Return the number of features ``self.conv`` yields for one input.

          Args:
              shape: Shape of a single unbatched input, ``(channels, height, width)``.

          Returns:
              Total element count (as ``int``) of the convolutional output for
              a batch containing exactly one all-zero input of that shape.
          """
          # Only the output shape matters here, so skip autograd bookkeeping.
          with torch.no_grad():
              probe_out = self.conv(torch.zeros(1, *shape))
          # The batch dimension is 1, so the element count equals the
          # per-sample feature count.
          return int(probe_out.numel())

      def forward(self, x):
          """Run the network on a batch of inputs.

          Args:
              x: Tensor whose first dimension is the batch dimension and whose
                 remaining dimensions match the ``input_shape`` given to the
                 constructor.

          Returns:
              Tensor of shape ``(batch, n_actions)``.
          """
          # Collapse every non-batch dimension so the linear head can consume it.
          conv_out = self.conv(x).flatten(start_dim=1)
          return self.fc(conv_out)
Add Comment
Please, Sign In to add comment