Guest User

Untitled

a guest
Apr 19th, 2018
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.09 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. from __future__ import print_function
  5.  
  6. import shutil
  7.  
  8. import numpy as np
  9. import tensorflow as tf
  10.  
  11.  
  12. INPUT_SIZE = 50
  13. N_TRAIN_SAMPLES = 102400
  14. # just random guess of 16 lstm cells
  15. HIDDEN_UNITS = 16
  16. BATCH_SIZE = 256
  17. LEARNING_RATE = 0.005
  18.  
  19. TRAINING_STEPS = 1000
  20. DISPLAY_STEP = TRAINING_STEPS / 100
  21.  
  22. TENSORBOARD_DIR = "./tensorboard"
  23.  
  24.  
  25. def generate_dataset(n=N_TRAIN_SAMPLES):
  26. max_50_bit_number = 2**50 - 1
  27. random_numbers = np.random.randint(0, max_50_bit_number, n)
  28. X_strings = list(map(lambda x: "{:050b}".format(x), random_numbers))
  29. # print(X_strings)
  30. y = []
  31. for x_s in X_strings:
  32. if x_s.count('1') % 2 == 0:
  33. y.append([1,0])
  34. else:
  35. y.append([0,1])
  36.  
  37. X = [[int(i) for i in s] for s in X_strings]
  38. return (X, y)
  39.  
  40.  
def build_and_train_model(dataset):
    """Build a single-layer LSTM parity classifier and train it with Adam.

    dataset: tuple (X, y) — X: list of 50-element 0/1 lists, y: list of
    one-hot pairs ([1, 0] even parity, [0, 1] odd parity), as produced by
    generate_dataset().

    Side effects: adds ops to the default TensorFlow graph, runs a training
    session, and writes TensorBoard summaries under TENSORBOARD_DIR.
    NOTE(review): this uses TF 1.x APIs (placeholder, static_rnn,
    tf.contrib) and will not run unmodified on TF 2.x.
    """
    print("[..] Building network model")

    # input tensor X is #{BATCH_SIZE} of 50bits lists
    X = tf.placeholder(tf.float32, (BATCH_SIZE, INPUT_SIZE), name="X")
    # the answer is either even [1, 0] or odd [0, 1]
    y = tf.placeholder(tf.float32, (BATCH_SIZE, 2), name="y")

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=HIDDEN_UNITS)

    # cast inputs to the shape `static_rnn` wants: list of 50 elements every of size [BATCH_SIZE, 1]
    inputs = tf.reshape(X, [BATCH_SIZE, INPUT_SIZE, 1])
    inputs = tf.unstack(inputs, INPUT_SIZE, axis=1)
    assert(len(inputs) == INPUT_SIZE)
    assert(inputs[0].shape == (BATCH_SIZE, 1) )
    outputs, final_state = tf.nn.static_rnn(lstm_cell, inputs, dtype=tf.float32)

    with tf.name_scope("Logits"):
        # Dense projection of the LAST timestep's output onto the 2 classes.
        Why = tf.get_variable("Why", shape=[HIDDEN_UNITS, 2],
                              initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable("b", shape=[2,],
                            initializer=tf.contrib.layers.xavier_initializer())
        logits = tf.matmul(outputs[-1], Why) + b
        tf.summary.histogram("Why", Why)

    with tf.name_scope("CE"):
        # Softmax cross-entropy against the one-hot labels, averaged over the batch.
        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
        tf.summary.scalar("CE_loss", loss_op)

    with tf.name_scope("accuracy"):
        predictions = tf.nn.softmax(logits)
        assert(tf.argmax(predictions, axis=1).shape == tf.argmax(y, axis=1).shape)
        correct_predictions = tf.equal(tf.argmax(predictions, axis=1), tf.argmax(y, axis=1))
        accuracy_op = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
        tf.summary.scalar("accuracy", accuracy_op)

    with tf.name_scope("train"):
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss_op)

    summ = tf.summary.merge_all()

    print("[..] Training the model")
    X_train, y_train = np.array(dataset[0]), np.array(dataset[1])

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)

        writer = tf.summary.FileWriter(TENSORBOARD_DIR)
        writer.add_graph(sess.graph)

        # One `step` is a full pass over the training set in fixed-size
        # batches.  Assumes N_TRAIN_SAMPLES is divisible by BATCH_SIZE
        # (true for 102400/256); otherwise the last slice would be short
        # and mismatch the fixed placeholder shapes — TODO confirm if the
        # constants ever change.
        for step in range(TRAINING_STEPS):
            i = 0  # batch counter within this epoch
            for batch_start in range(0, N_TRAIN_SAMPLES, BATCH_SIZE):
                batch_end = batch_start + BATCH_SIZE
                _, loss_val, acc_val = sess.run([train_op, loss_op, accuracy_op],
                                                feed_dict= {X: X_train[batch_start : batch_end],
                                                            y: y_train[batch_start : batch_end]})

                i += 1
                # Record summaries for every 5th batch.  NOTE(review): all
                # summaries within an epoch are written with the same
                # global_step value (`step`), so later ones overwrite
                # earlier ones in TensorBoard's view of that step.
                if i % 5 == 0:
                    s = sess.run(summ, feed_dict={X: X_train[batch_start : batch_end],
                                                  y: y_train[batch_start : batch_end]})
                    writer.add_summary(s, step)

            # Print progress every DISPLAY_STEP epochs, using the metrics
            # from the last batch of the epoch.
            if (step+1) % DISPLAY_STEP == 0:
                print("Step {0}, loss: {1:.4}, accuracy: {2:.4}".format(step, loss_val, acc_val))
  109.  
  110. def clean_logs_from_previous_run():
  111. print("[..] Removing \"{}\" dir".format(TENSORBOARD_DIR))
  112. shutil.rmtree(TENSORBOARD_DIR, ignore_errors=True)
  113.  
  114.  
  115. def main():
  116. clean_logs_from_previous_run()
  117. dataset = generate_dataset(N_TRAIN_SAMPLES)
  118. build_and_train_model(dataset)
  119.  
  120.  
# Run the full pipeline only when executed as a script (not on import).
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment