Advertisement
Guest User

Untitled

a guest
Jan 23rd, 2019
120
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.47 KB | None | 0 0
  1. @timeit
  2. def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
  3.   nets['proj'].eval()
  4.   all_masks_dict = defaultdict(set)
  5.   for maska, maskr in all_masks.tolist():
  6.     if sum(maska) != 0:
  7.       all_masks_dict[tuple(maskr)].add(tuple(maska))
  8.  
  9.   all_results = []
  10.   for maskr_ in all_masks_dict:
  11.     for maska_ in all_masks_dict[maskr_]:
  12.       for idx, x in enumerate(vars_['eval_data']):
  13.         # if specs.one_example_per_batch:
  14.         x = x.view(1, -1).repeat(specs.batch_size, 1)
  15.         # else:
  16.         #   raise NotImplementedError
  17.  
  18.         # if specs.one_mask_per_batch:
  19.         masksa = torch.Tensor(maska_).view(1,
  20.                                            -1).float().to(specs.device).repeat(
  21.                                                specs.batch_size, 1)
  22.         masksr = torch.Tensor(maskr_).view(1,
  23.                                            -1).float().to(specs.device).repeat(
  24.                                                specs.batch_size, 1)
  25.         # else:
  26.         #   raise NotImplementedError
  27.  
  28.         xa = x * masksa
  29.         xp, z, zmu, logvar = nets['proj'](xa, masksa, masksr)
  30.  
  31.         maska = masksa[0].long()
  32.         maskr = masksr[0].long()
  33.  
  34.         xp_ = xp[:, maskr.nonzero()]  # Batch of columns
  35.         xa_ = xa[:, maska.nonzero()]  # Batch of columns
  36.  
  37.         # Fetching conditional moments
  38.         cond_ops = cond_operators[maskr_][maska_]
  39.         mu_const = cond_ops['mu_const']
  40.         mu_reg = cond_ops['mu_reg']
  41.         sigma = cond_ops['sigma']
  42.  
  43.         # Conditional likelihood
  44.         # mu = mu_const + (xa_ @ mu_reg.t())
  45.         mu = mu_const + (mu_reg @ xa_)
  46.         # if specs.one_example_per_batch:
  47.         mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(2).sum().sqrt()
  48.         # else:
  49.         #   raise NotImplementedError
  50.  
  51.         xp_mu_hat = xp_ - xp_.mean(dim=0, keepdim=True)
  52.         sigma_hat = (xp_mu_hat.squeeze(-1).t().matmul(
  53.             xp_mu_hat.squeeze(-1))).div_(xp_mu_hat.shape[0] - 1.)
  54.         if specs.variance_correction:
  55.           raise NotImplementedError
  56.         cov_err = (sigma - sigma_hat).pow(2).sum().sqrt()
  57.  
  58.         errs = {
  59.             'maska': maska_, 'maskr': maskr_, 'idx': idx,
  60.             'mu': to_numpy(mean_err), 'sigma': to_numpy(cov_err)
  61.         }
  62.         all_results.append(errs)
  63.  
  64.   testdf = pd.DataFrame(all_results)
  65.   # Saving data frame as csv
  66.   if not specs.dry_run:
  67.     testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
  68.   return testdf
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement