import string
import unicodedata

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

## Building the data:

LETTRES = string.ascii_letters + string.punctuation + string.digits + ' '
id2lettre = dict(zip(range(1, len(LETTRES) + 1), LETTRES))
id2lettre[0] = ''  # null character
lettre2id = dict(zip(id2lettre.values(), id2lettre.keys()))

def normalize(s):
    # strip accents (NFD) and drop any character outside LETTRES
    return ''.join(c for c in unicodedata.normalize('NFD', s) if c in LETTRES)

def string2code(s):
    # encode a string as a 1-D tensor of character ids
    return torch.tensor([lettre2id[c] for c in normalize(s)])

def code2string(t):
    # decode a tensor (or list) of character ids back into a string
    if type(t) != list:
        t = t.tolist()
    return ''.join(id2lettre[i] for i in t)

with open('trump_full_speech.txt', 'r') as f:
    data_trump = string2code(f.read())


# Model:

class RNN(nn.Module):
    def __init__(self, latent, in_, out_):
        super(RNN, self).__init__()
        self.latent = latent
        self.linear1 = nn.Linear(in_, self.latent, bias=False)          # input -> hidden
        self.linear2 = nn.Linear(self.latent, self.latent, bias=True)   # hidden -> hidden
        self.tanh = nn.Tanh()
        # note: out_ is accepted but unused; the tanh hidden state doubles as the output

    def one_step(self, x, h):
        res = self.linear1(x.float()) + self.linear2(h.float())
        return self.tanh(res)

    def forward(self, x, h):
        self.next_h = h
        k = 0
        # one_step applied to each element of the sequence (x_i.shape = (batch_size, 1, embedDim))
        for x_i in x.transpose(0, 1):
            self.next_h = self.one_step(x_i, self.next_h)
            if k == 0:
                self.h_list = self.next_h
            else:
                self.h_list = torch.cat((self.h_list, self.next_h), 1)
            k += 1
        return self.next_h, self.h_list
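
# Shape sanity check (a small sketch, not in the original paste): with random
# inputs shaped like the ones used below, forward returns the last hidden
# state (batch, 1, latent) and the stacked hidden states (batch, seq, latent).
_rnn = RNN(latent=96, in_=50, out_=96)
_x = torch.randn(10, 200, 1, 50)    # (batch, seq, 1, embedDim)
_h0 = -torch.ones(10, 1, 96)
_h, _h_seq = _rnn(_x, _h0)
print(_h.shape, _h_seq.shape)       # torch.Size([10, 1, 96]) torch.Size([10, 200, 96])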


class SequenceGenerator(nn.Module):
    def __init__(self, inDim, embedDim, hidenDim, outDim):
        super().__init__()
        self.inDim, self.embedDim, self.hidenDim, self.outDim = inDim, embedDim, hidenDim, outDim
        self.embedding = nn.Embedding(self.inDim, self.embedDim)
        self.rnn = RNN(self.hidenDim, self.embedDim, self.outDim)
        # initial hidden state; note this relies on the global batch_size defined below
        self.h = -torch.ones(batch_size, 1, self.hidenDim)

    def forward(self, x):
        # embedding + recurrent forward pass
        x = self.embedding(x)
        return self.rnn(x.float(), self.h.float())

# Parameters

batch_size = 10
seq_length_train = 5558  # length (unused below)
latent = 96
nb_epoch = 5000
lr = 0.001
momentum = 1  # unused: the optimizer below is Adam
seq_length = 200
# inDim = dictionary size (95 characters + the null id 0), embedDim = arbitrary
inDim, embedDim, hidenDim, outDim = 96, 50, 96, 96


# Initialization

net = SequenceGenerator(inDim, embedDim, hidenDim, outDim)
loss_function = nn.CrossEntropyLoss()
# note: weight_decay=0.95 is an unusually strong L2 penalty
optim2 = torch.optim.Adam(params=net.parameters(), lr=lr, weight_decay=0.95)

loss_train = []
loss_test = []  # declared but never filled below

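# Quick forward check (a small sketch, not in the original paste): one random
# batch through the untrained net, to confirm the shapes the loss will expect.
_ids = torch.randint(0, inDim, (batch_size, seq_length, 1))
_last, _hs = net(_ids)
print(_hs.shape)    # torch.Size([10, 200, 96]) = (batch, seq, logits-per-char)
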
for k in range(nb_epoch):
    if k % 50 == 0:
        print(k)

    # Train

    ## Build the batch + targets: batch_size random windows of seq_length
    ## characters; each target is its window shifted one character ahead
    index = [np.random.randint(len(data_trump) - seq_length - 1) for i in range(batch_size)]
    x = torch.stack([data_trump[i:i + seq_length] for i in index], 0)
    target = torch.stack([data_trump[i + 1:i + seq_length + 1] for i in index], 0)

    ## RNN training step
    optim2.zero_grad()
    h, h_seq = net(x.unsqueeze(2))
    # CrossEntropyLoss wants (batch, classes, seq): transpose dims 1 and 2
    # (a view would silently scramble the class and time axes)
    loss = loss_function(h_seq.transpose(1, 2), target)
    loss.backward()
    optim2.step()
    loss_train.append(loss.item())

plt.plot(loss_train, label='loss_train')
plt.legend()
plt.show()
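
# A minimal greedy sampling sketch (not in the original paste): feed a seed
# string through the trained net one character at a time, then loop the
# argmax prediction back in. The seed 'America' and the 200-character length
# are arbitrary choices; the tanh hidden state is read directly as logits,
# exactly as in the training loss above.
with torch.no_grad():
    h = -torch.ones(1, 1, hidenDim)              # same init as in training
    out_ids = string2code('America').tolist()
    for i in out_ids:                            # warm up on the seed
        emb = net.embedding(torch.tensor([[i]])) # (1, 1, embedDim)
        h = net.rnn.one_step(emb.float(), h)
    for _ in range(200):                         # generate greedily
        i = int(h.argmax(-1))
        out_ids.append(i)
        emb = net.embedding(torch.tensor([[i]]))
        h = net.rnn.one_step(emb.float(), h)
    print(code2string(out_ids))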