Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from backpack import backpack, extend
- from backpack.extensions import BatchGrad
def iterate_dataset(dataset: Dataset, batch_size: int, device="cuda"):
    """Yield ``(inputs, targets)`` mini-batches of ``dataset`` on ``device``.

    Batches are produced in dataset order (``shuffle=False``).

    Args:
        dataset: Dataset whose collated batches unpack into an
            ``(inputs, targets)`` pair.
        batch_size: Number of examples per batch, forwarded to ``DataLoader``.
        device: Device every batch is moved to. Defaults to ``"cuda"``, which
            matches the previously hard-coded ``.cuda()`` calls; pass
            ``"cpu"`` to iterate on a machine without a GPU.

    Yields:
        ``(batch_X, batch_y)`` with both tensors moved to ``device``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for batch_X, batch_y in loader:
        yield batch_X.to(device), batch_y.to(device)
def vectorize(tensor: Tensor):
    """Collapse every dimension after the first into a single one.

    The leading dimension is kept, so an input of shape ``(n, d1, d2, ...)``
    is returned with shape ``(n, d1 * d2 * ...)``.
    """
    leading = tensor.size(0)
    return tensor.reshape((leading, -1))
def trainable_parameters(network: nn.Module):
    """Lazily yield the parameters of ``network`` with ``requires_grad`` set.

    Parameters are produced in the order of ``network.parameters()``.
    """
    yield from (p for p in network.parameters() if p.requires_grad)
def grad_batch2vec(network: nn.Module):
    """Join the ``grad_batch`` of every trainable parameter into one 2-D tensor.

    Each parameter's ``grad_batch`` attribute is flattened by ``vectorize``
    (first dimension kept), detached from the autograd graph, and the pieces
    are concatenated along dimension 1.

    Returns:
        A detached 2-D tensor with one column block per trainable parameter.
    """
    # NOTE(review): ``grad_batch`` is presumably filled in by BackPACK's
    # BatchGrad extension with one gradient per sample -- confirm.
    flat_grads = [
        vectorize(param.grad_batch).detach()
        for param in trainable_parameters(network)
    ]
    return torch.cat(flat_grads, dim=1)
def compute_gauss_newton_vector_product(network: nn.Module, dataset: Dataset,
                                        vector: torch.Tensor,
                                        batch_size: int = DEFAULT_BS, dtype=torch.float):
    """Accumulate ``sum_c G_c^T (G_c @ vector) / len(dataset)`` over ``dataset``.

    For every mini-batch and every output index ``c``, the summed ``c``-th
    network output is back-propagated with BackPACK's ``BatchGrad`` extension,
    the resulting ``grad_batch`` tensors are stacked into a matrix ``G`` by
    ``grad_batch2vec``, and ``G^T (G @ vector)`` is added to the running total.

    Args:
        network: Model to differentiate. It is passed through BackPACK's
            ``extend`` on every call.
        dataset: Indexable, sized dataset; ``dataset[0][0]`` must be a single
            input example.
        vector: Vector to multiply with; moved to the GPU internally.
        batch_size: Mini-batch size used while sweeping the dataset.
        dtype: dtype of the CPU accumulator and therefore of the result.

    Returns:
        The accumulated product as a CUDA tensor shaped like ``vector``.

    Note:
        Every ``backward()`` call also accumulates into the ``.grad`` of the
        network's parameters; this function does not reset them.
    """
    # Accumulate on the CPU; only the finished product is moved back to the GPU.
    prod = torch.zeros_like(vector, device='cpu', dtype=dtype)

    # Probe with a single example to learn the number of outputs. Only the
    # shape is needed, so skip building an autograd graph for this pass.
    with torch.no_grad():
        k = network(dataset[0][0].unsqueeze(0).cuda()).shape[1]

    vector = vector.cuda()
    network2 = extend(network)
    num_examples = len(dataset)

    for X, _ in iterate_dataset(dataset, batch_size):
        for c in range(k):
            # A fresh forward pass is run for every output index before its
            # backward pass (kept as in the original; no graph is retained).
            batch_score = network2(X)[:, c].sum()
            with backpack(BatchGrad()):
                batch_score.backward()
            # NOTE(review): rows of G are presumably per-sample gradients of
            # output ``c`` w.r.t. all trainable parameters -- confirm.
            G = grad_batch2vec(network2).detach()
            prod += (G.t().mv(G.mv(vector)) / num_examples).detach().cpu()
    return prod.cuda()
def lanczos(matrix_vector, p: int, k: int, warm_start=None, dtype=torch.float, maxiter=None, tol=0,
            device="cuda"):
    """Compute ``k`` eigenpairs of a symmetric ``p x p`` operator via ARPACK.

    The operator is only accessed through ``matrix_vector``; the eigenproblem
    is solved by ``scipy.sparse.linalg.eigsh`` with its default ``which``.

    Args:
        matrix_vector: Callable taking a torch vector of length ``p`` (with
            the given ``dtype`` and ``device``) and returning the
            operator-vector product as a torch tensor.
        p: Dimension of the operator.
        k: Number of eigenpairs to compute.
        warm_start: Optional starting vector, forwarded to ``eigsh`` as ``v0``.
        dtype: torch dtype of the vectors handed to ``matrix_vector`` and of
            the returned tensors.
        maxiter: Forwarded to ``eigsh``.
        tol: Forwarded to ``eigsh``.
        device: Device on which the vectors handed to ``matrix_vector`` are
            created. Defaults to ``"cuda"``, matching the previously
            hard-coded ``.cuda()`` call.

    Returns:
        ``(evals, evecs)`` as CPU tensors of ``dtype``: the eigenvalues in the
        reverse of the order ``eigsh`` returns them, and the matching
        eigenvectors as columns in the same reversed order.
    """
    def mv(vec: np.ndarray):
        torch_vec = torch.tensor(vec, dtype=dtype, device=device)
        return matrix_vector(torch_vec).cpu().numpy()

    # ``np.float`` (removed in NumPy 1.24) was an alias of the builtin float,
    # i.e. float64 -- the same as ``np.double``. The solver therefore always
    # ran in double precision regardless of ``dtype``; keep that explicitly.
    operator = LinearOperator((p, p), matvec=mv, dtype=np.float64)
    evals, evecs = eigsh(operator, k, v0=warm_start, maxiter=maxiter, tol=tol)

    # The reversed views have negative strides, which ``torch.from_numpy``
    # cannot wrap, hence the contiguous copies.
    evals_rev = np.ascontiguousarray(evals[::-1])
    evecs_rev = np.ascontiguousarray(np.flip(evecs, -1))

    # Cast to the requested dtype for every dtype; previously anything other
    # than float/double fell through and returned None.
    return torch.from_numpy(evals_rev).to(dtype), torch.from_numpy(evecs_rev).to(dtype)
def get_gn_eigenvalues(network: nn.Module, dataset, ncomponents=6):
    """Run ``lanczos`` on the Gauss-Newton vector product of ``network``.

    The operator handed to ``lanczos`` maps a vector ``delta`` to
    ``compute_gauss_newton_vector_product(network, dataset, delta)`` evaluated
    with a batch size of 1000; its dimension is ``nparams(network)``.

    Args:
        network: Model passed to the matrix-vector product.
        dataset: Dataset the product is accumulated over.
        ncomponents: Number of eigenpairs requested from ``lanczos``.

    Returns:
        The ``(eigenvalues, eigenvectors)`` pair produced by ``lanczos``.
    """
    def gn_vp_delta(delta):
        product = compute_gauss_newton_vector_product(
            network, dataset, delta, batch_size=1000
        )
        return product.detach()

    eigenvalues, eigenvectors = lanczos(gn_vp_delta, nparams(network), k=ncomponents)
    return eigenvalues, eigenvectors
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement