Advertisement
Guest User

Untitled

a guest
Aug 23rd, 2012
34
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.87 KB | None | 0 0
  1. #ifndef __LATENTSOMODEL_H__
  2. #define __LATENTSOMODEL_H__
  3.  
  4. #include <shogun/latent/LatentModel.h>
  5. #include <shogun/structure/StructuredModel.h>
  6.  
  7. namespace shogun
  8. {
  9.     /** @brief TODO
  10.      */
  11.     class CLatentSOModel: public CLatentModel
  12.     {
  13.         public:
  14.             /** default ctor */
  15.             CLatentSOModel();
  16.  
  17.             /** ctor
  18.              *
  19.              * @param feats Latent features
  20.              * @param labels Latent labels
  21.              */
  22.             CLatentSOModel(CLatentFeatures* feats, CLatentLabels* labels);
  23.  
  24.             virtual ~CLatentSOModel();
  25.  
  26.             /**
  27.              * get the dimension of PSI
  28.              *
  29.              * @return dimension of features, i.e. psi vector
  30.              */
  31.             virtual int32_t get_dim() const=0;
  32.  
  33.             /**
  34.              * Calculate the PSI vector for a given sample
  35.              *
  36.              * @param idx index of the sample
  37.              *
  38.              * @return PSI vector
  39.              */
  40.             virtual CDotFeatures* get_psi_feature_vectors()=0;
  41.  
  42.             virtual SGVector<float64_t> get_psi_feature_vector(index_t feat_idx, CStructuredData* ybar, CData* hbar)=0;
  43.  
  44.             virtual CData* infer_latent_variable(const SGVector<float64_t>& w, index_t idx)=0;
  45.  
  46.             /** Finds the most violated constraint
  47.              * \f[
  48.              *  argmax_{(ybar,hbar)} [<w,psi(x,ybar,hbar)> + loss(y,ybar,hbar)]
  49.              * \f]
  50.              *
  51.              * @param idx
  52.              * @param
  53.              */
  54.             virtual void get_most_violated_constraint(const SGVector<float64_t>& w, index_t feat_idx, CStructuredData* ybar, CData* hbar)=0;
  55.  
  56.             virtual void argmax_h(const SGVector<float64_t>& w);
  57.            
  58.             virtual float64_t delta(CStructuredData* y1, CStructuredData* y2, CData* h) = 0;
  59.  
  60.             virtual const char* get_name() const { return "LatentModel"; }
  61.  
  62.             float64_t get_sum_argmax_h() const;
  63.  
  64.         private:
  65.             void register_parameters();
  66.  
  67.         private:
  68.             float64_t m_sum_argmax_h;
  69.     };
  70.  
  71.     class CLSOModel: public CStructuredModel
  72.     {
  73.         public:
  74.             CLSOModel();
  75.  
  76.             CLSOModel(CLatentSOModel* latent_model);
  77.  
  78.             virtual ~CLSOModel();
  79.  
  80.             virtual int32_t get_dim() const;
  81.  
  82.             virtual CResultSet* argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training = true);
  83.  
  84.             virtual const char* get_name() const { return "LSOModel"; }
  85.  
  86.            
  87.             /** Calculates the risk function for Latent Structural SVM
  88.              *
  89.              * \f[
  90.              *  \sum_{i=1}^n \left( \max_{(\hat{y},\hat{h}) \in YxH} \left[ \mathbf{w}
  91.              *  \cdot \Psi(x_i, \hat{y}, \hat{h})+\Delta(y_i, \hat{y}, \hat{h}) \right ] \right )
  92.              *  - \sum_{i=1}^n \mathbf{w} \cdot \Psi(x_i, y_i, h^*_i)
  93.              * \f]
  94.              *
  95.              * For more details see [1]
  96.              * [1] C.-N. J. Yu and T. Joachims,
  97.              *     "Learning structural SVMs with latent variables"
  98.              *     presented at the Proceedings of the 26th Annual International Conference on Machine Learning,
  99.              *     New York, NY, USA, 2009, pp. 1169-1176.
  100.              * http://www.cs.cornell.edu/~cnyu/papers/icml09_latentssvm.pdf
  101.              *
  102.              */
  103.             virtual float64_t risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0);
  104.         private:
  105.             CLatentSOModel* m_latent_model;
  106.     };
  107. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement