Guest User

Untitled

a guest
Oct 17th, 2018
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.80 KB | None | 0 0
@torch.jit.weak_module
class Weak(torch.nn.Module):
    """Bias-free linear layer decorated as a TorchScript "weak" module.

    The weight is a single Parameter initialised to ones; ``forward`` applies
    ``F.linear`` with that weight and no bias.

    NOTE(review): ``torch.jit.weak_module`` / ``torch.jit.weak_script_method``
    belong to an early JIT API that is believed to be absent from current
    PyTorch releases; run this with the PyTorch version it was written for.
    """

    def __init__(self, in_features, out_features):
        """Register ``weight`` as a ones tensor of shape (out_features, in_features).

        Args:
            in_features: second dimension of ``weight``.
            out_features: first dimension of ``weight``.
        """
        super(Weak, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))

    @torch.jit.weak_script_method
    def forward(self, x):
        # This body is compiled by TorchScript, so it is documented with
        # comments only (no docstring) and left exactly as written.
        # NOTE(review): F is not imported in this snippet; presumably it is
        # torch.nn.functional, making this x @ weight.T with no bias.
        return F.linear(x, self.weight)
  10.  
class Strong(torch.jit.ScriptModule):
    """ScriptModule that stores another module and forwards its input to it."""

    def __init__(self, weak):
        """Store ``weak`` as the submodule called from ``forward``.

        Args:
            weak: the module to wrap; in this snippet it is a ``Weak`` instance.
        """
        super(Strong, self).__init__()
        # NOTE(review): assigning onto a ScriptModule appears to be where the
        # JIT picks up the weak module. The reassignment check at the bottom of
        # this file (reported False) suggests the Parameter objects are
        # captured here and not re-read later -- confirm against the JIT.
        self.weak = weak

    @torch.jit.script_method
    def forward(self, x):
        # Compiled by TorchScript; simply delegates to the wrapped module.
        return self.weak(x)
  19.  
  20. inp = torch.ones(5, 5) * 5
  21. weak_mod = Weak(5, 5)
  22. strong_mod = Strong(weak_mod)
  23.  
  24.  
  25. weak_mod.weight.data += torch.ones(5, 5) * 100
  26. strong_mod(inp).allclose(weak_mod(inp)) # True
  27.  
  28. weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
  29. strong_mod(inp).allclose(weak_mod(inp)) # False
Add Comment
Please sign in to add a comment.