Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
@torch.jit.weak_module
class Weak(torch.nn.Module):
    """Bias-free linear layer whose weight matrix starts as all ones.

    Decorated as a weak module so that it can be used as a submodule of a
    ``torch.jit.ScriptModule``.

    Args:
        in_features: number of columns of the weight matrix.
        out_features: number of rows of the weight matrix.
    """

    def __init__(self, in_features, out_features):
        super(Weak, self).__init__()
        # Weight is created with shape (out_features, in_features), all ones.
        self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))

    # Forward pass: apply F.linear with the stored weight and no bias.
    @torch.jit.weak_script_method
    def forward(self, x):
        return F.linear(x, self.weight)
class Strong(torch.jit.ScriptModule):
    """Script module that delegates its forward pass to a wrapped module.

    Args:
        weak: the module to wrap; it is stored as the ``weak`` attribute and
            called from the scripted ``forward``.
    """

    def __init__(self, weak):
        super(Strong, self).__init__()
        # Keep the wrapped module as an attribute so forward() can call it.
        self.weak = weak

    # Scripted forward pass: hand the input straight to the wrapped module.
    @torch.jit.script_method
    def forward(self, x):
        return self.weak(x)
# Reproduction script: compare the output of a Strong script module with the
# Weak module it wraps, first after an in-place weight update and then after
# the weight attribute is rebound to a new Parameter.
inp = torch.ones(5, 5) * 5
weak_mod = Weak(5, 5)
strong_mod = Strong(weak_mod)

# In-place update through .data: weak_mod.weight is still the same Parameter
# object, only its values change.
weak_mod.weight.data += torch.ones(5, 5) * 100
strong_mod(inp).allclose(weak_mod(inp))  # True

# Rebinding: weak_mod.weight now refers to a brand-new Parameter object.
# NOTE(review): presumably strong_mod keeps using the Parameter it captured
# at construction time, which is why the outputs diverge -- confirm against
# the TorchScript version in use.
weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
strong_mod(inp).allclose(weak_mod(inp))  # False
Add Comment
Please, Sign In to add comment