Guest User

Untitled

a guest
Feb 20th, 2018
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.54 KB | None | 0 0
  1. import tensorflow as tf
  2.  
  3.  
  # Data source
  def data_generator(start, end):
      """Yield ``(value, value)`` pairs for every integer in ``[start, end)``.

      Each pair is printed to stdout just before it is yielded (presumably as
      a debugging aid to show when the consumer pulls a new element).

      Args:
          start: first integer produced (inclusive).
          end: stop value (exclusive).

      Yields:
          A 2-tuple whose feature and label are the same integer.
      """
      for value in range(start, end):
          print(value, value)
          yield value, value
  9.  
  10.  
  # TF dataset
  def input_fn(data_getter):
      """Build an Estimator input pipeline from a Python data source.

      Args:
          data_getter: either a zero-argument callable that returns a fresh
              iterable of ``(feature, label)`` pairs on every call (preferred),
              or an iterable of such pairs. ``from_generator`` calls ``iter()``
              on what its generator callable returns each time the dataset is
              (re)initialised, including on every ``repeat()`` pass. A one-shot
              generator object passed directly is therefore exhausted after
              its first pass; pass a factory or a re-iterable object instead.

      Returns:
          A ``(features, labels)`` tuple of float32 tensors holding the next
          element of the endlessly repeated dataset.
      """
      # A callable is treated as a factory, so each (re)initialisation gets a
      # brand-new iterator. Anything else keeps the original behaviour of
      # handing the same object back every time, which only repeats correctly
      # if that object is re-iterable.
      generator = data_getter if callable(data_getter) else (lambda: data_getter)

      dataset = tf.data.Dataset.from_generator(
          generator=generator,
          # The source yields (feature, label) pairs, so the output structure
          # must be a 2-tuple. ``(tf.float32)`` has no comma and is just
          # ``tf.float32``, which describes a single tensor and breaks the
          # unpacking below.
          output_types=(tf.float32, tf.float32),
      )
      features, labels = dataset.repeat().make_one_shot_iterator().get_next()
      return features, labels
  21.  
  22.  
  def model_fn(features, labels, mode):
      """Toy Estimator model: predict ``features + var`` with one learned scalar.

      Args:
          features: float32 input tensor (presumably scalar per step, as
              produced by ``input_fn`` in this file; TODO confirm for other
              sources).
          labels: float32 target tensor that must broadcast against
              ``features``; ``None`` in PREDICT mode.
          mode: a ``tf.estimator.ModeKeys`` value.

      Returns:
          A ``tf.estimator.EstimatorSpec`` appropriate for ``mode``.
      """
      var = tf.Variable(0, dtype=tf.float32)
      prediction = features + var

      if mode == tf.estimator.ModeKeys.PREDICT:
          # Labels are None when predicting, so no loss can be computed.
          return tf.estimator.EstimatorSpec(mode=mode, predictions=prediction)

      # Squared error is bounded below by 0. The former signed difference
      # (prediction - labels) is minimised by pushing `var` to -inf, so
      # training could never converge. reduce_mean guarantees a scalar loss
      # even if inputs ever arrive batched.
      loss = tf.reduce_mean(tf.square(prediction - labels))

      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
          # Build optimiser ops (and Adam slot variables) only when training.
          train_op = tf.contrib.layers.optimize_loss(
              loss=loss,
              global_step=tf.train.get_global_step(),
              optimizer=tf.train.AdamOptimizer,
              learning_rate=0.01,
          )

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=prediction,
          loss=loss,
          train_op=train_op,
      )
  40.  
  41.  
  class _Reiterable(object):
      """Wrap a generator factory so that every ``iter()`` call starts afresh.

      ``tf.data.Dataset.from_generator`` calls ``iter()`` on the object its
      generator callable returns each time the dataset is (re)initialised.
      A bare generator object can only be consumed once, so ``repeat()`` and
      repeated evaluations would otherwise see an empty stream after the
      first pass.
      """

      def __init__(self, factory):
          # factory: zero-argument callable returning a new iterator per call.
          self._factory = factory

      def __iter__(self):
          return iter(self._factory())


  def run():
      """Train and periodically evaluate the toy model on generated data."""
      tf.logging.set_verbosity(tf.logging.DEBUG)

      # The external data source is still a Python generator, but each
      # dataset (re)initialisation now creates a fresh one. Previously a
      # single generator object was shared across passes and got exhausted.
      train_data = _Reiterable(lambda: data_generator(start=0, end=5))
      eval_data = _Reiterable(lambda: data_generator(start=100, end=105))

      estimator = tf.estimator.Estimator(model_fn=model_fn)
      train_spec = tf.estimator.TrainSpec(
          input_fn=lambda: input_fn(train_data))
      eval_spec = tf.estimator.EvalSpec(
          input_fn=lambda: input_fn(eval_data), start_delay_secs=0)
      tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


  if __name__ == '__main__':
      run()
Add Comment
Please, Sign In to add comment