G2A Many GEOs
SHARE
TWEET

Parity RNN

weidai Jul 5th, 2015 262 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import print_function
  2.  
  3. #!/usr/bin/env python
  4. """ A simple recurrent neural network that detects parity for arbitrary sequences. """
  5. """ Modified by Wei Dai from original PyBrain example """
  6.  
  7. __author__ = 'Tom Schaul (tom@idsia.ch)'
  8.  
  9. from pybrain.supervised.trainers.backprop import BackpropTrainer
  10. from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection
  11. from random import randint, seed
  12. from pybrain.datasets import SequentialDataSet
  13.  
  14. class ParityDataSet2(SequentialDataSet):
  15.     """ Determine whether the bitstring up to the current point conains a pair number of 1s or not."""
  16.     def __init__(self, n):
  17.         SequentialDataSet.__init__(self, 2,1)
  18.  
  19.         i = 0
  20.         y = 0
  21.         while i < 100:
  22.             if y == 0:
  23.                 self.newSequence()
  24.                 p = -1
  25.             x = randint(0,1)
  26.             if x == 1:
  27.                 p = -p
  28.             else:
  29.                 x = -1
  30.             y = randint(0,n)
  31.             if y == 0:
  32.                 self.addSample([1,x], [p])
  33.                 i = i+1
  34.             else:
  35.                 self.addSample([-1,x], [-1])
  36.  
  37. def buildParityNet():
  38.     net = RecurrentNetwork()
  39.     net.addInputModule(LinearLayer(2, name = 'i'))
  40.     net.addModule(TanhLayer(3, name = 'h'))
  41.     net.addModule(TanhLayer(2, name = '2'))
  42.     net.addModule(BiasUnit('bias'))
  43.     net.addOutputModule(TanhLayer(1, name = 'o'))
  44.     net.addConnection(FullConnection(net['i'], net['h'], inSliceFrom=1))
  45.     net.addConnection(FullConnection(net['i'], net['o'], inSliceTo=1))
  46.     net.addConnection(FullConnection(net['h'], net['2']))
  47.     net.addConnection(FullConnection(net['bias'], net['h']))
  48.     net.addConnection(FullConnection(net['bias'], net['2']))
  49.     net.addConnection(FullConnection(net['bias'], net['o']))
  50.     net.addConnection(FullConnection(net['2'], net['o']))
  51.     net.addRecurrentConnection(FullConnection(net['2'], net['h']))
  52.     net.sortModules()
  53.  
  54.     p = net.params
  55. #    p[:] = [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1]
  56.     p *= 10.
  57.  
  58.     return net
  59.  
  60. def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
  61.     """ evaluate the network on all the sequences of a dataset. """
  62.     r = 0.
  63.     samples = 0.
  64.     for seq in DS:
  65.         net.reset()
  66.         for i, t in seq:
  67.             res = net.activate(i)
  68.             if verbose:
  69.                 print(t, res)
  70.             r += sum((t-res)**2)
  71.             samples += 1
  72.         if verbose:
  73.             print('-'*20)
  74.     r /= samples
  75.     if not silent:
  76.         print('MSE:', r)
  77.     return r
  78.  
  79. if __name__ == "__main__":
  80.     seed(1)
  81.     N = buildParityNet()
  82.     DS = ParityDataSet2(3)
  83.     DS2 = ParityDataSet2(10)
  84.  
  85.     # Backprop improves the network performance, and sometimes even finds the global optimum.
  86.     N.randomize()
  87.     N.params[:] = [x / 10 for x in N.params]
  88.     bp = BackpropTrainer(N, DS, verbose = True, learningrate=0.05, momentum=0.5, weightdecay=0.0005)
  89. #    do:
  90. #        bp.trainOnDataset(epochs=100, dataset=DS)
  91.     while bp.train() > 0.0007:
  92.         pass
  93.     evalRnnOnSeqDataset(N, DS2, verbose = True)
  94.     print('(backprop-trained weights)')
  95.     print('Final weights:', N.params)
RAW Paste Data
Ledger Nano X - The secure hardware wallet
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Top