Advertisement
tryhardqaq

vgg11_transfer_to_vgg13

Apr 29th, 2023
760
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.99 KB | Source Code | 0 0
  # Transfer pretrained VGG11 (batch-norm variant) weights into a VGG13 model.
def copy_weights(name, layer):
    """Copy the VGG11 parameters/buffers stored under ``name`` into ``layer``.

    For each of ``weight`` and ``bias`` (conv, linear and BN layers) and
    ``running_mean``, ``running_var`` and ``num_batches_tracked`` (BN buffers),
    reads ``vgg11_state_dict['<name>.<attr>']`` from the module-level
    ``vgg11_state_dict`` and copies it in place into ``layer.<attr>``.
    Attributes that ``layer`` does not have, or has set to None, are skipped.
    For example, a conv layer has no running stats.

    Args:
        name: module name inside the VGG11 model, e.g. ``'features.4'``.
        layer: destination module in the VGG13 model. Its tensors must have
            the same shapes as those of the VGG11 module ``name``.

    Raises:
        KeyError: ``layer`` has one of the attributes above but
            ``vgg11_state_dict`` has no entry for it under ``name``.
        RuntimeError: raised by ``Tensor.copy_`` when source and destination
            shapes are incompatible.
    """
    import torch  # local import: the snippet's own import block is not visible

    # Decide what to copy from the destination layer itself, not from `name`.
    # A module path such as 'features.0' never contains 'weight', 'bias', etc.
    with torch.no_grad():  # in-place copy into parameters without autograd tracking
        for attr in ('weight', 'bias', 'running_mean', 'running_var',
                     'num_batches_tracked'):
            target = getattr(layer, attr, None)
            if target is None:
                continue
            target.copy_(vgg11_state_dict[name + '.' + attr])


# Conv/BN module indices inside `features` (ReLU and max-pool hold no weights
# and are omitted). Assumes the torchvision vgg11_bn / vgg13_bn order
# conv -> BN -> ReLU [-> max-pool]; TODO confirm against the models in use.
# Under that order, the next conv sits at BN index + 2 (a ReLU in between), or
# at BN index + 3 when a max-pool also sits in between.
#
#   out channels   VGG13 (conv, BN)   VGG11 (conv, BN)
#   64             0, 1               0, 1
#   64             3, 4               -
#   128            7, 8               4, 5
#   128            10, 11             -
#   256            14, 15             8, 9
#   256            17, 18             11, 12
#   512            21, 22             15, 16
#   512            24, 25             18, 19
#   512            28, 29             22, 23
#   512            31, 32             25, 26
#
# VGG13's extra blocks (features.3/4 and features.10/11) have no VGG11
# counterpart and keep whatever initialisation new_model already has.
VGG13_TO_VGG11 = {
    # features: VGG13 module name -> VGG11 module name
    'features.0': 'features.0',
    'features.1': 'features.1',
    'features.7': 'features.4',
    'features.8': 'features.5',
    'features.14': 'features.8',
    'features.15': 'features.9',
    'features.17': 'features.11',
    'features.18': 'features.12',
    'features.21': 'features.15',
    'features.22': 'features.16',
    'features.24': 'features.18',
    'features.25': 'features.19',
    'features.28': 'features.22',
    'features.29': 'features.23',
    'features.31': 'features.25',
    'features.32': 'features.26',
    # classifier: copied index-for-index (same Linear layout assumed in both).
    # NOTE(review): drop 'classifier.6' if new_model's output layer has a
    # different number of classes than the VGG11 checkpoint.
    'classifier.0': 'classifier.0',
    'classifier.3': 'classifier.3',
    'classifier.6': 'classifier.6',
}

# Look modules up by exact name. A substring test such as
# "'features.1' in name" would also match features.10-19 and copy the wrong
# layer into them. A KeyError here means new_model lacks the expected layout.
_vgg13_modules = dict(new_model.named_modules())
for _vgg13_name, _vgg11_name in VGG13_TO_VGG11.items():
    copy_weights(_vgg11_name, layer=_vgg13_modules[_vgg13_name])
  67.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement