Guest User

Untitled

a guest
Jun 24th, 2018
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.26 KB | None | 0 0
  1. train_data,word2index, unigram_table,vocab = SkipGram(data = None, columns = 'comment', cut = True, WINDOW_SIZE = 2, dataname = 'newdata_pytorch.csv')
  2.  
  3. #np.save('word2index.npy', word2index)
  4. #np.save('unigram_table.npy', unigram_table)
  5. #np.save('vocab.npy', vocab)
  6. ##word2index = np.load(word2index_name).item()
  7. #torch.save(train_data, 'train_data.pth')
  8.  
  9. EMBEDDING_SIZE = 300
  10. BATCH_SIZE = 256
  11. EPOCH = 100
  12. NEG = 10 # Num of Negative Sampling
  13. losses = []
  14. model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)
  15. if USE_CUDA:
  16. model = model.cuda()
  17.  
  18. optimizer = optim.Adam(model.parameters(), lr=0.001)
  19.  
  20. for epoch in range(EPOCH):
  21. for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):
  22.  
  23. inputs, targets = zip(*batch)
  24.  
  25. inputs = torch.cat(inputs) # B x 1
  26. targets = torch.cat(targets) # B x 1
  27. negs = negative_sampling(targets, unigram_table, NEG)
  28. model.zero_grad()
  29.  
  30. loss = model(inputs, targets, negs)
  31.  
  32. loss.backward()
  33. optimizer.step()
  34.  
  35. losses.append(loss.data.tolist())
  36.  
  37. if epoch % 5 == 0:
  38. print("Epoch : %d, mean_loss : %.02f" % (epoch, np.mean(losses)))
  39. losses = []
  40. # torch.save(model, 'skipgram.pt')
  41. torch.cuda.empty_cache()
Add Comment
Please, Sign In to add comment