Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _group_masks(all_masks):
    """Group available-feature masks by the requested-feature mask they pair with.

    Args:
        all_masks: Object whose ``.tolist()`` yields ``(maska, maskr)`` pairs of
            mask sequences (presumably a tensor of shape (N, 2, D) -- TODO confirm).

    Returns:
        ``defaultdict(set)`` mapping ``tuple(maskr)`` to the set of distinct
        ``tuple(maska)`` observed with it. Pairs whose available mask sums to
        zero are dropped.
    """
    grouped = defaultdict(set)
    for maska, maskr in all_masks.tolist():
        if sum(maska) != 0:
            grouped[tuple(maskr)].add(tuple(maska))
    return grouped


def _batched_mask(mask, batch_size, device):
    """Return ``mask`` as a float32 tensor on ``device``, repeated into ``batch_size`` rows."""
    row = torch.tensor(mask, dtype=torch.float32, device=device).view(1, -1)
    return row.repeat(batch_size, 1)


@timeit
def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
    """Compare the projection net's batch moments with analytic conditional moments.

    For every distinct (requested mask ``maskr``, available mask ``maska``) pair
    in ``all_masks`` and every example in ``vars_['eval_data']``, the example is
    repeated ``specs.batch_size`` times, multiplied by ``maska`` and passed to
    ``nets['proj']``. The mean and unbiased covariance of the projected
    requested features over the batch are compared with the conditional moments
    stored in ``cond_operators[maskr][maska]`` (keys ``'mu_const'``,
    ``'mu_reg'``, ``'sigma'``). The mean is ``mu_const + mu_reg @ xa_``. Errors
    are the Euclidean norm of the mean difference and the Frobenius norm of the
    covariance difference.

    Args:
        specs: Run configuration; reads ``batch_size``, ``device``,
            ``variance_correction``, ``dry_run`` and ``results_dir``.
        vars_: Mapping holding the evaluation examples under ``'eval_data'``;
            it is iterated once per mask pair, so it must be re-iterable.
        nets: Mapping holding the projection network under ``'proj'``. The
            network is switched to eval mode and left there.
        monitors: Unused; kept for interface compatibility.
        all_masks: Object whose ``.tolist()`` yields ``(maska, maskr)`` pairs.
        cond_operators: Nested mapping ``cond_operators[maskr][maska]`` keyed by
            mask tuples.
        **kwargs: Unused.

    Returns:
        ``pandas.DataFrame`` with one row per (maskr, maska, example) and the
        columns ``maska``, ``maskr``, ``idx``, ``mu`` (mean error) and
        ``sigma`` (covariance error). Unless ``specs.dry_run`` is set, it is
        also written to ``<specs.results_dir>/results.csv``.

    Raises:
        NotImplementedError: If ``specs.variance_correction`` is set.
    """
    # The flag does not depend on the loop variables, so fail before any model call.
    if specs.variance_correction:
        raise NotImplementedError('specs.variance_correction is not supported')

    proj = nets['proj']
    proj.eval()
    batch_size = specs.batch_size
    device = specs.device
    eval_data = vars_['eval_data']
    all_results = []

    # Pure evaluation: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for maskr_, maskas in _group_masks(all_masks).items():
            masksr = _batched_mask(maskr_, batch_size, device)
            # Tensor.nonzero() on a 1-D tensor returns (k_r, 1) indices.
            ridx = masksr[0].long().nonzero()
            for maska_ in maskas:
                masksa = _batched_mask(maska_, batch_size, device)
                aidx = masksa[0].long().nonzero()
                # Analytic conditional moments for this (maskr, maska) pair.
                cond_ops = cond_operators[maskr_][maska_]
                mu_const = cond_ops['mu_const']
                mu_reg = cond_ops['mu_reg']
                sigma = cond_ops['sigma']
                for idx, x in enumerate(eval_data):
                    # One example per batch: the same row repeated batch_size times,
                    # so batch statistics are taken over projections of one input.
                    x = x.to(device).view(1, -1).repeat(batch_size, 1)
                    xa = x * masksa
                    xp, _, _, _ = proj(xa, masksa, masksr)
                    xp_ = xp[:, ridx]  # (batch, k_r, 1) assuming xp is (batch, D) -- TODO confirm
                    xa_ = xa[:, aidx]  # (batch, k_a, 1)
                    mu = mu_const + (mu_reg @ xa_)
                    mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(2).sum().sqrt()
                    # Unbiased (n - 1) sample covariance over the batch.
                    # NOTE: batch_size == 1 divides by zero and yields non-finite values.
                    centered = (xp_ - xp_.mean(dim=0, keepdim=True)).squeeze(-1)
                    sigma_hat = centered.t().matmul(centered).div_(centered.shape[0] - 1.)
                    cov_err = (sigma - sigma_hat).pow(2).sum().sqrt()
                    all_results.append({
                        'maska': maska_, 'maskr': maskr_, 'idx': idx,
                        'mu': to_numpy(mean_err), 'sigma': to_numpy(cov_err),
                    })

    testdf = pd.DataFrame(all_results)
    # Persist results unless this is a dry run.
    if not specs.dry_run:
        testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
    return testdf
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement