Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- private Kernel getKernel (KernelType kType) {
- Kernel k = null;
- double width = 1.4;
- int cache_size = 10;
- switch (kType) {
- case GAUSSIAN:
- k = new GaussianKernel (10, width);
- break;
- case HIK:
- double beta = 1.4;
- HistogramIntersectionKernel hik = new HistogramIntersectionKernel (cache_size);
- hik.set_beta (beta);
- k = hik;
- break;
- case POLY:
- int degree = 1;
- k = new PolyKernel (cache_size, degree);
- break;
- case CHI2:
- k = new Chi2Kernel (cache_size, width);
- break;
- default:
- System.err.println ("invalid kernel type!");
- System.exit (1);
- }
- return k;
- }
- public void train (DoubleMatrix trainSet, Labels trainLabels) {
- double eps = 1e-5;
- double C = 1;
- RealFeatures trainFeats = new RealFeatures (trainSet);
- boolean slow = false;
- if (slow) {
- kernel = getKernel (kernelType);
- kernel.init (trainFeats, trainFeats);
- } else {
- kernel = new HistogramIntersectionKernel (trainFeats, trainFeats, 1.4);
- }
- ml = new LaRank (C, kernel, trainLabels);
- ((LaRank)ml).set_epsilon (eps);
- ml.train ();
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement