Guest User

Untitled

a guest
Nov 20th, 2018
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.42 KB | None | 0 0
  1. """
  2. Small code snippet with a SUPER SLOW AND INEFFICIENT MLP IMPLEMENTATION.
  3. """
  4. from math import exp, tanh, sin, pi
  5. import numpy as np
  6. from random import random
  7.  
  8.  
  9. class MLP:
  10. """
  11. Classe que cria a rede neural artificial do tipo "MLP"
  12. """
  13. def __init__(self, nodes_in, nodes_out, nodes_per_hid_layer=[],
  14. train_data=[], learn_rate=0.5, iteration_limit=10000,
  15. target_error=0.01, classification=False, output_function=0,
  16. hidden_function=1):
  17. self.hidden_function = hidden_function
  18. self.output_function = output_function
  19. self.layers = len(nodes_per_hid_layer)
  20. self.learn_rate = learn_rate
  21. self.epoch = iteration_limit
  22. self.nodes_out = nodes_out
  23. self.train_data = train_data
  24. self.error = []
  25. self.target_error = target_error
  26. self.classification = classification
  27. self.nodes_in = nodes_in
  28. self.nodes_per_hid_layer = nodes_per_hid_layer
  29. self.nodes_out = nodes_out
  30. self.network = list()
  31. self.label_dict = dict()
  32. self.resetWeights()
  33.  
  34. # Atualiza os dados de treinamento
  35. def setTrainData(self, new_data):
  36. self.train_data = new_data
  37.  
  38. # Atualiza o número de iterações
  39. def setEpochs(self, p):
  40. self.epoch = p
  41.  
  42. # Atualiza a taxa de aprendizado
  43. def setLearnRate(self, lr):
  44. self.learn_rate = lr
  45.  
  46. # Atualiza os dados de treinamento
  47. def resetWeights(self):
  48. """
  49. Resets weights
  50. """
  51. self.network = []
  52. previous_nodes = self.nodes_in
  53. for nodes in self.nodes_per_hid_layer:
  54. hidden_layers = [{'weights':[random() for i in range(previous_nodes + 1)]} for i in range(nodes)]
  55. self.network.append(hidden_layers)
  56. previous_nodes = nodes
  57. output_layer = [{'weights':[random() for i in range(previous_nodes + 1)]} for i in range(self.nodes_out)]
  58. self.network.append(output_layer)
  59.  
  60. # Propaga para frente ao longo da RNA dada uma entrada
  61. def forward_propagate(self, inputs):
  62. """
  63. Performs forward propagation
  64. """
  65. layer_count = 0
  66. for layer in self.network:
  67. new_inputs = []
  68. for neuron in layer:
  69. activation = self.activate(neuron['weights'], inputs)
  70. neuron['output'] = self.transfer(activation, layer_count)
  71. new_inputs.append(neuron['output'])
  72. inputs = new_inputs
  73. layer_count += 1
  74. return inputs
  75.  
  76. # Calcula a ativação do neurônio para uma dada entrada
  77. @staticmethod
  78. def activate(weights, inputs):
  79. """
  80. Node Linear Activation
  81. """
  82. activation = weights[-1]
  83. for i in range(len(weights)-1):
  84. activation += weights[i] * inputs[i]
  85. return activation
  86.  
  87. # Transfere para a Funções de Ativação do nó
  88. def transfer(self, activation, layer):
  89. """
  90. Activation Function
  91. """
  92. # Seleciona o tipo de ativação do neuron
  93. if not layer == self.layers:
  94. if self.hidden_function == 1:
  95. return 1.0 / (1.0 + exp(-activation))
  96. elif self.hidden_function == 2:
  97. return tanh(activation)
  98. else:
  99. return activation
  100. else:
  101. if self.output_function == 1:
  102. return 1.0 / (1.0 + exp(-activation))
  103. elif self.output_function == 2:
  104. return tanh(activation)
  105. else:
  106. return activation
  107.  
  108. def transfer_derivative(self, output, layer):
  109. """
  110. Derivative of the Activation Function
  111. """
  112. # Seleciona o tipo de ativação do neuron
  113. if not layer == self.layers:
  114. if self.hidden_function == 1:
  115. return output * (1.0 - output)
  116. elif self.hidden_function == 2:
  117. return 1 - (output) ** 2
  118. else:
  119. return 1.0
  120. else:
  121. if self.output_function == 1:
  122. return output * (1.0 - output)
  123. elif self.output_function == 2:
  124. return 1 - (output) ** 2
  125. else:
  126. return 1.0
  127.  
  128. # Backpropagate o erro e armazena nos próprios neurons
  129. def backward_propagate_error(self, expected):
  130. """
  131. Performs backpropagation
  132. """
  133. for i in reversed(range(len(self.network))):
  134. layer = self.network[i]
  135. errors = list()
  136. if i != len(self.network) - 1:
  137. for j in range(len(layer)):
  138. error = 0.0
  139. for neuron in self.network[i + 1]:
  140. error += (neuron['weights'][j] * neuron['delta'])
  141. errors.append(error)
  142. else:
  143. for j in range(len(layer)):
  144. neuron = layer[j]
  145. errors.append(expected[j] - neuron['output'])
  146. for j in range(len(layer)):
  147. neuron = layer[j]
  148. neuron['delta'] = errors[j] * self.transfer_derivative(neuron['output'], i)
  149.  
  150. # Atualiza os pesos dos neurons
  151. def update_weights(self, row, l_rate):
  152. """
  153. Random weights initializer
  154. """
  155. for i in range(len(self.network)):
  156. inputs = row[:-1]
  157. if i != 0:
  158. inputs = [neuron['output'] for neuron in self.network[i - 1]]
  159. for neuron in self.network[i]:
  160. for j in range(len(inputs)):
  161. neuron['weights'][j] += l_rate * neuron['delta'] * inputs[j]
  162. neuron['weights'][-1] += l_rate * neuron['delta']
  163.  
  164. # Representação da classificação
  165. def reshape_output(self):
  166. """
  167. Reshapes the output
  168. """
  169. self.label_dict = dict()
  170. itens_set = set()
  171. for item in self.train_data:
  172. for out in item[-self.nodes_out:]:
  173. itens_set.add(out)
  174. # Atualiza o dicionário
  175. i = 0
  176. for item in sorted(itens_set):
  177. self.label_dict[item] = i
  178. i += 1
  179.  
  180. # Train a network for a fixed number of epochs
  181. def train_network(self):
  182. """
  183. Trains the Network
  184. """
  185. self.reshape_output()
  186. self.error = []
  187. train = self.train_data
  188. n_epoch = self.epoch
  189. l_rate = self.learn_rate
  190. n_outputs = self.nodes_out
  191. # Reorganiza a saídas para classificação
  192. error_count = 0
  193. for epoch in range(n_epoch):
  194. sum_error = 0
  195. for row in train:
  196. outputs = self.forward_propagate(row)
  197. if self.classification:
  198. expected = [0 for i in range(n_outputs)]
  199. expected[self.label_dict[row[-1]]] = 1
  200. else:
  201. expected = row[-n_outputs:]
  202. sum_error += sum([(expected[i] - outputs[i])**2 for i in range(len(expected))])
  203. self.backward_propagate_error(expected)
  204. self.update_weights(row, l_rate)
  205. total_error = sum_error / (len(expected) * len(train))
  206. self.error.append(total_error)
  207. if (total_error <= self.target_error and error_count > 5) or epoch == n_epoch:
  208. print('>epoch=%d, Taxa de Apredizado=%.3f, error=%.3f' % (epoch, l_rate, total_error))
  209. error_count += 1
  210. break
  211. # print('>epoch=%d, Taxa de Apredizado=%.3f, error=%.3f' % (epoch, l_rate, total_error))
  212.  
  213. # Usa a RNA para fazer uma previsão
  214. def predict(self, row):
  215. outputs = self.forward_propagate(row)
  216. if not self.classification:
  217. return outputs
  218. else:
  219. return outputs.index(max(outputs))
  220.  
  221. ############### EXEMPLO
  222. # Arquitetura da rede
  223. # Variáveis de controle
  224. n_inputs = 2 # Numero de entradas
  225. n_outputs = 2 # Numnero de saídas
  226. n_hidden = [10] # Número de neurons na(s) camada(s) escondidas
  227. epochs = 5000 # Épocas
  228. learn_rate = 0.1 # Taxa de aprendizado
  229. target_error = 0.01 # Treina até alcançar esse erro ou o número de épocas
  230. classification = False # Se o problema é de Classificação ou não
  231.  
  232. # Dados
  233. x = np.random.uniform(-2.5 * pi, 2.5 * pi, 50) # random values
  234. dataset = [[item, sin(item)] for item in x]
  235.  
  236. # Modelo
  237. RNA = MLP(n_inputs, n_outputs, n_hidden, dataset, learn_rate, epochs, target_error, classification, output_function=2)
  238.  
  239. # Treinamento
  240. RNA.train_network()
Add Comment
Please, Sign In to add comment