Advertisement
Guest User

uhuhu

a guest
May 24th, 2019
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.74 KB | None | 0 0
  1. from random import seed
  2.  
  3. from math import exp
  4. from random import random
  5.  
  6.  
  7. # Initialize a network
  8. # Put fixed weights to 0.5 on code.finki.ukim.mk if there is a problem with random()
  9. def initialize_network(n_inputs, n_hidden, n_outputs):
  10. """Build the network and initialize the weights
  11.  
  12. :param n_inputs: number of neurons in the input layer
  13. :type n_inputs: int
  14. :param n_hidden: number of neurons in the hidden layer
  15. :type n_hidden: int
  16. :param n_outputs: number of neurons in the output layer
  17. (number of classes)
  18. :type n_outputs: int
  19. :return: the network as a list of the layers, where each
  20. layer is a dictionary with key 'weights' and their
  21. values
  22. :rtype: list(list(dict(str, list)))
  23. """
  24. network = list()
  25. hidden_layer = [{'weights': [random() for _ in range(n_inputs + 1)]}
  26. for _ in range(n_hidden)]
  27. network.append(hidden_layer)
  28. output_layer = [{'weights': [random() for _ in range(n_hidden + 1)]}
  29. for _ in range(n_outputs)]
  30. network.append(output_layer)
  31. return network
  32.  
  33.  
  34. def neuron_calculate(weights, inputs):
  35. """Calculate neuron activation value
  36.  
  37. :param weights: given vector (list) of weights
  38. :type weights: list(float)
  39. :param inputs: given vector (list) of inputs
  40. :type inputs: list(float)
  41. :return: neuron calculation
  42. :rtype: float
  43. """
  44. activation = weights[-1]
  45. for i in range(len(weights) - 1):
  46. activation += weights[i] * inputs[i]
  47. return activation
  48.  
  49.  
  50. def sigmoid_activation(activation):
  51. """Sigmoid activation function
  52.  
  53. :param activation: value for the activation function
  54. :type activation: float
  55. :return: activation function value
  56. :rtype: float
  57. """
  58. return 1.0 / (1.0 + exp(-activation))
  59.  
  60.  
  61. def forward_propagate(network, row):
  62. """Forward propagate input to a network output
  63.  
  64. :param network: the network
  65. :param row: current data instance
  66. :return: list of outputs from the last layer
  67. """
  68. inputs = row
  69. for layer in network:
  70. new_inputs = []
  71. for neuron in layer:
  72. activation = neuron_calculate(neuron['weights'], inputs)
  73. neuron['output'] = sigmoid_activation(activation)
  74. new_inputs.append(neuron['output'])
  75. inputs = new_inputs
  76. return inputs
  77.  
  78.  
  79. def sigmoid_activation_derivative(output):
  80. """Calculates the derivative of an neuron output
  81.  
  82. :param output: the output values
  83. :return: value of the derivative
  84. """
  85. return output * (1.0 - output)
  86.  
  87.  
  88. def backward_propagate_error(network, expected):
  89. """Backpropagate error and store in neurons
  90.  
  91. :param network: the network
  92. :type network: list(list(dict(str, list)))
  93. :param expected: expected output values
  94. :type expected: list(int)
  95. :return: None
  96. """
  97. for i in reversed(range(len(network))):
  98. layer = network[i]
  99. errors = list()
  100. if i != len(network) - 1:
  101. for j in range(len(layer)):
  102. error = 0.0
  103. for neuron in network[i + 1]:
  104. error += (neuron['weights'][j] * neuron['delta'])
  105. errors.append(error)
  106. else:
  107. for j in range(len(layer)):
  108. neuron = layer[j]
  109. errors.append(expected[j] - neuron['output'])
  110. for j in range(len(layer)):
  111. neuron = layer[j]
  112. neuron['delta'] = errors[j] * sigmoid_activation_derivative(neuron['output'])
  113.  
  114.  
  115. def update_weights(network, row, l_rate):
  116. """Update network weights with error
  117.  
  118. :param network: the network
  119. :type network: list(list(dict(str, list)))
  120. :param row: one data instance
  121. :type row: list
  122. :param l_rate: learning rate
  123. :type l_rate: float
  124. :return: None
  125. """
  126. for i in range(len(network)):
  127. inputs = row[:-1]
  128. if i != 0:
  129. inputs = [neuron['output'] for neuron in network[i - 1]]
  130. for neuron in network[i]:
  131. for j in range(len(inputs)):
  132. neuron['weights'][j] += l_rate * neuron['delta'] * inputs[j]
  133. neuron['weights'][-1] += l_rate * neuron['delta']
  134.  
  135.  
  136. def train_network(network, train, l_rate, n_epoch, n_outputs, verbose=True):
  137. """Train a network for a fixed number of epochs
  138.  
  139. :param network: the network
  140. :type network: list(list(dict(str, list)))
  141. :param train: train dataset
  142. :type train: list
  143. :param l_rate: learning rate
  144. :type l_rate: float
  145. :param n_epoch: number of epochs
  146. :type n_epoch: int
  147. :param n_outputs: number of neurons (classes) in output layer
  148. :type n_outputs: int
  149. :param verbose: True for printing log, else False
  150. :type: verbose: bool
  151. :return: None
  152. """
  153. for epoch in range(n_epoch):
  154. sum_error = 0
  155. for row in train:
  156. outputs = forward_propagate(network, row)
  157. expected = [0] * n_outputs
  158. expected[row[-1]] = 1
  159. sum_error += sum([(expected[i] - outputs[i]) ** 2 for i in range(len(expected))])
  160. backward_propagate_error(network, expected)
  161. update_weights(network, row, l_rate)
  162. if verbose:
  163. print('>epoch=%d, lrate=%.3f, error=%.3f' % (epoch, l_rate, sum_error))
  164.  
  165.  
  166. def predict(network, row):
  167. """Make a prediction with a network
  168.  
  169. :param network: the network
  170. :type network: list(list(dict(str, list)))
  171. :param row: one data instance
  172. :type row: list
  173. :return: predicted classes
  174. """
  175. outputs = forward_propagate(network, row)
  176. return outputs.index(max(outputs))
  177.  
  178.  
# Training data: each row is four float attributes followed by an int class
# label in {0, 1, 2} (the last element, used as the one-hot target index).
# NOTE(review): values resemble the classic Iris measurements (sepal/petal
# length and width) — confirm the column meanings before relying on them.
dataset = [
    [6.3, 2.9, 5.6, 1.8, 0],
    [6.5, 3.0, 5.8, 2.2, 0],
    [7.6, 3.0, 6.6, 2.1, 0],
    [4.9, 2.5, 4.5, 1.7, 0],
    [7.3, 2.9, 6.3, 1.8, 0],
    [6.7, 2.5, 5.8, 1.8, 0],
    [7.2, 3.6, 6.1, 2.5, 0],
    [6.5, 3.2, 5.1, 2.0, 0],
    [6.4, 2.7, 5.3, 1.9, 0],
    [6.8, 3.0, 5.5, 2.1, 0],
    [5.7, 2.5, 5.0, 2.0, 0],
    [5.8, 2.8, 5.1, 2.4, 0],
    [6.4, 3.2, 5.3, 2.3, 0],
    [6.5, 3.0, 5.5, 1.8, 0],
    [7.7, 3.8, 6.7, 2.2, 0],
    [7.7, 2.6, 6.9, 2.3, 0],
    [6.0, 2.2, 5.0, 1.5, 0],
    [6.9, 3.2, 5.7, 2.3, 0],
    [5.6, 2.8, 4.9, 2.0, 0],
    [7.7, 2.8, 6.7, 2.0, 0],
    [6.3, 2.7, 4.9, 1.8, 0],
    [6.7, 3.3, 5.7, 2.1, 0],
    [7.2, 3.2, 6.0, 1.8, 0],
    [6.2, 2.8, 4.8, 1.8, 0],
    [6.1, 3.0, 4.9, 1.8, 0],
    [6.4, 2.8, 5.6, 2.1, 0],
    [7.2, 3.0, 5.8, 1.6, 0],
    [7.4, 2.8, 6.1, 1.9, 0],
    [7.9, 3.8, 6.4, 2.0, 0],
    [6.4, 2.8, 5.6, 2.2, 0],
    [6.3, 2.8, 5.1, 1.5, 0],
    [6.1, 2.6, 5.6, 1.4, 0],
    [7.7, 3.0, 6.1, 2.3, 0],
    [6.3, 3.4, 5.6, 2.4, 0],
    [5.1, 3.5, 1.4, 0.2, 1],
    [4.9, 3.0, 1.4, 0.2, 1],
    [4.7, 3.2, 1.3, 0.2, 1],
    [4.6, 3.1, 1.5, 0.2, 1],
    [5.0, 3.6, 1.4, 0.2, 1],
    [5.4, 3.9, 1.7, 0.4, 1],
    [4.6, 3.4, 1.4, 0.3, 1],
    [5.0, 3.4, 1.5, 0.2, 1],
    [4.4, 2.9, 1.4, 0.2, 1],
    [4.9, 3.1, 1.5, 0.1, 1],
    [5.4, 3.7, 1.5, 0.2, 1],
    [4.8, 3.4, 1.6, 0.2, 1],
    [4.8, 3.0, 1.4, 0.1, 1],
    [4.3, 3.0, 1.1, 0.1, 1],
    [5.8, 4.0, 1.2, 0.2, 1],
    [5.7, 4.4, 1.5, 0.4, 1],
    [5.4, 3.9, 1.3, 0.4, 1],
    [5.1, 3.5, 1.4, 0.3, 1],
    [5.7, 3.8, 1.7, 0.3, 1],
    [5.1, 3.8, 1.5, 0.3, 1],
    [5.4, 3.4, 1.7, 0.2, 1],
    [5.1, 3.7, 1.5, 0.4, 1],
    [4.6, 3.6, 1.0, 0.2, 1],
    [5.1, 3.3, 1.7, 0.5, 1],
    [4.8, 3.4, 1.9, 0.2, 1],
    [5.0, 3.0, 1.6, 0.2, 1],
    [5.0, 3.4, 1.6, 0.4, 1],
    [5.2, 3.5, 1.5, 0.2, 1],
    [5.2, 3.4, 1.4, 0.2, 1],
    [5.5, 2.3, 4.0, 1.3, 2],
    [6.5, 2.8, 4.6, 1.5, 2],
    [5.7, 2.8, 4.5, 1.3, 2],
    [6.3, 3.3, 4.7, 1.6, 2],
    [4.9, 2.4, 3.3, 1.0, 2],
    [6.6, 2.9, 4.6, 1.3, 2],
    [5.2, 2.7, 3.9, 1.4, 2],
    [5.0, 2.0, 3.5, 1.0, 2],
    [5.9, 3.0, 4.2, 1.5, 2],
    [6.0, 2.2, 4.0, 1.0, 2],
    [6.1, 2.9, 4.7, 1.4, 2],
    [5.6, 2.9, 3.6, 1.3, 2],
    [6.7, 3.1, 4.4, 1.4, 2],
    [5.6, 3.0, 4.5, 1.5, 2],
    [5.8, 2.7, 4.1, 1.0, 2],
    [6.2, 2.2, 4.5, 1.5, 2],
    [5.6, 2.5, 3.9, 1.1, 2],
    [5.9, 3.2, 4.8, 1.8, 2],
    [6.1, 2.8, 4.0, 1.3, 2],
    [6.3, 2.5, 4.9, 1.5, 2],
    [6.1, 2.8, 4.7, 1.2, 2],
    [6.4, 2.9, 4.3, 1.3, 2],
    [6.6, 3.0, 4.4, 1.4, 2],
    [6.8, 2.8, 4.8, 1.4, 2],
    [6.7, 3.0, 5.0, 1.7, 2],
    [6.0, 2.9, 4.5, 1.5, 2],
    [5.7, 2.6, 3.5, 1.0, 2],
    [5.5, 2.4, 3.8, 1.1, 2],
    [5.4, 3.0, 4.5, 1.5, 2],
    [6.0, 3.4, 4.5, 1.6, 2],
    [6.7, 3.1, 4.7, 1.5, 2],
    [6.3, 2.3, 4.4, 1.3, 2],
    [5.6, 3.0, 4.1, 1.3, 2],
    [5.5, 2.5, 4.0, 1.3, 2],
    [5.5, 2.6, 4.4, 1.2, 2],
    [6.1, 3.0, 4.6, 1.4, 2],
    [5.8, 2.6, 4.0, 1.2, 2],
    [5.0, 2.3, 3.3, 1.0, 2],
    [5.6, 2.7, 4.2, 1.3, 2],
    [5.7, 3.0, 4.2, 1.2, 2],
    [5.7, 2.9, 4.2, 1.3, 2],
    [6.2, 2.9, 4.3, 1.3, 2],
    [5.1, 2.5, 3.0, 1.1, 2],
    [5.7, 2.8, 4.1, 1.3, 2],
    [6.4, 3.1, 5.5, 1.8, 0],
    [6.0, 3.0, 4.8, 1.8, 0],
    [6.9, 3.1, 5.4, 2.1, 0],
    [6.8, 3.2, 5.9, 2.3, 0],
    [6.7, 3.3, 5.7, 2.5, 0],
    [6.7, 3.0, 5.2, 2.3, 0],
    [6.3, 2.5, 5.0, 1.9, 0],
    [6.5, 3.0, 5.2, 2.0, 0],
    [6.2, 3.4, 5.4, 2.3, 0],
    [4.7, 3.2, 1.6, 0.2, 1],
    [4.8, 3.1, 1.6, 0.2, 1],
    [5.4, 3.4, 1.5, 0.4, 1],
    [5.2, 4.1, 1.5, 0.1, 1],
    [5.5, 4.2, 1.4, 0.2, 1],
    [4.9, 3.1, 1.5, 0.2, 1],
    [5.0, 3.2, 1.2, 0.2, 1],
    [5.5, 3.5, 1.3, 0.2, 1],
    [4.9, 3.6, 1.4, 0.1, 1],
    [4.4, 3.0, 1.3, 0.2, 1],
    [5.1, 3.4, 1.5, 0.2, 1],
    [5.0, 3.5, 1.3, 0.3, 1],
    [4.5, 2.3, 1.3, 0.3, 1],
    [4.4, 3.2, 1.3, 0.2, 1],
    [5.0, 3.5, 1.6, 0.6, 1],
    [5.9, 3.0, 5.1, 1.8, 0],
    [5.1, 3.8, 1.9, 0.4, 1],
    [4.8, 3.0, 1.4, 0.3, 1],
    [5.1, 3.8, 1.6, 0.2, 1],
    [5.5, 2.4, 3.7, 1.0, 2],
    [5.8, 2.7, 3.9, 1.2, 2],
    [6.0, 2.7, 5.1, 1.6, 2],
    [6.7, 3.1, 5.6, 2.4, 0],
    [6.9, 3.1, 5.1, 2.3, 0],
    [5.8, 2.7, 5.1, 1.9, 0],
]
  322.  
  323. if __name__ == "__main__":
  324. # ne menuvaj
  325. seed(1)
  326.  
  327. att1 = float(input())
  328. att2 = float(input())
  329. att3 = float(input())
  330. att4 = float(input())
  331. planttype = int(input())
  332. testCase = [att1, att2, att3, att4, planttype]
  333.  
  334. # vasiot kod ovde
  335.  
  336. lr = [0.3,0.5,0.7]
  337.  
  338. train_set = dataset[0:-10]
  339. validation_set = dataset[-10:]
  340.  
  341. size_inputs = len(dataset[0]) - 1
  342. size_outputs = len(set([row[-1] for row in dataset]))
  343.  
  344. best = 0
  345.  
  346. for learning_rate in lr:
  347. network = initialize_network(size_inputs, 3, size_outputs)
  348. train_network(network, train_set, learning_rate, 20, size_outputs,False)
  349.  
  350. score = 0
  351.  
  352. for i in validation_set:
  353. prediction = predict(network, i)
  354. if prediction == i[-1]:
  355. score = score + 1
  356.  
  357.  
  358. if score >= best:
  359. best = score
  360. best_network = network
  361.  
  362. print(predict(best_network, testCase))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement