Th3NiKo

Szekspir

Dec 7th, 2019
import sys

import torch
from torch import nn, optim

# Character-level Shakespeare
'''
Task

A neural character-level n-gram language model:
take a chunk of text (any text will do), train the model and build a generator.
'''
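
# Usage sketch (the script and corpus file names below are assumptions, not
# part of the original paste; the script reads its training text from stdin):
#   python szekspir.py < shakespeare.txt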

history_length = 32       # size of the character context (the "n" of the n-gram)
embedding_size = 10
nb_of_char_codes = 128    # ASCII range only
hidden_size = 100

history_encoded = [ord("\n")] * history_length

device = torch.device('cpu')


def char_source():
    # Stream character codes from stdin, skipping characters outside the ASCII range.
    for line in sys.stdin:
        for char in line:
            if ord(char) < nb_of_char_codes:
                yield ord(char)


class NGramLanguageModel(nn.Module):
    def __init__(self, nb_of_char_codes, history_length, embedding_size, hidden_size):
        super(NGramLanguageModel, self).__init__()

        self.embeddings = nn.Embedding(nb_of_char_codes, embedding_size).to(device)
        self.model = nn.Sequential(
            nn.Linear(history_length * embedding_size, hidden_size),
            nn.ReLU(),  # nonlinearity; without it the two linear layers collapse into one
            nn.Linear(hidden_size, nb_of_char_codes),
            nn.LogSoftmax(dim=-1)
        ).to(device)

    def forward(self, inputs):
        # Embed the history and flatten it into a single input vector.
        embedded_inputs = self.embeddings(inputs)
        return self.model(embedded_inputs.view(-1))

    def generate(self, to_be_continued, n):
        # Left-pad the prompt with spaces and keep the last history_length characters.
        t = (" " * history_length + to_be_continued)[-history_length:]
        history = [ord(c) for c in t]  # turn the text into a history (character string -> sequence of codes)

        with torch.no_grad():
            for _ in range(n):
                x = torch.tensor(history, dtype=torch.long, device=device)
                y = torch.exp(self(x))

                best = sorted(range(nb_of_char_codes), key=lambda i: -y[i])[0:4]  # four most likely characters

                yb = torch.tensor([
                    y[ix] if ix in best else 0.0
                    for ix in range(nb_of_char_codes)])  # modified y: only the top four keep their probability, the rest are zeroed
                c = torch.multinomial(yb, 1)[0].item()
                t += chr(c)

                history.pop(0)
                history.append(c)
                # Alternative: sample from the full distribution
                '''
                c = torch.multinomial(y, 1)[0].item()

                t += chr(c)
                history.pop(0)
                history.append(c)
                '''
        return t
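
# Note on the truncated sampling above: torch.multinomial samples indices with
# probability proportional to the given weights and does not require them to
# sum to 1, so zeroing everything but the top four characters effectively
# renormalises the remaining probability mass.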

model = NGramLanguageModel(nb_of_char_codes, history_length, embedding_size, hidden_size)

counter = 0
step = 1000
losses = []

criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

for c in char_source():
    # Predict the next character from the current history window.
    x = torch.tensor(history_encoded, dtype=torch.long, device=device)
    model.zero_grad()
    y = model(x)

    loss = criterion(y.view(1, -1), torch.tensor([c], dtype=torch.long, device=device))

    # Keep a sliding window of the last `step` losses for a running average.
    losses.append(loss.item())
    if len(losses) > step:
        losses.pop(0)

    if counter % step == 0:
        avg_loss = sum(losses) / len(losses)
        print(counter)
        print(avg_loss)
        print(y)
        print(model.generate("Machine translation is", 200))

    loss.backward()
    optimizer.step()

    # Slide the training history window forward by one character.
    history_encoded.pop(0)
    history_encoded.append(c)
    counter += 1
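
# A possible follow-up once stdin is exhausted (a sketch, not part of the
# original paste): save the trained weights so text can be generated later
# without retraining. The file name "szekspir.pt" is an assumption.
torch.save(model.state_dict(), "szekspir.pt")
print(model.generate("To be, or not to be", 400))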