Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- template <class _RealType = double>
- class mvnorm_distribution
- {
- public:
- // types
- typedef arma::mat result_type;
- class param_type
- {
- size_t dims_;
- result_type means_;
- result_type covs_;
- public:
- typedef mvnorm_distribution distribution_type;
- explicit param_type(arma::mat means, arma::mat covs)
- : means_(means), covs_(covs) {
- dims_ = means.n_rows;
- }
- size_t ndims() const {return dims_;}
- result_type means() const {return means_;}
- result_type covs() const {return covs_;}
- friend
- bool operator==(const param_type& x, const param_type& y)
- {return arma::approx_equal(x.means_, y.means_, "absdiff", 0.001)
- && arma::approx_equal(x.covs_, y.covs_, "absdiff", 0.001); }
- friend
- bool operator!=(const param_type& x, const param_type& y)
- {return !(x == y); }
- };
- private:
- arma::gmm_full gf_model_;
- arma::gmm_diag gd_model_;
- param_type p_;
- result_type v_;
- public:
- // constructor and reset functions
- explicit mvnorm_distribution(result_type means, result_type covs)
- : p_(param_type(means, covs)) {
- // check if it's diagonal, initiate the diag model
- gf_model_.reset(p_.ndims(), 1);
- // I need to redesign the params...
- arma::mat m(p_.ndims(), 1);
- arma::cube c(p_.ndims(), p_.ndims(), 1);
- m.col(0) = p_.means();
- c.slice(0) = p_.covs();
- // gf_model_.set_means(p_.means());
- // gf_model_.set_fcovs(p_.covs());
- gf_model_.set_means(m);
- gf_model_.set_fcovs(c);
- }
- explicit mvnorm_distribution(const param_type& p)
- : p_(p) {}
- void reset() {};
- // generating functions
- template<class URNG>
- result_type operator()(URNG& g)
- {return (*this)(g, p_);}
- template<class URNG> result_type operator()(URNG& g, const param_type& parm);
- // property functions
- result_type means() const {return p_.means();}
- result_type covs() const {return p_.covs();}
- param_type param() const {return p_;};
- void param(const param_type& params) { p_ = params;}
- result_type min() const {return -std::numeric_limits<_RealType>::infinity();}
- result_type max() const {return std::numeric_limits<_RealType>::infinity();}
- friend bool operator==(const mvnorm_distribution& x,
- const mvnorm_distribution& y)
- {return x.p_ == y.p_;}
- friend bool operator!=(const mvnorm_distribution& x,
- const mvnorm_distribution& y)
- {return !(x == y);}
- template <class charT, class traits>
- friend
- std::basic_ostream<charT, traits>&
- operator<<(std::basic_ostream<charT, traits>& os,
- const mvnorm_distribution& means);
- template <class charT, class traits>
- friend
- std::basic_istream<charT, traits>&
- operator>>(std::basic_istream<charT, traits>& is,
- mvnorm_distribution& means);
- };
- template <class _RealType>
- template <class _URNG>
- mvnorm_distribution<double>::result_type
- mvnorm_distribution<_RealType>::operator()(_URNG &g, const mvnorm_distribution<_RealType>::param_type &parm) {
- return gf_model_.generate();
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement