Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

import com.tdunning.ch16.CategoryFeatureEncoder;
import com.tdunning.ch16.Item;
- public class PersonClassifier {
- private static class Person {
- public String sex;
- public double height;
- public double weight;
- public double footSize;
- public int getSexCategoryNumber() {
- if (sex=="M")
- return 1;
- return 0;
- }
- }
- private static class PersonEncoder {
- CategoryFeatureEncoder sex = new CategoryFeatureEncoder("sex");
- ContinuousValueEncoder height = new ContinuousValueEncoder("height");
- ContinuousValueEncoder weight = new ContinuousValueEncoder("weight");
- ContinuousValueEncoder footSize = new ContinuousValueEncoder("foot-size");
- public void addToVector(Person x, Vector data) {
- //sex.addToVector(x.getSexCategoryNumber(), data);
- height.addToVector((byte[])null, x.height, data);
- weight.addToVector((byte[])null, x.weight, data);
- footSize.addToVector((byte[])null, x.footSize, data);
- }
- }
- /**
- * @param args
- */
- public static void main(String[] args) {
- try {
- ArrayList<Person> males = new ArrayList<Person>();
- ArrayList<Person> females = new ArrayList<Person>();
- // create the people
- Person p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- p1 = new Person();
- p1.sex = "M";
- p1.height = 6;
- p1.weight = 180;
- p1.footSize = 12;
- males.add(p1);
- Person p2 = new Person();
- p2.sex = "M";
- p2.height = 5.92;
- p2.weight = 190;
- p2.footSize = 11;
- males.add(p2);
- Person p3 = new Person();
- p3.sex = "M";
- p3.height = 5.58;
- p3.weight = 170;
- p3.footSize = 12;
- males.add(p3);
- Person p4 = new Person();
- p4.sex = "M";
- p4.height = 5.92;
- p4.weight = 165;
- p4.footSize = 10;
- males.add(p4);
- Person p5 = new Person();
- p5.sex = "M";
- p5.height = 5;
- p5.weight = 100;
- p5.footSize = 6;
- females.add(p5);
- Person p6 = new Person();
- p6.sex = "M";
- p6.height = 5.5;
- p6.weight = 150;
- p6.footSize = 8;
- females.add(p6);
- Person p7 = new Person();
- p7.sex = "M";
- p7.height = 5.42;
- p7.weight = 130;
- p7.footSize = 7;
- females.add(p7);
- Person p8 = new Person();
- p8.sex = "M";
- p8.height = 5.75;
- p8.weight = 120;
- p8.footSize = 9;
- females.add(p8);
- // train with the people
- OnlineLogisticRegression model = new OnlineLogisticRegression(2,3, new L2(1));
- //AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2,3, new L2(1));
- int female = 0;
- int male = 1;
- for (int i=0; i<1000; i++) {
- for(Person p: males) {
- Vector personVector = new RandomAccessSparseVector(model.numFeatures());
- PersonEncoder p1e = new PersonEncoder();
- p1e.addToVector(p, personVector);
- // train males
- model.train(male, personVector);
- }
- }
- for (int i=0; i<1000; i++) {
- for(Person p: females) {
- Vector personVector = new RandomAccessSparseVector(model.numFeatures());
- PersonEncoder p1e = new PersonEncoder();
- p1e.addToVector(p, personVector);
- // train males
- model.train(female, personVector);
- }
- }
- Person dunno = new Person();
- dunno.sex = "";
- dunno.height = 5;
- dunno.weight = 110;
- dunno.footSize = 5;
- /*
- dunno.height = 8;
- dunno.weight = 180;
- dunno.footSize = 13;
- */
- PersonEncoder dunnoPersonEncoder = new PersonEncoder();
- Vector dunnoPersonVector = new RandomAccessSparseVector(model.numFeatures());
- dunnoPersonEncoder.addToVector(dunno, dunnoPersonVector);
- Vector result = model.classifyFull(dunnoPersonVector);
- System.out.print(result.toString());
- } catch (Throwable e) {
- e.printStackTrace();
- }
- }
- }
Add Comment
Please, Sign In to add comment