PO8

kmeans.rs

PO8
Aug 18th, 2019
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 4.65 KB | None | 0 0
  1. // https://www.reddit.com/r/learnrust/comments/crrp3g/beginner_feedback_request_dead_simple_kmeans/
  2.  
  3. //! Compute k-means in a standard way. Original author
  4. //! <http://reddit.com/u/i_lack_discipline>. Rewritten
  5. //! by <http://reddit.com/u/po8> August 2019.
  6.  
  7. use std::env;
  8. use std::fs::File;
  9. use std::io::prelude::*;
  10.  
  11. use ordered_float::OrderedFloat;
  12. use rand::prelude::*;
  13.  
  14. type Point = (f64, f64);
  15.  
  16. /// The Euclidean distance between two points.
  17. fn euc_dist((x1, y1): Point, (x2, y2): Point) -> f64 {
  18.     let dx = x2 - x1;
  19.     let dy = y2 - y1;
  20.     let radicand = dx * dx + dy * dy;
  21.     radicand.sqrt()
  22. }
  23.  
  24. /// Return a list of the nearest centroid index for each coord.
  25. fn assign_identities(
  26.     coords: &[Point],
  27.     centroids: &[Point],
  28. ) -> Vec<usize> {
  29.     coords
  30.         .iter()
  31.         .map(|&coord| {
  32.             centroids
  33.                 .iter()
  34.                 .enumerate()
  35.                 .min_by_key(|(_, &centroid)| {
  36.                     OrderedFloat(euc_dist(coord, centroid))
  37.                 })
  38.                 .unwrap()
  39.                 .0
  40.         })
  41.         .collect()
  42. }
  43.  
  /// Arithmetic mean of `vals`.
  ///
  /// The slice must hold at least one value: for an empty slice the
  /// division is zero by zero and the result is NaN.
  fn mean(vals: &[f64]) -> f64 {
      let count = vals.len() as f64;
      vals.iter().sum::<f64>() / count
  }
  49.  
  50. /// Return the adjusted centroids given by the reassignment.
  51. fn move_centroids(
  52.     coords: &[Point],
  53.     centroids: &[Point],
  54.     labels: &[usize],
  55. ) -> Vec<Point> {
  56.     (0..centroids.len())
  57.         .map(|cent_id| {
  58.             let (x_vals, y_vals): (Vec<f64>, Vec<f64>) = coords
  59.                 .iter()
  60.                 .enumerate()
  61.                 .filter(|(coord_id, _)| labels[*coord_id] == cent_id)
  62.                 .map(|(_, &coord)| coord)
  63.                 .unzip();
  64.             (mean(&x_vals), mean(&y_vals))
  65.         })
  66.         .collect()
  67. }
  68.  
  69. /// K means description.
  70. struct KMeans<'a> {
  71.    /// List of coordinates.
  72.    #[allow(unused)]
  73.    coords: &'a[Point],
  74.     /// List of k means.
  75.     #[allow(unused)]
  76.     centroids: Vec<Point>,
  77.     /// Centroid label for each coordinate.
  78.     labels: Vec<usize>,
  79. }
  80.  
  81. impl<'a> KMeans<'a> {
  82.     /// Iteratively find good centroids for a collection of points.
  83.     fn fit(coords: &'a [Point], k: usize) -> Self {
  84.        let mut rng = rand::thread_rng();
  85.  
  86.        // initialize centroids using Forgy method - randomly choose
  87.        // k samples and use these as the means
  88.        let init_centroids: Vec<Point> =
  89.            (0..k).map(|_| *coords.choose(&mut rng).unwrap()).collect();
  90.  
  91.        // init assignment
  92.        let mut labels = assign_identities(coords, &init_centroids);
  93.  
  94.        // calculate new centroids
  95.        let mut centroids =
  96.            move_centroids(coords, &init_centroids, &labels);
  97.  
  98.        // repeat until stable
  99.        loop {
  100.            let prior_labels = labels;
  101.            labels = assign_identities(coords, &centroids);
  102.            centroids = move_centroids(coords, &centroids, &labels);
  103.  
  104.            if prior_labels == labels {
  105.                return KMeans { coords, centroids, labels };
  106.            }
  107.        }
  108.    }
  109. }
  110.  
  111. /// Read a file of coordinates and print cluster data.
  112. fn main() {
  113.    let args: Vec<String> = env::args().collect();
  114.    let fp = &args[1];
  115.    let k: usize = args[2].parse().unwrap_or_else(|why| {
  116.        panic!("could not parse '{}' for k: {}", args[2], why);
  117.    });
  118.  
  119.    let mut file = File::open(fp).unwrap_or_else(|why| {
  120.        panic!("could not open {}: {}", fp, why);
  121.    });
  122.  
  123.    let mut str_in = String::new();
  124.    file.read_to_string(&mut str_in)
  125.        .expect("could not read file");
  126.  
  127.    let coords: Vec<Point> = str_in
  128.        .split('\n')
  129.        .enumerate()
  130.        .filter(|(_, line)| !line.is_empty())
  131.        .map(|(lineno, line)| {
  132.            let vals: Vec<&str> = line.split(',').collect();
  133.            if vals.len() != 2 {
  134.                panic!("line {}: invalid format", lineno + 1);
  135.            }
  136.  
  137.            let x_val: f64 = vals[0].parse().unwrap_or_else(|why| {
  138.                panic!(
  139.                    "line {}: could not parse '{}': {}",
  140.                    lineno + 1,
  141.                    vals[0],
  142.                    why
  143.                );
  144.            });
  145.            let y_val: f64 = vals[1].parse().unwrap_or_else(|why| {
  146.                panic!(
  147.                    "line {}: could not parse '{}': {}",
  148.                    lineno + 1,
  149.                    vals[0],
  150.                    why
  151.                );
  152.            });
  153.            (x_val, y_val)
  154.        })
  155.        .collect();
  156.  
  157.    let model = KMeans::fit(&coords, k);
  158.  
  159.    for ident in model.labels.iter() {
  160.        println!("{}", ident);
  161.    }
  162. }
Add Comment
Please, Sign In to add comment