Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from __future__ import print_function
- #!/usr/bin/env python
- """ A simple recurrent neural network that detects parity for arbitrary sequences. """
- """ Modified by Wei Dai from original PyBrain example """
- __author__ = 'Tom Schaul (tom@idsia.ch)'
- from pybrain.supervised.trainers.backprop import BackpropTrainer
- from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection
- from random import randint, seed
- from pybrain.datasets import SequentialDataSet
class ParityDataSet2(SequentialDataSet):
    """ Determine whether the bitstring up to the current point contains an even number of 1s or not.

    Randomly generates sequences of bits (encoded as +1 / -1) interleaved with
    "query" steps; on a query step the target is the running parity of the 1s
    seen so far in the current sequence.

    NOTE(review): indentation was reconstructed from a flattened paste — the
    placement of `p = -1` inside the `y == 0` branch matches the logic of the
    original PyBrain parity example, but should be confirmed against it.
    """
    def __init__(self, n):
        # 2 input channels ([query-marker, bit]), 1 target channel.
        # `n` controls query frequency: each step is a query with
        # probability 1/(n+1) (whenever randint(0, n) draws 0).
        SequentialDataSet.__init__(self, 2,1)
        i = 0  # number of query samples emitted so far; stop at 100
        y = 0  # y == 0 marks a sequence boundary / query step (starts true)
        while i < 100:
            if y == 0:
                # Start a fresh sequence and reset the running parity.
                self.newSequence()
                p = -1  # parity accumulator: -1 = even count of 1s, +1 = odd
            x = randint(0,1)
            if x == 1:
                p = -p  # a 1-bit flips the parity
            else:
                x = -1  # encode bit 0 as -1 for the network input
            y = randint(0,n)
            if y == 0:
                # Query step: marker channel is 1, target is the true parity.
                self.addSample([1,x], [p])
                i = i+1
            else:
                # Ordinary step: marker channel is -1, dummy target -1.
                self.addSample([-1,x], [-1])
def buildParityNet():
    """Construct the recurrent parity-detection network.

    Topology: a 2-unit linear input feeds a 3-unit tanh hidden layer 'h'
    (bit channel) and the 1-unit tanh output 'o' (query-marker channel);
    'h' feeds a 2-unit tanh layer '2', which feeds 'o' and loops back to
    'h' through a recurrent connection. A bias unit feeds every non-input
    layer. Initial random weights are scaled up by a factor of 10.

    Returns the sorted, ready-to-activate RecurrentNetwork.
    """
    network = RecurrentNetwork()
    network.addInputModule(LinearLayer(2, name='i'))
    network.addModule(TanhLayer(3, name='h'))
    network.addModule(TanhLayer(2, name='2'))
    network.addModule(BiasUnit('bias'))
    network.addOutputModule(TanhLayer(1, name='o'))
    # Input channel 1 (the bit) goes to the hidden layer; input channel 0
    # (the query marker) is wired straight to the output.
    network.addConnection(FullConnection(network['i'], network['h'], inSliceFrom=1))
    network.addConnection(FullConnection(network['i'], network['o'], inSliceTo=1))
    network.addConnection(FullConnection(network['h'], network['2']))
    network.addConnection(FullConnection(network['bias'], network['h']))
    network.addConnection(FullConnection(network['bias'], network['2']))
    network.addConnection(FullConnection(network['bias'], network['o']))
    network.addConnection(FullConnection(network['2'], network['o']))
    network.addRecurrentConnection(FullConnection(network['2'], network['h']))
    network.sortModules()
    # (The original kept a commented-out hand-set weight vector here:
    #  [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1].)
    weights = network.params
    weights *= 10.  # in-place scale of the shared parameter array
    return network
def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
    """Evaluate ``net`` on every sequence of the dataset ``DS``.

    The network is reset before each sequence, then activated step by step;
    the squared error between target and output is accumulated over all
    samples and averaged.

    :param net: network exposing ``reset()`` and ``activate(input)``
    :param DS: iterable of sequences, each yielding (input, target) pairs
    :param verbose: if True, print every (target, output) pair and a
        separator line after each sequence
    :param silent: if True, suppress the final MSE line
    :return: mean squared error over all samples
    """
    total_error = 0.
    sample_count = 0.
    for sequence in DS:
        net.reset()  # clear recurrent state between sequences
        for inp, target in sequence:
            output = net.activate(inp)
            if verbose:
                print(target, output)
            total_error += sum((target - output) ** 2)
            sample_count += 1
        if verbose:
            print('-' * 20)
    mse = total_error / sample_count
    if not silent:
        print('MSE:', mse)
    return mse
if __name__ == "__main__":
    # Fix the RNG so the randomly generated parity datasets are reproducible.
    seed(1)
    net = buildParityNet()
    train_ds = ParityDataSet2(3)   # frequent queries: easier training data
    eval_ds = ParityDataSet2(10)   # sparser queries: harder evaluation data
    # Backprop improves the network performance, and sometimes even finds
    # the global optimum.
    net.randomize()
    # buildParityNet scales the initial weights by 10; shrink the fresh
    # random weights back into a small range before training.
    net.params[:] = [w / 10 for w in net.params]
    trainer = BackpropTrainer(net, train_ds, verbose=True, learningrate=0.05,
                              momentum=0.5, weightdecay=0.0005)
    # Train one epoch at a time until the epoch error drops below threshold.
    # (Alternatively: trainer.trainOnDataset(epochs=100, dataset=train_ds).)
    while trainer.train() > 0.0007:
        pass
    evalRnnOnSeqDataset(net, eval_ds, verbose=True)
    print('(backprop-trained weights)')
    print('Final weights:', net.params)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement