from math import exp
from random import random, seed


# Network initialization
# Set fixed weights of 0.5 on code.finki.ukim.mk if there is a problem with random()
def initialize_network(n_inputs, n_hidden, n_outputs):
    """Build the network and initialize the weights

    :param n_inputs: number of neurons in the input layer
    :type n_inputs: int
    :param n_hidden: number of neurons in the hidden layer
    :type n_hidden: int
    :param n_outputs: number of neurons in the output layer
        (number of classes)
    :type n_outputs: int
    :return: the network as a list of layers, where each layer
        is a list of dicts with the key 'weights' and their values
    :rtype: list(list(dict(str, list)))
    """
    network = list()
    hidden_layer = [{'weights': [random() for _ in range(n_inputs + 1)]}
                    for _ in range(n_hidden)]
    network.append(hidden_layer)
    output_layer = [{'weights': [random() for _ in range(n_hidden + 1)]}
                    for _ in range(n_outputs)]
    network.append(output_layer)
    return network

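# A rough sketch of the returned structure, for illustration only: initialize_network(2, 1, 2)
# builds something like
#   [[{'weights': [w0, w1, b_h]}],                        # hidden layer: 1 neuron, 2 inputs + bias
#    [{'weights': [v0, b_o1]}, {'weights': [v1, b_o2]}]]  # output layer: 2 neurons, 1 hidden input + bias
# where every value is a random float and the last weight of each neuron acts as its bias.
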
def neuron_calculate(weights, inputs):
    """Calculate the activation value of a neuron

    :param weights: a given vector (list) of weights
    :type weights: list(float)
    :param inputs: a given vector (list) of inputs
    :type inputs: list(float)
    :return: the computed value of the neuron
    :rtype: float
    """
    activation = weights[-1]
    for i in range(len(weights) - 1):
        activation += weights[i] * inputs[i]
    return activation

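# Worked example (illustrative values only): with weights = [0.5, -0.2, 0.1] and
# inputs = [2.0, 3.0], the last weight 0.1 is taken as the bias and the result is
# 0.1 + 0.5 * 2.0 + (-0.2) * 3.0 = 0.5, i.e. a dot product of weights and inputs plus the bias.
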
def sigmoid_activation(activation):
    """Sigmoid activation function

    :param activation: value passed to the activation function
    :type activation: float
    :return: value obtained by applying the activation function
    :rtype: float
    """
    return 1.0 / (1.0 + exp(-activation))

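# Reference points: sigmoid_activation(0.0) == 0.5, large positive activations approach 1.0
# and large negative activations approach 0.0, so every neuron output stays in (0, 1).
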
def forward_propagate(network, row):
    """Propagate the input forward to the output of the network

    :param network: the given network
    :param row: the current data instance
    :return: list of the outputs from the last layer
    """
    inputs = row
    for layer in network:
        new_inputs = []
        for neuron in layer:
            activation = neuron_calculate(neuron['weights'], inputs)
            neuron['output'] = sigmoid_activation(activation)
            new_inputs.append(neuron['output'])
        inputs = new_inputs
    return inputs

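# Data-flow sketch: with a network from initialize_network(4, 3, 3), the four attributes of
# `row` are mapped to 3 hidden outputs and then to 3 output-layer values; the trailing class
# label in `row` is ignored because each neuron only reads len(weights) - 1 inputs.
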
def sigmoid_activation_derivative(output):
    """Calculate the derivative of the neuron's output

    :param output: output values
    :return: value of the derivative
    """
    return output * (1.0 - output)

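# For the sigmoid s(x), the derivative can be expressed through the output alone:
# s'(x) = s(x) * (1 - s(x)), which is exactly what this function computes from neuron['output'].
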
def backward_propagate_error(network, expected):
    """Propagate the error backwards and store it in the neurons

    :param network: the given network
    :type network: list(list(dict(str, list)))
    :param expected: expected values for the output
    :type expected: list(int)
    :return: None
    """
    for i in reversed(range(len(network))):
        layer = network[i]
        errors = list()
        if i != len(network) - 1:
            for j in range(len(layer)):
                error = 0.0
                for neuron in network[i + 1]:
                    error += (neuron['weights'][j] * neuron['delta'])
                errors.append(error)
        else:
            for j in range(len(layer)):
                neuron = layer[j]
                errors.append(expected[j] - neuron['output'])
        for j in range(len(layer)):
            neuron = layer[j]
            neuron['delta'] = errors[j] * sigmoid_activation_derivative(neuron['output'])

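# Summary of the rule implemented above: output-layer neurons get
#   delta = (expected - output) * output * (1 - output)
# while hidden neurons first accumulate the error coming back from the next layer:
#   delta = (sum of weight_j * delta_next over next-layer neurons) * output * (1 - output)
# where weight_j is the weight connecting this neuron to that next-layer neuron.
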
def update_weights(network, row, l_rate):
    """Update the network weights with the error

    :param network: the given network
    :type network: list(list(dict(str, list)))
    :param row: a single data instance
    :type row: list
    :param l_rate: learning rate
    :type l_rate: float
    :return: None
    """
    for i in range(len(network)):
        inputs = row[:-1]
        if i != 0:
            inputs = [neuron['output'] for neuron in network[i - 1]]
        for neuron in network[i]:
            for j in range(len(inputs)):
                neuron['weights'][j] += l_rate * neuron['delta'] * inputs[j]
            neuron['weights'][-1] += l_rate * neuron['delta']

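# The update above is plain per-instance (online) gradient-style learning, with no momentum
# or batching: weights[j] += l_rate * delta * input_j for each input, and the bias term
# weights[-1] += l_rate * delta.
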
def train_network(network, train, l_rate, n_epoch, n_outputs, verbose=True):
    """Train the network for a fixed number of epochs

    :param network: the given network
    :type network: list(list(dict(str, list)))
    :param train: training set
    :type train: list
    :param l_rate: learning rate
    :type l_rate: float
    :param n_epoch: number of epochs
    :type n_epoch: int
    :param n_outputs: number of neurons (classes) in the output layer
    :type n_outputs: int
    :param verbose: True to print a training log, otherwise False
    :type verbose: bool
    :return: None
    """
    for epoch in range(n_epoch):
        sum_error = 0
        for row in train:
            outputs = forward_propagate(network, row)
            expected = [0] * n_outputs
            expected[row[-1]] = 1
            sum_error += sum([(expected[i] - outputs[i]) ** 2 for i in range(len(expected))])
            backward_propagate_error(network, expected)
            update_weights(network, row, l_rate)
        if verbose:
            print('>epoch=%d, lrate=%.3f, error=%.3f' % (epoch, l_rate, sum_error))

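# Each epoch performs per-row updates: the class label row[-1] is one-hot encoded into
# `expected`, the squared error is accumulated only for the optional log line, and then
# the error is backpropagated and the weights are updated for that single row.
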
def predict(network, row):
    """Make a prediction

    :param network: the given network
    :type network: list(list(dict(str, list)))
    :param row: a single data instance
    :type row: list
    :return: the predicted class
    """
    outputs = forward_propagate(network, row)
    return outputs.index(max(outputs))

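# predict returns the index of the largest output value (an argmax over the output layer),
# so with 3 output neurons the result is one of the class labels 0, 1 or 2. The dataset
# below holds rows of four numeric attributes followed by that integer class label.
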
dataset = [
    [6.3,2.9,5.6,1.8,0],
    [6.5,3.0,5.8,2.2,0],
    [7.6,3.0,6.6,2.1,0],
    [4.9,2.5,4.5,1.7,0],
    [7.3,2.9,6.3,1.8,0],
    [6.7,2.5,5.8,1.8,0],
    [7.2,3.6,6.1,2.5,0],
    [6.5,3.2,5.1,2.0,0],
    [6.4,2.7,5.3,1.9,0],
    [6.8,3.0,5.5,2.1,0],
    [5.7,2.5,5.0,2.0,0],
    [5.8,2.8,5.1,2.4,0],
    [6.4,3.2,5.3,2.3,0],
    [6.5,3.0,5.5,1.8,0],
    [7.7,3.8,6.7,2.2,0],
    [7.7,2.6,6.9,2.3,0],
    [6.0,2.2,5.0,1.5,0],
    [6.9,3.2,5.7,2.3,0],
    [5.6,2.8,4.9,2.0,0],
    [7.7,2.8,6.7,2.0,0],
    [6.3,2.7,4.9,1.8,0],
    [6.7,3.3,5.7,2.1,0],
    [7.2,3.2,6.0,1.8,0],
    [6.2,2.8,4.8,1.8,0],
    [6.1,3.0,4.9,1.8,0],
    [6.4,2.8,5.6,2.1,0],
    [7.2,3.0,5.8,1.6,0],
    [7.4,2.8,6.1,1.9,0],
    [7.9,3.8,6.4,2.0,0],
    [6.4,2.8,5.6,2.2,0],
    [6.3,2.8,5.1,1.5,0],
    [6.1,2.6,5.6,1.4,0],
    [7.7,3.0,6.1,2.3,0],
    [6.3,3.4,5.6,2.4,0],
    [5.1,3.5,1.4,0.2,1],
    [4.9,3.0,1.4,0.2,1],
    [4.7,3.2,1.3,0.2,1],
    [4.6,3.1,1.5,0.2,1],
    [5.0,3.6,1.4,0.2,1],
    [5.4,3.9,1.7,0.4,1],
    [4.6,3.4,1.4,0.3,1],
    [5.0,3.4,1.5,0.2,1],
    [4.4,2.9,1.4,0.2,1],
    [4.9,3.1,1.5,0.1,1],
    [5.4,3.7,1.5,0.2,1],
    [4.8,3.4,1.6,0.2,1],
    [4.8,3.0,1.4,0.1,1],
    [4.3,3.0,1.1,0.1,1],
    [5.8,4.0,1.2,0.2,1],
    [5.7,4.4,1.5,0.4,1],
    [5.4,3.9,1.3,0.4,1],
    [5.1,3.5,1.4,0.3,1],
    [5.7,3.8,1.7,0.3,1],
    [5.1,3.8,1.5,0.3,1],
    [5.4,3.4,1.7,0.2,1],
    [5.1,3.7,1.5,0.4,1],
    [4.6,3.6,1.0,0.2,1],
    [5.1,3.3,1.7,0.5,1],
    [4.8,3.4,1.9,0.2,1],
    [5.0,3.0,1.6,0.2,1],
    [5.0,3.4,1.6,0.4,1],
    [5.2,3.5,1.5,0.2,1],
    [5.2,3.4,1.4,0.2,1],
    [5.5,2.3,4.0,1.3,2],
    [6.5,2.8,4.6,1.5,2],
    [5.7,2.8,4.5,1.3,2],
    [6.3,3.3,4.7,1.6,2],
    [4.9,2.4,3.3,1.0,2],
    [6.6,2.9,4.6,1.3,2],
    [5.2,2.7,3.9,1.4,2],
    [5.0,2.0,3.5,1.0,2],
    [5.9,3.0,4.2,1.5,2],
    [6.0,2.2,4.0,1.0,2],
    [6.1,2.9,4.7,1.4,2],
    [5.6,2.9,3.6,1.3,2],
    [6.7,3.1,4.4,1.4,2],
    [5.6,3.0,4.5,1.5,2],
    [5.8,2.7,4.1,1.0,2],
    [6.2,2.2,4.5,1.5,2],
    [5.6,2.5,3.9,1.1,2],
    [5.9,3.2,4.8,1.8,2],
    [6.1,2.8,4.0,1.3,2],
    [6.3,2.5,4.9,1.5,2],
    [6.1,2.8,4.7,1.2,2],
    [6.4,2.9,4.3,1.3,2],
    [6.6,3.0,4.4,1.4,2],
    [6.8,2.8,4.8,1.4,2],
    [6.7,3.0,5.0,1.7,2],
    [6.0,2.9,4.5,1.5,2],
    [5.7,2.6,3.5,1.0,2],
    [5.5,2.4,3.8,1.1,2],
    [5.4,3.0,4.5,1.5,2],
    [6.0,3.4,4.5,1.6,2],
    [6.7,3.1,4.7,1.5,2],
    [6.3,2.3,4.4,1.3,2],
    [5.6,3.0,4.1,1.3,2],
    [5.5,2.5,4.0,1.3,2],
    [5.5,2.6,4.4,1.2,2],
    [6.1,3.0,4.6,1.4,2],
    [5.8,2.6,4.0,1.2,2],
    [5.0,2.3,3.3,1.0,2],
    [5.6,2.7,4.2,1.3,2],
    [5.7,3.0,4.2,1.2,2],
    [5.7,2.9,4.2,1.3,2],
    [6.2,2.9,4.3,1.3,2],
    [5.1,2.5,3.0,1.1,2],
    [5.7,2.8,4.1,1.3,2],
    [6.4,3.1,5.5,1.8,0],
    [6.0,3.0,4.8,1.8,0],
    [6.9,3.1,5.4,2.1,0],
    [6.8,3.2,5.9,2.3,0],
    [6.7,3.3,5.7,2.5,0],
    [6.7,3.0,5.2,2.3,0],
    [6.3,2.5,5.0,1.9,0],
    [6.5,3.0,5.2,2.0,0],
    [6.2,3.4,5.4,2.3,0],
    [4.7,3.2,1.6,0.2,1],
    [4.8,3.1,1.6,0.2,1],
    [5.4,3.4,1.5,0.4,1],
    [5.2,4.1,1.5,0.1,1],
    [5.5,4.2,1.4,0.2,1],
    [4.9,3.1,1.5,0.2,1],
    [5.0,3.2,1.2,0.2,1],
    [5.5,3.5,1.3,0.2,1],
    [4.9,3.6,1.4,0.1,1],
    [4.4,3.0,1.3,0.2,1],
    [5.1,3.4,1.5,0.2,1],
    [5.0,3.5,1.3,0.3,1],
    [4.5,2.3,1.3,0.3,1],
    [4.4,3.2,1.3,0.2,1],
    [5.0,3.5,1.6,0.6,1],
    [5.9,3.0,5.1,1.8,0],
    [5.1,3.8,1.9,0.4,1],
    [4.8,3.0,1.4,0.3,1],
    [5.1,3.8,1.6,0.2,1],
    [5.5,2.4,3.7,1.0,2],
    [5.8,2.7,3.9,1.2,2],
    [6.0,2.7,5.1,1.6,2],
    [6.7,3.1,5.6,2.4,0],
    [6.9,3.1,5.1,2.3,0],
    [5.8,2.7,5.1,1.9,0],
]

if __name__ == "__main__":
    # do not change
    seed(1)

    att1 = float(input())
    att2 = float(input())
    att3 = float(input())
    att4 = float(input())
    planttype = int(input())
    testCase = [att1, att2, att3, att4, planttype]

    # your code here
    train_set = dataset[:-10]
    val_set = dataset[-10:]
    network = initialize_network(4, 3, 3)
    train_network(network, train_set, 0.3, 20, 3, False)
    network2 = initialize_network(4, 3, 3)
    train_network(network2, train_set, 0.5, 20, 3, False)
    network3 = initialize_network(4, 3, 3)
    train_network(network3, train_set, 0.7, 20, 3, False)

    pogodeni1 = pogodeni2 = pogodeni3 = 0

    for x in val_set:
        if predict(network, x) == x[-1]:
            pogodeni1 += 1
        if predict(network2, x) == x[-1]:
            pogodeni2 += 1
        if predict(network3, x) == x[-1]:
            pogodeni3 += 1

    bestNetwork = network
    if pogodeni1 > pogodeni2 and pogodeni1 > pogodeni3:
        bestNetwork = network
    if pogodeni2 > pogodeni1 and pogodeni2 > pogodeni3:
        bestNetwork = network2
    if pogodeni3 > pogodeni1 and pogodeni3 > pogodeni2:
        bestNetwork = network3

    print(predict(bestNetwork, testCase))
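
# Usage sketch: the script expects five stdin lines (four float attributes and an integer
# class label), trains three networks with learning rates 0.3, 0.5 and 0.7 on all but the
# last 10 dataset rows, keeps whichever scores best on those 10 held-out rows (falling back
# to the first network when there is no strict winner), and prints its predicted class for
# the entered instance.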