Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class OutBlock(nn.Module):
    """Output block with a value head and a policy head on a shared feature map.

    Both heads start with a 1x1 convolution followed by batch norm and ReLU,
    are flattened per sample, and end in fully connected layers:

    * value head:  conv -> bn -> relu -> fc1 -> relu -> fc2 -> tanh
    * policy head: conv1 -> bn1 -> relu -> fc -> softmax

    All sizes are configurable; the defaults reproduce the original fixed
    architecture (128 input channels, a 6x7 grid and 7 policy outputs --
    presumably a Connect-Four-style board, confirm against the caller).

    Args:
        in_channels: Number of channels in the incoming feature map.
        board_height: Height of the incoming feature map.
        board_width: Width of the incoming feature map.
        num_actions: Number of policy outputs.
        value_channels: Output channels of the value head's 1x1 convolution.
        policy_channels: Output channels of the policy head's 1x1 convolution.
        value_hidden: Width of the value head's hidden linear layer.
    """

    def __init__(
        self,
        *,
        in_channels=128,
        board_height=6,
        board_width=7,
        num_actions=7,
        value_channels=3,
        policy_channels=32,
        value_hidden=32,
    ):
        super().__init__()
        cells = board_height * board_width

        # NOTE: submodule names and construction order are kept as in the
        # original so state_dict keys and seeded initialisation are unchanged.

        # value head
        self.conv = nn.Conv2d(in_channels, value_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(value_channels)
        self.fc1 = nn.Linear(value_channels * cells, value_hidden)
        self.fc2 = nn.Linear(value_hidden, 1)

        # policy head
        self.conv1 = nn.Conv2d(in_channels, policy_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(policy_channels)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(policy_channels * cells, num_actions)

    def forward(self, s):
        """Compute the policy distribution and value estimate for ``s``.

        Args:
            s: Feature map laid out as batch_size x channel x height x width,
                matching the sizes given to the constructor.

        Returns:
            A tuple ``(p, v)`` where ``p`` has one row per sample holding
            ``num_actions`` probabilities that sum to 1, and ``v`` has one
            tanh-squashed value in [-1, 1] per sample.

        Raises:
            RuntimeError: If the flattened per-sample feature count does not
                match what the linear layers expect (wrong spatial size).
        """
        # value head
        v = F.relu(self.bn(self.conv(s)))
        # flatten(start_dim=1) keeps the batch dimension fixed, so a
        # wrong-sized input fails loudly at fc1 instead of being silently
        # folded into extra batch rows as view(-1, N) would do.
        v = v.flatten(start_dim=1)
        v = F.relu(self.fc1(v))
        v = torch.tanh(self.fc2(v))

        # policy head
        p = F.relu(self.bn1(self.conv1(s)))
        p = p.flatten(start_dim=1)
        p = self.fc(p)
        # exp(log_softmax(x)) == softmax(x): probabilities over the actions.
        p = self.logsoftmax(p).exp()
        return p, v
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement