Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class DQN(nn.Module):
    """Convolutional Q-network mapping one observation to a Q-value per action.

    Three convolution layers (32 filters 8x8 stride 4, 64 filters 4x4
    stride 2, 64 filters 3x3 stride 1), each followed by ReLU, feed a
    fully-connected head with a single hidden layer.

    Args:
        input_shape: Shape of one observation *without* the batch
            dimension, ``(channels, height, width)``; ``input_shape[0]``
            is used as the number of input channels.
        n_actions: Number of discrete actions, i.e. the output width.
        hidden_size: Width of the hidden fully-connected layer. Defaults
            to 512, the value previously hard-coded.
    """

    def __init__(self, input_shape, n_actions: int, hidden_size: int = 512):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def _get_conv_out(self, shape) -> int:
        """Return the flattened feature count the conv stack yields for one input of *shape*."""
        # A dummy forward pass derives the size for any kernel/stride
        # combination without hand-computing output dims. no_grad() avoids
        # building an autograd graph for a tensor that is discarded.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return o.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute Q-values for a batch of observations.

        Args:
            x: Floating-point tensor of shape ``(batch, channels, height,
                width)``. No scaling or normalisation is applied here.

        Returns:
            Tensor of shape ``(batch, n_actions)``.
        """
        # flatten(1) keeps the batch dimension and, unlike .view(), also
        # works when the conv output is non-contiguous (e.g. channels_last
        # memory format) or the batch is empty.
        conv_out = self.conv(x).flatten(1)
        return self.fc(conv_out)
Add Comment
Please, Sign In to add comment