Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading.Tasks;
- namespace Classification
- {
/// <summary>
/// k-nearest-neighbours classifier over <see cref="Vector"/> samples labelled with integer classes.
/// Optionally standardizes every feature (subtract the mean, divide by the standard deviation)
/// using statistics computed from the training set.
/// </summary>
public class Knn
{
    private readonly int _neighborhoodSize;
    private readonly bool _normalizationOn;

    // One transform per feature dimension; applied to training vectors and to every query vector.
    private List<Func<double, double>> _normalizers = new List<Func<double, double>>();

    // Stays null until FeedWithTrainingSet has completed successfully.
    private List<ClassifiedVector> _trainingSet;

    /// <summary>Creates an untrained classifier.</summary>
    /// <param name="neighborhoodSize">Number of nearest neighbours that vote (the "k"); must be at least 1.</param>
    /// <param name="normalizationOn">Whether features are standardized before distances are measured.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="neighborhoodSize"/> is less than 1.</exception>
    public Knn(int neighborhoodSize, bool normalizationOn)
    {
        if (neighborhoodSize < 1)
        {
            // With k < 1 no neighbour could ever vote, so every Classify call would fail anyway.
            throw new ArgumentOutOfRangeException(
                nameof(neighborhoodSize), neighborhoodSize, "The neighborhood size must be at least 1.");
        }

        _neighborhoodSize = neighborhoodSize;
        _normalizationOn = normalizationOn;
    }

    /// <summary>
    /// Replaces the classifier's training data. The dimension count is taken from the first sample.
    /// </summary>
    /// <param name="trainingSet">Labelled samples; enumerated exactly once.</param>
    /// <exception cref="ArgumentNullException"><paramref name="trainingSet"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="trainingSet"/> is empty.</exception>
    public void FeedWithTrainingSet(IEnumerable<ClassifiedVector> trainingSet)
    {
        if (trainingSet == null)
        {
            throw new ArgumentNullException(nameof(trainingSet));
        }

        // Snapshot the input: it is read several times below and may be a deferred
        // or non-repeatable sequence.
        List<ClassifiedVector> samples = trainingSet.ToList();
        if (samples.Count == 0)
        {
            throw new InvalidOperationException("The training set must contain at least one sample.");
        }

        int dimensions = samples[0].Vector.DimensionCount;
        var normalizers = new List<Func<double, double>>(dimensions);
        for (int i = 0; i < dimensions; ++i)
        {
            if (_normalizationOn)
            {
                normalizers.Add(BuildStandardizer(samples, i));
            }
            else
            {
                normalizers.Add(x => x);
            }
        }

        List<ClassifiedVector> prepared = samples;
        if (_normalizationOn)
        {
            prepared = samples
                .Select(s => new ClassifiedVector
                {
                    Vector = Apply(normalizers, s.Vector),
                    Classification = s.Classification,
                })
                .ToList();
        }

        // Publish the new state only after everything succeeded, so a failure above
        // leaves the previously trained state untouched.
        _normalizers = normalizers;
        _trainingSet = prepared;
    }

    /// <summary>
    /// Returns the label held by the majority of the k nearest training samples
    /// (squared Euclidean distance). Ties in vote count go to the class that owns
    /// the closest neighbour.
    /// </summary>
    /// <param name="arg">Vector to classify; must have the training data's dimension count.</param>
    /// <exception cref="ArgumentNullException"><paramref name="arg"/> is null.</exception>
    /// <exception cref="InvalidOperationException">No training set has been fed yet.</exception>
    /// <exception cref="ArgumentException"><paramref name="arg"/> has the wrong number of dimensions.</exception>
    public int Classify(Vector arg)
    {
        if (arg == null)
        {
            throw new ArgumentNullException(nameof(arg));
        }

        if (_trainingSet == null)
        {
            throw new InvalidOperationException("FeedWithTrainingSet must be called before Classify.");
        }

        if (arg.DimensionCount != _normalizers.Count)
        {
            // Without this check a shorter vector would be compared on a prefix of the
            // features only, because the element-wise subtraction stops at the shorter operand.
            throw new ArgumentException(
                $"Expected a vector with {_normalizers.Count} dimensions but got {arg.DimensionCount}.",
                nameof(arg));
        }

        Vector v = Apply(_normalizers, arg);
        return _trainingSet
            // Compute each distance once and carry it through the rest of the query.
            .Select(t => new { t.Classification, Distance = (t.Vector - v).LenghtSquared })
            .OrderBy(n => n.Distance)
            .Take(_neighborhoodSize)
            .GroupBy(n => n.Classification)
            .OrderByDescending(g => g.Count())
            .ThenBy(g => g.Min(n => n.Distance))
            .First()
            .Key;
    }

    /// <summary>
    /// Builds the standardizing transform for one feature dimension from the training samples.
    /// </summary>
    private static Func<double, double> BuildStandardizer(List<ClassifiedVector> samples, int dimension)
    {
        List<double> column = samples.Select(s => s.Vector[dimension]).ToList();
        double avg = column.Average();
        double dev = column.StandardDeviation();

        if (dev == 0.0)
        {
            // Constant feature: dividing by zero would turn every value into NaN or
            // Infinity and poison all distances. Centre the value and leave the scale alone.
            return x => x - avg;
        }

        return x => (x - avg) / dev;
    }

    /// <summary>Applies the per-dimension transforms to every component of a vector.</summary>
    private static Vector Apply(List<Func<double, double>> normalizers, Vector arg)
    {
        var normalizedValues = new List<double>(arg.DimensionCount);
        for (int i = 0; i < arg.DimensionCount; ++i)
        {
            normalizedValues.Add(normalizers[i](arg[i]));
        }

        return new Vector(normalizedValues);
    }

    /// <summary>Describes the classifier's configuration (k and whether normalization is on).</summary>
    public override string ToString()
    {
        return $"k={_neighborhoodSize}, normalized={_normalizationOn}";
    }
}
/// <summary>
/// A labelled sample: a feature vector paired with the integer label of its class.
/// Members are plain public fields, so values can be built with an object initializer.
/// </summary>
public struct ClassifiedVector
{
    /// <summary>Feature values of the sample.</summary>
    public Vector Vector;

    /// <summary>Integer label of the class the sample belongs to.</summary>
    public int Classification;
}
/// <summary>Statistical helpers for sequences of doubles.</summary>
public static class Extension
{
    /// <summary>
    /// Computes the population standard deviation (divides by N, not N - 1) of a sequence.
    /// </summary>
    /// <param name="arg">Values to measure; enumerated at most once.</param>
    /// <returns>The square root of the mean squared deviation from the average.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="arg"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="arg"/> contains no elements.</exception>
    public static double StandardDeviation(this IEnumerable<double> arg)
    {
        if (arg == null)
        {
            throw new ArgumentNullException(nameof(arg));
        }

        // The computation needs the data three times (average, squared deviations, count).
        // Snapshot lazy sequences so deferred queries run once and non-repeatable
        // sources give a consistent result; existing collections are used as they are.
        ICollection<double> values = arg as ICollection<double> ?? arg.ToList();

        double avg = values.Average(); // throws InvalidOperationException when empty
        double sum = values.Sum(x => (x - avg) * (x - avg));
        return Math.Sqrt(sum / values.Count);
    }
}
/// <summary>
/// A vector of double components with element-wise arithmetic. No member of this
/// class modifies the components after construction.
/// </summary>
public class Vector
{
    private readonly List<double> _data;

    /// <summary>Creates a vector from a copy of the given components.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    public Vector(IEnumerable<double> values)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        _data = values.ToList();
    }

    /// <summary>Creates a two-dimensional vector.</summary>
    public Vector(double x, double y)
    {
        _data = new List<double> { x, y };
    }

    /// <summary>Number of components.</summary>
    public int DimensionCount
    {
        get { return _data.Count; }
    }

    /// <summary>Gets the component at the given zero-based index.</summary>
    public double this[int idx]
    {
        get { return _data[idx]; }
    }

    /// <summary>Sum of the squared components (squared Euclidean norm).</summary>
    public double LengthSquared
    {
        get { return _data.Sum(x => x * x); }
    }

    /// <summary>Euclidean norm.</summary>
    public double Length
    {
        get { return Math.Sqrt(LengthSquared); }
    }

    /// <summary>Same value as <see cref="LengthSquared"/>; the misspelled name is kept so existing callers keep compiling.</summary>
    public double LenghtSquared
    {
        get { return LengthSquared; }
    }

    /// <summary>Same value as <see cref="Length"/>; the misspelled name is kept so existing callers keep compiling.</summary>
    public double Lenght
    {
        get { return Length; }
    }

    /// <summary>Element-wise sum.</summary>
    /// <exception cref="ArgumentNullException">Either operand is null.</exception>
    /// <exception cref="ArgumentException">The operands have different dimension counts.</exception>
    public static Vector operator +(Vector left, Vector right)
    {
        EnsureSameDimensions(left, right);
        return new Vector(left._data.Zip(right._data, (a, b) => a + b));
    }

    /// <summary>Element-wise difference.</summary>
    /// <exception cref="ArgumentNullException">Either operand is null.</exception>
    /// <exception cref="ArgumentException">The operands have different dimension counts.</exception>
    public static Vector operator -(Vector left, Vector right)
    {
        EnsureSameDimensions(left, right);
        return new Vector(left._data.Zip(right._data, (a, b) => a - b));
    }

    /// <summary>Multiplies every component by a scalar.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="left"/> is null.</exception>
    public static Vector operator *(Vector left, double right)
    {
        if (left == null)
        {
            throw new ArgumentNullException(nameof(left));
        }

        return new Vector(left._data.Select(x => x * right));
    }

    /// <summary>Multiplies every component by a scalar.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="right"/> is null.</exception>
    public static Vector operator *(double left, Vector right)
    {
        if (right == null)
        {
            throw new ArgumentNullException(nameof(right));
        }

        return right * left;
    }

    /// <summary>
    /// Divides every component by a scalar. This is floating-point division, so a zero
    /// divisor produces Infinity or NaN components instead of throwing.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="left"/> is null.</exception>
    public static Vector operator /(Vector left, double right)
    {
        if (left == null)
        {
            throw new ArgumentNullException(nameof(left));
        }

        return new Vector(left._data.Select(x => x / right));
    }

    /// <summary>
    /// Rejects null or differently sized operands. Zip stops at the shorter sequence,
    /// so without this check a mismatch would silently yield a truncated result.
    /// </summary>
    private static void EnsureSameDimensions(Vector left, Vector right)
    {
        if (left == null)
        {
            throw new ArgumentNullException(nameof(left));
        }

        if (right == null)
        {
            throw new ArgumentNullException(nameof(right));
        }

        if (left._data.Count != right._data.Count)
        {
            throw new ArgumentException(
                $"Vector dimensions differ: {left._data.Count} and {right._data.Count}.");
        }
    }
}
- }
- //------------------
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement