Guest User

Untitled

a guest
Feb 19th, 2018
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.61 KB | None | 0 0
  1.  
  import java.util.ArrayList;
  import java.util.Collections;
  import java.util.List;
  import java.util.Random;

  import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
  import org.apache.mahout.classifier.sgd.L1;
  import org.apache.mahout.classifier.sgd.L2;
  import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
  import org.apache.mahout.math.RandomAccessSparseVector;
  import org.apache.mahout.math.Vector;
  import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
  import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;

  import com.tdunning.ch16.CategoryFeatureEncoder;
  import com.tdunning.ch16.Item;
  14.  
  15. public class PersonClassifier {
  16.  
  17. private static class Person {
  18.  
  19. public String sex;
  20. public double height;
  21. public double weight;
  22. public double footSize;
  23.  
  24. public int getSexCategoryNumber() {
  25. if (sex=="M")
  26. return 1;
  27. return 0;
  28. }
  29. }
  30.  
  31. private static class PersonEncoder {
  32. CategoryFeatureEncoder sex = new CategoryFeatureEncoder("sex");
  33. ContinuousValueEncoder height = new ContinuousValueEncoder("height");
  34. ContinuousValueEncoder weight = new ContinuousValueEncoder("weight");
  35. ContinuousValueEncoder footSize = new ContinuousValueEncoder("foot-size");
  36.  
  37. public void addToVector(Person x, Vector data) {
  38. //sex.addToVector(x.getSexCategoryNumber(), data);
  39. height.addToVector((byte[])null, x.height, data);
  40. weight.addToVector((byte[])null, x.weight, data);
  41. footSize.addToVector((byte[])null, x.footSize, data);
  42. }
  43. }
  44.  
  45. /**
  46. * @param args
  47. */
  48. public static void main(String[] args) {
  49.  
  50. try {
  51.  
  52. ArrayList<Person> males = new ArrayList<Person>();
  53. ArrayList<Person> females = new ArrayList<Person>();
  54.  
  55. // create the people
  56.  
  57. Person p1 = new Person();
  58. p1.sex = "M";
  59. p1.height = 6;
  60. p1.weight = 180;
  61. p1.footSize = 12;
  62. males.add(p1);
  63. p1 = new Person();
  64. p1.sex = "M";
  65. p1.height = 6;
  66. p1.weight = 180;
  67. p1.footSize = 12;
  68. males.add(p1);
  69. p1 = new Person();
  70. p1.sex = "M";
  71. p1.height = 6;
  72. p1.weight = 180;
  73. p1.footSize = 12;
  74. males.add(p1);
  75. p1 = new Person();
  76. p1.sex = "M";
  77. p1.height = 6;
  78. p1.weight = 180;
  79. p1.footSize = 12;
  80. males.add(p1);
  81. p1 = new Person();
  82. p1.sex = "M";
  83. p1.height = 6;
  84. p1.weight = 180;
  85. p1.footSize = 12;
  86. males.add(p1);
  87. p1 = new Person();
  88. p1.sex = "M";
  89. p1.height = 6;
  90. p1.weight = 180;
  91. p1.footSize = 12;
  92. males.add(p1);
  93.  
  94. Person p2 = new Person();
  95. p2.sex = "M";
  96. p2.height = 5.92;
  97. p2.weight = 190;
  98. p2.footSize = 11;
  99. males.add(p2);
  100.  
  101. Person p3 = new Person();
  102. p3.sex = "M";
  103. p3.height = 5.58;
  104. p3.weight = 170;
  105. p3.footSize = 12;
  106. males.add(p3);
  107.  
  108. Person p4 = new Person();
  109. p4.sex = "M";
  110. p4.height = 5.92;
  111. p4.weight = 165;
  112. p4.footSize = 10;
  113. males.add(p4);
  114.  
  115. Person p5 = new Person();
  116. p5.sex = "M";
  117. p5.height = 5;
  118. p5.weight = 100;
  119. p5.footSize = 6;
  120. females.add(p5);
  121.  
  122. Person p6 = new Person();
  123. p6.sex = "M";
  124. p6.height = 5.5;
  125. p6.weight = 150;
  126. p6.footSize = 8;
  127. females.add(p6);
  128.  
  129. Person p7 = new Person();
  130. p7.sex = "M";
  131. p7.height = 5.42;
  132. p7.weight = 130;
  133. p7.footSize = 7;
  134. females.add(p7);
  135.  
  136. Person p8 = new Person();
  137. p8.sex = "M";
  138. p8.height = 5.75;
  139. p8.weight = 120;
  140. p8.footSize = 9;
  141. females.add(p8);
  142.  
  143.  
  144.  
  145. // train with the people
  146. OnlineLogisticRegression model = new OnlineLogisticRegression(2,3, new L2(1));
  147.  
  148. //AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2,3, new L2(1));
  149.  
  150. int female = 0;
  151. int male = 1;
  152.  
  153. for (int i=0; i<1000; i++) {
  154. for(Person p: males) {
  155. Vector personVector = new RandomAccessSparseVector(model.numFeatures());
  156. PersonEncoder p1e = new PersonEncoder();
  157. p1e.addToVector(p, personVector);
  158. // train males
  159. model.train(male, personVector);
  160. }
  161. }
  162.  
  163. for (int i=0; i<1000; i++) {
  164. for(Person p: females) {
  165. Vector personVector = new RandomAccessSparseVector(model.numFeatures());
  166. PersonEncoder p1e = new PersonEncoder();
  167. p1e.addToVector(p, personVector);
  168. // train males
  169. model.train(female, personVector);
  170. }
  171. }
  172.  
  173. Person dunno = new Person();
  174. dunno.sex = "";
  175. dunno.height = 5;
  176. dunno.weight = 110;
  177. dunno.footSize = 5;
  178.  
  179. /*
  180. dunno.height = 8;
  181. dunno.weight = 180;
  182. dunno.footSize = 13;
  183. */
  184.  
  185. PersonEncoder dunnoPersonEncoder = new PersonEncoder();
  186.  
  187. Vector dunnoPersonVector = new RandomAccessSparseVector(model.numFeatures());
  188. dunnoPersonEncoder.addToVector(dunno, dunnoPersonVector);
  189.  
  190.  
  191. Vector result = model.classifyFull(dunnoPersonVector);
  192. System.out.print(result.toString());
  193.  
  194.  
  195. } catch (Throwable e) {
  196. e.printStackTrace();
  197. }
  198.  
  199. }
  200.  
  201. }
Add Comment
Please, Sign In to add comment