Advertisement
Guest User

Untitled

a guest
Apr 28th, 2017
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.71 KB | None | 0 0
  1. import org.apache.spark.api.java.JavaPairRDD;
  2. import org.apache.spark.api.java.JavaRDD;
  3. import org.apache.spark.mllib.linalg.Vector;
  4. import org.apache.spark.mllib.linalg.Vectors;
  5. import org.apache.spark.mllib.linalg.distributed.*;
  6. import org.slf4j.Logger;
  7. import org.slf4j.LoggerFactory;
  8. import scala.Tuple2;
  9.  
  10. import java.io.Serializable;
  11. import java.util.*;
  12.  
  13.  
  14. public class SimilarityDimSum {
  15. private static final Logger LOGGER = LoggerFactory.getLogger(SimilarityDimSum.class);
  16.  
  17. public static class Relation implements Comparable<Relation>{
  18. private Double score;
  19. private String first;
  20. private String second;
  21.  
  22. public void swap(){
  23. new Relation(second, first, score);
  24. }
  25. public Relation(String first, String second, Double similarity){
  26. this.first = first;
  27. this.second = second;
  28. this.score = similarity;
  29. }
  30. public Relation(Long first, Long second, Double similarity){
  31. this.first = first.toString();
  32. this.second = second.toString();
  33. this.score = similarity;
  34. }
  35. public int compareTo(Relation other){
  36. if(!this.first.equals(other.first)){
  37. //TODO ERROR .. what to do??
  38. }
  39. return this.score.compareTo(other.score);
  40. }
  41.  
  42. public Double getScore() {
  43. return score;
  44. }
  45.  
  46. public void setScore(Double score) {
  47. this.score = score;
  48. }
  49.  
  50. public String getFirst() {
  51. return first;
  52. }
  53.  
  54. public void setFirst(String first) {
  55. this.first = first;
  56. }
  57.  
  58. public String getSecond() {
  59. return second;
  60. }
  61.  
  62. public void setSecond(String second) {
  63. this.second = second;
  64. }
  65.  
  66. @Override
  67. public String toString(){
  68. return String.format("[\"%s\",\"%s\",\"%s\"]", first, second, score.toString());
  69. }
  70. }
  71. public static JavaRDD<Relation> computeItemItemSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){
  72. RowMatrix ratingsMatrix = getUserVectors(ratingJavaRDD);
  73. return computeSimilarities(ratingsMatrix, threshold, k);
  74. }
  75.  
  76. public static JavaRDD<Relation> computeUserUserSimilarities(JavaRDD<DmRating> ratingJavaRDD, double threshold, int k){
  77. RowMatrix ratingsMatrix = getItemVectors(ratingJavaRDD);
  78. return computeSimilarities(ratingsMatrix, threshold, k);
  79. }
  80. public static JavaRDD<Relation> computeSimilarities(RowMatrix ratingsMatrix, double threshold, int k){
  81. CoordinateMatrix similaritiesMatrix = ratingsMatrix.columnSimilarities(threshold);
  82. JavaRDD<MatrixEntry> entries = similaritiesMatrix.entries().toJavaRDD().cache();
  83. LOGGER.info("Number of entries: {}", entries.count());
  84.  
  85. JavaPairRDD<String, Relation> similaritiesRdd = entries.mapPartitionsToPair((Iterator<MatrixEntry> entriesIterator) -> {
  86. ArrayList<Tuple2<String,Relation>> list = new ArrayList<>();
  87. while(entriesIterator.hasNext()){
  88. MatrixEntry matrixEntry = entriesIterator.next();
  89. Relation sim = new Relation(matrixEntry.i(), matrixEntry.j(), matrixEntry.value());
  90. list.add(new Tuple2<>(sim.getFirst(), sim));
  91.  
  92. }
  93. return list;
  94. }).cache();
  95. LOGGER.info("Number of similarities: {}", similaritiesRdd.count());
  96. JavaPairRDD<String,Iterable<Relation>> grouped = similaritiesRdd.groupByKey().cache();
  97. JavaRDD<Relation> topSimilarities = grouped.flatMap((Tuple2<String, Iterable<Relation>> item) -> {
  98. Iterator<Relation> sim = item._2().iterator();
  99. ArrayList<Relation> simList = new ArrayList<>();
  100. while(sim.hasNext()){
  101. simList.add(sim.next());
  102. }
  103. return getTopSimilarities(simList, k);
  104. });
  105.  
  106. LOGGER.info("Number of top similarities: {}", topSimilarities.count());
  107. return topSimilarities;
  108. }
  109. public static double cosineSimilarity(double dotProduct, double norm1, double norm2){
  110. return dotProduct/(norm1 * norm2);
  111. }
  112.  
  113. public static List<Relation> getTopSimilarities(ArrayList<Relation> grouped, int k){
  114. Comparator<Relation> reverse = Collections.reverseOrder();
  115. grouped.sort(reverse);
  116. if(grouped.size()>k) {
  117. return grouped.subList(0, k);
  118. }else{
  119. return grouped;
  120. }
  121. }
  122.  
  123. public static class DmRatingComparator implements Comparator<Tuple2<Integer, Iterable<DmRating>>>, Serializable {
  124. @Override
  125. public int compare(Tuple2<Integer, Iterable<DmRating>> o1, Tuple2<Integer, Iterable<DmRating>> o2) {
  126. return o1._1().compareTo(o2._1());
  127. }
  128. }
  129. public static RowMatrix getUserVectors(JavaRDD<DmRating> ratingJavaRDD){
  130. JavaPairRDD<Integer, Iterable<DmRating>> groupedByItem = ratingJavaRDD.groupBy( rating -> rating.modVideoId);
  131. int maxItemId = groupedByItem.max(new DmRatingComparator())._1();
  132. JavaPairRDD<Integer, Iterable<DmRating>> groupedByUser = ratingJavaRDD.groupBy(rating -> rating.modUserId);
  133.  
  134. JavaRDD<Vector> vectorJavaRDD = groupedByUser.mapPartitions(itemIter ->{
  135. ArrayList<Vector> vectors = new ArrayList<>();
  136. while(itemIter.hasNext()){
  137. Tuple2<Integer, Iterable<DmRating>> item = itemIter.next();
  138. HashMap<Integer, Tuple2<Integer, Double>> videoIdRating = new HashMap<>();
  139. Iterator<DmRating> videos = item._2().iterator();
  140. while(videos.hasNext()){
  141. DmRating rating = videos.next();
  142. // make sure we only use one rating per video
  143. videoIdRating.put(rating.modVideoId, new Tuple2<>(rating.modVideoId, rating.rating));
  144. }
  145. vectors.add(Vectors.sparse(maxItemId+1, videoIdRating.values()));
  146. }
  147. return vectors;
  148. });
  149. return new RowMatrix(vectorJavaRDD.rdd());
  150. }
  151.  
  152. public static RowMatrix getItemVectors(JavaRDD<DmRating> ratingJavaRDD){
  153. JavaPairRDD<Integer, Iterable<DmRating>> groupedByUser = ratingJavaRDD.groupBy(rating -> rating.modUserId);
  154. int maxUserId = groupedByUser.max(new DmRatingComparator())._1();
  155. JavaPairRDD<Integer, Iterable<DmRating>> groupedByItem = ratingJavaRDD.groupBy(rating -> rating.modVideoId);
  156.  
  157. /*
  158. * videoID 1 : [ userID, userId, userId .. ]
  159. * videoID 2 : [ userID, userId, userId .. ]
  160. *
  161. * videoId1 videoId2
  162. * userId 1
  163. * userId 1
  164. */
  165. JavaRDD<Vector> vectorJavaRDD = groupedByItem.mapPartitions((Iterator<Tuple2<Integer, Iterable<DmRating>>> tuple2Iterator) -> {
  166. ArrayList<Vector> vectors = new ArrayList<>();
  167. while(tuple2Iterator.hasNext()){
  168. Tuple2<Integer, Iterable<DmRating>> item = tuple2Iterator.next();
  169. HashMap<Integer, Tuple2<Integer, Double>> userIdRating = new HashMap<>();
  170. Iterator<DmRating> videos = item._2().iterator();
  171. while(videos.hasNext()){
  172. DmRating rating = videos.next();
  173. // make sure we only use one rating per user
  174. userIdRating.put(rating.modUserId, new Tuple2<>(rating.modUserId, rating.rating));
  175. }
  176. vectors.add(Vectors.sparse(maxUserId+1, userIdRating.values()));
  177. }
  178. return vectors;
  179. });
  180. return new RowMatrix(vectorJavaRDD.rdd());
  181. }
  182.  
  183. public static JavaRDD<Relation> computeAlsFeatureSimilarity(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){
  184. //TODO figure out how to do similarity computation.
  185. return computeAlsFeatureSimilarityDimSum(features, threshold, k);
  186. }
  187. public static JavaRDD<Relation> computeAlsFeatureSimilarityDimSum(JavaRDD<Tuple2<Object, double[]>> features, double threshold, int k){
  188. JavaRDD<MatrixEntry> alsRows = features.mapPartitions(ti2 -> {
  189. ArrayList<MatrixEntry> list = new ArrayList<>();
  190. while(ti2.hasNext()){
  191. Tuple2<Object, double[]> t2 = ti2.next();
  192. double[] vals = t2._2();
  193. Integer id = (Integer) t2._1();
  194. for(int i=0; i<vals.length; i++){
  195. list.add(new MatrixEntry(i, id, vals[i]));
  196. }
  197. }
  198. return list;
  199. });
  200. JavaRDD<MatrixEntry> cachedProductAlsRows = alsRows.cache();
  201. CoordinateMatrix alsCoordinateRowMatrix = new CoordinateMatrix(cachedProductAlsRows.rdd());
  202. RowMatrix alsRowMatrix = alsCoordinateRowMatrix.toRowMatrix();
  203. return computeSimilarities(alsRowMatrix, threshold, k);
  204. }
  205. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement