daily pastebin goal
2%
SHARE
TWEET

Untitled

Posted by a guest on Jan 23rd, 2019 (64 views, never expires)
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # -*- coding: utf-8 -*-
  2. from collections import defaultdict
  3. import numpy as np
  4. from pae.numerics.numpy import idx_, regularize_symmetric_matrix
  5.  
  6. # MEANS = [0., 0., 0.]
  7. # MEANS = [2., 4., 6.]
  8. EPS = 1e-5
  9.  
  10.  
  11. def compute_unconditional_moment(data_dim, means, variances, corrs, **kwargs):
  12.   def _floatify(iterable):
  13.     return list(map(float, iterable))
  14.   if data_dim == 3:
  15.     mu1, mu2, mu3 = _floatify(means)
  16.     mu = np.array([mu1, mu2, mu3])
  17.  
  18.     sigma1, sigma2, sigma3 = _floatify(variances)
  19.     rho12, rho13, rho23 = _floatify(corrs)
  20.  
  21.     sigma = np.array(
  22.           [
  23.             [sigma1, rho12, rho13],
  24.             [rho12, sigma2, rho23],
  25.             [rho13, rho23, sigma3]
  26.           ]
  27.       ).astype("float32")
  28.     sigma = regularize_symmetric_matrix(sigma, eps=EPS)
  29.     lam = np.linalg.pinv(sigma)
  30.     return {"mu": mu, "sigma": sigma, "lam": lam}
  31.   elif data_dim == 2:
  32.     raise NotImplementedError
  33.   else:
  34.     raise ValueError
  35.  
  36.  
  37. def partition_moments(maska, maskr, mu, sigma, lam):
  38.  
  39.   ar = ac = idx_(maska)
  40.   rr = rc = idx_(maskr)
  41.  
  42.   mu_a = mu[ar]
  43.   mu_r = mu[rr]
  44.  
  45.   sigma_aa = sigma[ar[:, None], ac]
  46.   sigma_ar = sigma[ar[:, None], rc]
  47.   sigma_ra = sigma[rr[:, None], ac]
  48.   sigma_rr = sigma[rr[:, None], rc]
  49.  
  50.   lam_aa = lam[ar[:, None], ac]
  51.   lam_ar = lam[ar[:, None], rc]
  52.   lam_ra = lam[rr[:, None], ac]
  53.   lam_rr = lam[rr[:, None], rc]
  54.  
  55.   return dict(
  56.     mu={"a": mu_a, "r": mu_r},
  57.     sigma={"aa": sigma_aa, "ar": sigma_ar, "ra": sigma_ra, "rr": sigma_rr},
  58.     lam={"aa": lam_aa, "ar": lam_ar, "ra": lam_ra, "rr": lam_rr},
  59.   )
  60.  
  61.  
  62. def compute_conditional_moments(xa, part_moms):
  63.   # mu_r_a = part_moms["mu"]["r"] - np.linalg.pinv(part_moms["lam"]["rr"]).dot(
  64.   #   part_moms["lam"]["ra"]
  65.   # ).dot(xa - part_moms["mu"]["a"])
  66.   # sigma_r_a = np.linalg.pinv(part_moms["lam"]["rr"])
  67.  
  68.   mu_r, mu_a = part_moms['mu']['r'], part_moms['mu']['a']
  69.   sigma_ar = part_moms['sigma']['ar']
  70.   sigma_rr = part_moms['sigma']['rr']
  71.   sigma_ra = part_moms['sigma']['ra']
  72.   sigma_aa = part_moms['sigma']['aa']
  73.  
  74.   sigma_aa = regularize_symmetric_matrix(sigma_aa, eps=EPS)
  75.   sigma_aa_inv = np.linalg.pinv(sigma_aa)
  76.   mu_r_a = mu_r + sigma_ra.dot(sigma_aa_inv).dot(xa - mu_a)
  77.   sigma_r_a = sigma_rr - sigma_ra.dot(sigma_aa_inv).dot(sigma_ar)
  78.  
  79.   return {"mu": {"ra": mu_r_a}, "sigma": {"ra": sigma_r_a}}
  80.  
  81.  
  82. def make_conditional_operators(all_masks, mu, sigma, lam):
  83.   all_masks_dict = defaultdict(set)
  84.   for maska, maskr in all_masks.tolist():
  85.     if sum(maska) != 0: #and sum(maskr) != 3:
  86.       all_masks_dict[tuple(maskr)].add(tuple(maska))
  87.   cond_operators = {}
  88.   for maskr, masksa in all_masks_dict.items():
  89.     cond_operators[maskr] = {}
  90.     for maska in masksa:
  91.       part_moms = partition_moments(maska, maskr, mu, sigma, lam)
  92.       mu_r, mu_a = part_moms['mu']['r'], part_moms['mu']['a']
  93.       sigma_ar = part_moms['sigma']['ar']
  94.       sigma_rr = part_moms['sigma']['rr']
  95.       sigma_ra = part_moms['sigma']['ra']
  96.       sigma_aa = part_moms['sigma']['aa']
  97.       sigma_aa = regularize_symmetric_matrix(sigma_aa, eps=EPS)
  98.       sigma_aa_inv = np.linalg.pinv(sigma_aa)
  99.  
  100.       # \Sigma_{R|A} = \Sigma_{RR} - \Sigma_{RA} \Sigma_{AA}^{-1} \Sigma_{AR}
  101.       sigma_r_a = sigma_rr - sigma_ra.dot(sigma_aa_inv).dot(sigma_ar)
  102.       sigma_r_a = regularize_symmetric_matrix(sigma_r_a, eps=EPS)
  103.  
  104.       _, sigma_r_a_log_det = np.linalg.slogdet(sigma_r_a)
  105.       sigma_r_a_inv = np.linalg.pinv(sigma_r_a)
  106.  
  107.       # \mu_{R|A} = \mu_R + \Sigma_{RA}\Sigma_{AA}^{-1}(x_A - mu_A)
  108.       # \mu_{R|A} = \mu_R - \Sigma_{RA}\Sigma_{AA}^{-1} \mu_A + \Sigma_{RA}\Sigma_{AA}^{-1}(x_A)
  109.       mu_r_a_const = mu_r - sigma_ra.dot(sigma_aa_inv).dot(mu_a)
  110.       mu_r_reg = sigma_ra.dot(sigma_aa_inv)
  111.       cond_operators[maskr][maska] = {
  112.         # Batch of columns
  113.         'mu_const': np.expand_dims(np.expand_dims(mu_r_a_const, axis=0), axis=-1),
  114.         # Batch of columns
  115.         'mu_reg': np.expand_dims(mu_r_reg, axis=0),
  116.         # Broadcast for batch of Matrix
  117.         'sigma': np.expand_dims(sigma_r_a, axis=0),
  118.         # Broadcast for a batch of matrix
  119.         'sigma_inv': np.expand_dims(sigma_r_a_inv, axis=0),
  120.         # Scalar
  121.         'sigma_log_det': sigma_r_a_log_det
  122.       }
  123.   return cond_operators
  124.  
  125.  
  126. if __name__ == '__main__':
  127.   pass
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top