readix

tp5.py

Nov 8th, 2022 (edited)
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path

# The course helpers may live next to this file or under student_tp5/src.
try:
    from textloader import *
except ImportError:
    from student_tp5.src.textloader import *

try:
    from generate import *
except ImportError:
    from student_tp5.src.generate import *

# TrumpDataset is copied over from TP4 (student_tp4.src.exo4).

class TrumpDataset(Dataset):
    def __init__(self, text, maxsent=None, maxlen=None):
        """Dataset over Trump speeches.
           * text: raw text
           * maxsent: maximum number of sentences
           * maxlen: maximum sentence length
        """
        maxlen = maxlen or sys.maxsize
        full_text = normalize(text)
        self.phrases = [p[:maxlen].strip() + "." for p in full_text.split(".") if len(p) > 0]
        if maxsent is not None:
            self.phrases = self.phrases[:maxsent]
        self.MAX_LEN = max([len(p) for p in self.phrases])

    def __len__(self):
        return len(self.phrases)

    def __getitem__(self, i):
        t = string2code(self.phrases[i])
        # Left-pad with zeros up to MAX_LEN, then shift by one character:
        # the input is t[:-1] and the target is t[1:].
        t = torch.cat([torch.zeros(self.MAX_LEN - t.size(0), dtype=torch.long), t])
        return t[:-1], t[1:]

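# Quick illustration (a hypothetical snippet; the sentences are sample text):
# each item is an (input, target) pair obtained by shifting the padded,
# encoded sentence by one character.
# ds_demo = TrumpDataset("Make America great again. Believe me.")
# x, y = ds_demo[0]
# assert x.size(0) == y.size(0) == ds_demo.MAX_LEN - 1
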
cle = CrossEntropyLoss(reduction='none')

def maskedCrossEntropy(output: torch.Tensor, target: torch.LongTensor, padcar: int):
    """Cross-entropy that ignores padding characters, without any explicit loop.
    :param output: tensor of shape length x batch x output_dim
    :param target: tensor of shape length x batch
    :param padcar: index of the padding character
    """
    # Flatten the (length, batch) dimensions and zero out the padded positions.
    mask = (target != padcar).contiguous().view(-1)
    output = output.view(-1, output.size(2))
    target = target.contiguous().view(-1).long()
    return (cle(output, target) * mask).sum() / mask.sum()

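# Sanity check (a throwaway sketch; 3 is an arbitrary padding index here, not
# necessarily PAD_IX): padded positions must not contribute to the loss.
# out = torch.randn(5, 2, 4)                 # length x batch x vocab
# tgt = torch.randint(0, 3, (5, 2))
# tgt[-2:, 0] = 3                            # pad the tail of sequence 0
# print(maskedCrossEntropy(out, tgt, padcar=3))
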
class RNN(nn.Module):
    def __init__(self, dimH, dimX, dimY):
        super().__init__()
        self.dimH = dimH
        self.dimX = dimX
        self.w1 = nn.Linear(dimX, dimH)   # input-to-hidden
        self.w2 = nn.Linear(dimH, dimH)   # hidden-to-hidden
        self.wd = nn.Linear(dimH, dimY)   # hidden-to-output (decoder)
        self.device = None
        self.batchDim = 0
        self.lengthDim = 1

    def hzero(self, batch, dimH):
        # Initial hidden state.
        return torch.rand(batch, dimH)

    def one_step(self, xi, hi):
        """xi: batch x dimX, hi: batch x dimH -> next hidden state, batch x dimH"""
        return torch.tanh(self.w1(xi) + self.w2(hi))

    def forward(self, seq, h=None):
        """seq: batch x length x dimX -> all hidden states, batch x length x dimH"""
        allH = torch.empty(seq.size(self.batchDim), seq.size(self.lengthDim), self.dimH, device=self.device)

        if h is None:
            h = self.hzero(seq.size(self.batchDim), self.dimH).to(self.device)

        for i in range(seq.size(self.lengthDim)):
            x = seq[:, i, :].view(seq.size(self.batchDim), self.dimX)
            h = self.one_step(x, h)
            allH[:, i, :] = h
        return allH

    def decode(self, h):
        # Return raw logits: CrossEntropyLoss applies log-softmax internally,
        # so a softmax here would be applied twice.
        return self.wd(h)

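# Shape check (a throwaway example): a batch of 4 sequences of length 7 with
# dimX=5 should yield hidden states of shape 4 x 7 x 8.
# rnn_demo = RNN(dimH=8, dimX=5, dimY=3)
# assert rnn_demo(torch.randn(4, 7, 5)).shape == (4, 7, 8)
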
#  TODO: Implement an LSTM; a possible sketch follows.
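# One possible LSTM (a sketch, not necessarily the expected solution): the
# recurrent state is the pair (h, c), so forward is overridden to thread the
# cell state c alongside h. The inherited w1/w2 layers go unused.
class LSTM(RNN):
    def __init__(self, dimH, dimX, dimY):
        super().__init__(dimH, dimX, dimY)
        self.wf = nn.Linear(dimX + dimH, dimH)   # forget gate
        self.wi = nn.Linear(dimX + dimH, dimH)   # input gate
        self.wo = nn.Linear(dimX + dimH, dimH)   # output gate
        self.wc = nn.Linear(dimX + dimH, dimH)   # candidate cell state

    def one_step(self, xi, hi, ci):
        hx = torch.cat([hi, xi], dim=1)
        f = torch.sigmoid(self.wf(hx))
        i = torch.sigmoid(self.wi(hx))
        o = torch.sigmoid(self.wo(hx))
        c = f * ci + i * torch.tanh(self.wc(hx))
        return o * torch.tanh(c), c

    def forward(self, seq, h=None):
        allH = torch.empty(seq.size(self.batchDim), seq.size(self.lengthDim), self.dimH, device=self.device)
        if h is None:
            h = self.hzero(seq.size(self.batchDim), self.dimH).to(self.device)
        c = torch.zeros_like(h)
        for i in range(seq.size(self.lengthDim)):
            x = seq[:, i, :].view(seq.size(self.batchDim), self.dimX)
            h, c = self.one_step(x, h, c)
            allH[:, i, :] = h
        return allH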

#  TODO: Implement a GRU; a possible sketch follows.
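# One possible GRU (a sketch): a single hidden state with update and reset
# gates, subclassing RNN (rather than nn.Module as in the stub) so the forward
# loop is reused and only the step function changes. w1/w2 go unused.
class GRU(RNN):
    def __init__(self, dimH, dimX, dimY):
        super().__init__(dimH, dimX, dimY)
        self.wz = nn.Linear(dimX + dimH, dimH)   # update gate
        self.wr = nn.Linear(dimX + dimH, dimH)   # reset gate
        self.wh = nn.Linear(dimX + dimH, dimH)   # candidate hidden state

    def one_step(self, xi, hi):
        hx = torch.cat([hi, xi], dim=1)
        z = torch.sigmoid(self.wz(hx))
        r = torch.sigmoid(self.wr(hx))
        h_tilde = torch.tanh(self.wh(torch.cat([r * hi, xi], dim=1)))
        return (1 - z) * hi + z * h_tilde
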
class State:
    """Checkpointing helper: bundles the model, its optimizer and counters."""
    def __init__(self, model, optim, writePath):
        self.model = model
        self.optim = optim
        self.epoch, self.iteration = 0, 0
        self.writePath = writePath

    def save(self, path):
        # Save the whole State (model, optimizer, counters) so that load()
        # can resume training exactly where it stopped.
        with path.open("wb") as fp:
            torch.save(self, fp)

    @staticmethod
    def load(path):
        with path.open("rb") as fp:
            state = torch.load(fp)  # resume from the saved checkpoint
            return state

class RNN_gen(torch.nn.Module):
    """Character embedding followed by a recurrent network and a decoder."""
    def __init__(self, reseau, nbChar, dimEmb, dimH, dimX, dimY, device):
        super().__init__()
        self.emb = nn.Embedding(nbChar, dimEmb)
        self.rnn = reseau(dimH, dimX, dimY)   # dimX must equal dimEmb
        self.rnn.device = device

    def forward(self, x):
        x_enc = self.emb(x.long())   # batch x length x dimEmb
        res = self.rnn(x_enc)        # batch x length x dimH
        return self.rnn.decode(res)  # batch x length x nbChar (logits)

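# A minimal greedy-generation sketch (hypothetical: the imported generate
# module presumably provides the real version, and this assumes textloader
# also exposes code2string and EOS_IX as in the course skeleton).
# def generate_greedy(model, start, maxlen=100):
#     codes = string2code(start).unsqueeze(0)   # 1 x length
#     for _ in range(maxlen):
#         logits = model(codes)                 # 1 x length x vocab
#         nxt = logits[0, -1].argmax().view(1, 1)
#         if nxt.item() == EOS_IX:
#             break
#         codes = torch.cat([codes, nxt], dim=1)
#     return code2string(codes[0])
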
BATCH_SIZE = 32
LENGTH = 10

PATH = "data/"
with open(PATH + 'trump_full_speech.txt') as f:
    text = f.read()

ds = TrumpDataset(text, 200, LENGTH)
data_train = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=pad_collate_fn)
len_emb = len(id2lettre)

DIM_EMB = 50
DIM_H = 10
device = torch.device('cpu')

model = RNN_gen(RNN, len_emb, DIM_EMB, DIM_H, DIM_EMB, len_emb, device)

# maskedCrossEntropy flattens the length and batch dimensions, so it accepts
# the batch-first tensors produced by RNN_gen just as well.
lossfunc = lambda yhat, y: maskedCrossEntropy(yhat, y, PAD_IX)
alpha = 0.0001   # learning rate
n_epochs = 200
nameState = 'student_tp5/src/rnn7'
dataLoader = data_train
optimizer = torch.optim.SGD

writePath = "student_tp5/runs/trump" + datetime.now().strftime("%Y%m%d-%H%M%S")

savepath = Path(nameState + ".pch")
if savepath.is_file():
    # Resume from the saved checkpoint.
    state = State.load(savepath)
    state.model = state.model.to(device)
else:
    model = model.to(device)
    optim = optimizer(params=model.parameters(), lr=alpha)  # lr is the gradient step size
    state = State(model, optim, writePath)

writer = SummaryWriter(state.writePath)  # TensorBoard logging (the import was unused)

for epoch in range(state.epoch, n_epochs):

    state.iteration = 0
    for x, y in dataLoader:
        x = x.to(device)
        y = y.to(device)
        state.optim.zero_grad()
        predict = state.model(x)
        l = lossfunc(predict, y.long())
        l.backward()
        state.optim.step()
        state.iteration += 1

    # Checkpoint the whole training state at the end of each epoch.
    state.epoch = epoch + 1
    state.save(savepath)

    # Logging
    with torch.no_grad():
        x, y = next(iter(data_train))
        x = x.to(device)
        y = y.to(device)
        predict = state.model(x)
        l_train = lossfunc(predict, y.long()).item()

        writer.add_scalar('loss/train', l_train, epoch)
        print('epoch', epoch, 'loss train', l_train)