Advertisement
Guest User

Untitled

a guest
Apr 21st, 2019
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 3.03 KB | None | 0 0
  1. import java.awt.List;
  2. import java.util.ArrayList;
  3.  
  4. import javax.validation.constraints.Max;
  5.  
  6. import org.apache.spark.SparkConf;
  7. import org.apache.spark.api.java.JavaPairRDD;
  8. import org.apache.spark.api.java.JavaRDD;
  9. import org.apache.spark.api.java.JavaSparkContext;
  10. import org.apache.spark.api.java.function.Function2;
  11. import org.apache.spark.ml.feature.LabeledPoint;
  12. import org.apache.spark.ml.linalg.Vector;
  13. import org.apache.spark.ml.linalg.Vectors;
  14.  
  15. import scala.Tuple2;
  16.  
  17. public class KmeansImpl {
  18.    
  19.     private String inputPath;
  20.     private String outPath;
  21.     private Integer numCentroids;
  22.     private Integer dimensions;
  23.    
  24.     public KmeansImpl(String inputPath, String outPath, Integer numCentroids, Integer dimensions) {
  25.         super();
  26.         this.inputPath = inputPath;
  27.         this.outPath = outPath;
  28.         this.numCentroids = numCentroids;
  29.         this.dimensions = dimensions;
  30.        
  31.         String appName = "kmeans";
  32.         SparkConf conf = new SparkConf().setAppName(appName);
  33.         JavaSparkContext sc = new JavaSparkContext(conf);
  34.        
  35.         JavaRDD<String> data  = sc.textFile(inputPath);
  36.        
  37.         JavaRDD<Vector> all_points = data.map(line -> {
  38.             String[] sarray = line.split(" ");
  39.             double[] values = new double[sarray.length];
  40.             for (int i = 0; i < sarray.length; i++) {
  41.                 values[i] = Double.parseDouble(sarray[i]);
  42.             }
  43.             return Vectors.dense(values);      
  44.         }) ;
  45.        
  46.         ArrayList<Vector> centroids = (ArrayList<Vector>) all_points.take(numCentroids) ;
  47.         ArrayList<Integer> setCounts = new ArrayList<Integer>(centroids.size() + 1) ;
  48.        
  49.         while(true) {          
  50.                        
  51.             JavaPairRDD<Integer,Vector> points = all_points.mapToPair(point ->{
  52.                 int centroidAssigned = -1 ;
  53.                 double minDistance = Integer.MAX_VALUE ;
  54.                 for(int i = 0 ; i < centroids.size() ; i++) {
  55.                     double currentDistance = distance(centroids.get(i), point);  
  56.                     if(currentDistance < minDistance) {
  57.                         minDistance = currentDistance ;
  58.                         centroidAssigned = i ;
  59.                     }
  60.                 }
  61.                
  62.                 setCounts.set(centroidAssigned, setCounts.get(centroidAssigned) + 1);
  63.                
  64.                 return new Tuple2<Integer,Vector>(centroidAssigned, point);
  65.             });
  66.            
  67.            
  68.             JavaPairRDD<Integer,Vector> calculated_centroids = points.reduceByKey(new Function2<Vector, Vector, Vector>() {
  69.                 /**
  70.                  *
  71.                  */
  72.                 private static final long serialVersionUID = 1L;
  73.                
  74.                 @Override
  75.                 public Vector call(Vector point1, Vector point2) throws Exception {
  76.                     // TODO Auto-generated method stub
  77.                     double[] total = new double[point1.size()];
  78.                     for(int i = 0 ; i < point1.size() ; i++) {
  79.                         total[i] += point1.apply(i) - point2.apply(i);
  80.                     }
  81.                     return Vectors.dense(total);
  82.                 }
  83.             });
  84.            
  85.            
  86.             centroids = (ArrayList<Vector>) calculated_centroids.values().collect();
  87.            
  88.            
  89.            
  90.         }
  91.        
  92.        
  93.        
  94.        
  95.        
  96.        
  97.        
  98.        
  99.        
  100.        
  101.        
  102.        
  103.        
  104.         sc.close();
  105.        
  106.        
  107.     }
  108.    
  109.     public double distance(Vector v1, Vector v2) {
  110.        
  111.         double totalSquare = 0 ;
  112.        
  113.         for(int i = 0 ; i < v1.size() ; i++) {
  114.             totalSquare += Math.pow(v1.apply(i) - v2.apply(i), 2);
  115.         }
  116.        
  117.         return Math.sqrt(totalSquare) ;
  118.     }
  119. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement