Guest User

Untitled

a guest
Sep 18th, 2018
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.62 KB | None | 0 0
  1. train_dataset, val_dataset = train_utils.get_datasets()
  2. types = train_dataset.output_types
  3. shapes = train_dataset.output_shapes
  4.  
  5. # LOOK HERE #
  6. train_iterator = train_dataset.make_one_shot_iterator()
  7. val_iterator = val_dataset.make_one_shot_iterator()
  8. handle = tf.placeholder(tf.string, shape=[], name='dataset_handle')
  9. iterator = tf.data.Iterator.from_string_handle(handle, types, shapes)
  10.  
  11. inputs, outputs = iterator.get_next()
  12.  
  13. model = MyModel(inputs)
  14. loss = MyLoss(model.outputs, outputs)
  15.  
  16. train_step = ...
  17. val_step = ...
  18.  
  19. with tf.Session() as sess:
  20. train_handle = sess.run(train_iterator.string_handle())
  21. val_handle = sess.run(val_iterator.string_handle())
  22.  
  23. sess.graph.finalize()
  24. while True:
  25. sess.run(train_step, {handle: train_handle})
  26. sess.run(val_step, {handle: val_handle})
  27.  
  28. train_dataset, val_dataset = train_utils.get_datasets()
  29. types = train_dataset.output_types
  30. shapes = train_dataset.output_shapes
  31.  
  32. # LOOK HERE #
  33. iterator = tf.data.Iterator.from_structure(types, shapes)
  34. train_init_op = iterator.make_initializer(train_dataset)
  35. val_init_op = iterator.make_initializer(val_dataset)
  36.  
  37. inputs, outputs = iterator.get_next()
  38.  
  39. model = MyModel(inputs)
  40. loss = MyLoss(model.outputs, outputs)
  41.  
  42. train_step = ...
  43. val_step = ...
  44.  
  45. with tf.Session() as sess:
  46. sess.graph.finalize()
  47. while True:
  48. sess.run(train_init_op)
  49. sess.run(train_step)
  50.  
  51. sess.run(val_init_op)
  52. sess.run(val_step)
  53.  
  54. my_np_array = np.array(...)
  55. def gen():
  56. while True: yield my_np_array
  57.  
  58. data_1 = tf.data.Dataset.from_generator(gen)
  59. data_2 = tf.data.Dataset.from_csv(...)
  60. data = tf.data.Dataset.zip((data_1, data_2))
Add Comment
Please sign in to add a comment.