Untitled

a guest
Aug 22nd, 2019

"""
PyTorch CrossEntropyLoss usage
"""
import math

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
# One sample with raw scores (logits) for 5 classes
input = torch.randn(1, 5, requires_grad=True)
# One target class index, drawn uniformly from 0..4
target = torch.empty(1, dtype=torch.long).random_(5)
output = loss(input, target)

print("Input (5 classes):")
print(input)
print("Target class used to compute the loss:")
print(target)
print("Loss computed by CrossEntropyLoss:")
print(output)

# Manual check: per sample, loss = -x[target] + log(sum_j exp(x[j])),
# then averaged over the batch (CrossEntropyLoss defaults to reduction="mean").
res = 0.0
for i in range(input.size(0)):
    first = -input[i][target[i]].item()
    second = 0.0
    for j in range(input.size(1)):
        second += math.exp(input[i][j].item())
    res += first + math.log(second)
res /= input.size(0)
print("Manual computation result:")
print(res)
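As a cross-check, the same quantity can be computed in vectorized form for a whole batch. For a sample with logits x and target class y the loss is -x[y] + log(sum_j exp(x[j])), averaged over the batch by default. The sketch below assumes a reasonably recent PyTorch with `torch.nn.functional`; the batch size of 3, the seed, and the names `logits`, `picked`, `manual`, and `big` are illustrative choices, not part of the original paste. It compares the built-in loss against `log_softmax` + `nll_loss` and against the explicit formula using `torch.logsumexp`, which, unlike the `math.exp` loop, does not overflow for large logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the run is reproducible (illustrative)

# A batch of 3 samples with 5 classes each
logits = torch.randn(3, 5, requires_grad=True)
target = torch.randint(0, 5, (3,))

# 1) Built-in module; default reduction="mean" averages over the batch
builtin = nn.CrossEntropyLoss()(logits, target)

# 2) Equivalent: log_softmax followed by negative log-likelihood
via_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

# 3) Explicit formula, vectorized: -x[i, y_i] + logsumexp(x[i]), then mean
picked = logits.gather(1, target.unsqueeze(1)).squeeze(1)  # x[i, target[i]]
manual = (-picked + torch.logsumexp(logits, dim=1)).mean()

print(builtin.item(), via_nll.item(), manual.item())
assert torch.allclose(builtin, via_nll)
assert torch.allclose(builtin, manual)

# Gradients agree as well: d loss / d x = (softmax(x) - one_hot(y)) / batch_size
builtin.backward()
expected_grad = (F.softmax(logits, dim=1) - F.one_hot(target, 5).float()) / 3
assert torch.allclose(logits.grad, expected_grad.detach())

# Large logits: math.exp(1000.0) raises OverflowError, logsumexp stays finite
big = torch.tensor([[1000.0, 0.0, -1000.0]])
print(torch.logsumexp(big, dim=1))
```

The `gather` call picks the target logit for each row without a Python loop, so the check scales to any batch size.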