Guest User

Untitled

a guest
Dec 19th, 2018
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.50 KB | None | 0 0
  1. class lstm(nn.Module):
  2.  
  3. def __init__(self,
  4. input_dim=None,
  5. hidden_dim=None,
  6. output_dim=None):
  7.  
  8. super(lstm, self).__init__()
  9. self.input_dim = input_dim
  10. self.hidden_dim = hidden_dim
  11. self.output_dim = output_dim
  12. self.initial_hidden = (None, None)
  13. self.lstmcell = nn.LSTM(input_size=input_dim,
  14. hidden_size=hidden_dim,
  15. batch_first=True)
  16. self.fc = nn.Linear(hidden_dim, output_dim)
  17.  
  18. def forward(self, x):
  19. h0 = Variable(torch.zeros(1, x.size(0), self.hidden_dim).cuda())
  20. c0 = Variable(torch.zeros(1, x.size(0), self.hidden_dim).cuda())
  21. self.initial_hidden = (h0, c0)
  22. output, _ = self.lstmcell(x, self.initial_hidden)
  23. y = self.fc(output[:, -1, :])
  24. return F.log_softmax(y)
  25.  
  26.  
  27. train_batch_size = 128
  28. valid_batch_size = 1000
  29. test_batch_size = 1000
  30.  
  31.  
  32. train_dataloader = DataLoader(dataset=train_dataset,
  33. batch_size=train_batch_size,
  34. shuffle=True,
  35. num_workers=1,
  36. drop_last=True)
  37.  
  38.  
  39. model = lstm(input_dim=28,
  40. hidden_dim=128,
  41. output_dim=28)
  42. model = nn.DataParallel(model)
  43.  
  44.  
  45. for epoch_index in range(5):
  46. model.train()
  47. error_sum = 0.
  48. for batch_index, (inputs, labels) in enumerate(train_dataloader):
  49. inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
  50. optimizer.zero_grad()
  51. outputs = model(inputs)
  52. error = F.nll_loss(outputs, labels)
  53. error_sum += batch_average_error.data[0]
  54. error.backward()
  55. optimizer.step()
Add Comment
Please, Sign In to add comment