Guest User

script.js

a guest
Apr 11th, 2021
157
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import {MnistData} from './data.js';
  2.  
  /**
   * Draws a batch of test images into a tfjs-vis visor surface so the raw
   * input data can be inspected before training.
   *
   * @param {MnistData} data - loaded dataset. `nextTestBatch(n)` must return
   *   `{xs, labels}` tensors; each row of `xs` is reshaped to [28, 28, 1], so
   *   rows must hold 28 * 28 = 784 pixel values.
   * @returns {Promise<void>} resolves once every canvas has been appended.
   */
  async function showExamples(data) {
    const IMAGE_SIZE = 28;

    // Create a container in the visor.
    const surface = tfvis
      .visor()
      .surface({ name: 'Input Data Examples', tab: 'Input Data' });

    const examples = data.nextTestBatch(20);
    try {
      const numExamples = examples.xs.shape[0];
      for (let i = 0; i < numExamples; i++) {
        // Slice out row i and reshape the flat pixel vector into an image.
        const imageTensor = tf.tidy(() =>
          examples.xs
            .slice([i, 0], [1, examples.xs.shape[1]])
            .reshape([IMAGE_SIZE, IMAGE_SIZE, 1]),
        );
        try {
          const canvas = document.createElement('canvas');
          canvas.width = IMAGE_SIZE;
          canvas.height = IMAGE_SIZE;
          canvas.style.margin = '4px';
          await tf.browser.toPixels(imageTensor, canvas);
          surface.drawArea.appendChild(canvas);
        } finally {
          // Release the per-image tensor even if rendering fails.
          imageTensor.dispose();
        }
      }
    } finally {
      // The batch tensors were created outside any tidy() scope, so they must
      // be released explicitly or their memory leaks on every call.
      examples.xs.dispose();
      examples.labels.dispose();
    }
  }
  31.  
  /**
   * Top-level pipeline: loads MNIST, builds and summarizes the CNN, trains it,
   * and then renders evaluation charts for the trained model.
   *
   * @returns {Promise<void>} resolves once training and evaluation are done.
   */
  async function run() {
    const data = new MnistData();
    await data.load();
    // Optional: uncomment to preview a batch of input images before training.
    // await showExamples(data);

    const model = getModel();
    tfvis.show.modelSummary({ name: 'Model Architecture', tab: 'Model' }, model);

    await train(model, data);

    // Evaluation needs the trained model and the loaded data, which only
    // exist in this scope, so it is triggered here after training.
    await showAccuracy(model, data);
    await showConfusion(model, data);
  }
  41.  
  /**
   * Kicks off run() and reports any failure of the async pipeline instead of
   * leaving its promise floating.
   */
  function startWhenReady() {
    run().catch((err) => console.error('MNIST training pipeline failed:', err));
  }

  // If the document is still parsing, wait for DOMContentLoaded; otherwise the
  // event may already have fired and would never reach a new listener.
  if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', startWhenReady);
  } else {
    startWhenReady();
  }
  43.  
  /**
   * Builds and compiles the CNN classifier for 28x28 single-channel digit
   * images. The layers are two conv2d + max-pooling stages, a flatten, and a
   * 10-unit softmax output.
   *
   * @returns {tf.Sequential} the compiled model (Adam optimizer, categorical
   *   cross-entropy loss, accuracy metric).
   */
  function getModel() {
    const IMAGE_WIDTH = 28;
    const IMAGE_HEIGHT = 28;
    const IMAGE_CHANNELS = 1;
    const NUM_OUTPUT_CLASSES = 10;

    // Settings shared by both convolution layers.
    const convSettings = {
      kernelSize: 5,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'varianceScaling',
    };
    // 2x2 max pooling with stride 2: downsamples by taking the max of each
    // window instead of averaging. Fresh config per call, nothing shared.
    const poolingLayer = () =>
      tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] });

    const model = tf.sequential();

    // Only the first layer declares the input shape; later layers infer theirs.
    model.add(tf.layers.conv2d({
      inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
      filters: 8,
      ...convSettings,
    }));
    model.add(poolingLayer());

    // Second conv + pool stage, with more filters.
    model.add(tf.layers.conv2d({ filters: 16, ...convSettings }));
    model.add(poolingLayer());

    // Collapse the 2D feature maps into a 1D vector for the dense classifier.
    model.add(tf.layers.flatten());

    // One softmax unit per output class (digits 0-9).
    model.add(tf.layers.dense({
      units: NUM_OUTPUT_CLASSES,
      kernelInitializer: 'varianceScaling',
      activation: 'softmax',
    }));

    model.compile({
      optimizer: tf.train.adam(),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });
    return model;
  }
  88.  
  /**
   * Trains `model` on the training split, validates on the test split, and
   * plots loss/accuracy curves live in the visor's Model tab.
   *
   * @param {tf.LayersModel} model - a compiled model, e.g. from getModel().
   * @param {MnistData} data - loaded dataset providing nextTrainBatch/nextTestBatch.
   * @returns {Promise<tf.History>} whatever model.fit resolves with.
   */
  async function train(model, data) {
    const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
    const container = {
      name: 'Model Training', tab: 'Model', styles: { height: '1000px' },
    };
    const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

    const BATCH_SIZE = 512;
    const TRAIN_DATA_SIZE = 55000;
    const TEST_DATA_SIZE = 10000;

    // tidy() disposes intermediates created inside the callback; only the
    // returned reshaped images and label tensors survive.
    const [trainXs, trainYs] = tf.tidy(() => {
      const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
      return [d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]), d.labels];
    });
    const [testXs, testYs] = tf.tidy(() => {
      const d = data.nextTestBatch(TEST_DATA_SIZE);
      return [d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]), d.labels];
    });

    try {
      // `return await` keeps the tensors alive until fit has fully settled.
      return await model.fit(trainXs, trainYs, {
        batchSize: BATCH_SIZE,
        validationData: [testXs, testYs],
        epochs: 10,
        shuffle: true,
        callbacks: fitCallbacks,
      });
    } finally {
      // These full-dataset tensors escaped tidy(); free them whether training
      // succeeded or failed.
      tf.dispose([trainXs, trainYs, testXs, testYs]);
    }
  }
  120.  
  // Human-readable label for each output class; index i names digit i and
  // matches the order of the model's 10 softmax outputs.
  const classNames = [
    'Zero',
    'One',
    'Two',
    'Three',
    'Four',
    'Five',
    'Six',
    'Seven',
    'Eight',
    'Nine',
  ];
  122.  
  /**
   * Runs `model` on a fresh batch of test images and returns predicted and
   * true class indices.
   *
   * @param {tf.LayersModel} model - trained classifier.
   * @param {MnistData} data - dataset providing nextTestBatch(n).
   * @param {number} [testDataSize=500] - number of test examples to score.
   * @returns {[tf.Tensor, tf.Tensor]} `[preds, labels]`, both produced by
   *   argMax(-1) over the last axis. The caller owns them and must dispose them.
   */
  function doPrediction(model, data, testDataSize = 500) {
    const IMAGE_WIDTH = 28;
    const IMAGE_HEIGHT = 28;
    // tidy() frees the raw batch, the one-hot labels, and the probability
    // output of predict(); only the returned index tensors survive.
    return tf.tidy(() => {
      const testData = data.nextTestBatch(testDataSize);
      const testxs = testData.xs.reshape([testDataSize, IMAGE_HEIGHT, IMAGE_WIDTH, 1]);
      const labels = testData.labels.argMax(-1);
      const preds = model.predict(testxs).argMax(-1);
      return [preds, labels];
    });
  }
  133.  
  /**
   * Computes per-class accuracy of `model` on a batch of test images and
   * renders it in the visor's Evaluation tab.
   *
   * @param {tf.LayersModel} model - trained classifier.
   * @param {MnistData} data - dataset passed through to doPrediction.
   * @returns {Promise<void>}
   */
  async function showAccuracy(model, data) {
    const [preds, labels] = doPrediction(model, data);
    try {
      const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
      const container = { name: 'Accuracy', tab: 'Evaluation' };
      tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
    } finally {
      // doPrediction hands ownership of both tensors to the caller; release
      // them even if computing the metric fails.
      labels.dispose();
      preds.dispose();
    }
  }
  142.  
  /**
   * Computes the confusion matrix of `model` on a batch of test images and
   * renders it in the visor's Evaluation tab, labelled with classNames.
   *
   * @param {tf.LayersModel} model - trained classifier.
   * @param {MnistData} data - dataset passed through to doPrediction.
   * @returns {Promise<void>}
   */
  async function showConfusion(model, data) {
    const [preds, labels] = doPrediction(model, data);
    try {
      const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
      const container = { name: 'Confusion Matrix', tab: 'Evaluation' };
      tfvis.render.confusionMatrix(container, { values: confusionMatrix, tickLabels: classNames });
    } finally {
      // doPrediction hands ownership of both tensors to the caller; release
      // them even if computing the matrix fails.
      labels.dispose();
      preds.dispose();
    }
  }
  151.  
  // Model evaluation is intentionally not started at module top level.
  // `model` and `data` are locals of run(), so a top-level
  // `await showAccuracy(model, data)` throws a ReferenceError while the module
  // is evaluated. Call showAccuracy/showConfusion from run() after training.
  153.  
RAW Paste Data