• API
• FAQ
• Tools
• Archive
SHARE
TWEET

# Parity RNN weidai  Jul 5th, 2015 262 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
#!/usr/bin/env python
""" A simple recurrent neural network that detects parity for arbitrary sequences.

Modified by Wei Dai from original PyBrain example.
"""
# The docstring must precede the __future__ import to be the module docstring;
# a __future__ import may only be preceded by the docstring, comments and blank lines.
from __future__ import print_function

__author__ = 'Tom Schaul (tom@idsia.ch)'
8.
9. from pybrain.supervised.trainers.backprop import BackpropTrainer
10. from pybrain.structure import RecurrentNetwork, LinearLayer, TanhLayer, BiasUnit, FullConnection
11. from random import randint, seed
12. from pybrain.datasets import SequentialDataSet
13.
14. class ParityDataSet2(SequentialDataSet):
15.     """ Determine whether the bitstring up to the current point conains a pair number of 1s or not."""
16.     def __init__(self, n):
17.         SequentialDataSet.__init__(self, 2,1)
18.
19.         i = 0
20.         y = 0
21.         while i < 100:
22.             if y == 0:
23.                 self.newSequence()
24.                 p = -1
25.             x = randint(0,1)
26.             if x == 1:
27.                 p = -p
28.             else:
29.                 x = -1
30.             y = randint(0,n)
31.             if y == 0:
33.                 i = i+1
34.             else:
36.
37. def buildParityNet():
38.     net = RecurrentNetwork()
52.     net.sortModules()
53.
54.     p = net.params
55. #    p[:] = [-0.5, -1.5, 1, 1, -1, 1, 1, -1, 1]
56.     p *= 10.
57.
58.     return net
59.
def evalRnnOnSeqDataset(net, DS, verbose = False, silent = False):
    """Return the mean squared error of ``net`` over every sequence in ``DS``.

    Before each sequence the network is cleared with ``net.reset()``; each
    (input, target) step is then fed through ``net.activate``. The squared
    error of a step is summed over its output components, and the total is
    divided by the number of steps seen across all sequences.

    :param net: object providing ``reset()`` and ``activate(input)``.
    :param DS: iterable of sequences, each an iterable of (input, target) pairs.
    :param verbose: when true, print every target/output pair and a dashed
        separator after each sequence.
    :param silent: when true, do not print the final ``MSE:`` line.
    :returns: the averaged squared error.
    :raises ZeroDivisionError: if ``DS`` yields no steps at all.
    """
    total_error = 0.
    step_count = 0.
    for sequence in DS:
        net.reset()  # each sequence starts from a freshly reset network
        for inp, target in sequence:
            output = net.activate(inp)
            if verbose:
                print(target, output)
            total_error += sum((target - output) ** 2)
            step_count += 1
        if verbose:
            print('-' * 20)
    mse = total_error / step_count
    if not silent:
        print('MSE:', mse)
    return mse
78.
if __name__ == "__main__":
    seed(1)
    N = buildParityNet()
    DS = ParityDataSet2(3)
    DS2 = ParityDataSet2(10)

    # Backprop improves the network performance, and sometimes even finds the global optimum.
    N.randomize()
    # Scale the freshly randomized weights down by a factor of 10 before training.
    N.params[:] = [x / 10 for x in N.params]
    bp = BackpropTrainer(N, DS, verbose=True, learningrate=0.05, momentum=0.5, weightdecay=0.0005)

    # Train until the per-epoch error returned by bp.train() drops to the target.
    # Backprop is not guaranteed to get there (it can stall in a local optimum),
    # so cap the number of epochs instead of looping forever.
    TARGET_ERROR = 0.0007
    MAX_EPOCHS = 10000
    for _ in range(MAX_EPOCHS):
        if bp.train() <= TARGET_ERROR:
            break
    else:
        print('Stopped after %d epochs without reaching error %g' % (MAX_EPOCHS, TARGET_ERROR))

    evalRnnOnSeqDataset(N, DS2, verbose=True)
    print('(backprop-trained weights)')
    print('Final weights:', N.params)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.
Top