Untitled

a guest, Aug 22nd, 2019
"""
PyTorch CrossEntropyLoss usage
"""
import math

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
# One sample (batch size 1) of raw, unnormalized scores for 5 classes
input = torch.randn(1, 5, requires_grad=True)
# Target class index, drawn uniformly from 0..4
target = torch.empty(1, dtype=torch.long).random_(5)
output = loss(input, target)

print("Input (5 classes):")
print(input)
print("Target class for the loss:")
print(target)
print("Loss computed by CrossEntropyLoss:")
print(output)

# Recompute by hand: loss = -x[target] + log(sum_j exp(x[j])).
# The loops below assume batch size 1; with a larger batch, the default
# reduction='mean' would also divide the summed loss by the batch size.
first = 0.0
for i in range(1):
    first -= input[i][target[i]].item()
second = 0.0
for i in range(1):
    for j in range(5):
        second += math.exp(input[i][j].item())
res = first + math.log(second)
print("Manually computed result:")
print(res)
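The manual check above only handles a batch of one sample. For a batch of N samples with logits x and target classes y, the loss is the per-sample value -x[i, y_i] + log(sum_j exp(x[i, j])) averaged over i under the default reduction='mean'. The sketch below vectorizes that with `torch.logsumexp` and `gather`. It also checks that the gradient with respect to the logits equals (softmax(x) - one_hot(y)) / N, which is why the input was created with `requires_grad=True`. The helper name `manual_cross_entropy`, the seed, and the batch size of 4 are illustrative assumptions, not part of the original paste.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean cross-entropy over a batch, from its definition.

    logits: (N, C) raw scores; target: (N,) class indices in [0, C).
    """
    # log p[i, j] = x[i, j] - logsumexp_k x[i, k]  (a numerically stable log-softmax)
    log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    # Pick the log-probability of each sample's target class
    picked = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    # The default reduction='mean' averages the per-sample losses
    return -picked.mean()


torch.manual_seed(0)
N, C = 4, 5  # assumed batch size and number of classes
logits = torch.randn(N, C, requires_grad=True)
target = torch.randint(0, C, (N,))

builtin = nn.CrossEntropyLoss()(logits, target)
manual = manual_cross_entropy(logits, target)
print(builtin.item(), manual.item())

# For reduction='mean', d(loss)/d(logits) = (softmax(x) - one_hot(y)) / N
builtin.backward()
expected_grad = (F.softmax(logits.detach(), dim=1) - F.one_hot(target, C).float()) / N
print(torch.allclose(logits.grad, expected_grad, atol=1e-6))
```

Both printed loss values should agree up to floating-point rounding, and the gradient check should print `True`.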