Advertisement
Guest User

Untitled

a guest
Aug 20th, 2019
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.97 KB | None | 0 0
  1. class RNN(nn.Module):
  2. def __init__(self, input_size, hidden_size1, hidden_size2, final_layer, num_layers, output_size, batch_size):
  3. super(RNN, self).__init__()
  4. self.hidden_size1 = hidden_size1
  5. self.hidden_size2 = hidden_size2
  6. self.batch_size = batch_size
  7. self.num_layers = num_layers
  8. self.lstm = nn.LSTM(input_size, hidden_size1, num_layers, batch_first=True).cuda()
  9. self.dense1 = nn.Linear(hidden_size1, output_size).cuda()
  10. self.tanh1 = nn.Tanh()
  11. self.dense2 = nn.Linear(hidden_size2, final_layer).cuda()
  12. self.tanh2 = nn.Tanh()
  13. self.dense3 = nn.Linear(final_layer, 1).cuda()
  14. #self.fc = nn.Linear(hidden)
  15. def forward(self, x):
  16. # Set initial hidden and cell states
  17. h0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size1).to(device)
  18. c0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size1).to(device)
  19. # print(h0)
  20. # print(x.shape)
  21. # print(c0)
  22. #y = np.zeros((self.batch_size, input_size, input_size))
  23. x = x.view(self.batch_size, x.shape[1], 1)
  24. hidden = (h0, c0)
  25. # Forward propagate LSTM
  26.  
  27. #print("h0.shape and c0.shape" + str(h0.shape))
  28. #print("x.shape" + str(x.shape))
  29. out, _ = self.lstm(x, hidden) # out: tensor of shape (batch_size, seq_length, hidden_size)
  30. print("out.shape:")
  31. print(out.shape)
  32. # Decode the hidden state of the last time step
  33. #print("after lstm:" + str(out.shape))
  34. out = self.dense1(out)
  35. out = self.tanh1(out)
  36. #print("after first dense layer:" + str(out.shape))
  37. out = out.view(batch_size, x.shape[1])
  38. out = self.dense2(out)
  39. out = self.tanh2(out)
  40. #print("after second dense layer:" + str(out.shape))
  41. #print(out.view(-1).shape)
  42. out = self.dense3(out)
  43. out = out.view(batch_size)
  44. return out
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement