import os
import time

import tensorflow as tf  # written against the TF 1.x API (tf.placeholder, tf.Session)

# Project-local modules, not part of this paste. W2V is assumed to expose
# One_hot / Convert2Vec / Zero_padding and Bi_LSTM the model class used below;
# hedged sketches of both interfaces follow the script.
import Bi_LSTM
import W2V

tokens = newData()  # data-loading step from the surrounding project (newData is not defined in this paste);
                    # train_X / train_Y below are assumed to be produced by it
train_Y_ = W2V.One_hot(train_Y)                       # labels -> one-hot vectors
train_X_ = W2V.Convert2Vec("Data/my_model", train_X)  # tokens -> word2vec embeddings

Batch_size = 32
Total_size = len(train_X)
Vector_size = 300
seq_length = [len(x) for x in train_X]  # true (unpadded) length of each sequence
Maxseq_length = max(seq_length)
learning_rate = 0.001
lstm_units = 128
num_class = 2
training_epochs = 5
keep_prob = 0.75

X = tf.placeholder(tf.float32, shape=[None, Maxseq_length, Vector_size], name='X')
Y = tf.placeholder(tf.float32, shape=[None, num_class], name='Y')
seq_len = tf.placeholder(tf.int32, shape=[None])  # per-example lengths for dynamic unrolling

BiLSTM = Bi_LSTM.Bi_LSTM(lstm_units, num_class, keep_prob)

with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
    logits = BiLSTM.logits(X, BiLSTM.W, BiLSTM.b, seq_len)
    loss, optimizer = BiLSTM.model_build(logits, Y, learning_rate)

prediction = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

total_batch = int(Total_size / Batch_size)

print("Start training!")

modelName = "BiLSTM.ckpt"
saver = tf.train.Saver()

with tf.Session() as sess:

    start_time = time.time()

    # Resume from the latest checkpoint in Data/ if one exists.
    ckpt = tf.train.get_checkpoint_state('Data')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(init)

    train_writer = tf.summary.FileWriter('Bidirectional_LSTM', sess.graph)

    i = 0
    for epoch in range(training_epochs):
        avg_acc, avg_loss = 0.0, 0.0
        for step in range(total_batch):
            train_batch_X = train_X_[step * Batch_size : (step + 1) * Batch_size]
            train_batch_Y = train_Y_[step * Batch_size : (step + 1) * Batch_size]
            batch_seq_length = seq_length[step * Batch_size : (step + 1) * Batch_size]

            # Pad every sequence in the batch up to the global maximum length.
            train_batch_X = W2V.Zero_padding(train_batch_X, Batch_size, Maxseq_length, Vector_size)

            feed = {X: train_batch_X, Y: train_batch_Y, seq_len: batch_seq_length}
            sess.run(optimizer, feed_dict=feed)

            # Re-evaluate loss and accuracy on the same batch to accumulate epoch averages.
            loss_ = sess.run(loss, feed_dict=feed)
            avg_loss += loss_ / total_batch
            acc = sess.run(accuracy, feed_dict=feed)
            avg_acc += acc / total_batch

            print("epoch : {:02d} step : {:04d} loss = {:.6f} accuracy = {:.6f}".format(
                epoch + 1, step + 1, loss_, acc))

        # Log the per-epoch averages to TensorBoard.
        summary = sess.run(BiLSTM.graph_build(avg_loss, avg_acc))
        train_writer.add_summary(summary, i)
        i += 1

    duration = time.time() - start_time
    minute = int(duration / 60)
    second = int(duration) % 60
    print("%d minutes %d seconds" % (minute, second))

    # The original paste saved to os.getcwd() (a bare directory path); join in
    # the model name so the checkpoint files get a proper prefix.
    save_path = saver.save(sess, os.path.join(os.getcwd(), modelName))

    train_writer.close()
    print('save_path', save_path)
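The script depends on a project-local Bi_LSTM class that is not part of the paste. Below is a minimal sketch, assuming the standard TF 1.x bidirectional-RNN API, of the interface the training loop calls (the W and b attributes plus logits, model_build, and graph_build); the original project's implementation may well differ.

import tensorflow as tf  # TF 1.x

class Bi_LSTM:
    """Minimal sketch of the interface assumed by the training script."""

    def __init__(self, lstm_units, num_class, keep_prob):
        # Forward and backward cells with dropout on the outputs.
        self.lstm_fw = tf.contrib.rnn.DropoutWrapper(
            tf.contrib.rnn.BasicLSTMCell(lstm_units), output_keep_prob=keep_prob)
        self.lstm_bw = tf.contrib.rnn.DropoutWrapper(
            tf.contrib.rnn.BasicLSTMCell(lstm_units), output_keep_prob=keep_prob)
        # Output projection: concatenated fw+bw outputs -> class scores.
        self.W = tf.Variable(tf.random_normal([2 * lstm_units, num_class]), name='W')
        self.b = tf.Variable(tf.random_normal([num_class]), name='b')

    def logits(self, X, W, b, seq_len):
        # Run the bidirectional LSTM over the padded batch; seq_len masks the padding.
        (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
            self.lstm_fw, self.lstm_bw, X, sequence_length=seq_len, dtype=tf.float32)
        outputs = tf.concat([output_fw, output_bw], axis=2)
        # Use the last time step as the sentence representation. Note: with
        # sequence_length set, forward outputs past the true length are zero;
        # a production version would gather at each sequence's true final step.
        last = outputs[:, -1, :]
        return tf.matmul(last, W) + b

    def model_build(self, logits, labels, learning_rate=0.001):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels))
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        return loss, optimizer

    def graph_build(self, avg_loss, avg_acc):
        # Scalar summaries for the per-epoch averages the script logs.
        tf.summary.scalar('Average Loss', avg_loss)
        tf.summary.scalar('Average Accuracy', avg_acc)
        return tf.summary.merge_all()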
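The W2V helpers (One_hot, Convert2Vec, Zero_padding) are likewise not included. The following is a hypothetical reconstruction, assuming a trained gensim word2vec model saved at Data/my_model; everything here beyond the three function names the script calls is an assumption for illustration.

import numpy as np
from gensim.models import Word2Vec  # assumed; the paste only shows the "Data/my_model" path

def One_hot(labels, num_class=2):
    # Integer labels -> one-hot rows, e.g. 1 -> [0, 1].
    eye = np.eye(num_class)
    return [eye[int(y)] for y in labels]

def Convert2Vec(model_path, sentences):
    # Map each token to its embedding; unknown tokens get a zero vector.
    model = Word2Vec.load(model_path)
    size = model.vector_size
    return [[model.wv[w] if w in model.wv else np.zeros(size) for w in sent]
            for sent in sentences]

def Zero_padding(batch, batch_size, maxseq_length, vector_size):
    # Copy each variable-length sequence into a fixed [batch, maxlen, dim] array.
    padded = np.zeros((batch_size, maxseq_length, vector_size), dtype=np.float32)
    for i, sent in enumerate(batch):
        padded[i, :len(sent), :] = sent
    return padded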