SHARE
TWEET

Untitled

a guest Jul 16th, 2019 56 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch.nn
  2. import geoopt
  3. # package.nn.modules.py
  4. def create_ball(ball=None, c=None):
  5.     if ball is None:
  6.         assert c is not None, "curvature of the ball should be explicitly specified"
  7.         ball = geoopt.PoincareBall(c)
  8.     elif not isinstance(ball, geoopt.PoincareBall):
  9.         raise ValueError("ball should be an instance of PoncareMall")
  10.     return ball
  11.  
  12. class MobiusLinear(torch.nn.Linear):
  13.     def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
  14.         super().__init__(*args, **kwargs)
  15.         self.ball = create_ball(ball, c)
  16.         if self.bias is not None:
  17.             self.bias = geoopt.ManifoldParameter(self.bias, manifold=self.ball)
  18.         self.nonlin = nonlin
  19.         self.reset_parameters()
  20.  
  21.     def forward(self, input):
  22.         return mobius_linear(
  23.             input,
  24.             weight=self.weight,
  25.             bias=self.bias,
  26.             nonlin=self.nonlin,
  27.             ball=self.ball,
  28.         )
  29.  
  30.     @torch.no_grad()
  31.     def reset_parameters(self):
  32.         torch.nn.init.eye_(self.weight)
  33.         self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))
  34.         if self.bias is not None:
  35.             self.bias.zero_()
  36.  
  37. # package.nn.functional.py
  38. def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: geoopt.PoincareBall):
  39.     output = ball.mobius_matvec(weight, input)
  40.     if bias is not None:
  41.         output = ball.mobius_add(output, bias)
  42.     if nonlin is not None:
  43.         output = ball.mobius_fn_apply(nonlin, output)
  44.     return output
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top