Advertisement
Guest User

Untitled

a guest
Sep 22nd, 2023
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.64 KB | None | 0 0
  1. from backpack import backpack, extend
  2. from backpack.extensions import BatchGrad
  3.  
  4. def iterate_dataset(dataset: Dataset, batch_size: int):
  5. loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  6. for i, (batch_X, batch_y) in enumerate(loader):
  7. yield batch_X.cuda(), batch_y.cuda()
  8.  
  9. def vectorize(tensor: Tensor):
  10. return tensor.reshape(tensor.shape[0], -1)
  11.  
  12. def trainable_parameters(network: nn.Module):
  13. for param in network.parameters():
  14. if param.requires_grad:
  15. yield param
  16.  
  17. def grad_batch2vec(network: nn.Module):
  18. vec = []
  19. for param in trainable_parameters(network):
  20. vec.append(vectorize(param.grad_batch).detach())
  21. result = torch.cat(vec, dim=1)
  22. del vec
  23. return result
  24.  
  25.  
  26. def compute_gauss_newton_vector_product(network: nn.Module, dataset: Dataset,
  27. vector: torch.tensor,
  28. batch_size: int = DEFAULT_BS, dtype=torch.float):
  29. prod = torch.zeros_like(vector, device='cpu', dtype=dtype)
  30. k = network(dataset[0][0].unsqueeze(0).cuda()).shape[1]
  31. vector = vector.cuda()
  32. network2 = extend(network)
  33. for (X, y) in iterate_dataset(dataset, batch_size):
  34. for c in range(k):
  35. batch_score = network2(X)[:, c].sum()
  36. with backpack(BatchGrad()):
  37. batch_score.backward()
  38. G = grad_batch2vec(network2).detach()
  39. prod += (G.t().mv(G.mv(vector)) / len(dataset)).detach().cpu()
  40. return prod.cuda()
  41.  
  42. def lanczos(matrix_vector, p: int, k: int, warm_start=None, dtype=torch.float, maxiter=None, tol=0):
  43. def mv(vec: np.ndarray):
  44. gpu_vec = torch.tensor(vec, dtype=dtype).cuda()
  45. return matrix_vector(gpu_vec).cpu().numpy()
  46.  
  47. # print(p, k, maxiter)
  48. np_dtype = np.float if dtype == torch.float else np.double
  49. operator = LinearOperator((p, p), matvec=mv, dtype=np_dtype)
  50. evals, evecs = eigsh(operator, k, v0=warm_start, maxiter=maxiter, tol=tol)
  51.  
  52.  
  53. if dtype == torch.float:
  54. return torch.from_numpy(np.ascontiguousarray(evals[::-1])).float(), \
  55. torch.from_numpy(np.ascontiguousarray(np.flip(evecs, -1))).float()
  56. elif dtype == torch.double:
  57. return torch.from_numpy(np.ascontiguousarray(evals[::-1])), \
  58. torch.from_numpy(np.ascontiguousarray(np.flip(evecs, -1)))
  59.  
  60.  
  61. def get_gn_eigenvalues(network: nn.Module, dataset, ncomponents=6):
  62. gn_vp_delta = lambda delta: compute_gauss_newton_vector_product(network, dataset, delta, batch_size=1000).detach()
  63. Lam, V = lanczos(gn_vp_delta, nparams(network), k=ncomponents)
  64. return Lam, V
  65.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement