Guest User

Untitled

a guest
May 22nd, 2018
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.09 KB | None | 0 0
  class lstm_h_time(nn.Module):
      """Sequence model built from a single ``nn.LSTMCell`` plus a linear head.

      The input sequence is consumed one time step at a time by the LSTM
      cell; the hidden state left after the last step is projected by a
      linear layer to ``output_size`` values.
      """

      def __init__(self, input_f=12, hidden_size=1024, lstm_history=48,
                   output_size=144):
          """Build the LSTM cell and the output projection.

          :param input_f: number of features per time step.
          :param hidden_size: width of the LSTM hidden and cell states.
          :param lstm_history: accepted for backward compatibility; it is not
              used here because the sequence length is read from the input
              in ``forward``.
          :param output_size: number of values produced per sample
              (defaults to the previously hard-coded 144).
          """
          super().__init__()
          self.hidden_size = hidden_size
          self.lstmcell = nn.LSTMCell(input_f, hidden_size)
          self.h2y = nn.Linear(hidden_size, output_size)

      def forward(self, x):
          """Run the LSTM cell over every time step of ``x``.

          :param x: tensor of shape (batch_size, T, input_f).
          :return: tensor of shape (batch_size, output_size), computed from
              the hidden state after the final time step.
              NOTE(review): the default 144 looks like a flattened
              (48, 3) target that the caller reshapes -- confirm.
          """
          batch_size, T, _ = x.size()

          # Draw a fresh state on every call so that no autograd history is
          # carried over from a previous batch.
          self.init_hidden(batch_size)
          hx, cx = self.hx, self.cx

          for t in range(T):
              # Both states keep the shape (batch_size, hidden_size).
              hx, cx = self.lstmcell(x[:, t, :], (hx, cx))

          return self.h2y(hx)  # (batch_size, output_size)

      def init_hidden(self, batch_size=1):
          """Store fresh random hidden and cell states in ``self.hx``/``self.cx``.

          The states are sampled from a standard normal distribution, so two
          calls give different states unless the RNG is re-seeded. They are
          allocated with the device and dtype of the module's weights, which
          keeps the model usable on CPU as well as on GPU.

          :param batch_size: number of rows in each state tensor.
          """
          ref = self.h2y.weight
          shape = (batch_size, self.hidden_size)
          self.hx = torch.randn(shape, device=ref.device, dtype=ref.dtype)
          self.cx = torch.randn(shape, device=ref.device, dtype=ref.dtype)

      def predict(self, x):
          """Return the model output for ``x``; see ``forward`` for shapes."""
          # Calling the module (rather than ``forward``) runs registered hooks.
          return self(x)
Add Comment
Please, Sign In to add comment