Advertisement
denisb413

Untitled

Nov 10th, 2017
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.04 KB | None | 0 0
  1.  
  2.  
  def load_corpus(dirname, tokens_only=False):
      """Lazily load every ``.txt`` file in a directory as one document.

      Each matching file is read whole, decoded as ISO-8859-1, and tokenized
      with ``gensim.utils.simple_preprocess``.

      Args:
          dirname: Directory to scan; only entries whose name ends in
              ``.txt`` are used.
          tokens_only: When true, yield the bare token list of each file.
              Otherwise yield a ``TaggedDocument`` whose only tag is the
              file name (the form used for training data).

      Yields:
          One token list or ``TaggedDocument`` per ``.txt`` file, in the
          order ``listdir`` returns the names.
      """
      for name in listdir(dirname):
          if not name.endswith('.txt'):
              continue

          fname = dirname + "/" + name
          with smart_open.smart_open(fname, encoding="iso-8859-1") as fin:
              text = fin.read()

          # Both output forms share the same tokenization step.
          tokens = gensim.utils.simple_preprocess(text)
          if tokens_only:
              yield tokens
          else:
              # For training data, tag the document with its file name.
              yield gensim.models.doc2vec.TaggedDocument(tokens, [name])
  19.  
  def start_training(hyperparams, train_corpus):
      """Create a Doc2Vec model from ``hyperparams`` and train it.

      Args:
          hyperparams: Mapping that provides the keys ``size``, ``min_count``,
              ``iter``, ``window``, ``alpha``, ``min_alpha`` and ``dm``; each
              is forwarded to the ``Doc2Vec`` constructor under the same name.
          train_corpus: Training documents. It is passed to both
              ``build_vocab`` and ``train``, so it has to support being
              iterated more than once (the caller in this file passes a list).

      Returns:
          The trained model.
      """
      doc2vec_options = dict(
          size=hyperparams['size'],
          min_count=hyperparams['min_count'],
          iter=hyperparams['iter'],
          workers=16,  # worker count is fixed here, not read from hyperparams
          window=hyperparams['window'],
          alpha=hyperparams['alpha'],
          min_alpha=hyperparams['min_alpha'],
          dm=hyperparams['dm'],
      )
      model = gensim.models.doc2vec.Doc2Vec(**doc2vec_options)

      print("Building vocabulary")
      # Reseed the model's random generator before the vocabulary is built
      # (presumably to make runs repeatable).
      model.random.seed(0)
      model.build_vocab(train_corpus)

      print("Training the model")
      print(model)
      model.train(
          train_corpus,
          total_examples=model.corpus_count,
          epochs=model.iter,
      )

      return model
  34.  
  def get_word_vec(filename):
      """Return the whitespace-separated tokens of a preprocessed text file.

      Args:
          filename: Path of the file to read.

      Returns:
          A list of strings: the file's whole content split on any run of
          whitespace (an empty file gives an empty list).
      """
      with open(filename) as data_file:
          contents = data_file.read()
      return contents.split()
  41.  
  def decay_equation(position):
      """Return an exponentially decaying weight for a ranking position.

      The weight is ``exp(ln(0.5) / 9 * (position - 1))``: exactly 1.0 at
      position 1 and 0.5 at position 10, halving every 9 further positions.
      Positions below 1 give weights above 1.0.

      Args:
          position: Ranking position (number or NumPy array).

      Returns:
          The decay weight for ``position``.
      """
      # Per-position decay rate chosen so the weight halves between 1 and 10.
      rate = np.log(0.5) / (10 - 1)
      return np.exp(rate * (position - 1))
  44.  
  def eval_model(model, eval_dir, hyperparams):
      """Score a trained model against a directory of evaluation files.

      For every file in ``eval_dir`` a vector is inferred from the file's
      tokens, all stored document vectors are ranked by similarity to it,
      and the ranking is searched for the document tag expected for that
      file. A hit at 1-based rank ``r`` adds ``decay_equation(r)`` to the
      score, so a first-place hit counts exactly 1.0.

      Args:
          model: Trained model exposing ``random``, ``infer_vector`` and
              ``docvecs``.
          eval_dir: Directory holding the preprocessed evaluation files; it
              may be given with or without a trailing path separator.
          hyperparams: Mapping providing ``alpha``, ``min_alpha`` and
              ``iter``, used as the inference settings.

      Returns:
          A tuple ``(accuracy_rate, ranked_eval)``. ``accuracy_rate`` is the
          summed decay weights divided by the number of evaluation files,
          times 100. ``ranked_eval`` maps each file name to the 0-based index
          at which its expected tag was found; files whose tag never appears
          in the ranking are absent. An empty directory yields ``(0.0, {})``.
      """
      ranked_eval = {}

      eval_files_list = os.listdir(eval_dir)
      if not eval_files_list:
          # Nothing to evaluate: avoid dividing by zero below.
          return 0.0, ranked_eval

      correct = 0
      for file_name in eval_files_list:
          words_vec = get_word_vec(os.path.join(eval_dir, file_name))

          # Reseed before each inference so every file is inferred from the
          # same random state.
          model.random.seed(0)
          inferred_vector = model.infer_vector(
              words_vec,
              alpha=hyperparams['alpha'],
              min_alpha=hyperparams['min_alpha'],
              steps=(hyperparams['iter']),
          )
          similars = model.docvecs.most_similar(
              [inferred_vector], topn=len(model.docvecs))

          target_ER = _expected_tag(file_name)

          for i, sim in enumerate(similars):
              if sim[0] == target_ER:
                  print(file_name, "found in position", i)
                  ranked_eval[file_name] = i
                  # decay_equation is defined on 1-based ranks (weight 1.0
                  # at position 1); passing the 0-based index would score a
                  # top hit above 1.0 and push accuracy past 100%.
                  correct += decay_equation(i + 1)
                  break

      accuracy_rate = (correct / len(eval_files_list)) * 100

      return accuracy_rate, ranked_eval


  def _expected_tag(file_name):
      """Derive the document tag an evaluation file should match.

      Names of length 21-25 drop their first 4 characters, names of length
      29 keep their last 18 characters, and any other length maps to an
      empty tag (which never matches a ranked entry with a non-empty tag).
      """
      name_len = len(file_name)
      if 21 <= name_len <= 25:
          return file_name[4:]
      if name_len == 29:
          return file_name[-18:]
      return ''
  83.  
  # Train the model.
  # Collect the module-level settings into the mapping that the training and
  # evaluation helpers read from.
  hyperparams = dict(
      size=size,
      min_count=min_count,
      iter=iter,
      window=window,
      alpha=alpha,
      min_alpha=min_alpha,
      dm=dm,
  )
  print("Training with the hyperparams:")
  print(hyperparams, "\n")

  print("Loading files")
  # Materialize the generator: the corpus is iterated more than once while
  # training.
  train_corpus = list(load_corpus(train_dir))
  model = start_training(hyperparams, train_corpus)

  # The model is stored under model_dir, named after the preprocessed input.
  model_file = model_dir + preprocessed_input + '.model'
  model.save(model_file)
  print("Trained model saved in:", model_file)
  104.  
  # Evaluation.
  accuracy_rate, ranked_eval = eval_model(model, eval_dir, hyperparams)

  # Record the outcome next to the settings that produced it.
  hyperparams.update(accuracy_rate=accuracy_rate, model_file=model_file)

  print("Accuracy rate: ", accuracy_rate)

  # Append the hyperparameters and the per-file ranking to a running log.
  # NOTE(review): the two dicts are written back to back with no separator
  # between them; only the second is followed by a newline.
  results_file = model_dir + 'tests_result.txt'
  with open(results_file, "a") as f:
      f.write(str(hyperparams))
      f.write(str(ranked_eval) + "\n")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement