Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.awt.List;
- import java.util.ArrayList;
- import javax.validation.constraints.Max;
- import org.apache.spark.SparkConf;
- import org.apache.spark.api.java.JavaPairRDD;
- import org.apache.spark.api.java.JavaRDD;
- import org.apache.spark.api.java.JavaSparkContext;
- import org.apache.spark.api.java.function.Function2;
- import org.apache.spark.ml.feature.LabeledPoint;
- import org.apache.spark.ml.linalg.Vector;
- import org.apache.spark.ml.linalg.Vectors;
- import scala.Tuple2;
- public class KmeansImpl {
- private String inputPath;
- private String outPath;
- private Integer numCentroids;
- private Integer dimensions;
- public KmeansImpl(String inputPath, String outPath, Integer numCentroids, Integer dimensions) {
- super();
- this.inputPath = inputPath;
- this.outPath = outPath;
- this.numCentroids = numCentroids;
- this.dimensions = dimensions;
- String appName = "kmeans";
- SparkConf conf = new SparkConf().setAppName(appName);
- JavaSparkContext sc = new JavaSparkContext(conf);
- JavaRDD<String> data = sc.textFile(inputPath);
- JavaRDD<Vector> all_points = data.map(line -> {
- String[] sarray = line.split(" ");
- double[] values = new double[sarray.length];
- for (int i = 0; i < sarray.length; i++) {
- values[i] = Double.parseDouble(sarray[i]);
- }
- return Vectors.dense(values);
- }) ;
- ArrayList<Vector> centroids = (ArrayList<Vector>) all_points.take(numCentroids) ;
- ArrayList<Integer> setCounts = new ArrayList<Integer>(centroids.size() + 1) ;
- while(true) {
- JavaPairRDD<Integer,Vector> points = all_points.mapToPair(point ->{
- int centroidAssigned = -1 ;
- double minDistance = Integer.MAX_VALUE ;
- for(int i = 0 ; i < centroids.size() ; i++) {
- double currentDistance = distance(centroids.get(i), point);
- if(currentDistance < minDistance) {
- minDistance = currentDistance ;
- centroidAssigned = i ;
- }
- }
- setCounts.set(centroidAssigned, setCounts.get(centroidAssigned) + 1);
- return new Tuple2<Integer,Vector>(centroidAssigned, point);
- });
- JavaPairRDD<Integer,Vector> calculated_centroids = points.reduceByKey(new Function2<Vector, Vector, Vector>() {
- /**
- *
- */
- private static final long serialVersionUID = 1L;
- @Override
- public Vector call(Vector point1, Vector point2) throws Exception {
- // TODO Auto-generated method stub
- double[] total = new double[point1.size()];
- for(int i = 0 ; i < point1.size() ; i++) {
- total[i] += point1.apply(i) - point2.apply(i);
- }
- return Vectors.dense(total);
- }
- });
- centroids = (ArrayList<Vector>) calculated_centroids.values().collect();
- }
- sc.close();
- }
- public double distance(Vector v1, Vector v2) {
- double totalSquare = 0 ;
- for(int i = 0 ; i < v1.size() ; i++) {
- totalSquare += Math.pow(v1.apply(i) - v2.apply(i), 2);
- }
- return Math.sqrt(totalSquare) ;
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement