Guest User

Untitled

a guest
Feb 17th, 2019
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.32 KB | None | 0 0
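  1. Make sure you have local copies of TensorFlow.js (`tf.min.js`) and p5.js (`p5.min.js`) next to your HTML file. p5.js is required, not optional: it calls `setup()` once and `draw()` every frame, which drives the continuous prediction loop.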
  2.  
  3. Put the following in your `index.html` file:
  4.  
  5. ```html
  6. <html>
  7. <head>
  8. <meta name="viewport" width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=0>
  9. <style> body {padding: 0; margin: 0;} </style>
  10. <script src="tf.min.js"></script>
  11. <script src="p5.min.js"></script>
  12. <script src="main.js"></script>
  13. </head>
  14. <body>
  15. <canvas id="canvasId" style="display: none"></canvas>
  16. <video id="videoId" autoplay></video>
  17. <img id="imageId"/>
  18. <br>
  19. <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('RED')">Red</button>
  20. <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('GREEN')">Green</button>
  21. <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('BLUE')">Blue</button>
  22. <br>
  23. <button style="margin-top: 15px; margin-left: 15px" onclick="trainModel()">Train Model</button>
  24. <br>
  25. <p style="margin-top: 15px; margin-left: 15px" id="paragraphId"></p>
  26. </body>
  27. </html>
  28. ```
  29.  
  30. Add the following to your main JavaScript file (`main.js`, the script loaded by `index.html`):
  31.  
  32. ```javascript
  33. var categories = {
  34. RED: 0,
  35. GREEN: 1,
  36. BLUE: 2
  37. };
  38.  
  39. var categoriesInverse = {
  40. 0: 'RED',
  41. 1: 'GREEN',
  42. 2: 'BLUE'
  43. };
  44.  
  45. var data = [];
  46.  
  47. var model;
  48. var isTraining = false;
  49. var isTrained = false;
  50.  
  51. var imgWidth = 100;
  52. var imgHeight = 75;
  53.  
  54. var retainColor = true;
  55.  
  56. var predictionFrequency = 30;
  57.  
  58. function setup() {
  59. loadVideoPreview();
  60. }
  61.  
  62. function draw() {
  63. if (isTrained && !isTraining){
  64. if (frameCount % predictionFrequency === 0){
  65. predict();
  66. }
  67. }
  68. }
  69.  
  70. function loadVideoPreview () {
  71. var videoElement = document.querySelector("#videoId");
  72. var canvasElement = document.querySelector("#canvasId");
  73. var imageElement = document.querySelector("#imageId");
  74.  
  75. videoElement.width = imgWidth;
  76. canvasElement.width = imgWidth;
  77. imageElement.width = imgWidth;
  78. videoElement.height = imgHeight;
  79. canvasElement.height = imgHeight;
  80. imageElement.height = imgHeight;
  81.  
  82. navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia || navigator.oGetUserMedia;
  83.  
  84. if (navigator.getUserMedia) {
  85. navigator.getUserMedia({ video: true }, handleVideo, videoError);
  86. }
  87.  
  88. function handleVideo(stream) {
  89. videoElement.srcObject = stream
  90. videoStream = stream;
  91. }
  92.  
  93. function videoError(e) {
  94. console.log(e);
  95. }
  96. };
  97.  
  98. function captureImage() {
  99. var videoElement = document.querySelector("#videoId");
  100. var canvasElement = document.querySelector("#canvasId");
  101. var imageElement = document.querySelector("#imageId");
  102.  
  103. var canvasContext = canvasElement.getContext('2d');
  104.  
  105. canvasContext.drawImage(videoElement, 0, 0, canvasElement.width, canvasElement.height);
  106.  
  107. var image = canvasContext.getImageData(0, 0, canvasElement.width, canvasElement.height);
  108.  
  109. var pixels = [];
  110. if (retainColor){
  111. pixels = convertToRgb(image.data);
  112.  
  113. var px = 0;
  114. for (var i = 0; i < image.data.length; i += 4) {
  115. image.data[i+0] = pixels[px];
  116. image.data[i+1] = pixels[px+1];
  117. image.data[i+2] = pixels[px+2];
  118. image.data[i+3] = 255;
  119. px+=3;
  120. }
  121. canvasContext.putImageData(image, 0, 0);
  122. }
  123. else {
  124. pixels = convertToGrayscale(image.data);
  125.  
  126. var px = 0;
  127. for (var i = 0; i < image.data.length; i += 4) {
  128. image.data[i+0] = pixels[px];
  129. image.data[i+1] = pixels[px];
  130. image.data[i+2] = pixels[px];
  131. image.data[i+3] = 255;
  132. px++;
  133. }
  134. canvasContext.putImageData(image, 0, 0);
  135. }
  136.  
  137. var photo = canvasElement.toDataURL();
  138.  
  139. imageElement.src = photo;
  140.  
  141. return pixels;
  142. }
  143.  
  144. function categorize (categoryName) {
  145. var pixels = captureImage();
  146.  
  147. var datum = {
  148. Image: pixels,
  149. Category: categoryName
  150. }
  151.  
  152. data.push(datum);
  153. }
  154.  
  155. //Remove Alpha from color
  156. function convertToRgb(input) {
  157. var output = [];
  158. for (var i = 0; i < input.length; i+=4){
  159. output.push(input[i]);
  160. output.push(input[i+1]);
  161. output.push(input[i+2]);
  162. }
  163. return output;
  164. }
  165.  
  166. function convertToGrayscale(input) {
  167. var output = [];
  168. for (var i = 0; i < input.length; i+=4){
  169. var r = input[i];
  170. var g = input[i+1];
  171. var b = input[i+2];
  172.  
  173. output.push(floor((r + g + b) / 3));
  174. }
  175. return output;
  176. }
  177.  
  178. function buildModel() {
  179. let md = tf.sequential();
  180.  
  181. var colorSpace = retainColor ? 3 : 1;
  182.  
  183. const hidden = tf.layers.dense({
  184. units: 15,
  185. inputShape: [imgWidth * imgHeight * colorSpace],
  186. activation: 'sigmoid'
  187. });
  188.  
  189. const output = tf.layers.dense({
  190. units: 3,
  191. activation: 'softmax'
  192. });
  193. md.add(hidden);
  194. md.add(output);
  195.  
  196. const LEARNING_RATE = 0.25;
  197. const optimizer = tf.train.sgd(LEARNING_RATE);
  198.  
  199. md.compile({
  200. optimizer: optimizer,
  201. loss: 'categoricalCrossentropy',
  202. metrics: ['accuracy'],
  203. });
  204.  
  205. return md
  206. }
  207.  
  208. function trainModel() {
  209. if (data == null || data.length == 0){
  210. console.error('No Training Data Found');
  211. return;
  212. }
  213.  
  214. let images = [];
  215. let cats = [];
  216. for (let datum of data) {
  217. //Normalize color 0-255 to 0-1
  218. var normalized = datum.Image.map(function (p) { return p / 255 });
  219. images.push(normalized);
  220. cats.push(categories[datum.Category]);
  221. }
  222.  
  223. var xs = tf.tensor2d(images);
  224. let categoryTensor = tf.tensor1d(cats, 'int32');
  225.  
  226. var ys = tf.oneHot(categoryTensor, 3).cast('float32');
  227. categoryTensor.dispose();
  228.  
  229. model = buildModel();
  230.  
  231. isTraining = true;
  232.  
  233. model.fit(xs, ys, {
  234. shuffle: true,
  235. validationSplit: 0.1,
  236. epochs: 10,
  237. callbacks: {
  238. onEpochEnd: (epoch, logs) => {
  239. console.log('EPOCH: ' + epoch);
  240. },
  241. onBatchEnd: async (batch, logs) => {
  242. tf.nextFrame();
  243. },
  244. onTrainEnd: () => {
  245. isTraining = false;
  246. isTrained = true;
  247. console.log('finished');
  248. },
  249. },
  250. });
  251. }
  252.  
  253. function predict() {
  254. if (isTraining){
  255. console.log('Still Training');
  256. return;
  257. }
  258.  
  259. var pixels = captureImage();
  260.  
  261. //Normalize color 0-255 to 0-1
  262. var normalized = pixels.map(function (p) { return p / 255 });
  263.  
  264. var input = tf.tensor2d(normalized, [1, normalized.length]);
  265.  
  266. let results = model.predict(input);
  267.  
  268. let argMax = results.argMax(1);
  269. let index = argMax.dataSync()[0];
  270.  
  271. var prediction = categoriesInverse[index];
  272.  
  273. var paragraphElement = document.getElementById("paragraphId");
  274.  
  275. paragraphElement.innerText = 'I predict ' + prediction;
  276. }
  277. ```
Add Comment
Please, Sign In to add comment