4ever_bored

lasagne_test.py

Oct 6th, 2016
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.89 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2.  
  3. from __future__ import print_function
  4.  
  5. import time
  6.  
  7. import numpy as np
  8.  
  9. import theano
  10. import theano.tensor as T
  11.  
  12. import lasagne
  13.  
  14. import matplotlib.pyplot as plt
  15.  
  16.  
  17. def load_dataset():
  18.     np.random.seed(1337)
  19.     X_train = np.random.randn(10000, 16*16)
  20.     X_train = X_train.astype('float32')
  21.     Y_train = np.mean(X_train, 1)
  22.  
  23.     X_test = np.random.randn(1000, 16*16)
  24.     X_test = X_test.astype('float32')
  25.     Y_test = np.mean(X_test, 1)
  26.  
  27.     X_train = np.reshape(X_train, (10000, 1, 16, 16))
  28.     X_test = np.reshape(X_test, (1000, 1, 16, 16))
  29.  
  30.     return X_train, Y_train, X_test, Y_test
  31.  
  32.  
  33. def build_cnn(input_var=None):
  34.  
  35.     network = lasagne.layers.InputLayer(shape=(None, 1, 16, 16),
  36.                                         input_var=input_var)
  37.  
  38.     network = lasagne.layers.Conv2DLayer(
  39.             network, num_filters=20, filter_size=(3, 3),
  40.             nonlinearity=lasagne.nonlinearities.rectify,
  41.             W=lasagne.init.GlorotUniform())
  42.  
  43.     network = lasagne.layers.DenseLayer(
  44.             network,
  45.             num_units=10,
  46.             nonlinearity=lasagne.nonlinearities.sigmoid)
  47.  
  48.     network = lasagne.layers.DenseLayer(
  49.             network,
  50.             num_units=1,
  51.             nonlinearity=lasagne.nonlinearities.linear)
  52.  
  53.     return network
  54.  
  55.  
  56. def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
  57.     assert len(inputs) == len(targets)
  58.     if shuffle:
  59.         indices = np.arange(len(inputs))
  60.         np.random.shuffle(indices)
  61.     for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
  62.         if shuffle:
  63.             excerpt = indices[start_idx:start_idx + batchsize]
  64.         else:
  65.             excerpt = slice(start_idx, start_idx + batchsize)
  66.         yield inputs[excerpt], targets[excerpt]
  67.  
  68.  
  69. def main(model='cnn', num_epochs=10):
  70.  
  71.     print("Loading data...")
  72.     X_train, y_train, X_test, y_test = load_dataset()
  73.  
  74.     input_var = T.tensor4('inputs')
  75.     target_var = T.vector('targets')
  76.  
  77.     print("Building model and compiling functions...")
  78.     network = build_cnn(input_var)
  79.  
  80.  
  81.     prediction = lasagne.layers.get_output(network)
  82.     loss = lasagne.objectives.squared_error(prediction, target_var)
  83.     loss = loss.mean()
  84.  
  85.     params = lasagne.layers.get_all_params(network, trainable=True)
  86.     updates = lasagne.updates.nesterov_momentum(
  87.             loss, params, learning_rate=0.1, momentum=0.9)
  88. #    updates = lasagne.updates.adam(loss, params)
  89.  
  90.     test_prediction = lasagne.layers.get_output(network)
  91.     test_loss = lasagne.objectives.squared_error(test_prediction,
  92.                                                             target_var)
  93.     test_loss = test_loss.mean()
  94.  
  95.  
  96.     train_fn = theano.function([input_var, target_var], loss, updates=updates)
  97.  
  98.     val_fn = theano.function([input_var, target_var], test_loss)
  99.  
  100.     preds = theano.function([input_var], test_prediction)
  101.  
  102.     print("Starting training...")
  103.  
  104.     for epoch in range(num_epochs):
  105.  
  106.         train_err = 0.0
  107.         train_batches = 0
  108.         start_time = time.time()
  109.         for batch in iterate_minibatches(X_train, y_train, 500, shuffle=False):
  110.             inputs, targets = batch
  111.             train_err += train_fn(inputs, targets)
  112.             train_batches += 1
  113.  
  114.         test_err = 0.0
  115.         test_batches = 0
  116.         for batch in iterate_minibatches(X_test, y_test, 500, shuffle=False):
  117.             inputs, targets = batch
  118.             err = val_fn(inputs, targets)
  119.             test_err += err
  120.             test_batches += 1
  121.         print("Epoch {} of {} took {:.3f}s".format(
  122.             epoch + 1, num_epochs, time.time() - start_time))
  123.         print("  training loss:\t\t{:.6f}".format(train_err / train_batches))
  124.         print("  test loss:\t\t{:.6f}".format(test_err / test_batches))
  125.  
  126.     pds = preds(X_test)
  127.     plt.scatter(y_test, pds)
  128.     plt.show()
  129.  
  130.  
  131.  
  132. if __name__ == '__main__':
  133.  
  134.     main()
Add Comment
Please, Sign In to add comment