SHARE
TWEET

Untitled

Posted by a guest on Jan 23rd, 2019 · 75 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. @timeit
  2. def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
  3.   nets['proj'].eval()
  4.   all_masks_dict = defaultdict(set)
  5.   for maska, maskr in all_masks.tolist():
  6.     if sum(maska) != 0:
  7.       all_masks_dict[tuple(maskr)].add(tuple(maska))
  8.  
  9.   all_results = []
  10.   for maskr_ in all_masks_dict:
  11.     for maska_ in all_masks_dict[maskr_]:
  12.       for idx, x in enumerate(vars_['eval_data']):
  13.         # if specs.one_example_per_batch:
  14.         x = x.view(1, -1).repeat(specs.batch_size, 1)
  15.         # else:
  16.         #   raise NotImplementedError
  17.  
  18.         # if specs.one_mask_per_batch:
  19.         masksa = torch.Tensor(maska_).view(1,
  20.                                            -1).float().to(specs.device).repeat(
  21.                                                specs.batch_size, 1)
  22.         masksr = torch.Tensor(maskr_).view(1,
  23.                                            -1).float().to(specs.device).repeat(
  24.                                                specs.batch_size, 1)
  25.         # else:
  26.         #   raise NotImplementedError
  27.  
  28.         xa = x * masksa
  29.         xp, z, zmu, logvar = nets['proj'](xa, masksa, masksr)
  30.  
  31.         maska = masksa[0].long()
  32.         maskr = masksr[0].long()
  33.  
  34.         xp_ = xp[:, maskr.nonzero()]  # Batch of columns
  35.         xa_ = xa[:, maska.nonzero()]  # Batch of columns
  36.  
  37.         # Fetching conditional moments
  38.         cond_ops = cond_operators[maskr_][maska_]
  39.         mu_const = cond_ops['mu_const']
  40.         mu_reg = cond_ops['mu_reg']
  41.         sigma = cond_ops['sigma']
  42.  
  43.         # Conditional likelihood
  44.         # mu = mu_const + (xa_ @ mu_reg.t())
  45.         mu = mu_const + (mu_reg @ xa_)
  46.         # if specs.one_example_per_batch:
  47.         mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(2).sum().sqrt()
  48.         # else:
  49.         #   raise NotImplementedError
  50.  
  51.         xp_mu_hat = xp_ - xp_.mean(dim=0, keepdim=True)
  52.         sigma_hat = (xp_mu_hat.squeeze(-1).t().matmul(
  53.             xp_mu_hat.squeeze(-1))).div_(xp_mu_hat.shape[0] - 1.)
  54.         if specs.variance_correction:
  55.           raise NotImplementedError
  56.         cov_err = (sigma - sigma_hat).pow(2).sum().sqrt()
  57.  
  58.         errs = {
  59.             'maska': maska_, 'maskr': maskr_, 'idx': idx,
  60.             'mu': to_numpy(mean_err), 'sigma': to_numpy(cov_err)
  61.         }
  62.         all_results.append(errs)
  63.  
  64.   testdf = pd.DataFrame(all_results)
  65.   # Saving data frame as csv
  66.   if not specs.dry_run:
  67.     testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
  68.   return testdf
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top