Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class lstm_h_time(nn.Module):
    """Single-layer LSTM that maps an input sequence to one fixed-size vector.

    The sequence is unrolled step by step through an ``nn.LSTMCell``; the
    hidden state after the last step is projected by a linear layer.
    """

    def __init__(self, input_f=12, hidden_size=1024, lstm_history=48,
                 output_size=144, random_hidden_init=True):
        """
        :param input_f: number of features per time step (LSTMCell input size).
        :param hidden_size: size of the LSTM hidden and cell states.
        :param lstm_history: kept for backward compatibility and stored as
            ``self.lstm_history``; ``forward`` takes the sequence length from
            ``x`` itself and does not read this value.
        :param output_size: width of the final linear projection. Defaults to
            144, the value that used to be hard-coded. NOTE(review): the old
            docstring suggests callers reshape this to (48, 3) — confirm.
        :param random_hidden_init: if True (the original behaviour) the initial
            hidden/cell states are drawn from a standard normal on every
            ``forward`` call, so outputs are stochastic; if False they are
            zeros and ``forward``/``predict`` are deterministic.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm_history = lstm_history
        self.random_hidden_init = random_hidden_init
        self.lstmcell = nn.LSTMCell(input_f, hidden_size)
        self.h2y = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the LSTM over every time step of ``x`` and project the last state.

        :param x: tensor of shape (batch_size, T, input_f).
        :return: tensor of shape (batch_size, output_size), computed by
            ``h2y`` from the hidden state after the last time step (or from
            the initial hidden state when T == 0).
        :raises ValueError: if ``x`` is not 3-dimensional.
        """
        # Unpacking all three sizes doubles as a rank check on x.
        batch_size, seq_len, _ = x.size()
        # Fresh initial state on every call, created where the input lives
        # (a hard-coded .cuda() breaks CPU-only runs).
        self.init_hidden(batch_size, device=x.device, dtype=x.dtype)
        hx, cx = self.hx, self.cx
        for t in range(seq_len):
            hx, cx = self.lstmcell(x[:, t, :], (hx, cx))  # (batch_size, hidden_size)
        return self.h2y(hx)  # (batch_size, output_size)

    def init_hidden(self, batch_size=1, device=None, dtype=None):
        """Set ``self.hx`` and ``self.cx`` to fresh (batch_size, hidden_size) states.

        :param batch_size: number of sequences in the batch.
        :param device: target device; defaults to the device of this module's
            parameters.
        :param dtype: target dtype; defaults to the dtype of this module's
            parameters.
        """
        if device is None or dtype is None:
            ref = next(self.parameters())
            device = ref.device if device is None else device
            dtype = ref.dtype if dtype is None else dtype
        shape = (batch_size, self.hidden_size)
        # getattr keeps models pickled before this attribute existed working.
        if getattr(self, "random_hidden_init", True):
            # Sample on the CPU generator and then move, which reproduces the
            # exact values the old `torch.randn(...).cuda()` produced for a
            # given torch.manual_seed.
            self.hx = torch.randn(shape).to(device=device, dtype=dtype)
            self.cx = torch.randn(shape).to(device=device, dtype=dtype)
        else:
            self.hx = torch.zeros(shape, device=device, dtype=dtype)
            self.cx = torch.zeros(shape, device=device, dtype=dtype)

    def predict(self, x):
        """Alias for ``forward``. Stochastic unless ``random_hidden_init`` is False."""
        return self.forward(x)
Add Comment
Please, Sign In to add comment