Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading.Tasks;
- using System.IO;
- namespace Kaggle_OCR_KNN
- {
- class Program
- {
/// <summary>
/// Entry point: loads the training and test CSV files, labels every test row with
/// K-nearest-neighbours (K = 10), writes the predicted labels to a CSV file and then
/// waits for the user to press enter.
/// </summary>
/// <param name="args">
/// Optional path overrides: args[0] = training CSV, args[1] = test CSV, args[2] = output CSV.
/// Any argument that is not supplied falls back to the built-in default location.
/// </param>
static void Main(string[] args)
{
    int suppliedArgs = args == null ? 0 : args.Length;
    string trainPath = suppliedArgs > 0 ? args[0] : @"C:\Users\Admin\Desktop\stash\kagg\train.csv";
    string testPath = suppliedArgs > 1 ? args[1] : @"C:\Users\Admin\Desktop\stash\kagg\test.csv";
    string resultsPath = suppliedArgs > 2 ? args[2] : @"C:\Users\Admin\Desktop\stash\kagg\KNN_K10.csv";

    string[] trainColumnLabels;
    int[] trainRowLabels;
    float[][] trainRows;
    string[] testColumnLabels;
    int[] testRowLabels;
    float[][] testRows;

    LoadTrainData(trainPath, out trainColumnLabels, out trainRowLabels, out trainRows);
    LoadTestData(testPath, out testColumnLabels, out testRowLabels, out testRows);

    // KNN replaces the placeholder labels produced by LoadTestData with the predictions.
    KNN(10, trainRowLabels, trainRows, out testRowLabels, testRows);
    SaveTestResults(resultsPath, testRowLabels);

    Console.WriteLine("done, press enter to exit");
    Console.ReadLine();
}
/// <summary>
/// Announces on the console which file is being read, then loads the entire file as text.
/// </summary>
/// <param name="filename">Path of the file to read.</param>
/// <param name="fileContents">Receives the complete text of the file.</param>
static void ReadFile(string filename, out string fileContents)
{
    Console.WriteLine("reading " + filename);
    fileContents = File.ReadAllText(filename);
}
/// <summary>
/// Parses training CSV text: the first line is the header, and in every following line the
/// first column is the integer row label and the remaining columns are float features.
/// </summary>
/// <remarks>
/// Both "\n" and "\r\n" line endings are accepted, empty lines are skipped (so a trailing
/// newline is optional), and numbers are parsed with the invariant culture so the result
/// does not depend on the machine's regional settings.
/// </remarks>
/// <param name="trainFileContents">Complete text of the training CSV file.</param>
/// <param name="trainColumnLabels">Receives the comma-separated header fields.</param>
/// <param name="trainRowLabels">Receives the label (first column) of each data row.</param>
/// <param name="trainRows">Receives the feature values (all columns after the first) of each data row.</param>
/// <exception cref="FormatException">A field is not a valid number.</exception>
static void ParseTrainFile(string trainFileContents, out string[] trainColumnLabels, out int[] trainRowLabels, out float[][] trainRows)
{
    System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

    string[] lines = trainFileContents.Split('\n');

    // Strip the '\r' left behind by Windows line endings so it does not end up in the last header label.
    trainColumnLabels = lines[0].TrimEnd('\r').Split(',');

    List<int> labels = new List<int>();
    List<float[]> rows = new List<float[]>();
    for (int lineIdx = 1; lineIdx < lines.Length; lineIdx++)
    {
        string line = lines[lineIdx].TrimEnd('\r');
        if (line.Length == 0)
        {
            // Typically the empty fragment after the final newline; nothing to parse.
            continue;
        }

        string[] fields = line.Split(',');
        labels.Add(int.Parse(fields[0], invariant));

        float[] features = new float[fields.Length - 1];
        for (int fieldIdx = 1; fieldIdx < fields.Length; fieldIdx++)
        {
            features[fieldIdx - 1] = float.Parse(fields[fieldIdx], invariant);
        }
        rows.Add(features);
    }

    trainRowLabels = labels.ToArray();
    trainRows = rows.ToArray();
}
/// <summary>
/// Parses test CSV text: the first line is the header, and every column of each following
/// line is a float feature. Test rows carry no label, so each row label is set to -1.
/// </summary>
/// <remarks>
/// Both "\n" and "\r\n" line endings are accepted, empty lines are skipped (so a trailing
/// newline is optional), and numbers are parsed with the invariant culture so the result
/// does not depend on the machine's regional settings.
/// </remarks>
/// <param name="testFileContents">Complete text of the test CSV file.</param>
/// <param name="testColumnLabels">Receives the comma-separated header fields.</param>
/// <param name="testRowLabels">Receives one placeholder label (-1) per data row.</param>
/// <param name="testRows">Receives the feature values of each data row.</param>
/// <exception cref="FormatException">A field is not a valid number.</exception>
static void ParseTestFile(string testFileContents, out string[] testColumnLabels, out int[] testRowLabels, out float[][] testRows)
{
    System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

    string[] lines = testFileContents.Split('\n');

    // Strip the '\r' left behind by Windows line endings so it does not end up in the last header label.
    testColumnLabels = lines[0].TrimEnd('\r').Split(',');

    List<float[]> rows = new List<float[]>();
    for (int lineIdx = 1; lineIdx < lines.Length; lineIdx++)
    {
        string line = lines[lineIdx].TrimEnd('\r');
        if (line.Length == 0)
        {
            // Typically the empty fragment after the final newline; nothing to parse.
            continue;
        }

        string[] fields = line.Split(',');
        float[] features = new float[fields.Length];
        for (int fieldIdx = 0; fieldIdx < fields.Length; fieldIdx++)
        {
            features[fieldIdx] = float.Parse(fields[fieldIdx], invariant);
        }
        rows.Add(features);
    }

    testRows = rows.ToArray();

    // -1 marks "not classified yet".
    testRowLabels = new int[testRows.Length];
    for (int rowIdx = 0; rowIdx < testRowLabels.Length; rowIdx++)
    {
        testRowLabels[rowIdx] = -1;
    }
}
/// <summary>
/// Reads the training CSV at <paramref name="filename"/> with ReadFile and hands the text
/// to ParseTrainFile, which fills the three output arrays.
/// </summary>
/// <param name="filename">Path of the training CSV file.</param>
/// <param name="trainColumnLabels">Receives the header fields.</param>
/// <param name="trainRowLabels">Receives the label of each row.</param>
/// <param name="trainRows">Receives the feature values of each row.</param>
static void LoadTrainData(string filename, out string[] trainColumnLabels, out int[] trainRowLabels, out float[][] trainRows)
{
    string rawCsv;

    ReadFile(filename, out rawCsv);
    ParseTrainFile(rawCsv, out trainColumnLabels, out trainRowLabels, out trainRows);
}
/// <summary>
/// Reads the test CSV at <paramref name="filename"/> with ReadFile and hands the text
/// to ParseTestFile, which fills the three output arrays.
/// </summary>
/// <param name="filename">Path of the test CSV file.</param>
/// <param name="testColumnLabels">Receives the header fields.</param>
/// <param name="testRowLabels">Receives the per-row labels produced by ParseTestFile.</param>
/// <param name="testRows">Receives the feature values of each row.</param>
static void LoadTestData(string filename, out string[] testColumnLabels, out int[] testRowLabels, out float[][] testRows)
{
    string rawCsv;

    ReadFile(filename, out rawCsv);
    ParseTestFile(rawCsv, out testColumnLabels, out testRowLabels, out testRows);
}
/// <summary>
/// Computes the Euclidean distance between two vectors: the square root of the sum of
/// squared component differences.
/// </summary>
/// <param name="a">First vector; its length determines how many components are compared.</param>
/// <param name="b">Second vector; must have at least as many components as <paramref name="a"/>.</param>
/// <param name="distance">Receives the distance.</param>
static void Distance(float[] a, float[] b, out float distance)
{
    float sumOfSquares = 0f;

    int i = 0;
    while (i < a.Length)
    {
        float diff = b[i] - a[i];
        sumOfSquares += diff * diff;
        i++;
    }

    distance = (float)Math.Sqrt(sumOfSquares);
}
/// <summary>
/// Finds the labels of the (at most) K training rows closest to <paramref name="testRow"/>,
/// using the distance computed by Distance.
/// </summary>
/// <remarks>
/// The candidate list is kept sorted by ascending distance and never holds more than K
/// entries. A row is always accepted while fewer than K neighbours have been collected;
/// after that it must be strictly closer than the current worst neighbour, so on equal
/// distance the row seen earlier is kept.
/// </remarks>
/// <param name="K">Maximum number of neighbours to return; must not be negative.</param>
/// <param name="trainRowLabels">Label of each training row, parallel to <paramref name="trainRows"/>.</param>
/// <param name="trainRows">Feature vectors of the training rows.</param>
/// <param name="testRow">Feature vector to find neighbours for.</param>
/// <param name="nearestNeighboursLabels">
/// Receives the neighbours' labels ordered from nearest to farthest. Its length is
/// min(K, number of training rows), so it never contains filler entries.
/// </param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="K"/> is negative.</exception>
static void FindNearestNeighboursLabels(int K, int[] trainRowLabels, float[][] trainRows, float[] testRow, out int[] nearestNeighboursLabels)
{
    if (K < 0)
    {
        throw new ArgumentOutOfRangeException("K", "K must not be negative.");
    }

    List<KeyValuePair<int, float>> nearestNeighbours = new List<KeyValuePair<int, float>>();
    for (int trainRowIdx = 0; trainRowIdx < trainRows.Length; trainRowIdx++)
    {
        float distance;
        Distance(trainRows[trainRowIdx], testRow, out distance);

        // Once K neighbours are held, only a strictly closer row can displace the worst one.
        bool isFull = nearestNeighbours.Count >= K;
        if (isFull && (K == 0 || !(distance < nearestNeighbours[K - 1].Value)))
        {
            continue;
        }

        // Walk back from the end to the slot that keeps the list sorted; stopping at
        // equal distances places the new row after earlier rows with the same distance.
        int insertAt = nearestNeighbours.Count;
        while (insertAt > 0 && nearestNeighbours[insertAt - 1].Value > distance)
        {
            insertAt--;
        }
        nearestNeighbours.Insert(insertAt, new KeyValuePair<int, float>(trainRowLabels[trainRowIdx], distance));

        if (nearestNeighbours.Count > K)
        {
            // The list is sorted, so the last entry is the farthest neighbour.
            nearestNeighbours.RemoveAt(nearestNeighbours.Count - 1);
        }
    }

    // Size the result to what was actually found: unused slots would otherwise hold 0
    // and be indistinguishable from genuine neighbours labelled 0.
    nearestNeighboursLabels = new int[nearestNeighbours.Count];
    for (int neighbourIdx = 0; neighbourIdx < nearestNeighbours.Count; neighbourIdx++)
    {
        nearestNeighboursLabels[neighbourIdx] = nearestNeighbours[neighbourIdx].Key;
    }
}
/// <summary>
/// Finds the most frequent value (the mode) in an array of labels.
/// </summary>
/// <remarks>
/// When several labels share the highest count, the one that occurs first in
/// <paramref name="nearestNeighboursLabels"/> wins, so the result is fully determined by
/// the input order.
/// </remarks>
/// <param name="nearestNeighboursLabels">Labels to vote over; must contain at least one entry.</param>
/// <param name="modeNearestNeighbourLabel">Receives the winning label.</param>
/// <exception cref="ArgumentException"><paramref name="nearestNeighboursLabels"/> is empty.</exception>
static void Mode(int[] nearestNeighboursLabels, out int modeNearestNeighbourLabel)
{
    if (nearestNeighboursLabels.Length == 0)
    {
        throw new ArgumentException("Cannot take the mode of an empty set of labels.", "nearestNeighboursLabels");
    }

    // Histogram: label -> number of occurrences.
    Dictionary<int, int> labelCounts = new Dictionary<int, int>();
    foreach (int label in nearestNeighboursLabels)
    {
        int seenSoFar;
        labelCounts.TryGetValue(label, out seenSoFar); // leaves 0 when the label is new
        labelCounts[label] = seenSoFar + 1;
    }

    // Scan in input order with a strict comparison so the earliest label wins a tie.
    int bestLabel = nearestNeighboursLabels[0];
    int bestCount = 0;
    foreach (int label in nearestNeighboursLabels)
    {
        int count = labelCounts[label];
        if (count > bestCount)
        {
            bestCount = count;
            bestLabel = label;
        }
    }

    modeNearestNeighbourLabel = bestLabel;
}
/// <summary>
/// Labels every test row: for each row it collects the neighbour labels with
/// FindNearestNeighboursLabels, takes their Mode as the prediction, and rewrites a
/// "done/total" progress counter on the current console line.
/// </summary>
/// <param name="K">Number of neighbours passed on to FindNearestNeighboursLabels.</param>
/// <param name="trainRowLabels">Label of each training row.</param>
/// <param name="trainRows">Feature vectors of the training rows.</param>
/// <param name="testRowLabels">Receives one predicted label per test row.</param>
/// <param name="testRows">Feature vectors of the rows to label.</param>
static void KNN(int K, int[] trainRowLabels, float[][] trainRows, out int[] testRowLabels, float[][] testRows)
{
    int total = testRows.Length;
    testRowLabels = new int[total];

    int completed = 0;
    while (completed < total)
    {
        int[] votes;
        FindNearestNeighboursLabels(K, trainRowLabels, trainRows, testRows[completed], out votes);

        int winner;
        Mode(votes, out winner);
        testRowLabels[completed] = winner;

        completed++;
        // "\r" returns to the start of the line so the counter overwrites itself.
        Console.Write(string.Format("\r{0}/{1}", completed, total).PadRight(10));
    }

    // Final counter, this time terminated with a newline.
    Console.WriteLine(string.Format("\r{0}/{1}", total, total).PadRight(10));
}
/// <summary>
/// Writes the labels to a CSV file: an "ImageId,Label" header followed by one line per
/// label, where the id is the label's 1-based position in the array.
/// </summary>
/// <param name="filename">Path of the file to create (File.CreateText replaces an existing file).</param>
/// <param name="testRowLabels">Labels to write, in row order.</param>
static void SaveTestResults(string filename, int[] testRowLabels)
{
    Console.WriteLine("saving test results");

    using (StreamWriter output = File.CreateText(filename))
    {
        output.WriteLine("ImageId,Label");

        int imageId = 0;
        foreach (int label in testRowLabels)
        {
            imageId++;
            output.WriteLine(imageId.ToString() + "," + label.ToString());
        }
    }
}
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement