Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Demo: move a single "head" one cell on a grid using convolution filters.
#
# Four 5x5 environments each start with one head in the centre, and each
# environment applies a different movement action. The move is computed by
# convolving the head map with one 3x3 filter per action, then keeping, for
# each environment, only the output channel selected by its one-hot action.
import torch
import torch.nn.functional as F

# One 3x3 filter per action, stacked into a conv2d weight of shape
# (out_channels=4, in_channels=1, 3, 3).
#
# NOTE: F.conv2d computes a cross-correlation (the kernel is NOT flipped).
# With padding=1, output[i, j] = input[i + di - 1, j + dj - 1], where
# (di, dj) is the position of the 1 in the filter. So the 1 marks the cell
# a head comes *from*, and the head moves opposite to where the 1 is drawn:
#   action 0 (1 drawn at top)    -> head moves down  (row + 1)
#   action 1 (1 drawn at right)  -> head moves left  (col - 1)
#   action 2 (1 drawn at bottom) -> head moves up    (row - 1)
#   action 3 (1 drawn at left)   -> head moves right (col + 1)
movement_filters = torch.tensor(
    [
        [
            [0, 1, 0],
            [0, 0, 0],
            [0, 0, 0],
        ],
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 0],
        ],
        [
            [0, 0, 0],
            [0, 0, 0],
            [0, 1, 0],
        ],
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 0, 0],
        ],
    ],
    dtype=torch.float32,
).unsqueeze(1)  # add the in_channels=1 axis expected by conv2d

# Head maps: (batch=4, channels=1, height=5, width=5), one head per
# environment, placed in the centre cell.
heads = torch.zeros((4, 1, 5, 5))
heads[:, 0, 2, 2] = 1

# Each environment takes a different action (0..3); encode them one-hot as
# (batch, num_actions) floats so they can weight the conv output channels.
actions = torch.arange(4)
actions_onehot = F.one_hot(actions, num_classes=4).float()

# (batch, num_actions, 5, 5): the head map shifted by every possible action.
# F.conv2d only supports zero padding, so a head that moves off the grid
# simply disappears (its map becomes all zeros).
intermediate = F.conv2d(heads, movement_filters, padding=1)

# For each batch element b, sum the shifted maps weighted by that element's
# one-hot action; with a true one-hot this selects exactly one channel.
# Re-add the channel axis to get back to (batch, 1, 5, 5).
heads = torch.einsum("bchw,bc->bhw", intermediate, actions_onehot).unsqueeze(1)

print(heads[:, 0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement