Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import numpy as np
- from collections import defaultdict
- from pae.experiments.toy.gaussians.masking import make_possible_masks
- from pae.experiments.toy.gaussians.data import make_conditional_operators, compute_unconditional_moment
- from pae.utils import to_numpy
- import pandas as pd
- import os
def _group_masks(all_masks):
    """Group available masks under the requested mask they are paired with.

    ``all_masks.tolist()`` must yield ``(maska, maskr)`` pairs. Pairs whose
    available mask sums to zero (nothing is given to the network) are skipped.
    Returns a ``defaultdict`` mapping ``tuple(maskr) -> set of tuple(maska)``.
    """
    grouped = defaultdict(set)
    for maska, maskr in all_masks.tolist():
        if sum(maska) != 0:
            grouped[tuple(maskr)].add(tuple(maska))
    return grouped


def _batched_mask(mask, batch_size, device):
    """Return ``mask`` as a float32 tensor of shape (batch_size, len(mask)) on ``device``."""
    row = torch.tensor(mask, dtype=torch.float32, device=device).view(1, -1)
    return row.repeat(batch_size, 1)


def _empirical_covariance(samples):
    """Return the unbiased (N - 1) sample covariance of ``samples`` over dim 0.

    ``samples`` is presumably (batch, k, 1) -- the trailing singleton axis comes
    from indexing with ``nonzero()`` output -- giving a (k, k) result.
    Needs batch >= 2: with a single sample the divisor is zero.
    """
    centered = samples - samples.mean(dim=0, keepdim=True)
    flat = centered.squeeze(-1)
    return flat.t().matmul(flat).div_(centered.shape[0] - 1.)


def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
    """Score the projection network against analytic Gaussian conditional moments.

    For every (requested mask, available mask) pair in ``all_masks`` and every
    example in ``vars_['eval_data']``:

    1. The example is repeated ``specs.batch_size`` times and masked with the
       available mask.
    2. It is passed through ``nets['proj']``.
    3. The batch of outputs on the requested coordinates is compared with the
       moments stored in ``cond_operators[maskr][maska]``:

    * ``mu``    -- L2 norm of (batch mean of the outputs) minus (batch mean of
      ``mu_const + mu_reg @ xa_``).
    * ``sigma`` -- Frobenius norm of ``sigma`` minus the unbiased empirical
      covariance of the outputs.

    Args:
        specs: config object; reads ``variance_correction``, ``batch_size``,
            ``device``, ``dry_run`` and ``results_dir``.
        vars_: dict holding an iterable of example tensors under ``'eval_data'``.
        nets: dict holding the projection network under ``'proj'``; it is
            called as ``proj(xa, masksa, masksr)`` and must return 4 values,
            the first being the projected batch.
        monitors: unused; kept for interface compatibility.
        all_masks: array/tensor whose ``tolist()`` yields ``(maska, maskr)`` pairs.
        cond_operators: nested mapping ``cond_operators[tuple(maskr)][tuple(maska)]``
            with keys ``'mu_const'``, ``'mu_reg'`` and ``'sigma'``.
        **kwargs: unused.

    Returns:
        pandas.DataFrame with one row per (mask pair, example) and columns
        ``maska``, ``maskr``, ``idx``, ``mu``, ``sigma``. It is also written to
        ``<specs.results_dir>/results.csv`` unless ``specs.dry_run`` is set.

    Raises:
        NotImplementedError: if ``specs.variance_correction`` is set.
    """
    # Reject the unsupported option up front, before any forward passes are spent.
    if specs.variance_correction:
        raise NotImplementedError('variance_correction is not implemented')

    proj = nets['proj']
    # NOTE(review): the network is left in eval mode on return; a caller that
    # resumes training afterwards has to call .train() itself.
    proj.eval()
    batch_size, device = specs.batch_size, specs.device

    all_results = []
    # Pure evaluation: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for maskr_key, maska_keys in _group_masks(all_masks).items():
            for maska_key in maska_keys:
                # Everything below depends only on the mask pair, so build it
                # once per pair instead of once per example.
                masksa = _batched_mask(maska_key, batch_size, device)
                masksr = _batched_mask(maskr_key, batch_size, device)
                avail_idx = masksa[0].long().nonzero()
                req_idx = masksr[0].long().nonzero()
                cond_ops = cond_operators[maskr_key][maska_key]
                mu_const = cond_ops['mu_const']
                mu_reg = cond_ops['mu_reg']
                sigma = cond_ops['sigma']

                for idx, x in enumerate(vars_['eval_data']):
                    # Only "one example repeated over the whole batch" is
                    # supported. Move it to the masks' device before masking.
                    x = x.to(device).view(1, -1).repeat(batch_size, 1)
                    xa = x * masksa
                    xp, _, _, _ = proj(xa, masksa, masksr)
                    # nonzero() returns (k, 1) indices, so these slices keep a
                    # trailing singleton axis: batches of columns.
                    xp_ = xp[:, req_idx]
                    xa_ = xa[:, avail_idx]

                    mu = mu_const + (mu_reg @ xa_)
                    mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(2).sum().sqrt()
                    cov_err = (sigma - _empirical_covariance(xp_)).pow(2).sum().sqrt()

                    all_results.append({
                        'maska': maska_key, 'maskr': maskr_key, 'idx': idx,
                        'mu': to_numpy(mean_err), 'sigma': to_numpy(cov_err),
                    })

    testdf = pd.DataFrame(all_results)
    if not specs.dry_run:
        testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
    return testdf
if __name__ == '__main__':
    # Smoke check: build the masks, unconditional moments and conditional
    # operators for a 3-dimensional toy Gaussian.
    data_dim = 3
    moment_kwargs = dict(
        means=[2., 4., 6.],
        variances=[1., 1., 1.],
        corrs=[0.5, 0.25, 0.],
    )
    all_masks = make_possible_masks(
        data_dim,
        complementary_masks=True,
        no_overlap=True,
        dtype='float32',
    )
    moms = compute_unconditional_moment(data_dim, **moment_kwargs)
    cond_operators = make_conditional_operators(
        all_masks, *(moms[key] for key in ('mu', 'sigma', 'lam'))
    )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement