import tensorflow as tf
import numpy as np
import csv


def gen_dataset():
    # Write a tiny toy dataset: each row is a label followed by ten features.
    # Label-1 rows count up 0..9, label-0 rows count down 9..0.
    up = [i for i in range(10)]
    down = [9 - i for i in range(10)]

    with open('./est_dataset.csv', 'w') as f:
        writer = csv.writer(f, delimiter=',')
        for i in range(1000):
            writer.writerow([1] + up)
            writer.writerow([0] + down)


def input_fn():
    # Read the CSV line by line, shuffle, and batch ten rows at a time.
    dataset = tf.contrib.data.TextLineDataset('./est_dataset.csv')
    dataset = dataset.shuffle(7777).batch(10)

    itr = dataset.make_one_shot_iterator()
    batch = itr.get_next()

    # Parse the 11 integer columns (1 label + 10 features) of each line.
    batch = tf.decode_csv(batch, [[0]] * 11)

    train = tf.cast(tf.stack(batch[1:], axis=1), dtype=tf.float32)
    label = tf.cast(batch[0], dtype=tf.float32)
    return train, label


# Quick sanity check of the input pipeline:
# train, label = input_fn()
# with tf.Session() as sess:
#     _train, _label = sess.run([train, label])
#     for t, l in zip(_train, _label):
#         print(t, l)


def model_fn(features, labels, mode):
    # A small fully connected network with a single logit output.
    layer1 = tf.layers.dense(features, 10)
    layer2 = tf.layers.dense(layer1, 10)
    out = tf.layers.dense(layer2, 1)

    # Flatten the [batch, 1] logits to [batch] so they match the label shape.
    out = tf.reshape(out, [-1])

    global_step = tf.train.get_global_step()
    loss = tf.losses.sigmoid_cross_entropy(labels, out)
    train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(loss, global_step)
    # Training-only spec; predictions and eval metrics are not defined here.
    return tf.estimator.EstimatorSpec(mode=mode, train_op=train_op, loss=loss)


# Generate the CSV file before training so input_fn has something to read.
gen_dataset()

est = tf.estimator.Estimator(model_fn, model_dir='./est_logs/')

for i in range(100):
    est.train(input_fn)
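

# Optional: the same input pipeline without tf.contrib. A minimal sketch,
# assuming TensorFlow 1.4+ where TextLineDataset lives under tf.data; the
# name input_fn_v2 is just for illustration and is not used above.
def input_fn_v2():
    dataset = tf.data.TextLineDataset('./est_dataset.csv')
    dataset = dataset.shuffle(7777).batch(10)
    batch = dataset.make_one_shot_iterator().get_next()
    columns = tf.decode_csv(batch, [[0]] * 11)
    features = tf.cast(tf.stack(columns[1:], axis=1), dtype=tf.float32)
    labels = tf.cast(columns[0], dtype=tf.float32)
    return features, labels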