Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
// https://www.reddit.com/r/learnrust/comments/crrp3g/beginner_feedback_request_dead_simple_kmeans/
//! Compute k-means clustering in the standard way. Original author:
//! <http://reddit.com/u/i_lack_discipline>. Rewritten
//! by <http://reddit.com/u/po8>, August 2019.
use std::env;
use std::fs::File;
use std::io::prelude::*;
use std::process;

use ordered_float::OrderedFloat;
use rand::prelude::*;
/// A point in the plane, stored as an `(x, y)` pair of `f64` coordinates.
type Point = (f64, f64);
- /// The Euclidean distance between two points.
- fn euc_dist((x1, y1): Point, (x2, y2): Point) -> f64 {
- let dx = x2 - x1;
- let dy = y2 - y1;
- let radicand = dx * dx + dy * dy;
- radicand.sqrt()
- }
- /// Return a list of the nearest centroid index for each coord.
- fn assign_identities(
- coords: &[Point],
- centroids: &[Point],
- ) -> Vec<usize> {
- coords
- .iter()
- .map(|&coord| {
- centroids
- .iter()
- .enumerate()
- .min_by_key(|(_, ¢roid)| {
- OrderedFloat(euc_dist(coord, centroid))
- })
- .unwrap()
- .0
- })
- .collect()
- }
/// Mean of input values. Must be at least one value.
///
/// An empty slice is not rejected: the division `0.0 / 0.0` yields NaN.
fn mean(vals: &[f64]) -> f64 {
    vals.iter().sum::<f64>() / vals.len() as f64
}
- /// Return the adjusted centroids given by the reassignment.
- fn move_centroids(
- coords: &[Point],
- centroids: &[Point],
- labels: &[usize],
- ) -> Vec<Point> {
- (0..centroids.len())
- .map(|cent_id| {
- let (x_vals, y_vals): (Vec<f64>, Vec<f64>) = coords
- .iter()
- .enumerate()
- .filter(|(coord_id, _)| labels[*coord_id] == cent_id)
- .map(|(_, &coord)| coord)
- .unzip();
- (mean(&x_vals), mean(&y_vals))
- })
- .collect()
- }
- /// K means description.
- struct KMeans<'a> {
- /// List of coordinates.
- #[allow(unused)]
- coords: &'a[Point],
- /// List of k means.
- #[allow(unused)]
- centroids: Vec<Point>,
- /// Centroid label for each coordinate.
- labels: Vec<usize>,
- }
- impl<'a> KMeans<'a> {
- /// Iteratively find good centroids for a collection of points.
- fn fit(coords: &'a [Point], k: usize) -> Self {
- let mut rng = rand::thread_rng();
- // initialize centroids using Forgy method - randomly choose
- // k samples and use these as the means
- let init_centroids: Vec<Point> =
- (0..k).map(|_| *coords.choose(&mut rng).unwrap()).collect();
- // init assignment
- let mut labels = assign_identities(coords, &init_centroids);
- // calculate new centroids
- let mut centroids =
- move_centroids(coords, &init_centroids, &labels);
- // repeat until stable
- loop {
- let prior_labels = labels;
- labels = assign_identities(coords, ¢roids);
- centroids = move_centroids(coords, ¢roids, &labels);
- if prior_labels == labels {
- return KMeans { coords, centroids, labels };
- }
- }
- }
- }
- /// Read a file of coordinates and print cluster data.
- fn main() {
- let args: Vec<String> = env::args().collect();
- let fp = &args[1];
- let k: usize = args[2].parse().unwrap_or_else(|why| {
- panic!("could not parse '{}' for k: {}", args[2], why);
- });
- let mut file = File::open(fp).unwrap_or_else(|why| {
- panic!("could not open {}: {}", fp, why);
- });
- let mut str_in = String::new();
- file.read_to_string(&mut str_in)
- .expect("could not read file");
- let coords: Vec<Point> = str_in
- .split('\n')
- .enumerate()
- .filter(|(_, line)| !line.is_empty())
- .map(|(lineno, line)| {
- let vals: Vec<&str> = line.split(',').collect();
- if vals.len() != 2 {
- panic!("line {}: invalid format", lineno + 1);
- }
- let x_val: f64 = vals[0].parse().unwrap_or_else(|why| {
- panic!(
- "line {}: could not parse '{}': {}",
- lineno + 1,
- vals[0],
- why
- );
- });
- let y_val: f64 = vals[1].parse().unwrap_or_else(|why| {
- panic!(
- "line {}: could not parse '{}': {}",
- lineno + 1,
- vals[0],
- why
- );
- });
- (x_val, y_val)
- })
- .collect();
- let model = KMeans::fit(&coords, k);
- for ident in model.labels.iter() {
- println!("{}", ident);
- }
- }
Add Comment
Please, Sign In to add comment