import torch
import torch.nn as nn
import torch.nn.functional as F

class netCompressed(Net):
    """Variant of Net whose linear layers are rebuilt from rank-r SVD
    approximations of a trained network's weight matrices."""

    def __init__(self, net, r):
        super().__init__()
        self.r = r
        self._state = net.state_dict()  # weights of the trained, uncompressed net
        # Replace each fully connected layer with its low-rank reconstruction.
        for name in ('fc1', 'fc2', 'fc3', 'fc4', 'fc5', 'fc6'):
            layer_weight = self.compressing(name + '.weight')
            # nn.Linear takes (in_features, out_features); the stored weight
            # matrix has shape (out_features, in_features).
            layer = nn.Linear(layer_weight.shape[1], layer_weight.shape[0],
                              bias=False)
            layer.weight = nn.Parameter(layer_weight)
            setattr(self, name, layer)
        self.ReLU = torch.relu

    def forward(self, x):
        x = self.fc1(x.view(-1, 3 * 32 * 32))  # flatten 3x32x32 input images
        x = self.ReLU(x)
        x = self.fc2(x)
        x = self.ReLU(x)
        x = self.fc3(x)
        x = self.ReLU(x)
        x = self.fc4(x)
        x = self.ReLU(x)
        x = self.fc5(x)
        x = self.ReLU(x)
        x = self.fc6(x)
        return F.log_softmax(x, dim=1)

    def compressing(self, key):
        # torch.svd factorises W as U @ diag(S) @ V.T, with the singular
        # values in S sorted in descending order.
        u, s, v = torch.svd(self._state[key])
        # Truncate to rank r: keep the r largest singular values and the
        # corresponding *columns* of U and V.
        main_s = s[:self.r]
        main_u = u[:, :self.r]
        main_v = v[:, :self.r]
        # Rank-r reconstruction, same shape as the original weight matrix.
        compressed_weight = torch.matmul(
            torch.matmul(main_u, torch.diag(main_s)), main_v.t())
        return compressed_weight
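
A minimal usage sketch (assuming the Net base class and its training loop are defined elsewhere; the variable names and the rank r=20 below are illustrative, not from the original paste):

# Hypothetical usage: build the low-rank model from a trained net.
net = Net()
# ... train net on 3x32x32 inputs (e.g. CIFAR-10) ...
compressed = netCompressed(net, r=20)  # r = 20 is an illustrative rank
with torch.no_grad():
    x = torch.randn(4, 3, 32, 32)   # dummy batch of CIFAR-sized images
    log_probs = compressed(x)       # shape: (4, num_classes)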