Advertisement
Guest User

Untitled

a guest
Apr 25th, 2015
206
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.72 KB | None | 0 0
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Threading.Tasks;
  6. using System.IO;
  7.  
  namespace Kaggle_OCR_KNN
  {
      /// <summary>
      /// k-nearest-neighbours classifier for the Kaggle OCR CSV files: reads a training CSV
      /// (integer label in the first column, numeric features after it) and a test CSV
      /// (numeric features only), labels every test row with the most common label among its
      /// K nearest training rows by Euclidean distance, and writes an "ImageId,Label" CSV.
      /// </summary>
      class Program
      {
          /// <summary>Directory used for train.csv / test.csv / output when no directory argument is given.</summary>
          const string DefaultDataDirectory = @"C:\Users\Admin\Desktop\stash\kagg";

          /// <summary>Number of neighbours that vote on each test row.</summary>
          const int DefaultK = 10;

          /// <summary>
          /// Culture used for all CSV number parsing/formatting, so the files are read the same
          /// way regardless of the machine's locale (e.g. one that uses ',' as decimal separator).
          /// </summary>
          static readonly System.Globalization.CultureInfo Invariant = System.Globalization.CultureInfo.InvariantCulture;

          /// <summary>
          /// Entry point: load data, classify the test rows, save the results, wait for Enter.
          /// </summary>
          /// <param name="args">
          /// Optional. args[0] overrides the directory containing train.csv and test.csv, which is
          /// also where KNN_K10.csv is written. Without it the original hard-coded directory is used.
          /// </param>
          static void Main(string[] args)
          {
              string dataDirectory = args.Length > 0 ? args[0] : DefaultDataDirectory;

              string[] trainColumnLabels;
              int[] trainRowLabels;
              float[][] trainRows;
              string[] testColumnLabels;
              int[] testRowLabels;
              float[][] testRows;

              LoadTrainData(Path.Combine(dataDirectory, "train.csv"), out trainColumnLabels, out trainRowLabels, out trainRows);
              LoadTestData(Path.Combine(dataDirectory, "test.csv"), out testColumnLabels, out testRowLabels, out testRows);
              // KNN replaces the -1 placeholder labels produced by LoadTestData.
              KNN(DefaultK, trainRowLabels, trainRows, out testRowLabels, testRows);
              SaveTestResults(Path.Combine(dataDirectory, "KNN_K" + DefaultK.ToString(Invariant) + ".csv"), testRowLabels);

              Console.WriteLine("done, press enter to exit");
              Console.ReadLine();
          }

          /// <summary>Reads the whole text file into memory, logging the file name first.</summary>
          /// <param name="filename">Path of the file to read.</param>
          /// <param name="fileContents">Receives the complete file text.</param>
          internal static void ReadFile(string filename, out string fileContents)
          {
              Console.WriteLine("reading " + filename);
              using (StreamReader streamReader = File.OpenText(filename))
              {
                  fileContents = streamReader.ReadToEnd();
              }
          }

          /// <summary>
          /// Splits CSV text into its header fields and its non-blank data lines.
          /// Accepts "\n" or "\r\n" line endings, with or without a trailing newline.
          /// </summary>
          /// <param name="fileContents">Complete CSV text; the first line is the header.</param>
          /// <param name="columnLabels">Receives the header fields (with any trailing '\r' removed).</param>
          /// <returns>The data lines in file order, '\r'-trimmed, blank lines skipped.</returns>
          static List<string> SplitDataLines(string fileContents, out string[] columnLabels)
          {
              string[] lines = fileContents.Split('\n');
              columnLabels = lines[0].TrimEnd('\r').Split(',');

              List<string> dataLines = new List<string>(lines.Length);
              for (int lineIdx = 1; lineIdx < lines.Length; lineIdx++)
              {
                  string line = lines[lineIdx].TrimEnd('\r');
                  // Skip blank lines (normally just the one after the final newline) rather than
                  // assuming exactly one trailing empty line, which dropped the last row when absent.
                  if (line.Trim().Length != 0)
                  {
                      dataLines.Add(line);
                  }
              }
              return dataLines;
          }

          /// <summary>Parses fields[firstField..] as floats using the invariant culture.</summary>
          static float[] ParseFloats(string[] fields, int firstField)
          {
              float[] values = new float[fields.Length - firstField];
              for (int fieldIdx = firstField; fieldIdx < fields.Length; fieldIdx++)
              {
                  values[fieldIdx - firstField] = float.Parse(fields[fieldIdx], Invariant);
              }
              return values;
          }

          /// <summary>
          /// Parses training CSV text whose first column is an integer label and whose remaining
          /// columns are numeric features.
          /// </summary>
          /// <param name="trainFileContents">Complete CSV text including the header line.</param>
          /// <param name="trainColumnLabels">Receives the header fields.</param>
          /// <param name="trainRowLabels">Receives the first-column label of every data row.</param>
          /// <param name="trainRows">Receives the feature values (all columns after the first) of every data row.</param>
          /// <exception cref="FormatException">A field is not a valid number.</exception>
          internal static void ParseTrainFile(string trainFileContents, out string[] trainColumnLabels, out int[] trainRowLabels, out float[][] trainRows)
          {
              List<string> dataLines = SplitDataLines(trainFileContents, out trainColumnLabels);
              trainRowLabels = new int[dataLines.Count];
              trainRows = new float[dataLines.Count][];
              for (int rowIdx = 0; rowIdx < dataLines.Count; rowIdx++)
              {
                  string[] fields = dataLines[rowIdx].Split(',');
                  trainRowLabels[rowIdx] = int.Parse(fields[0], Invariant);
                  trainRows[rowIdx] = ParseFloats(fields, 1);
              }
          }

          /// <summary>
          /// Parses test CSV text in which every column is a numeric feature (no label column).
          /// </summary>
          /// <param name="testFileContents">Complete CSV text including the header line.</param>
          /// <param name="testColumnLabels">Receives the header fields.</param>
          /// <param name="testRowLabels">Receives -1 for every data row (placeholder until classified).</param>
          /// <param name="testRows">Receives the feature values of every data row.</param>
          /// <exception cref="FormatException">A field is not a valid number.</exception>
          internal static void ParseTestFile(string testFileContents, out string[] testColumnLabels, out int[] testRowLabels, out float[][] testRows)
          {
              List<string> dataLines = SplitDataLines(testFileContents, out testColumnLabels);
              testRowLabels = new int[dataLines.Count];
              testRows = new float[dataLines.Count][];
              for (int rowIdx = 0; rowIdx < dataLines.Count; rowIdx++)
              {
                  testRowLabels[rowIdx] = -1;
                  testRows[rowIdx] = ParseFloats(dataLines[rowIdx].Split(','), 0);
              }
          }

          /// <summary>Reads and parses the training CSV file.</summary>
          internal static void LoadTrainData(string filename, out string[] trainColumnLabels, out int[] trainRowLabels, out float[][] trainRows)
          {
              string trainFileContents;
              ReadFile(filename, out trainFileContents);
              ParseTrainFile(trainFileContents, out trainColumnLabels, out trainRowLabels, out trainRows);
          }

          /// <summary>Reads and parses the test CSV file.</summary>
          internal static void LoadTestData(string filename, out string[] testColumnLabels, out int[] testRowLabels, out float[][] testRows)
          {
              string testFileContents;
              ReadFile(filename, out testFileContents);
              ParseTestFile(testFileContents, out testColumnLabels, out testRowLabels, out testRows);
          }

          /// <summary>Euclidean distance between two equal-length vectors.</summary>
          /// <param name="a">First vector.</param>
          /// <param name="b">Second vector; must have the same length as <paramref name="a"/>.</param>
          /// <param name="distance">Receives sqrt(sum((b[i] - a[i])^2)).</param>
          /// <exception cref="ArgumentException">The vectors differ in length.</exception>
          internal static void Distance(float[] a, float[] b, out float distance)
          {
              if (a.Length != b.Length)
              {
                  throw new ArgumentException("Vectors must have the same length.", "b");
              }

              distance = 0f;
              for (int componentIdx = 0; componentIdx < a.Length; componentIdx++)
              {
                  float delta = b[componentIdx] - a[componentIdx];
                  distance += delta * delta;
              }
              distance = (float)Math.Sqrt(distance);
          }

          /// <summary>
          /// Finds the labels of the K training rows closest to <paramref name="testRow"/>.
          /// </summary>
          /// <param name="K">Number of neighbours to return; must be positive.</param>
          /// <param name="trainRowLabels">Label of each training row (parallel to <paramref name="trainRows"/>).</param>
          /// <param name="trainRows">Training feature vectors.</param>
          /// <param name="testRow">Feature vector to classify.</param>
          /// <param name="nearestNeighboursLabels">
          /// Receives the neighbour labels ordered from nearest to farthest (ties keep training-row
          /// order). Its length is min(K, trainRows.Length), so it is never padded with fake labels.
          /// </param>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="K"/> is not positive.</exception>
          internal static void FindNearestNeighboursLabels(int K, int[] trainRowLabels, float[][] trainRows, float[] testRow, out int[] nearestNeighboursLabels)
          {
              if (K <= 0)
              {
                  throw new ArgumentOutOfRangeException("K", "K must be positive.");
              }

              // Invariant: sorted by ascending distance and never longer than K after each iteration.
              List<KeyValuePair<int, float>> nearestNeighbours = new List<KeyValuePair<int, float>>(K + 1);
              for (int trainRowIdx = 0; trainRowIdx < trainRows.Length; trainRowIdx++)
              {
                  float distance;
                  Distance(trainRows[trainRowIdx], testRow, out distance);

                  // Always accept until K neighbours are collected; afterwards only if closer than the current farthest.
                  if (nearestNeighbours.Count < K || distance < nearestNeighbours[nearestNeighbours.Count - 1].Value)
                  {
                      // Insert after any entries with distance <= this one, so earlier training rows win ties.
                      int insertAt = nearestNeighbours.Count;
                      while (insertAt > 0 && nearestNeighbours[insertAt - 1].Value > distance)
                      {
                          insertAt--;
                      }
                      nearestNeighbours.Insert(insertAt, new KeyValuePair<int, float>(trainRowLabels[trainRowIdx], distance));

                      if (nearestNeighbours.Count > K)
                      {
                          nearestNeighbours.RemoveAt(nearestNeighbours.Count - 1);
                      }
                  }
              }

              nearestNeighboursLabels = new int[nearestNeighbours.Count];
              for (int nearestNeighboursIdx = 0; nearestNeighboursIdx < nearestNeighbours.Count; nearestNeighboursIdx++)
              {
                  nearestNeighboursLabels[nearestNeighboursIdx] = nearestNeighbours[nearestNeighboursIdx].Key;
              }
          }

          /// <summary>
          /// Returns the most frequent label. Ties go to the label that appears first in the input,
          /// which for <see cref="FindNearestNeighboursLabels"/> output is the nearer neighbour.
          /// </summary>
          /// <param name="nearestNeighboursLabels">Labels to vote over; must be non-empty.</param>
          /// <param name="modeNearestNeighbourLabel">Receives the winning label.</param>
          /// <exception cref="ArgumentException">The label array is null or empty.</exception>
          internal static void Mode(int[] nearestNeighboursLabels, out int modeNearestNeighbourLabel)
          {
              if (nearestNeighboursLabels == null || nearestNeighboursLabels.Length == 0)
              {
                  throw new ArgumentException("At least one label is required.", "nearestNeighboursLabels");
              }

              Dictionary<int, int> histogram = new Dictionary<int, int>();
              int maxCount = 0;
              foreach (int label in nearestNeighboursLabels)
              {
                  int count;
                  histogram.TryGetValue(label, out count); // count stays 0 for unseen labels
                  count++;
                  histogram[label] = count;
                  if (count > maxCount)
                  {
                      maxCount = count;
                  }
              }

              // Deterministic tie-break: first label (in input order) that reaches the maximum count.
              modeNearestNeighbourLabel = nearestNeighboursLabels[0];
              foreach (int label in nearestNeighboursLabels)
              {
                  if (histogram[label] == maxCount)
                  {
                      modeNearestNeighbourLabel = label;
                      break;
                  }
              }
          }

          /// <summary>
          /// Classifies every test row by majority vote of its K nearest training rows,
          /// printing an in-place "done/total" progress counter to the console.
          /// </summary>
          /// <param name="K">Number of voting neighbours; must be positive.</param>
          /// <param name="trainRowLabels">Label of each training row.</param>
          /// <param name="trainRows">Training feature vectors.</param>
          /// <param name="testRowLabels">Receives one predicted label per test row.</param>
          /// <param name="testRows">Feature vectors to classify.</param>
          internal static void KNN(int K, int[] trainRowLabels, float[][] trainRows, out int[] testRowLabels, float[][] testRows)
          {
              testRowLabels = new int[testRows.Length];
              for (int testRowIdx = 0; testRowIdx < testRows.Length; testRowIdx++)
              {
                  int[] nearestNeighboursLabels;
                  int modeNearestNeighbourLabel;
                  FindNearestNeighboursLabels(K, trainRowLabels, trainRows, testRows[testRowIdx], out nearestNeighboursLabels);
                  Mode(nearestNeighboursLabels, out modeNearestNeighbourLabel);
                  testRowLabels[testRowIdx] = modeNearestNeighbourLabel;
                  // Leading "\r" returns the cursor so the counter overwrites itself on one line.
                  Console.Write(("\r" + (testRowIdx + 1) + "/" + testRows.Length).PadRight(10));
              }
              Console.WriteLine(("\r" + testRows.Length + "/" + testRows.Length).PadRight(10));
          }

          /// <summary>
          /// Writes the predictions as CSV: header "ImageId,Label", then one line per test row
          /// with a 1-based ImageId.
          /// </summary>
          /// <param name="filename">Output path; an existing file is overwritten.</param>
          /// <param name="testRowLabels">Predicted label for each test row, in test-file order.</param>
          internal static void SaveTestResults(string filename, int[] testRowLabels)
          {
              Console.WriteLine("saving test results");
              using (StreamWriter streamWriter = File.CreateText(filename))
              {
                  streamWriter.WriteLine("ImageId,Label");
                  for (int testRowLabelsIdx = 0; testRowLabelsIdx < testRowLabels.Length; testRowLabelsIdx++)
                  {
                      streamWriter.WriteLine((testRowLabelsIdx + 1).ToString(Invariant) + "," + testRowLabels[testRowLabelsIdx].ToString(Invariant));
                  }
              }
          }
      }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement