Advertisement
Guest User

Untitled

a guest
Jan 23rd, 2019
80
0
Never
Not a member of Pastebin yet? Sign up; it unlocks many cool features!
Python 3.30 KB | None | 0 0
  1. import torch
  2. import numpy as np
  3. from collections import defaultdict
  4. from pae.experiments.toy.gaussians.masking import make_possible_masks
  5. from pae.experiments.toy.gaussians.data import make_conditional_operators, compute_unconditional_moment
  6. from pae.utils import to_numpy
  7. import pandas as pd
  8. import os
  9.  
  10.  
  11. def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
  12.   nets['proj'].eval()
  13.   all_masks_dict = defaultdict(set)
  14.   for maska, maskr in all_masks.tolist():
  15.     if sum(maska) != 0:
  16.       all_masks_dict[tuple(maskr)].add(tuple(maska))
  17.  
  18.   all_results = []
  19.   for maskr_ in all_masks_dict:
  20.     for maska_ in all_masks_dict[maskr_]:
  21.       for idx, x in enumerate(vars_['eval_data']):
  22.         # if specs.one_example_per_batch:
  23.         x = x.view(1, -1).repeat(specs.batch_size, 1)
  24.         # else:
  25.         #   raise NotImplementedError
  26.  
  27.         # if specs.one_mask_per_batch:
  28.         masksa = torch.Tensor(maska_).view(1,
  29.                                            -1).float().to(specs.device).repeat(
  30.                                                specs.batch_size, 1)
  31.         masksr = torch.Tensor(maskr_).view(1,
  32.                                            -1).float().to(specs.device).repeat(
  33.                                                specs.batch_size, 1)
  34.         # else:
  35.         #   raise NotImplementedError
  36.  
  37.         xa = x * masksa
  38.         # Probly will have to change this line
  39.         xp, z, zmu, logvar = nets['proj'](xa, masksa, masksr)
  40.  
  41.         maska = masksa[0].long()
  42.         maskr = masksr[0].long()
  43.  
  44.         xp_ = xp[:, maskr.nonzero()]  # Batch of columns
  45.         xa_ = xa[:, maska.nonzero()]  # Batch of columns
  46.  
  47.         # Fetching conditional moments
  48.         cond_ops = cond_operators[maskr_][maska_]
  49.         mu_const = cond_ops['mu_const']
  50.         mu_reg = cond_ops['mu_reg']
  51.         sigma = cond_ops['sigma']
  52.  
  53.         # Conditional likelihood
  54.         # mu = mu_const + (xa_ @ mu_reg.t())
  55.         mu = mu_const + (mu_reg @ xa_)
  56.         # if specs.one_example_per_batch:
  57.         mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(2).sum().sqrt()
  58.         # else:
  59.         #   raise NotImplementedError
  60.  
  61.         xp_mu_hat = xp_ - xp_.mean(dim=0, keepdim=True)
  62.         sigma_hat = (xp_mu_hat.squeeze(-1).t().matmul(
  63.             xp_mu_hat.squeeze(-1))).div_(xp_mu_hat.shape[0] - 1.)
  64.         if specs.variance_correction:
  65.           raise NotImplementedError
  66.         cov_err = (sigma - sigma_hat).pow(2).sum().sqrt()
  67.  
  68.         errs = {
  69.             'maska': maska_, 'maskr': maskr_, 'idx': idx,
  70.             'mu': to_numpy(mean_err), 'sigma': to_numpy(cov_err)
  71.         }
  72.         all_results.append(errs)
  73.  
  74.   testdf = pd.DataFrame(all_results)
  75.   # Saving data frame as csv
  76.   if not specs.dry_run:
  77.     testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
  78.   return testdf
  79.  
  80.  
  81. if __name__ == '__main__':
  82.   data_dim = 3
  83.   means = [2., 4., 6.]
  84.   variances = [1., 1., 1.]
  85.   corrs = [0.5, 0.25, 0.]
  86.   all_masks = make_possible_masks(
  87.       data_dim, complementary_masks=True, no_overlap=True, dtype='float32')
  88.  
  89.   moms = compute_unconditional_moment(
  90.       data_dim, means=means, variances=variances, corrs=corrs)
  91.   cond_operators = make_conditional_operators(all_masks, moms['mu'],
  92.                                               moms['sigma'], moms['lam'])
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement