Advertisement
Guest User

Untitled

a guest
Dec 5th, 2016
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.29 KB | None | 0 0
  1. import os
  2. import pickle
  3. import tensorflow as tf
  4. from sklearn.model_selection import train_test_split
  5.  
  6. with open('djma_v3.pkl', 'rb') as input:
  7. data = pickle.load(input)
  8. X = data["X"]
  9. y = data["y"]
  10.  
  11. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
  12.  
  13. learning_rate = 0.01
  14. training_epochs = 1000
  15. batch_size = 100
  16. display_step = 1
  17.  
  18. # Architecture
  19. n_hidden_1 = 512
  20. n_hidden_2 = 512
  21. n_hidden_3 = 512
  22.  
  23.  
  24. def layer(input, weight_shape, bias_shape):
  25. weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
  26. bias_init = tf.constant_initializer(value=0)
  27. W = tf.get_variable("W", weight_shape,
  28. initializer=weight_init)
  29. b = tf.get_variable("b", bias_shape,
  30. initializer=bias_init)
  31. return tf.nn.relu(tf.matmul(input, W) + b)
  32.  
  33. def inference(x):
  34. with tf.variable_scope("hidden_1"):
  35. hidden_1 = layer(x, [363, n_hidden_1], [n_hidden_1])
  36.  
  37. with tf.variable_scope("hidden_2"):
  38. hidden_2 = layer(hidden_1, [n_hidden_1, n_hidden_2], [n_hidden_2])
  39. with tf.variable_scope("hidden_3"):
  40. hidden_3 = layer(hidden_2, [n_hidden_2, n_hidden_3], [n_hidden_3])
  41. with tf.variable_scope("output"):
  42. output = layer(hidden_3, [n_hidden_3, 5], [5])
  43.  
  44. return output
  45.  
  46. def loss(output, y):
  47. xentropy = tf.nn.softmax_cross_entropy_with_logits(output, y)
  48. loss = tf.reduce_mean(xentropy)
  49. return loss
  50.  
  51. def training(cost, global_step):
  52. tf.scalar_summary("cost", cost)
  53. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  54. train_op = optimizer.minimize(cost, global_step=global_step)
  55. return train_op
  56.  
  57.  
  58. def evaluate(output, y):
  59. correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
  60. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  61. tf.scalar_summary("validation error", (1.0 - accuracy))
  62. return accuracy
  63.  
  64. if __name__ == '__main__':
  65.  
  66. with tf.Graph().as_default():
  67.  
  68. with tf.variable_scope("mlp_model"):
  69.  
  70. x = tf.placeholder("float", [None, X_train.shape[1]])
  71. y = tf.placeholder("float", [None, y_train.shape[1]])
  72.  
  73. output = inference(x)
  74.  
  75. cost = loss(output, y)
  76.  
  77. global_step = tf.Variable(0, name='global_step', trainable=False)
  78.  
  79. train_op = training(cost, global_step)
  80.  
  81. eval_op = evaluate(output, y)
  82.  
  83. summary_op = tf.merge_all_summaries()
  84.  
  85. sess = tf.Session()
  86.  
  87. # saver = tf.train.import_meta_graph('mlp_logs/model-checkpoint-383800.meta')
  88. saver.restore(sess, tf.train.latest_checkpoint('mlp_logs/'))
  89.  
  90. # saver = tf.train.Saver()
  91.  
  92. summary_writer = tf.train.SummaryWriter("mlp_logs/",
  93. graph_def=sess.graph_def)
  94.  
  95. init_op = tf.initialize_all_variables()
  96.  
  97. sess.run(init_op)
  98.  
  99.  
  100. # Training cycle
  101. for epoch in range(training_epochs):
  102.  
  103. avg_cost = 0.
  104. total_batch = int(X_train.shape[0]/batch_size)
  105. # Loop over all batches
  106. for i in range(total_batch):
  107. idx = i * batch_size
  108. minibatch_x, minibatch_y = X_train[idx:idx+batch_size], y_train[idx:idx+batch_size]
  109. # Fit training using batch data
  110. sess.run(train_op, feed_dict={x: minibatch_x, y: minibatch_y})
  111. # Compute average loss
  112. avg_cost += sess.run(cost, feed_dict={x: minibatch_x, y: minibatch_y})/total_batch
  113. # Display logs per epoch step
  114. if epoch % display_step == 0:
  115. print("Epoch:", '%04d' % (epoch+1), "cost =", "{:.9f}".format(avg_cost))
  116.  
  117. accuracy = sess.run(eval_op, feed_dict={x: X_test, y: y_test})
  118.  
  119. print("Validation Error:", (1 - accuracy))
  120.  
  121. summary_str = sess.run(summary_op, feed_dict={x: minibatch_x, y: minibatch_y})
  122. summary_writer.add_summary(summary_str, sess.run(global_step))
  123.  
  124. saver.save(sess, "mlp_logs/model-checkpoint", global_step=global_step)
  125.  
  126.  
  127. print("Optimization Finished!")
  128.  
  129.  
  130. accuracy = sess.run(eval_op, feed_dict={x: X_test, y: y_test})
  131.  
  132. print("Test Accuracy:", accuracy)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement