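import torch

# Random 4x4x4 tensor with values in [0, 1)
a = torch.rand(4, 4, 4)
# Boolean mask of shape (4, 4): True where element 0 along the last dim exceeds 0.5
i = a[:, :, 0] > 0.5
# Indexing with the 2-D mask keeps only the matching rows: shape (N, 4)
b = a[i]
# Insert a singleton dimension at position 1: shape (N, 1, 4)
print(b.unsqueeze(dim=1))
# Example output (values differ on every run because of torch.rand):
# tensor([[[0.9476, 0.3862, 0.4544, 0.5905]],
#         [[0.9413, 0.9987, 0.6411, 0.6876]],
#         [[0.5807, 0.6687, 0.0952, 0.1582]],
#         [[0.6057, 0.6513, 0.4329, 0.2501]],
#         [[0.8998, 0.4524, 0.9219, 0.0447]]])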
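
For readers who want to see what the mask indexing in the snippet above is doing, here is a rough plain-Python sketch of the same selection on nested lists. It is only an illustration: `mask_select` and `unsqueeze_dim1` are hypothetical helper names, not PyTorch functions, and the small 2x2x2 input is made up so the result is deterministic.

```python
# Hypothetical helpers that mimic a[a[:, :, 0] > 0.5] and .unsqueeze(dim=1)
# on plain nested lists (no PyTorch required).

def mask_select(a, threshold=0.5):
    """Keep every innermost row whose first element exceeds the threshold.

    Rows are visited in row-major order, matching the order in which a
    2-D boolean mask picks rows out of a 3-D tensor.
    """
    return [row for plane in a for row in plane if row[0] > threshold]


def unsqueeze_dim1(rows):
    """Wrap each row in a length-1 list: shape (N, K) -> (N, 1, K)."""
    return [[row] for row in rows]


# Made-up 2x2x2 input standing in for torch.rand(4, 4, 4)
a = [
    [[0.9, 0.1], [0.2, 0.8]],
    [[0.6, 0.3], [0.4, 0.7]],
]

b = mask_select(a)          # shape (2, 2)
result = unsqueeze_dim1(b)  # shape (2, 1, 2)
print(result)               # [[[0.9, 0.1]], [[0.6, 0.3]]]
```

The point of the sketch is that the number of selected rows depends on the data, which is why the printed tensor in the paste has five rows rather than a fixed count.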