Advertisement
Guest User

Untitled

a guest
May 26th, 2019
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.57 KB | None | 0 0
  1. import math
  2. import torch
  3. import torch.nn as nn
  4.  
  5.  
  class SparseLinear(nn.Module):
      """Linear layer whose input is a sparse matrix.

      Computes ``torch.sparse.mm(input, weight) + bias``. The weight is
      stored with shape ``(in_features, out_features)`` (the transpose of
      the layout used by ``nn.Linear``), so no transpose is needed in the
      forward pass.

      Args:
          in_features: Number of columns of the sparse input matrix.
          out_features: Number of columns of the dense output matrix.
          bias: When ``False`` no additive bias is learned and ``self.bias``
              is registered as ``None``.
      """

      def __init__(self, in_features, out_features, bias=True):
          super().__init__()
          self.in_features = in_features
          self.out_features = out_features
          self.weight = nn.Parameter(torch.empty(in_features, out_features))
          if bias:
              self.bias = nn.Parameter(torch.empty(out_features))
          else:
              self.register_parameter('bias', None)
          self.reset_parameters()

      def reset_parameters(self):
          """Re-initialise weight and bias uniformly in +/- 1/sqrt(in_features).

          This is the bound ``kaiming_uniform_(a=sqrt(5))`` yields for a given
          fan-in. It is computed from ``in_features`` directly because the
          weight is laid out as ``(in_features, out_features)``; the generic
          fan-in helper reads ``size(1)`` and would pick ``out_features``.
          """
          # Guard against a zero-width input to avoid dividing by zero.
          bound = 1 / math.sqrt(self.in_features) if self.in_features > 0 else 0
          nn.init.uniform_(self.weight, -bound, bound)
          if self.bias is not None:
              nn.init.uniform_(self.bias, -bound, bound)

      def forward(self, input_sparse_tensor):
          """Return ``input_sparse_tensor @ weight`` plus the bias, if any.

          Args:
              input_sparse_tensor: Sparse matrix accepted by
                  ``torch.sparse.mm`` as its first argument; its second
                  dimension must equal ``in_features``.
          """
          output = torch.sparse.mm(input_sparse_tensor, self.weight)
          # The bias is optional; adding ``None`` would raise a TypeError.
          if self.bias is not None:
              output = output + self.bias
          return output

      def extra_repr(self):
          """Return the layer configuration shown in the module's repr."""
          return 'in_features={}, out_features={}, bias={}'.format(
              self.in_features, self.out_features, self.bias is not None
          )
  32.  
  33.  
  if __name__ == '__main__':
      # Smoke test: push a small 2x3 sparse matrix through the layer and
      # back-propagate to the sparse input itself.
      # COO indices: row 0 holds the row ids, row 1 the column ids.
      indices = torch.tensor([[0, 1, 1],
                              [2, 0, 2]], dtype=torch.long)
      values = torch.tensor([3, 4, 5], dtype=torch.float)
      # torch.sparse_coo_tensor replaces the deprecated legacy
      # torch.sparse.FloatTensor constructor.
      input_tensor = torch.sparse_coo_tensor(
          indices, values, torch.Size([2, 3])
      ).requires_grad_(True)

      layer = SparseLinear(3, 5)

      output_tensor = layer(input_tensor)
      print('Size: {}'.format(output_tensor.size()))

      # Reduce to a scalar so backward() needs no explicit gradient argument.
      target = output_tensor.sum()
      target.backward()

      print('Gradients: {}'.format(input_tensor.grad))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement