Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Build the per-layer weight-transfer helper.
def copy_weights(name, layer, state_dict=None):
    """Copy the tensors saved under module path ``name`` into ``layer`` in place.

    For each attribute in ``weight``, ``bias``, ``running_mean``,
    ``running_var`` and ``num_batches_tracked``, the tensor
    ``state_dict[f"{name}.{attr}"]`` is copied into ``layer.<attr>``. A copy
    happens only when the key is present AND ``layer`` has a non-None
    attribute of that name; everything else is skipped. BatchNorm layers need
    the last three (buffers) in addition to weight and bias.

    Args:
        name: Module path in the *source* model, e.g. ``'features.4'``.
        layer: Destination module whose parameters/buffers are overwritten.
        state_dict: Mapping from parameter/buffer names to tensors. Defaults
            to the module-level ``vgg11_state_dict``.

    Returns:
        The list of state-dict keys that were copied, in attribute order.

    Raises:
        Propagates any error from ``Tensor.copy_`` (e.g. a RuntimeError when
        the source and destination shapes are incompatible), which surfaces
        a wrong layer mapping instead of hiding it.
    """
    if state_dict is None:
        state_dict = vgg11_state_dict

    copied = []
    for attr in ('weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'):
        key = f'{name}.{attr}'
        # Test the state-dict key and the destination attribute, not the module
        # path: a path such as 'features.0' never contains the word 'weight'.
        target = getattr(layer, attr, None)
        # getattr(..., None) also covers bias=None (a conv built with bias=False).
        if target is None or key not in state_dict:
            continue
        target.data.copy_(state_dict[key])
        copied.append(key)
    return copied
# Transfer pretrained VGG11 weights (vgg11_state_dict) into the matching layers
# of the VGG13 model (new_model).
#
# Conv/BN layout of both models (ReLU and pooling omitted). B = BatchNorm.
# The basic pattern repeats conv -> BN -> conv -> BN ... and ends with a BN.
# Index gaps: BN -> next conv is +2 when only a ReLU sits between them, and +3
# when a ReLU and a max-pool do (e.g. between "B, 128").
#   'VGG13': [ 64, B, 64, B, 128, B, 128,  B, 256,  B, 256,  B, 512,  B, 512,  B, 512,  B, 512,  B]
#   cnn_13 : [  0, 1,  3, 4,   7, 8,  10, 11,  14, 15,  17, 18,  21, 22,  24, 25,  28, 29,  31, 32]
#   'VGG11': [ 64, B, 128, B, 256, B, 256,  B, 512,  B, 512,  B, 512,  B, 512,  B]
#   cnn_11 : [  0, 1,   4, 5,   8, 9,  11, 12,  15, 16,  18, 19,  22, 23,  25, 26]
# VGG13 features.3/4 and features.10/11 have no VGG11 counterpart and are left
# with whatever values new_model already holds.
import warnings

# Exact VGG13 module path -> VGG11 module path.
_VGG13_TO_VGG11 = {
    # features
    'features.0': 'features.0',
    'features.1': 'features.1',
    'features.7': 'features.4',
    'features.8': 'features.5',
    'features.14': 'features.8',
    'features.15': 'features.9',
    'features.17': 'features.11',
    'features.18': 'features.12',
    'features.21': 'features.15',
    'features.22': 'features.16',
    'features.24': 'features.18',
    'features.25': 'features.19',
    'features.28': 'features.22',
    'features.29': 'features.23',
    'features.31': 'features.25',
    'features.32': 'features.26',
    # classifier: same indices in both models
    'classifier.0': 'classifier.0',
    'classifier.3': 'classifier.3',
    'classifier.6': 'classifier.6',
}

_matched = set()
for name, layer in new_model.named_modules():
    # Exact lookup, not a substring test: `'features.1' in name` would also
    # match features.10 ... features.19 and copy into the wrong layers.
    src_name = _VGG13_TO_VGG11.get(name)
    if src_name is None:
        continue
    copy_weights(src_name, layer=layer)
    _matched.add(name)

# Surface silent no-ops, e.g. a wrapped model whose module names carry a prefix.
_unmatched = sorted(set(_VGG13_TO_VGG11) - _matched)
if _unmatched:
    warnings.warn(f'no module in new_model matched these names: {_unmatched}')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement