Advertisement
Guest User

Untitled

a guest
Apr 25th, 2019
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.85 KB | None | 0 0
  1. class OutBlock(nn.Module):
  2. def __init__(self):
  3. super(OutBlock, self).__init__()
  4. self.conv = nn.Conv2d(128, 3, kernel_size=1) # value head
  5. self.bn = nn.BatchNorm2d(3)
  6. self.fc1 = nn.Linear(3*6*7, 32)
  7. self.fc2 = nn.Linear(32, 1)
  8.  
  9. self.conv1 = nn.Conv2d(128, 32, kernel_size=1) # policy head
  10. self.bn1 = nn.BatchNorm2d(32)
  11. self.logsoftmax = nn.LogSoftmax(dim=1)
  12. self.fc = nn.Linear(6*7*32, 7)
  13.  
  14. def forward(self,s):
  15. v = F.relu(self.bn(self.conv(s))) # value head
  16. v = v.view(-1, 3*6*7) # batch_size X channel X height X width
  17. v = F.relu(self.fc1(v))
  18. v = torch.tanh(self.fc2(v))
  19.  
  20. p = F.relu(self.bn1(self.conv1(s))) # policy head
  21. p = p.view(-1, 6*7*32)
  22. p = self.fc(p)
  23. p = self.logsoftmax(p).exp()
  24. return p, v
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement