Advertisement
NLinker

mnist.py

Feb 16th, 2019
494
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.53 KB | None | 0 0
  1. import sys
  2. import os
  3. import time
  4.  
  5. import numpy as np
  6.  
  7. __doc__="""taken from https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py"""
  8.  
  9. def load_dataset():
  10.     # We first define a download function, supporting both Python 2 and 3.
  11.     if sys.version_info[0] == 2:
  12.         from urllib import urlretrieve
  13.     else:
  14.         from urllib.request import urlretrieve
  15.  
  16.     def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
  17.         print("Downloading %s" % filename)
  18.         urlretrieve(source + filename, filename)
  19.  
  20.     # We then define functions for loading MNIST images and labels.
  21.     # For convenience, they also download the requested files if needed.
  22.     import gzip
  23.  
  24.     def load_mnist_images(filename):
  25.         if not os.path.exists(filename):
  26.             download(filename)
  27.         # Read the inputs in Yann LeCun's binary format.
  28.         with gzip.open(filename, 'rb') as f:
  29.             data = np.frombuffer(f.read(), np.uint8, offset=16)
  30.         # The inputs are vectors now, we reshape them to monochrome 2D images,
  31.         # following the shape convention: (examples, channels, rows, columns)
  32.         data = data.reshape(-1, 1, 28, 28)
  33.         # The inputs come as bytes, we convert them to float32 in range [0,1].
  34.         # (Actually to range [0, 255/256], for compatibility to the version
  35.         # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
  36.         return (data / np.float32(256)).squeeze()
  37.  
  38.     def load_mnist_labels(filename):
  39.         if not os.path.exists(filename):
  40.             download(filename)
  41.         # Read the labels in Yann LeCun's binary format.
  42.         with gzip.open(filename, 'rb') as f:
  43.             data = np.frombuffer(f.read(), np.uint8, offset=8)
  44.         # The labels are vectors of integers now, that's exactly what we want.
  45.         return data
  46.  
  47.     # We can now download and read the training and test set images and labels.
  48.     X_train = load_mnist_images('train-images-idx3-ubyte.gz')
  49.     y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
  50.     X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
  51.     y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')
  52.  
  53.     # We reserve the last 10000 training examples for validation.
  54.     X_train, X_val = X_train[:-10000], X_train[-10000:]
  55.     y_train, y_val = y_train[:-10000], y_train[-10000:]
  56.  
  57.     # We just return all the arrays in order, as expected in main().
  58.     # (It doesn't matter how we do this as long as we can read them again.)
  59.     return X_train, y_train, X_val, y_val, X_test, y_test
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement