Advertisement
weidai

Parity RNN

Jul 5th, 2015
511
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.17 KB | None | 0 0
  1. from __future__ import print_function
  2.  
  3. #!/usr/bin/env python
  4. """ A simple recurrent neural network that detects parity for arbitrary sequences. """
  5. """ Modified by Wei Dai from original PyBrain example """
  6.  
  7. __author__ = 'Tom Schaul (tom@idsia.ch)'
  8.  
  9. from pybrain.supervised.trainers.backprop import BackpropTrainer
  10. from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection
  11. from random import randint, seed
  12. from pybrain.datasets import SequentialDataSet
  13.  
  14. class ParityDataSet2(SequentialDataSet):
  15.     """ Determine whether the bitstring up to the current point conains a pair number of 1s or not."""
  16.     def __init__(self, n):
  17.         SequentialDataSet.__init__(self, 2,1)
  18.  
  19.         i = 0
  20.         y = 0
  21.         while i < 100:
  22.             if y == 0:
  23.                 self.newSequence()
  24.                 p = -1
  25.             x = randint(0,1)
  26.             if x == 1:
  27.                 p = -p
  28.             else:
  29.                 x = -1
  30.             y = randint(0,n)
  31.             if y == 0:
  32.                 self.addSample([1,x], [p])
  33.                 i = i+1
  34.             else:
  35.                 self.addSample([-1,x], [-1])
  36.  
  37. def buildParityNet():
  38.     net = RecurrentNetwork()
  39.     net.addInputModule(LinearLayer(2, name = 'i'))
  40.     net.addModule(TanhLayer(3, name = 'h'))
  41.     net.addModule(TanhLayer(2, name = '2'))
  42.     net.addModule(BiasUnit('bias'))
  43.     net.addOutputModule(TanhLayer(1, name = 'o'))
  44.     net.addConnection(FullConnection(net['i'], net['h'], inSliceFrom=1))
  45.     net.addConnection(FullConnection(net['i'], net['o'], inSliceTo=1))
  46.     net.addConnection(FullConnection(net['h'], net['2']))
  47.     net.addConnection(FullConnection(net['bias'], net['h']))
  48.     net.addConnection(FullConnection(net['bias'], net['2']))
  49.     net.addConnection(FullConnection(net['bias'], net['o']))
  50.     net.addConnection(FullConnection(net['2'], net['o']))
  51.     net.addRecurrentConnection(FullConnection(net['2'], net['h']))
  52.     net.sortModules()
  53.  
  54.     p = net.params
  55. #    p[:] = [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1]
  56.     p *= 10.
  57.  
  58.     return net
  59.  
  60. def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
  61.     """ evaluate the network on all the sequences of a dataset. """
  62.     r = 0.
  63.     samples = 0.
  64.     for seq in DS:
  65.         net.reset()
  66.         for i, t in seq:
  67.             res = net.activate(i)
  68.             if verbose:
  69.                 print(t, res)
  70.             r += sum((t-res)**2)
  71.             samples += 1
  72.         if verbose:
  73.             print('-'*20)
  74.     r /= samples
  75.     if not silent:
  76.         print('MSE:', r)
  77.     return r
  78.  
  79. if __name__ == "__main__":
  80.     seed(1)
  81.     N = buildParityNet()
  82.     DS = ParityDataSet2(3)
  83.     DS2 = ParityDataSet2(10)
  84.  
  85.     # Backprop improves the network performance, and sometimes even finds the global optimum.
  86.     N.randomize()
  87.     N.params[:] = [x / 10 for x in N.params]
  88.     bp = BackpropTrainer(N, DS, verbose = True, learningrate=0.05, momentum=0.5, weightdecay=0.0005)
  89. #    do:
  90. #        bp.trainOnDataset(epochs=100, dataset=DS)
  91.     while bp.train() > 0.0007:
  92.         pass
  93.     evalRnnOnSeqDataset(N, DS2, verbose = True)
  94.     print('(backprop-trained weights)')
  95.     print('Final weights:', N.params)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement