SHARE
TWEET

Untitled




Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
@timeit
def test(specs, vars_, nets, monitors, all_masks, cond_operators, **kwargs):
    """Compare the projection network's outputs with analytic conditional moments.

    For every distinct (available mask, requested mask) pair in ``all_masks``
    (pairs whose available mask sums to zero are skipped) and every example in
    ``vars_['eval_data']``, the example is tiled ``specs.batch_size`` times,
    multiplied by the available mask and passed through ``nets['proj']``.
    The batch mean and the unbiased sample covariance of the network output at
    the requested positions are compared with the moments built from
    ``cond_operators[maskr][maska]`` (``mu_const``, ``mu_reg``, ``sigma``);
    each error is the Frobenius norm of the difference.

    The network is put in eval mode and run without gradient tracking; its
    previous train/eval mode is restored before returning (also on error).

    Args:
        specs: Run configuration. Reads ``batch_size``, ``device``,
            ``variance_correction``, ``dry_run`` and ``results_dir``.
        vars_: Mapping holding the evaluation examples under ``'eval_data'``.
        nets: Mapping of networks; only ``nets['proj']`` is used.
        monitors: Unused; kept for interface compatibility.
        all_masks: Object whose ``tolist()`` yields ``(maska, maskr)`` pairs.
        cond_operators: Nested mapping ``[maskr][maska] -> dict`` (keys are
            mask tuples) with entries ``'mu_const'``, ``'mu_reg'``, ``'sigma'``.
        **kwargs: Unused; kept for interface compatibility.

    Returns:
        pandas.DataFrame with one row per (mask pair, example) and columns
        ``maska``, ``maskr``, ``idx``, ``mu`` (mean error) and ``sigma``
        (covariance error). It is also written to
        ``<results_dir>/results.csv`` unless ``specs.dry_run`` is set.

    Raises:
        NotImplementedError: if ``specs.variance_correction`` is set.
    """
    if specs.variance_correction:
        # Not supported yet; fail before doing any network work.
        raise NotImplementedError

    # Group available masks by requested mask; the set drops duplicate pairs.
    all_masks_dict = defaultdict(set)
    for maska_list, maskr_list in all_masks.tolist():
        if sum(maska_list) != 0:
            all_masks_dict[tuple(maskr_list)].add(tuple(maska_list))

    proj = nets['proj']
    was_training = proj.training
    proj.eval()
    all_results = []
    try:
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for maskr_, maska_set in all_masks_dict.items():
                # Requested-mask tensor depends only on maskr_, not on the
                # available mask or the example, so build it once here.
                masksr = torch.Tensor(maskr_).view(1, -1).float().to(
                    specs.device).repeat(specs.batch_size, 1)
                cols_r = masksr[0].long().nonzero()
                for maska_ in maska_set:
                    masksa = torch.Tensor(maska_).view(1, -1).float().to(
                        specs.device).repeat(specs.batch_size, 1)
                    cols_a = masksa[0].long().nonzero()
                    for idx, x in enumerate(vars_['eval_data']):
                        # Tile the single example across the whole batch.
                        x = x.view(1, -1).repeat(specs.batch_size, 1)
                        xa = x * masksa
                        xp, _z, _zmu, _logvar = proj(xa, masksa, masksr)
                        # Indexing with nonzero() keeps a trailing singleton
                        # dim: these are batches of column vectors.
                        xp_ = xp[:, cols_r]
                        xa_ = xa[:, cols_a]

                        # Analytic conditional moments for this mask pair.
                        cond_ops = cond_operators[maskr_][maska_]
                        mu_const = cond_ops['mu_const']
                        mu_reg = cond_ops['mu_reg']
                        sigma = cond_ops['sigma']

                        # Conditional mean, then error of the batch means.
                        mu = mu_const + (mu_reg @ xa_)
                        mean_err = (xp_.mean(dim=0) - mu.mean(dim=0)).pow(
                            2).sum().sqrt()

                        # Unbiased sample covariance of the outputs.
                        # NOTE(review): with batch_size == 1 this divides by
                        # zero and yields inf/nan.
                        xp_centered = xp_ - xp_.mean(dim=0, keepdim=True)
                        xp_flat = xp_centered.squeeze(-1)
                        sigma_hat = xp_flat.t().matmul(xp_flat).div_(
                            xp_centered.shape[0] - 1.)
                        cov_err = (sigma - sigma_hat).pow(2).sum().sqrt()

                        all_results.append({
                            'maska': maska_, 'maskr': maskr_, 'idx': idx,
                            'mu': to_numpy(mean_err),
                            'sigma': to_numpy(cov_err),
                        })
    finally:
        # Leave the network in the mode the caller had it in.
        proj.train(was_training)

    testdf = pd.DataFrame(all_results)
    if not specs.dry_run:
        testdf.to_csv(os.path.join(specs.results_dir, 'results.csv'))
    return testdf
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.