Advertisement
Guest User

Untitled

a guest
Apr 9th, 2021
403
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. const IMAGE_SIZE = 784;
  2. const NUM_CLASSES = 10;
  3. const NUM_DATASET_ELEMENTS = 65000;
  4.  
  5. const NUM_TRAIN_ELEMENTS = 55000;
  6. const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
  7.  
  8. const MNIST_IMAGES_SPRITE_PATH =
  9.     'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
  10. const MNIST_LABELS_PATH =
  11.     'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
  12.  
  13. /**
  14.  * A class that fetches the sprited MNIST dataset and returns shuffled batches.
  15.  *
  16.  * NOTE: This will get much easier. For now, we do data fetching and
  17.  * manipulation manually.
  18.  */
  19.  
  20. export class MnistData {
  21.   constructor() {
  22.     this.shuffledTrainIndex = 0;
  23.     this.shuffledTestIndex = 0;
  24.   }
  25.  
  26.   async load() {
  27.     // Make a request for the MNIST sprited image.
  28.     const img = new Image();
  29.     const canvas = document.createElement('canvas');
  30.     const ctx = canvas.getContext('2d');
  31.     const imgRequest = new Promise((resolve, reject) => {
  32.       img.crossOrigin = '';
  33.       img.onload = () => {
  34.         img.width = img.naturalWidth;
  35.         img.height = img.naturalHeight;
  36.  
  37.         const datasetBytesBuffer =
  38.             new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
  39.  
  40.         const chunkSize = 5000;
  41.         canvas.width = img.width;
  42.         canvas.height = chunkSize;
  43.  
  44.         for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
  45.           const datasetBytesView = new Float32Array(
  46.               datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
  47.               IMAGE_SIZE * chunkSize);
  48.           ctx.drawImage(
  49.               img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
  50.               chunkSize);
  51.  
  52.           const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  53.  
  54.           for (let j = 0; j < imageData.data.length / 4; j++) {
  55.             // All channels hold an equal value since the image is grayscale, so
  56.             // just read the red channel.
  57.             datasetBytesView[j] = imageData.data[j * 4] / 255;
  58.           }
  59.         }
  60.         this.datasetImages = new Float32Array(datasetBytesBuffer);
  61.  
  62.         resolve();
  63.       };
  64.       img.src = MNIST_IMAGES_SPRITE_PATH;
  65.     });
  66.  
  67.     const labelsRequest = fetch(MNIST_LABELS_PATH);
  68.     const [imgResponse, labelsResponse] =
  69.         await Promise.all([imgRequest, labelsRequest]);
  70.  
  71.     this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
  72.  
  73.     // Create shuffled indices into the train/test set for when we select a
  74.     // random dataset element for training / validation.
  75.     this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
  76.     this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
  77.  
  78.     // Slice the the images and labels into train and test sets.
  79.     this.trainImages =
  80.         this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  81.     this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  82.     this.trainLabels =
  83.         this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  84.     this.testLabels =
  85.         this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  86.   }
  87.  
  88.   nextTrainBatch(batchSize) {
  89.     return this.nextBatch(
  90.         batchSize, [this.trainImages, this.trainLabels], () => {
  91.           this.shuffledTrainIndex =
  92.               (this.shuffledTrainIndex + 1) % this.trainIndices.length;
  93.           return this.trainIndices[this.shuffledTrainIndex];
  94.         });
  95.   }
  96.  
  97.   nextTestBatch(batchSize) {
  98.     return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
  99.       this.shuffledTestIndex =
  100.           (this.shuffledTestIndex + 1) % this.testIndices.length;
  101.       return this.testIndices[this.shuffledTestIndex];
  102.     });
  103.   }
  104.  
  105.   nextBatch(batchSize, data, index) {
  106.     const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
  107.     const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
  108.  
  109.     for (let i = 0; i < batchSize; i++) {
  110.       const idx = index();
  111.  
  112.       const image =
  113.           data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
  114.       batchImagesArray.set(image, i * IMAGE_SIZE);
  115.  
  116.       const label =
  117.           data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
  118.       batchLabelsArray.set(label, i * NUM_CLASSES);
  119.     }
  120.  
  121.     const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
  122.     const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
  123.  
  124.     return {xs, labels};
  125.   }
  126. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement