
Untitled

a guest
Aug 22nd, 2019
"""
PyTorch CrossEntropyLoss usage
"""
import math

import torch
import torch.nn as nn

# nn.CrossEntropyLoss applies log-softmax to the raw scores (logits), then takes the
# negative log-likelihood of the target class, averaged over the batch by default.
loss = nn.CrossEntropyLoss()
input = torch.randn(1, 5, requires_grad=True)  # 1 sample, 5 class scores
target = torch.empty(1, dtype=torch.long).random_(5)  # one class index in [0, 5)
output = loss(input, target)

print("Input (5 classes):")
print(input)
print("Target class for the loss:")
print(target)
print("Loss computed by nn.CrossEntropyLoss:")
print(output)

# Manual check for batch size 1: loss = -x[target] + log(sum_j exp(x[j]))
first = 0
for i in range(1):
    first -= input[i][target[i]].item()
second = 0
for i in range(1):
    for j in range(5):
        second += math.exp(input[i][j].item())
res = first + math.log(second)
print("Manually computed result:")
print(res)
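The manual check above only covers a single sample. The following is a minimal sketch showing that the same formula vectorizes over a batch. It assumes a batch of 3 samples and the default `reduction="mean"`. The seed, batch size, and target indices are arbitrary illustration values. The sketch compares `nn.CrossEntropyLoss` against a `torch.logsumexp` + `gather` version and against `F.log_softmax` followed by `F.nll_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Illustrative batch: 3 samples, 5 classes of raw (unnormalized) scores.
logits = torch.randn(3, 5)
targets = torch.tensor([1, 0, 4])

# Library result; the default reduction="mean" averages the per-sample losses.
builtin = nn.CrossEntropyLoss()(logits, targets)

# Manual, vectorized: loss_i = logsumexp(x_i) - x_i[target_i]
picked = logits.gather(1, targets.unsqueeze(1)).squeeze(1)  # x_i[target_i]
per_sample = torch.logsumexp(logits, dim=1) - picked
manual_mean = per_sample.mean()

# Equivalent formulation: log-softmax followed by negative log-likelihood.
via_nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# reduction="none" returns the per-sample losses instead of their mean.
builtin_none = nn.CrossEntropyLoss(reduction="none")(logits, targets)

print(builtin, manual_mean, via_nll)
print(torch.allclose(builtin, manual_mean), torch.allclose(builtin_none, per_sample))
```

`torch.logsumexp` is used here instead of `math.log(sum(math.exp(...)))` because it stays numerically stable for large scores.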
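The manual check above calls `math.exp` directly on each score. For large logits this raises `OverflowError` (for example, `math.exp(1000.0)`). The following is a plain-Python sketch (no PyTorch) of the same formula using the log-sum-exp trick, which subtracts the maximum score before exponentiating. It assumes each sample is a list of floats. `cross_entropy` and `mean_cross_entropy` are hypothetical helper names chosen for illustration.

```python
import math

def cross_entropy(logits, target):
    """One sample: log(sum_j exp(x_j)) - x[target].

    Subtracting the maximum logit before exponentiating yields the same value
    but keeps math.exp from overflowing on large scores.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Average over samples, like nn.CrossEntropyLoss's default reduction="mean"."""
    losses = [cross_entropy(row, t) for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(cross_entropy([0.0, 0.0], 0))     # equals math.log(2)
print(cross_entropy([1000.0, 0.0], 0))  # no OverflowError despite the large score
print(mean_cross_entropy([[0.0, 0.0], [1000.0, 0.0]], [0, 0]))
```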