Advertisement
Guest User

Untitled

a guest
May 25th, 2021
56
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.35 KB | None | 0 0
  1. from math import *
  2. import random
  3.  
  4. def sig(x):
  5.     return 1 / (1 + exp(-x))
  6.  
  7. def sig_prime(x): # x already sigmoided
  8.     return x * (1 - x)
  9.  
  10. def tanh_prime(x): # x already tanhed
  11.     return 1 - x * x
  12.  
  13. def relu(x):
  14.     return max(x,0)
  15.  
  16. def relu_prime(x): # x already relued
  17.     return 1 if x > 0 else 0
  18.  
  19.  
  20. class Network:
  21.     def __init__(self, input_size, hidden_size, alpha, mom):
  22.         self.hidden_size = hidden_size
  23.         self.input_size = input_size
  24.         self.input_weights = []
  25.         self.input_biases = []
  26.         self.output_weights = []
  27.         self.output_bias = random.uniform(-0.5,0.5)
  28.         self.input_momentum_weights = []
  29.         self.output_momentum_weights = []
  30.         self.alpha = alpha
  31.         self.mom = mom
  32.         for i in range(hidden_size):
  33.             input_weight_list = []
  34.             input_momentum_list = []
  35.            
  36.             for i in range(input_size):
  37.                 input_weight_list.append(random.uniform(-0.5,0.5))
  38.            
  39.             for i in range(input_size):
  40.                 input_momentum_list.append(0)
  41.                
  42.             self.input_weights.append(input_weight_list)
  43.             self.input_biases.append(random.uniform(-0.5,0.5))
  44.             self.output_weights.append(random.uniform(-0.5,0.5))
  45.             self.input_momentum_weights.append(input_momentum_list)
  46.             self.output_momentum_weights.append(0)
  47.  
  48.     def forward(self, inputs):
  49.         h = [] # hidden node value
  50.         for i in range(len(self.input_weights)):
  51.             result = self.input_biases[i]
  52.             for j in range(self.input_size):
  53.                 result += self.input_weights[i][j]*inputs[j]
  54.             h.append(relu(result))
  55.         o = 0 # output value
  56.         for i in range(len(self.output_weights)):
  57.             o += self.output_weights[i] * h[i]
  58.         o += self.output_bias
  59.         return tanh(o)
  60.  
  61.     def learn(self,inputs,t):
  62.         h = [] # hidden node value
  63.         for i in range(len(self.input_weights)):
  64.             result = self.input_biases[i]
  65.             for j in range(self.input_size):
  66.                 result += self.input_weights[i][j]*inputs[j]
  67.             h.append(relu(result))
  68.         o = 0 # output value
  69.         for i in range(len(self.output_weights)):
  70.             o += self.output_weights[i] * h[i]
  71.         o += self.output_bias
  72.         o = tanh(o)
  73.      
  74.         e = t - o
  75.  
  76.         do = e * tanh_prime(o)
  77.      
  78.         eh = []
  79.         for i in range(len(self.output_weights)):
  80.             eh.append(do * self.output_weights[i])
  81.      
  82.         dh = []
  83.         for i in range(len(self.output_weights)):
  84.             dh.append(eh[i] * relu_prime(h[i]))
  85.      
  86.         for i in range(len(self.output_weights)):
  87.             self.output_momentum_weights[i] = self.mom * self.output_momentum_weights[i] + self.alpha * h[i] * do
  88.             self.output_weights[i] += self.output_momentum_weights[i]
  89.         self.output_bias += self.alpha * do
  90.      
  91.         for i in range(len(self.input_weights)):
  92.             for j in range(self.input_size):
  93.                 self.input_momentum_weights[i][j] = self.mom * self.input_momentum_weights[i][j] + self.alpha * inputs[j] * dh[i]
  94.             for j in range(self.input_size):
  95.                 self.input_weights[i][j] += self.input_momentum_weights[i][j]
  96.             self.input_biases[i] += self.alpha * dh[i]
  97.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement