Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def sample(self, inputs, states=None, max_len=20, *, end_idx=None):
    """Greedily decode up to ``max_len`` token ids from the decoder.

    At every step the current ``inputs`` are fed through ``self.lstm``. The
    output is squeezed on dim 1 and projected with ``self.linear``, and the
    highest-scoring index is taken as the next token. That token is then
    embedded with ``self.embed`` to form the next step's input.

    Args:
        inputs: Initial input to ``self.lstm``. This presumably has shape
            ``(1, 1, embed_size)`` for a ``batch_first`` LSTM; TODO confirm
            against the caller. Only batch size 1 is supported, because each
            predicted token is converted with ``Tensor.item()``.
        states: Initial recurrent state passed to ``self.lstm``. ``None``
            lets the LSTM use its default initial state.
        max_len: Maximum number of tokens to generate.
        end_idx: Optional token id that ends decoding early. When the model
            emits it, the id is appended and the loop stops. ``None`` (the
            default) always runs ``max_len`` steps.

    Returns:
        A list of predicted token ids as Python ints.
    """
    out = []
    # Inference only: the result is plain ints, so there is no reason to
    # build an autograd graph that spans every decoding step.
    with torch.no_grad():
        for _ in range(max_len):
            lstm_out, states = self.lstm(inputs, states)
            lstm_out = torch.squeeze(lstm_out, 1)
            scores = self.linear(lstm_out)
            word = scores.argmax(dim=1)
            token = word.item()  # requires a single element, i.e. batch size 1
            out.append(token)
            if end_idx is not None and token == end_idx:
                break
            # Re-add the time dimension so the embedding can be fed back in.
            inputs = torch.unsqueeze(self.embed(word), 1)
    return out
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement