Guest User

Untitled

a guest
Dec 14th, 2017
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.30 KB | None | 0 0
  1. diff --git a/onmt/Models.py b/onmt/Models.py
  2. index c52cfdc..241e052 100644
  3. --- a/onmt/Models.py
  4. +++ b/onmt/Models.py
  5. @@ -443,9 +443,11 @@ class DecoderState(object):
  6. Detaches all Variables from the graph
  7. that created it, making it a leaf.
  8. """
  9. - for h in self._all:
  10. + def f(h):
  11. if h is not None:
  12. - h.detach_()
  13. + return h.detach()
  14. + return self.__new__(
  15. + type(self), f(self.input_feed), self.hidden_size, tuple(f(x) for x in self.hidden))
  16.  
  17. def beam_update(self, idx, positions, beam_size):
  18. """ Update when beam advances. """
  19. @@ -473,6 +475,7 @@ class RNNDecoderState(DecoderState):
  20. else:
  21. self.hidden = rnnstate
  22. self.coverage = None
  23. + self.hidden_size = hidden_size
  24.  
  25. # Init the input feed.
  26. batch_size = context.size(1)
  27. diff --git a/onmt/Trainer.py b/onmt/Trainer.py
  28. index 78bd439..c810194 100644
  29. --- a/onmt/Trainer.py
  30. +++ b/onmt/Trainer.py
  31. @@ -130,7 +130,7 @@ class Trainer(object):
  32.  
  33. # If truncated, don't backprop fully.
  34. if dec_state is not None:
  35. - dec_state.detach()
  36. + dec_state = dec_state.detach()
  37.  
  38. if report_func is not None:
  39. report_stats = report_func(
Add Comment
Please, Sign In to add comment