Advertisement
Guest User

Untitled

a guest
Dec 11th, 2019
120
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.80 KB | None | 0 0
  1. import math
  2. import os
  3. import random
  4. import string
  5. import time
  6. import matplotlib.pyplot as plt
  7. import matplotlib.ticker as ticker
  8. import torch
  9. from torch import nn
  10.  
  11. #sva slova abecede koja koristimo za rijeci, broj tih slova, broj ukupnih kategorija
  12. all_letters = string.ascii_letters
  13. n_letters = len(all_letters)
  14.  
  15.  
  16.  
  17. class AnimalDataset():
  18.     def readLines(filename):
  19.         lines = open(filename, encoding='utf-8').read().strip().split('\n')
  20.         return [line for line in lines]
  21.  
  22.     def __init__(self, data_root):
  23.         self.data_root = data_root
  24.         self.animals = []
  25.         self.animal_words = {}
  26.  
  27.         for animal in os.listdir(self.data_root):  # otvori Data direktorij
  28.  
  29.             self.animals.append(animal)
  30.  
  31.             animal_folder = os.path.join(self.data_root, animal)  # put poddirektorija
  32.             filepath = animal_folder + "/" + str(os.listdir(animal_folder)[0])  # put seta podataka
  33.             lines = open(filepath, encoding='utf-8').read().strip().split('\n')
  34.             self.animal_words[animal] = lines
  35.             # dodajemo uzorke u mapu: Zivotinja - rijec za zivotinju
  36.  
  37.  
  38. # Find letter index from all_letters, e.g. "a" = 0
  39. def letterToIndex(letter):
  40.     return all_letters.find(letter)
  41.  
  42.  
  43. # Turn a line into a <line_length x 1 x n_letters>,
  44. # or an array of one-hot letter vectors
  45. def lineToTensor(line):
  46.     tensor = torch.zeros(len(line), 1, n_letters)
  47.     for li, letter in enumerate(line):
  48.         tensor[li][0][letterToIndex(letter)] = 1
  49.     return tensor
  50.  
  51.  
  52. dataset = AnimalDataset('/home/pero/Documents/Neuronske/Data/') #promjeniti na lokaciju dataseta
  53. all_categories = dataset.animals
  54. words = dataset.animal_words
  55. n_categories = len(all_categories)
  56.  
  57. class RNN(nn.Module):
  58.     def __init__(self, input_size, hidden_size, output_size):
  59.         super(RNN, self).__init__()
  60.  
  61.         self.hidden_size = hidden_size
  62.  
  63.         self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
  64.         self.i2o = nn.Linear(input_size + hidden_size, output_size)
  65.         self.softmax = nn.LogSoftmax(dim=1)
  66.  
  67.     def forward(self, input, hidden):
  68.         combined = torch.cat((input, hidden), 1)
  69.         hidden = self.i2h(combined)
  70.         output = self.i2o(combined)
  71.         output = self.softmax(output)
  72.         return output, hidden
  73.  
  74.     def initHidden(self):
  75.         return torch.zeros(1, self.hidden_size)
  76.  
  77.  
  78. def categoryFromOutput(output):
  79.     top_n, top_i = output.topk(1)
  80.     category_i = top_i[0].item()
  81.     return all_categories[category_i], category_i
  82.  
  83.  
  84. n_hidden = 500
  85. rnn = RNN(n_letters, n_hidden, n_categories)
  86.  
  87. inputs = lineToTensor('kamel')
  88. hidden = torch.zeros(1, n_hidden)
  89.  
  90. output, next_hidden = rnn(inputs[0], hidden)
  91. print(output) #test run za jedno ponavljanje mreže
  92.  
  93. print(categoryFromOutput(output))
  94.  
  95.  
  96. def randomChoice(l):
  97.     return l[random.randint(0, len(l) - 1)]
  98.  
  99.  
  100. def randomTrainingExample():
  101.     category = randomChoice(all_categories)
  102.     line = randomChoice(words[category])
  103.     category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
  104.     line_tensor = lineToTensor(line)
  105.     return category, line, category_tensor, line_tensor
  106.  
  107.  
  108. for i in range(10):
  109.     category, line, category_tensor, line_tensor = randomTrainingExample()
  110.     print('category =', category, '/ line =', line)
  111.  
  112. criterion = nn.NLLLoss()
  113. learning_rate = 0.00045 #PARAM 1
  114.  
  115.  
  116. def train(category_tensor, line_tensor):
  117.     hidden = rnn.initHidden()
  118.  
  119.     rnn.zero_grad()
  120.  
  121.     for i in range(line_tensor.size()[0]):
  122.         output, hidden = rnn(line_tensor[i], hidden)
  123.  
  124.     loss = criterion(output, category_tensor)
  125.     loss.backward()
  126.  
  127.     # Add parameters' gradients to their values, multiplied by learning rate
  128.     for p in rnn.parameters():
  129.         p.data.add_(-learning_rate, p.grad.data)
  130.  
  131.     return output, loss.item()
  132.  
  133.  
  134. n_iters = 70000 #PARAM 2
  135. print_every = 500
  136. plot_every = 100
  137.  
  138. # Keep track of losses for plotting
  139. current_loss = 0
  140. all_losses = []
  141.  
  142.  
  143. def timeSince(since):
  144.     now = time.time()
  145.     s = now - since
  146.     m = math.floor(s / 60)
  147.     s -= m * 60
  148.     return '%dm %ds' % (m, s)
  149.  
  150.  
  151. start = time.time()
  152.  
  153. for iter in range(1, n_iters):
  154.     category, line, category_tensor, line_tensor = randomTrainingExample()
  155.     output, loss = train(category_tensor, line_tensor)
  156.     current_loss += loss
  157.  
  158.     # Print iter number, loss, name and guess
  159.     if iter % print_every == 0:
  160.         guess, guess_i = categoryFromOutput(output)
  161.         correct = '✓' if guess == category else '✗ (%s)' % category
  162.         print(
  163.             '%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))
  164.  
  165.     # Add current loss avg to list of losses
  166.     if iter % plot_every == 0:
  167.         all_losses.append(current_loss / plot_every)
  168.         current_loss = 0
  169.  
  170. plt.figure()
  171. plt.plot(all_losses)
  172. # Keep track of correct guesses in a confusion matrix
  173. confusion = torch.zeros(n_categories, n_categories)
  174. n_confusion = 10000
  175.  
  176.  
  177. # Just return an output given a line
  178. def evaluate(line_tensor):
  179.     hidden = rnn.initHidden()
  180.  
  181.     for i in range(line_tensor.size()[0]):
  182.         output, hidden = rnn(line_tensor[i], hidden)
  183.  
  184.     return output
  185.  
  186.  
  187. # Go through a bunch of examples and record which are correctly guessed
  188. for i in range(n_confusion):
  189.     category, line, category_tensor, line_tensor = randomTrainingExample()
  190.     output = evaluate(line_tensor)
  191.     guess, guess_i = categoryFromOutput(output)
  192.     category_i = all_categories.index(category)
  193.     confusion[category_i][guess_i] += 1
  194.  
  195. # Normalize by dividing every row by its sum
  196. for i in range(n_categories):
  197.     confusion[i] = confusion[i] / confusion[i].sum()
  198.  
  199. # Set up plot
  200. fig = plt.figure()
  201. ax = fig.add_subplot(111)
  202. cax = ax.matshow(confusion.numpy())
  203. fig.colorbar(cax)
  204.  
  205. # Set up axes
  206. ax.set_xticklabels([''] + all_categories, rotation=90)
  207. ax.set_yticklabels([''] + all_categories)
  208.  
  209. # Force label at every tick
  210. ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
  211. ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
  212.  
  213. # sphinx_gallery_thumbnail_number = 2
  214. plt.show()
  215.  
  216.  
  217. def predict(input_line, n_predictions=3):
  218.     print('\n> %s' % input_line)
  219.     with torch.no_grad():
  220.         output = evaluate(lineToTensor(input_line))
  221.  
  222.         # Get top N categories
  223.         topv, topi = output.topk(n_predictions, 1, True)
  224.         predictions = []
  225.  
  226.         for i in range(n_predictions):
  227.             value = topv[0][i].item()
  228.             category_index = topi[0][i].item()
  229.             print('(%.2f) %s' % (value, all_categories[category_index]))
  230.             predictions.append([value, all_categories[category_index]])
  231.  
  232.  
  233. predict('pes')
  234. predict('cucak')
  235. predict('dva')
  236. predict('treger')
  237. a = input("Daj zivinu: ")
  238.  
  239. while a != 'kraj':
  240.     predict(str(a))
  241.     a = input("Daj zivinu: ")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement