Advertisement
Guest User

Untitled

a guest
Feb 14th, 2017
486
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.46 KB | None | 0 0
  1. import torch
  2. from torch.autograd import Variable
  3. import torch.nn as nn
  4. import numpy as np
  5.  
  6. N, D_in, H, D_out, num_class = 64, 1000, 100, 10, 4
  7. dtype = torch.FloatTensor
  8.  
  9. class Net(nn.Module):
  10.     def __init__(self):
  11.         super(Net, self).__init__()
  12.         self.linear1 = torch.nn.Linear(D_in, H)
  13.         self.linear2 = torch.nn.Linear(H, D_out)
  14.  
  15.     def forward(self, x):
  16.         h_relu = self.linear1(x).clamp(min=0)
  17.         y_pred = self.linear2(h_relu)
  18.         return y_pred, h_relu
  19.    
  20. # loss_function
  21. def CenterLoss(y, y_pred, centers):
  22.     centers_pred = center.index_select(0, y.long())
  23.     difference   = y_pred - centers_pred
  24.     loss         = difference.pow(2).sum() / (2*  y.size()[0])
  25.     return loss
  26.  
  27. # Input
  28. x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
  29. # Output
  30. y = Variable(torch.Tensor(np.random.randint(0, num_class-1, size=N)).type(dtype), requires_grad=False)
  31. # Center-Variable
  32. center = Variable(torch.randn(num_class, H).type(dtype), requires_grad=True)
  33. # Network
  34. net = Net()
  35. # Classification Criterion
  36. criterion = nn.CrossEntropyLoss()
  37.  
  38. # Optimizer
  39. optimizer = torch.optim.Adam([
  40.                 {'params': net.parameters()},
  41.                 {'params': [center]}
  42.             ], lr=1e-2)
  43.  
  44. # Forward Pass
  45. y_pred, features  = net(x)
  46. loss = CenterLoss(y, features, center) + criterion(y_pred, y.long())
  47.  
  48. # compute gradient and do SGD step
  49. optimizer.zero_grad()
  50. loss.backward()
  51. optimizer.step()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement