import os
import time
import tensorflow as tf

# This paste relies on project code that is not shown here: the modules `W2V`
# (word2vec helpers One_hot / Convert2Vec / Zero_padding) and `Bi_LSTM` (the model
# class), plus newData() and the raw train_X / train_Y lists. Their imports and
# loading are assumed to happen before this point.

# --- Data preparation ---
tokens = newData()
train_Y_ = W2V.One_hot(train_Y)                       # one-hot encode the labels
train_X_ = W2V.Convert2Vec("Data/my_model", train_X)  # map each token to its word vector

# --- Hyperparameters ---
Batch_size = 32
Total_size = len(train_X)
Vector_size = 300
seq_length = [len(x) for x in train_X]  # true (unpadded) length of every sample
Maxseq_length = max(seq_length)
learning_rate = 0.001
lstm_units = 128
num_class = 2
training_epochs = 5
keep_prob = 0.75
# --- Graph definition ---
X = tf.placeholder(tf.float32, shape=[None, Maxseq_length, Vector_size], name='X')
Y = tf.placeholder(tf.float32, shape=[None, num_class], name='Y')
seq_len = tf.placeholder(tf.int32, shape=[None])

BiLSTM = Bi_LSTM.Bi_LSTM(lstm_units, num_class, keep_prob)

with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
    logits = BiLSTM.logits(X, BiLSTM.W, BiLSTM.b, seq_len)
    loss, optimizer = BiLSTM.model_build(logits, Y, learning_rate)

# Accuracy: compare the argmax of the softmax output against the one-hot labels.
prediction = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()
total_batch = int(Total_size / Batch_size)

print("Start training!")

modelName = "BiLSTM.ckpt"
saver = tf.train.Saver()
with tf.Session() as sess:
    start_time = time.time()

    # Resume from an existing checkpoint in Data/ if one is found, otherwise
    # initialize all variables from scratch.
    ckpt = tf.train.get_checkpoint_state('Data')
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(init)

    train_writer = tf.summary.FileWriter('Bidirectional_LSTM', sess.graph)
    i = 0

    for epoch in range(training_epochs):
        avg_acc, avg_loss = 0., 0.
        for step in range(total_batch):
            # Slice out the current mini-batch and pad it to the maximum sequence length.
            train_batch_X = train_X_[step * Batch_size : (step + 1) * Batch_size]
            train_batch_Y = train_Y_[step * Batch_size : (step + 1) * Batch_size]
            batch_seq_length = seq_length[step * Batch_size : (step + 1) * Batch_size]
            train_batch_X = W2V.Zero_padding(train_batch_X, Batch_size, Maxseq_length, Vector_size)
            # One optimization step, then the batch loss and accuracy for the running averages.
            sess.run(optimizer, feed_dict={X: train_batch_X, Y: train_batch_Y, seq_len: batch_seq_length})
            loss_ = sess.run(loss, feed_dict={X: train_batch_X, Y: train_batch_Y, seq_len: batch_seq_length})
            avg_loss += loss_ / total_batch
            acc = sess.run(accuracy, feed_dict={X: train_batch_X, Y: train_batch_Y, seq_len: batch_seq_length})
            avg_acc += acc / total_batch
            print("epoch : {:02d} step : {:04d} loss = {:.6f} accuracy = {:.6f}".format(epoch + 1, step + 1, loss_, acc))

        # Log the per-epoch average loss and accuracy to TensorBoard.
        summary = sess.run(BiLSTM.graph_build(avg_loss, avg_acc))
        train_writer.add_summary(summary, i)
        i += 1
    duration = time.time() - start_time
    minute = int(duration / 60)
    second = int(duration) % 60
    print("%d minutes %d seconds" % (minute, second))

    # Save the trained weights. The original passed a bare directory to saver.save(),
    # which expects a file prefix, so modelName is joined onto the path here.
    save_path = saver.save(sess, os.path.join(os.getcwd(), modelName))
    train_writer.close()
    print('save_path', save_path)
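
For completeness, here is a minimal sketch of how the saved checkpoint could be restored and scored on held-out data. It assumes the graph above has been built in the same process, and that test_X / test_Y (hypothetical names, not in the original paste) are prepared with the same W2V helpers as the training set.

# --- Evaluation sketch (not in the original paste) ---
# test_X / test_Y are hypothetical held-out data prepared like train_X / train_Y.
test_Y_ = W2V.One_hot(test_Y)
test_X_ = W2V.Convert2Vec("Data/my_model", test_X)
test_seq_length = [len(x) for x in test_X]

with tf.Session() as sess:
    saver.restore(sess, os.path.join(os.getcwd(), modelName))
    test_batches = int(len(test_X) / Batch_size)
    test_acc = 0.
    for step in range(test_batches):
        batch_X = test_X_[step * Batch_size : (step + 1) * Batch_size]
        batch_Y = test_Y_[step * Batch_size : (step + 1) * Batch_size]
        batch_len = test_seq_length[step * Batch_size : (step + 1) * Batch_size]
        # Pad with the training Maxseq_length; longer test sequences would need truncation first.
        batch_X = W2V.Zero_padding(batch_X, Batch_size, Maxseq_length, Vector_size)
        test_acc += sess.run(accuracy, feed_dict={X: batch_X, Y: batch_Y, seq_len: batch_len}) / test_batches
    print("test accuracy = {:.6f}".format(test_acc))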