SHARE
TWEET

fuck you facebook

a guest May 20th, 2019 80 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def plot_grad_flow(model):
  2.     """
  3.     Plots the gradients flowing through different layers in the net during training.
  4.     Can be used for checking for possible gradient vanishing / exploding problems.
  5.     """
  6.     ave_grads = []
  7.     max_grads = []
  8.     layers = []
  9.     for n, p in model.named_parameters():
  10.         if (p.requires_grad) and ("bias" not in n and "norm" not in n):
  11.             layers.append(n[:n.find('.weight')])
  12.             ave_grads.append(p.grad.abs().mean())
  13.             max_grads.append(p.grad.abs().max())
  14.     plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
  15.     plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
  16.     plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
  17.     plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
  18.     plt.xlim(left=0, right=len(ave_grads))
  19.     plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
  20.     plt.xlabel("Layers")
  21.     plt.ylabel("average gradient")
  22.     plt.title("Gradient flow")
  23.     plt.grid(True)
  24.     plt.legend([Line2D([0], [0], color="c", lw=4),
  25.                 Line2D([0], [0], color="b", lw=4),
  26.                 Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
  27.     plt.tight_layout()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top