
Untitled

a guest
Jul 19th, 2019
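import torch

a = torch.rand(4, 4, 4)       # random tensor, shape (4, 4, 4)

i = a[:, :, 0] > 0.5          # boolean mask over the first two dims, shape (4, 4)
b = a[i]                      # keep the length-4 rows where the mask is True, shape (k, 4)
print(b.unsqueeze(dim=1))     # insert a singleton dim at position 1, shape (k, 1, 4)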

Output of one run (torch.rand is unseeded, so the values and the number of rows k will differ):

tensor([[[0.9476, 0.3862, 0.4544, 0.5905]],

        [[0.9413, 0.9987, 0.6411, 0.6876]],

        [[0.5807, 0.6687, 0.0952, 0.1582]],

        [[0.6057, 0.6513, 0.4329, 0.2501]],

        [[0.8998, 0.4524, 0.9219, 0.0447]]])
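
Because the snippet above uses random input, the shapes are easier to check on fixed data. The following is a small deterministic sketch of the same boolean-mask indexing, with hand-picked values in place of torch.rand. The names mask, b_explicit and c_none are illustrative only. It shows three things. First, a[mask] with a 2-D mask over the first two dimensions returns the selected last-dimension rows in row-major order. Second, this matches indexing with mask.nonzero(as_tuple=True). Third, unsqueeze(dim=1) is equivalent to inserting None at position 1.

```python
import torch

# Hand-picked values (instead of torch.rand) so the selection is predictable.
a = torch.tensor([
    [[0.9, 0.1], [0.2, 0.3]],
    [[0.7, 0.4], [0.6, 0.5]],
])                                    # shape (2, 2, 2)

mask = a[:, :, 0] > 0.5               # shape (2, 2): [[True, False], [True, True]]
b = a[mask]                           # selected rows in row-major order, shape (3, 2)

# The same selection written with explicit index tensors.
rows, cols = mask.nonzero(as_tuple=True)
b_explicit = a[rows, cols]

# unsqueeze(dim=1) inserts a size-1 dimension, same as indexing with None there.
c = b.unsqueeze(dim=1)                # shape (3, 1, 2)
c_none = b[:, None, :]

print(b.shape, c.shape)
```

Running it prints torch.Size([3, 2]) torch.Size([3, 1, 2]). With the random 4x4x4 input, the first dimension of b is simply the number of True entries in the mask, i.e. int(i.sum()).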