import torch
import torch.nn.functional as F


class netCompressed(torch.nn.Module):
    """Low-rank copy of a six-layer fully-connected net: each fcN weight is
    replaced by its rank-r truncated SVD factors (u, s, v), and the layer is
    applied as ((x @ v) * s) @ u.T instead of a dense matmul."""

    def __init__(self, net, r):
        super().__init__()
        self.r = r
        state = net.state_dict()
        # Factor every weight matrix; biases are kept unchanged.
        for name in ['fc1', 'fc2', 'fc3', 'fc4', 'fc5', 'fc6']:
            setattr(self, name, self.compressing(net, name + '.weight'))
            setattr(self, name + '_bias', state[name + '.bias'])

    def forward(self, x):
        x = x.view(-1, 3 * 32 * 32)  # flatten 3x32x32 (CIFAR-sized) input
        layers = [self.fc1, self.fc2, self.fc3, self.fc4, self.fc5, self.fc6]
        biases = [self.fc1_bias, self.fc2_bias, self.fc3_bias,
                  self.fc4_bias, self.fc5_bias, self.fc6_bias]
        for layer, b in zip(layers, biases):
            u, s, v = layer
            # x @ W.T with W ~= u @ diag(s) @ v.T, applied factor by factor,
            # so the cost per sample drops from in*out to r*(in + out).
            x = x @ v
            x = s.view(1, -1) * x
            x = x @ u.T
            x = x + b
            if layer is self.fc6:
                x = F.log_softmax(x, dim=1)  # output layer
            else:
                x = torch.relu(x)
        return x

    def compressing(self, net, key):
        # Rank-r truncated SVD: W ~= u[:, :r] @ diag(s[:r]) @ v[:, :r].T.
        # torch.svd returns V (not V.T); torch.linalg.svd is the newer API.
        weight = net.state_dict()[key]
        u, s, v = torch.svd(weight)
        main_s = s[:self.r]
        main_u = u[:, :self.r]
        main_v = v[:, :self.r]
        return main_u, main_s, main_v
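
For context, a minimal usage sketch. The Net class, its layer widths, and r=32 are assumptions made for illustration; the paste does not show the network being compressed, only that it exposes fc1..fc6 weight and bias entries in its state_dict.

import torch
import torch.nn as nn

# Hypothetical stand-in for the original model: six Linear layers named
# fc1..fc6 so that netCompressed finds the state_dict keys it expects.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        sizes = [3 * 32 * 32, 1024, 512, 256, 128, 64, 10]  # assumed widths
        for i in range(6):
            setattr(self, f'fc{i + 1}', nn.Linear(sizes[i], sizes[i + 1]))

net = Net()                            # stands in for a trained model
compressed = netCompressed(net, r=32)  # keep top 32 singular values per layer

x = torch.randn(4, 3, 32, 32)
print(compressed(x).shape)             # torch.Size([4, 10])

# fc1 factors: u (1024, 32), s (32,), v (3072, 32) -- about 131k numbers
# in place of the ~3.1M entries of the dense 1024x3072 weight.
u, s, v = compressed.fc1
print(u.shape, s.shape, v.shape)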