Advertisement
Guest User

Untitled

a guest
Jul 16th, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.47 KB | None | 0 0
  1. import torch.nn
  2. import geoopt
  3. # package.nn.modules.py
  4. def create_ball(ball=None, c=None):
  5. if ball is None:
  6. assert c is not None, "curvature of the ball should be explicitly specified"
  7. ball = geoopt.PoincareBall(c)
  8. elif not isinstance(ball, geoopt.PoincareBall):
  9. raise ValueError("ball should be an instance of PoncareMall")
  10. return ball
  11.  
  12. class MobiusLinear(torch.nn.Linear):
  13. def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
  14. super().__init__(*args, **kwargs)
  15. self.ball = create_ball(ball, c)
  16. if self.bias is not None:
  17. self.bias = geoopt.ManifoldParameter(self.bias, manifold=self.ball)
  18. self.nonlin = nonlin
  19. self.reset_parameters()
  20.  
  21. def forward(self, input):
  22. return mobius_linear(
  23. input,
  24. weight=self.weight,
  25. bias=self.bias,
  26. nonlin=self.nonlin,
  27. ball=self.ball,
  28. )
  29.  
  30. @torch.no_grad()
  31. def reset_parameters(self):
  32. torch.nn.init.eye_(self.weight)
  33. self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))
  34. if self.bias is not None:
  35. self.bias.zero_()
  36.  
  37. # package.nn.functional.py
  38. def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: geoopt.PoincareBall):
  39. output = ball.mobius_matvec(weight, input)
  40. if bias is not None:
  41. output = ball.mobius_add(output, bias)
  42. if nonlin is not None:
  43. output = ball.mobius_fn_apply(nonlin, output)
  44. return output
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement