Want more features on Pastebin? Sign Up, it's FREE!
Guest

gz_nn_minibatch.py

By: a guest on Dec 30th, 2013  |  syntax: Python  |  size: 1.65 KB  |  views: 661  |  expires: Never
download  |  raw  |  embed  |  report abuse  |  print
Text below is selected. Please press Ctrl+C to copy to your clipboard. (⌘+C on Mac)
  1. from pybrain.datasets import ClassificationDataSet
  2. from pybrain.tools.shortcuts import buildNetwork
  3. from pybrain.supervised.trainers import BackpropTrainer
  4. from pybrain.structure.modules import SoftmaxLayer
  5. from sklearn import datasets
  6. import random
  7.  
  8. # load the data from sklearn
  9. iris = datasets.load_iris()
  10. X = iris['data']
  11. y = iris['target']
  12.  
  13. # set up the dataset. 4 input features, 3 output classes
  14. all_inds = range(X.shape[0])
  15.  
  16. # build the network
  17. fnn = buildNetwork(4, 10, 3, outclass=SoftmaxLayer, bias=True)
  18. trainer = BackpropTrainer(fnn, momentum=0.1, verbose=True, weightdecay=0.01, learningrate=0.01)
  19.  
  20. # repeat the batch training several times
  21. for i in xrange(200):
  22.     # get a random order for the training examples for batch gradient descent
  23.     random.shuffle(all_inds)
  24.     # split the indexes into lists with the indices for each batch
  25.     batch_inds = [all_inds[i:i+10] for i in xrange(0, len(all_inds), 10)]
  26.  
  27.     # train on each batch
  28.     for inds in batch_inds:
  29.         # rebuild the dataset
  30.         ds = ClassificationDataSet(4, nb_classes=3)
  31.         for x_i, y_i in zip(X[inds, :], y[inds]):
  32.             ds.appendLinked(x_i, y_i)
  33.         ds._convertToOneOfMany()
  34.         # train on the current batch
  35.         trainer.trainOnDataset(ds)
  36.  
  37. # make a dataset with all the iris data
  38. ds_all = ClassificationDataSet(4, nb_classes=3)
  39. for x_i, y_i in zip(X, y):
  40.     ds_all.appendLinked(x_i, y_i)
  41. ds_all._convertToOneOfMany()
  42.  
  43. # test the result
  44. # Note that we are testing on our training data, which is bad practice,
  45. # but it does demonstrate the network is trained
  46. print sum(fnn.activateOnDataset(ds_all).argmax(axis=1) == y)/float(len(y))
clone this paste RAW Paste Data