Advertisement
Guest User

Untitled

a guest
Oct 22nd, 2019
120
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.16 KB | None | 0 0
  1. for mname, m in model.named_modules():
  2.    
  3.     if isinstance(m, nn.BatchNorm2d):
  4.         # For main branches
  5.         prev_mname = mname.replace('bn','conv')
  6.        
  7.         # For residual connections
  8.         if prev_mname == mname:
  9.             prefix, suffix = mname.rsplit('.', 1)
  10.             prev_mname = prefix + '.' + str(int(suffix) - 1)
  11.        
  12.         prev_m = get_module(model, prev_mname)
  13.         print(mname, prev_mname)
  14.  
  15.         # If conv layer is decomposed
  16.         if isinstance(prev_m, nn.Sequential):
  17.             for lidx, (lname, l) in enumerate(prev_m.named_children()):
  18.                 if lidx < len(prev_m) - 1:
  19.                     # replace every conv exept last to ConvBN(conv = l, bn = None)
  20.                     print(lidx, lname, 'Wrap')
  21.                 else:
  22.                     # replace last conv to ConvBN(conv = l, bn = m)
  23.                     print(lidx, lname,  'Wrap & Merdge with ', mname)                
  24.  
  25.         # If conv layer is not decomposed
  26.         elif isinstance(prev_m, nn.Conv2d):
  27.             print(prev_m)
  28.            
  29.         else:
  30.             print("WTF?")
  31.            
  32.         # delete bn layer
  33.         print()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement