Advertisement
Guest User

Untitled

a guest
Jun 24th, 2019
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.20 KB | None | 0 0
  1. template <class _RealType = double>
  2. class mvnorm_distribution
  3. {
  4. public:
  5. // types
  6. typedef arma::mat result_type;
  7.  
  8. class param_type
  9. {
  10. size_t dims_;
  11. result_type means_;
  12. result_type covs_;
  13. public:
  14. typedef mvnorm_distribution distribution_type;
  15.  
  16. explicit param_type(arma::mat means, arma::mat covs)
  17. : means_(means), covs_(covs) {
  18. dims_ = means.n_rows;
  19. }
  20.  
  21. size_t ndims() const {return dims_;}
  22. result_type means() const {return means_;}
  23. result_type covs() const {return covs_;}
  24.  
  25. friend
  26. bool operator==(const param_type& x, const param_type& y)
  27. {return arma::approx_equal(x.means_, y.means_, "absdiff", 0.001)
  28. && arma::approx_equal(x.covs_, y.covs_, "absdiff", 0.001); }
  29. friend
  30. bool operator!=(const param_type& x, const param_type& y)
  31. {return !(x == y); }
  32. };
  33.  
  34. private:
  35. arma::gmm_full gf_model_;
  36. arma::gmm_diag gd_model_;
  37. param_type p_;
  38. result_type v_;
  39.  
  40. public:
  41. // constructor and reset functions
  42. explicit mvnorm_distribution(result_type means, result_type covs)
  43. : p_(param_type(means, covs)) {
  44. // check if it's diagonal, initiate the diag model
  45. gf_model_.reset(p_.ndims(), 1);
  46.  
  47. // I need to redesign the params...
  48. arma::mat m(p_.ndims(), 1);
  49. arma::cube c(p_.ndims(), p_.ndims(), 1);
  50. m.col(0) = p_.means();
  51. c.slice(0) = p_.covs();
  52.  
  53. // gf_model_.set_means(p_.means());
  54. // gf_model_.set_fcovs(p_.covs());
  55.  
  56. gf_model_.set_means(m);
  57. gf_model_.set_fcovs(c);
  58. }
  59.  
  60. explicit mvnorm_distribution(const param_type& p)
  61. : p_(p) {}
  62. void reset() {};
  63.  
  64. // generating functions
  65. template<class URNG>
  66. result_type operator()(URNG& g)
  67. {return (*this)(g, p_);}
  68. template<class URNG> result_type operator()(URNG& g, const param_type& parm);
  69.  
  70. // property functions
  71. result_type means() const {return p_.means();}
  72. result_type covs() const {return p_.covs();}
  73.  
  74. param_type param() const {return p_;};
  75. void param(const param_type& params) { p_ = params;}
  76.  
  77. result_type min() const {return -std::numeric_limits<_RealType>::infinity();}
  78. result_type max() const {return std::numeric_limits<_RealType>::infinity();}
  79.  
  80. friend bool operator==(const mvnorm_distribution& x,
  81. const mvnorm_distribution& y)
  82. {return x.p_ == y.p_;}
  83. friend bool operator!=(const mvnorm_distribution& x,
  84. const mvnorm_distribution& y)
  85. {return !(x == y);}
  86.  
  87. template <class charT, class traits>
  88. friend
  89. std::basic_ostream<charT, traits>&
  90. operator<<(std::basic_ostream<charT, traits>& os,
  91. const mvnorm_distribution& means);
  92.  
  93. template <class charT, class traits>
  94. friend
  95. std::basic_istream<charT, traits>&
  96. operator>>(std::basic_istream<charT, traits>& is,
  97. mvnorm_distribution& means);
  98.  
  99. };
  100.  
  101. template <class _RealType>
  102. template <class _URNG>
  103. mvnorm_distribution<double>::result_type
  104. mvnorm_distribution<_RealType>::operator()(_URNG &g, const mvnorm_distribution<_RealType>::param_type &parm) {
  105. return gf_model_.generate();
  106. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement