Advertisement
Guest User

Untitled

a guest
Nov 13th, 2019
101
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.95 KB | None | 0 0
  1. class netCompressed( Net):
  2.     def __init__(self, net, r):
  3.         super().__init__()
  4.         layer_weight = self.compressing('fc1.weight')
  5.         self.fc1 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  6.         self.fc1.weight = layer_weight
  7.                              
  8.         layer_weight = self.compressing('fc2.weight')
  9.         self.fc2 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  10.         self.fc2.weight = layer_weight
  11.        
  12.         layer_weight = self.compressing('fc3.weight')
  13.         self.fc3 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  14.         self.fc3.weight = layer_weight
  15.        
  16.         layer_weight = self.compressing('fc4.weight')
  17.         self.fc4 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  18.         self.fc4.weight = layer_weight
  19.        
  20.         layer_weight = self.compressing('fc5.weight')
  21.         self.fc5 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  22.         self.fc5.weight = layer_weight
  23.        
  24.         layer_weight = self.compressing('fc6.weight')
  25.         self.fc6 = nn.Linear(layer_weight.shape[0],layer_weight.shape[1],bias = False)
  26.         self.fc6.weight = layer_weight
  27.        
  28.         self.ReLu = torch.ReLu
  29.        
  30.     def forward(self, x):
  31.         x = self.fc1(x.view(-1, 3 * 32*32))
  32.         x = self.ReLU(x)
  33.         x = self.fc2(x)
  34.         x = self.ReLU(x)
  35.         x = self.fc3(x)
  36.         x = self.ReLU(x)
  37.         x = self.fc4(x)
  38.         x = self.ReLU(x)
  39.         x = self.fc5(x)
  40.         x = self.ReLU(x)
  41.         x = self.fc6(x)
  42.         return F.log_softmax(x, dim=1)
  43.    
  44.     def compressing(key):
  45.         u,s,v  = torch.svd(net.state_dict()[key])
  46.         main_s = s [:r]
  47.         main_u = u[:r, : r]
  48.         main_v = v[:r, :r]
  49.         compressed_weight = torch.matmul(torch.matmul(main_u, torch.diag_embed(main_s)), main_v.transpose())
  50.         return compressed_weight
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement