Guest User

Untitled

a guest
Jul 20th, 2018
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.50 KB | None | 0 0
  1. import math
  2. import random
  3.  
  4.  
  5.  
  6. class XORNeuralNetwork(object):
  7.  
  8.     def __init__(self, inputs, outputs, f=None, fp=None, beta=1.2):
  9.        
  10.         self.inputs = []
  11.         for i in inputs:
  12.             self.inputs.append(list(i) + [1.])
  13.         self.outputs = outputs
  14.        
  15.         if not f:
  16.             self.f = lambda u: 1./(1. + math.e**(-beta*u))
  17.         else:
  18.             self.f = f
  19.            
  20.         if not fp:
  21.             self.fp = lambda u: beta*self.f(u)*(1.-self.f(u))
  22.         else:
  23.             self.fp = fp
  24.            
  25.         self.s = [random.random() for i in range(3)]
  26.         self.w = [[random.random() for i in range(3)], [random.random() for i in range(3)]] #could be better
  27.        
  28.     def printParams(self):
  29.    
  30.         print "\ts_1  = ", self.s[0]
  31.         print "\ts_2  = ", self.s[1]
  32.         print "\ts_3  = ", self.s[2]
  33.         print "\tw_11 = ", self.w[0][0]
  34.         print "\tw_12 = ", self.w[0][1]
  35.         print "\tw_13 = ", self.w[0][2]
  36.         print "\tw_21 = ", self.w[1][0]
  37.         print "\tw_22 = ", self.w[1][1]
  38.         print "\tw_23 = ", self.w[1][2]
  39.        
  40.     def calcX_pi(self, input, i, derive=False):
  41.    
  42.         if i == 3:
  43.             return 1.
  44.            
  45.         if derive:
  46.             fout = self.fp
  47.         else:
  48.             fout = self.f
  49.    
  50.         return fout(self.w[i-1][0]*input[0] + self.w[i-1][1]*input[1] + self.w[i-1][2]*input[2])
  51.  
  52.     def calcOutput(self, input, derive=False):
  53.    
  54.         if len(input) == 2:
  55.             input = list(input) + [1.] # add u_p3
  56.    
  57.         if derive:
  58.             fout = self.fp
  59.         else:
  60.             fout = self.f
  61.    
  62.         x_pi = [self.calcX_pi(input, i+1) for i in range(3)]
  63.        
  64.         return fout(self.s[0]*x_pi[0] + self.s[1]*x_pi[1] + self.s[2]*x_pi[2])
  65.    
  66.     def learn(self, printflag=False, c=0.1, eps=0.001):
  67.    
  68.         network_err = float("inf")
  69.         n = 0
  70.        
  71.         while network_err > eps:
  72.        
  73.             new_s = [0,0,0]
  74.             new_w = [[0,0,0],[0,0,0]]
  75.             network_err = 0.
  76.            
  77.             for i in range(3):
  78.                 new_s[i] = self.s[i] - c * sum((self.calcOutput(self.inputs[p]) - self.outputs[p]) * self.calcOutput(self.inputs[p], True) * self.calcX_pi(self.inputs[p], i+1) for p in range(len(self.inputs)))
  79.                
  80.             for i in range(2):
  81.                 for j in range(3):
  82.                     new_w[i][j] = self.w[i][j] - c * sum((self.calcOutput(self.inputs[p]) - self.outputs[p]) * self.calcOutput(self.inputs[p], True) * self.s[i] * self.calcX_pi(self.inputs[p], i+1, True) * self.inputs[p][j] for p in range(len(self.inputs)))
  83.            
  84.             self.s = new_s[:]
  85.             self.w[0] = new_w[0][:]
  86.             self.w[1] = new_w[1][:]
  87.            
  88.             for p in range(len(self.inputs)):
  89.                 network_err += (self.calcOutput(self.inputs[p]) - self.outputs[p])**2
  90.            
  91.             n += 1
  92.             if printflag:
  93.                 if n%2000 ==0:
  94.                     print n, network_err
  95.            
  96.         return n
  97.            
  98.    
  99. if __name__ == "__main__":
  100.    
  101.     inputs = ((0,0), (0,1), (1,0), (1,1))
  102.     outputs = (0, 1, 1, 0)
  103.    
  104.     network = XORNeuralNetwork(inputs, outputs)
  105.     print 'learning...'
  106.     print '[iteration | current_error]'
  107.     network.learn(True)
  108.    
  109.     print "\nparameters:"
  110.     network.printParams()
  111.    
  112.     print "\noutputs:"
  113.     for i in range(4):
  114.         print inputs[i], '-->', network.calcOutput(inputs[i])
Add Comment
Please, Sign In to add comment