Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _preceding_layer_name(bn_name):
    """Guess the dotted name of the layer that feeds the BatchNorm ``bn_name``.

    Two naming conventions assumed by this script are recognised. Both are
    applied to the LAST path component only, so a "bn" appearing in a parent
    module's name is never rewritten:

    * main branch: the leaf contains "bn" and the matching conv has the same
      name with "bn" -> "conv", e.g. ``layer1.0.bn1`` -> ``layer1.0.conv1``;
    * residual/shortcut: the leaf is a positive integer index inside a
      container, e.g. ``layer1.0.downsample.1`` -> ``layer1.0.downsample.0``.

    Args:
        bn_name: dotted module name as yielded by ``named_modules()``.

    Returns:
        The inferred predecessor name, or None when neither convention applies
        (leaf is neither a "bn" name nor a positive integer index; index 0 has
        no predecessor inside the same container).
    """
    prefix, sep, leaf = bn_name.rpartition('.')
    if 'bn' in leaf:
        return prefix + sep + leaf.replace('bn', 'conv')
    # isdecimal() (not isdigit()) guarantees int() accepts the string.
    if leaf.isdecimal() and int(leaf) > 0:
        return prefix + sep + str(int(leaf) - 1)
    return None


# Dry run: for every BatchNorm2d in `model`, find the layer that precedes it and
# print how it would be fused. Nothing is modified; the ConvBN wrapping and BN
# deletion mentioned below are only planned.
# NOTE(review): assumes `model`, `nn` (torch.nn) and `get_module(model, name)`
# are defined earlier in the file -- confirm.
for mname, m in model.named_modules():
    if not isinstance(m, nn.BatchNorm2d):
        continue
    prev_mname = _preceding_layer_name(mname)
    if prev_mname is None:
        print(mname, 'cannot infer the preceding layer; skipped')
        print()
        continue
    prev_m = get_module(model, prev_mname)
    print(mname, prev_mname)
    if isinstance(prev_m, nn.Sequential):
        # Conv layer was decomposed into a sequence of convs; only the last
        # one should absorb the BN. Materialise the children once:
        # named_children() skips None entries and repeated module instances,
        # so len(prev_m) can overcount and the last child would be missed.
        children = list(prev_m.named_children())
        last_idx = len(children) - 1
        for lidx, (lname, layer) in enumerate(children):
            if lidx < last_idx:
                # Planned: replace every conv except the last with
                # ConvBN(conv=layer, bn=None).
                print(lidx, lname, 'Wrap')
            else:
                # Planned: replace the last conv with ConvBN(conv=layer, bn=m).
                print(lidx, lname, 'Wrap & Merge with ', mname)
    elif isinstance(prev_m, nn.Conv2d):
        # Conv layer was not decomposed.
        print(prev_m)
    else:
        print('Unexpected layer before', mname + ':', prev_mname,
              type(prev_m).__name__)
    # TODO: delete the BN layer once fusion is implemented; for now only a
    # blank separator line is printed.
    print()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement