Advertisement
Guest User

Untitled

a guest
Nov 14th, 2019
150
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.21 KB | None | 0 0
  1. class netCompressed:
  2.     def __init__(self, net, r):
  3.         super().__init__()
  4.         self.r = r
  5.         layer_weight = self.compressing(net, 'fc1.weight')
  6.         self.fc1 = layer_weight
  7.         self.fc1_bias = net.state_dict()['fc1.bias']
  8.        
  9.                              
  10.         layer_weight = self.compressing(net, 'fc2.weight')
  11.         self.fc2 = layer_weight
  12.         self.fc2_bias = net.state_dict()['fc2.bias']
  13.        
  14.        
  15.         layer_weight = self.compressing(net,'fc3.weight')
  16.         self.fc3 = layer_weight
  17.         self.fc3_bias = net.state_dict()['fc3.bias']
  18.        
  19.        
  20.         layer_weight = self.compressing(net, 'fc4.weight')
  21.         self.fc4 = layer_weight
  22.         self.fc4_bias = net.state_dict()['fc4.bias']
  23.        
  24.        
  25.         layer_weight = self.compressing(net, 'fc5.weight')
  26.         self.fc5 = layer_weight
  27.         self.fc5_bias = net.state_dict()['fc5.bias']
  28.        
  29.        
  30.         layer_weight = self.compressing(net, 'fc6.weight')
  31.         self.fc6 = layer_weight
  32.         self.fc6_bias = net.state_dict()['fc6.bias']
  33.                                          
  34.     def forward(self, x):
  35.         x = x.view(-1, 3 * 32*32)
  36.        
  37.         for layer, b in zip([self.fc1, self.fc2, self.fc3, self.fc4, self.fc5, self.fc6],
  38.                            [self.fc1_bias, self.fc2_bias, self.fc3_bias, self.fc4_bias, self.fc5_bias, self.fc6_bias]):
  39.             u, s, v = layer
  40. #             print(u.shape, s.shape, v.shape)
  41.             x =  x@v
  42.             x = s.view(1, -1)*x
  43. #             print(x.shape)
  44.             x= x @ u.T
  45.             x = x + b
  46.             if layer is self.fc6:
  47.                 x = F.log_softmax(x, dim=1)
  48.             else:
  49.                 x = torch.relu(x)
  50.        
  51.        
  52.         return x
  53.    
  54.     def compressing(self, net, key):
  55.         weight = net.state_dict()[key]
  56. #         print(key, weight.shape)
  57.         u,s,v  = torch.svd(weight)
  58. #         print(u.shape, s.shape, v.shape)
  59.         main_s = s [:self.r]
  60.         main_u = u[:, :self.r]
  61.         main_v = v[:, :self.r]
  62. #         compressed_weight = torch.matmul(torch.matmul(main_u, torch.diag_embed(main_s)), main_v.transpose())
  63.         return main_u, main_s, main_v
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement