Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch.nn
- import geoopt
- # package.nn.modules.py
def create_ball(ball=None, c=None):
    """Return a Poincaré ball manifold, building a new one from ``c`` if needed.

    Args:
        ball: An existing ``geoopt.PoincareBall`` to reuse. When it is given,
            ``c`` is ignored.
        c: Curvature passed to ``geoopt.PoincareBall`` when ``ball`` is None.

    Returns:
        The ``geoopt.PoincareBall`` instance to use.

    Raises:
        ValueError: If both ``ball`` and ``c`` are None, or if ``ball`` is not
            a ``geoopt.PoincareBall`` instance.
    """
    if ball is None:
        # An explicit raise (instead of ``assert``) keeps this validation
        # active when Python runs with ``-O``.
        if c is None:
            raise ValueError("curvature of the ball should be explicitly specified")
        return geoopt.PoincareBall(c)
    if not isinstance(ball, geoopt.PoincareBall):
        raise ValueError(
            "ball should be an instance of PoincareBall, "
            f"got {type(ball).__name__}"
        )
    return ball
class MobiusLinear(torch.nn.Linear):
    """Linear layer whose forward pass is computed with Möbius operations.

    The weight is the ordinary ``torch.nn.Linear`` weight. When the layer has
    a bias, that bias is re-registered as a ``geoopt.ManifoldParameter`` on
    ``self.ball``. ``forward`` delegates to ``mobius_linear``.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Linear``.
        nonlin: Optional callable passed through to ``mobius_linear``.
        ball: Optional existing ``geoopt.PoincareBall`` (see ``create_ball``).
        c: Curvature used by ``create_ball`` when ``ball`` is not supplied.
    """

    def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.ball = create_ball(ball, c)
        plain_bias = self.bias
        if plain_bias is not None:
            # Attach the bias to the ball's manifold; presumably this is so
            # geoopt's Riemannian optimizers treat it as a point on the ball.
            self.bias = geoopt.ManifoldParameter(plain_bias, manifold=self.ball)
        self.nonlin = nonlin
        # torch.nn.Linear.__init__ already ran reset_parameters() on the plain
        # bias; run it again now that the bias has been re-wrapped.
        self.reset_parameters()

    def forward(self, input):
        """Apply the Möbius linear map (and optional nonlinearity) to ``input``."""
        return mobius_linear(
            input,
            ball=self.ball,
            weight=self.weight,
            bias=self.bias,
            nonlin=self.nonlin,
        )

    @torch.no_grad()
    def reset_parameters(self):
        """Set the weight to identity plus uniform noise in [0, 1e-3); zero the bias."""
        # Drawing the noise first is safe: rand_like only reads the weight's
        # shape/dtype/device, and eye_ consumes no random numbers.
        jitter = torch.rand_like(self.weight) * 1e-3
        torch.nn.init.eye_(self.weight)
        self.weight.add_(jitter)
        current_bias = self.bias
        if current_bias is not None:
            current_bias.zero_()
- # package.nn.functional.py
def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: geoopt.PoincareBall):
    """Möbius analogue of a linear layer: matvec, optional bias, optional nonlinearity.

    The steps are ``ball.mobius_matvec(weight, input)``, then
    ``ball.mobius_add(..., bias)`` when ``bias`` is given, then
    ``ball.mobius_fn_apply(nonlin, ...)`` when ``nonlin`` is given.

    Args:
        input: Input tensor. Presumably points on ``ball`` with features in the
            last dimension; TODO confirm against geoopt's ``mobius_matvec``.
        weight: Matrix handed to ``ball.mobius_matvec``.
        bias: Optional tensor combined via ``ball.mobius_add``.
        nonlin: Optional callable applied via ``ball.mobius_fn_apply``.
        ball: The Poincaré ball whose Möbius operations are used.

    Returns:
        The tensor produced by the last operation applied.
    """
    result = ball.mobius_matvec(weight, input)
    result = result if bias is None else ball.mobius_add(result, bias)
    if nonlin is None:
        return result
    return ball.mobius_fn_apply(nonlin, result)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement