Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- diff --git a/onmt/Models.py b/onmt/Models.py
- index c52cfdc..241e052 100644
- --- a/onmt/Models.py
- +++ b/onmt/Models.py
- @@ -443,9 +443,11 @@ class DecoderState(object):
- Detaches all Variables from the graph
- that created it, making it a leaf.
- """
- - for h in self._all:
- + def f(h):
- if h is not None:
- - h.detach_()
- + return h.detach()
- + return self.__new__(
- + type(self), f(self.input_feed), self.hidden_size, tuple(f(x) for x in self.hidden))
- def beam_update(self, idx, positions, beam_size):
- """ Update when beam advances. """
- @@ -473,6 +475,7 @@ class RNNDecoderState(DecoderState):
- else:
- self.hidden = rnnstate
- self.coverage = None
- + self.hidden_size = hidden_size
- # Init the input feed.
- batch_size = context.size(1)
- diff --git a/onmt/Trainer.py b/onmt/Trainer.py
- index 78bd439..c810194 100644
- --- a/onmt/Trainer.py
- +++ b/onmt/Trainer.py
- @@ -130,7 +130,7 @@ class Trainer(object):
- # If truncated, don't backprop fully.
- if dec_state is not None:
- - dec_state.detach()
- + dec_state = dec_state.detach()
- if report_func is not None:
- report_stats = report_func(
Add Comment
Please sign in to add a comment.