Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Feedable-iterator pattern: a single generic iterator whose source dataset
# is chosen per sess.run() call by feeding a string handle.
train_dataset, val_dataset = train_utils.get_datasets()
types, shapes = train_dataset.output_types, train_dataset.output_shapes

# LOOK HERE #
# Each dataset owns its own one-shot iterator; the generic iterator below is
# only a view that dispatches to whichever handle is fed.
train_iterator, val_iterator = (
    ds.make_one_shot_iterator() for ds in (train_dataset, val_dataset)
)
handle = tf.placeholder(tf.string, shape=[], name='dataset_handle')
iterator = tf.data.Iterator.from_string_handle(handle, types, shapes)
inputs, outputs = iterator.get_next()

model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)
train_step = ...  # placeholder: training op
val_step = ...  # placeholder: validation op

with tf.Session() as sess:
    # string_handle() adds ops to the graph, so both handles must be fetched
    # before finalize() makes the graph read-only.
    train_handle, val_handle = sess.run(
        [train_iterator.string_handle(), val_iterator.string_handle()]
    )
    sess.graph.finalize()
    train_feed = {handle: train_handle}
    val_feed = {handle: val_handle}
    # Switching handles does not reset either underlying iterator, so train
    # and val each continue from where they last stopped.
    # NOTE(review): no exit condition; presumably the datasets repeat
    # indefinitely, otherwise an exhausted iterator raises
    # tf.errors.OutOfRangeError here — confirm.
    while True:
        sess.run(train_step, train_feed)
        sess.run(val_step, val_feed)
# Reinitializable-iterator pattern: one iterator built from a shared
# structure, re-pointed at a dataset by running that dataset's initializer.
train_dataset, val_dataset = train_utils.get_datasets()
types, shapes = train_dataset.output_types, train_dataset.output_shapes

# LOOK HERE #
iterator = tf.data.Iterator.from_structure(types, shapes)
train_init_op, val_init_op = (
    iterator.make_initializer(ds) for ds in (train_dataset, val_dataset)
)
inputs, outputs = iterator.get_next()

model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)
train_step = ...  # placeholder: training op
val_step = ...  # placeholder: validation op

with tf.Session() as sess:
    sess.graph.finalize()
    # NOTE: per the tf.data guide, running an initializer restarts the
    # iterator from the beginning of that dataset. Because both initializers
    # run on every pass, each step re-reads the start of its dataset (modulo
    # any shuffling inside the dataset). That is the key difference from the
    # feedable-handle pattern.
    phases = ((train_init_op, train_step), (val_init_op, val_step))
    while True:
        for init_op, step in phases:
            sess.run(init_op)
            sess.run(step)
# Pair an in-memory NumPy array (served through a Python generator) with
# records read from CSV.
my_np_array = np.array(...)  # placeholder: the array to serve


def gen():
    """Yield ``my_np_array`` forever, so the resulting dataset is infinite."""
    while True:
        yield my_np_array


# from_generator cannot infer element types from a Python generator, so
# output_types is a required argument. output_shapes is optional, but passing
# it gives tf.data a static shape instead of an unknown one.
data_1 = tf.data.Dataset.from_generator(
    gen,
    output_types=tf.as_dtype(my_np_array.dtype),
    output_shapes=tf.TensorShape(my_np_array.shape),
)
# tf.data.Dataset has no from_csv. CsvDataset is tf.data's CSV reader:
# tf.data.experimental.CsvDataset in TF >= 1.12, tf.contrib.data.CsvDataset in
# older 1.x releases. Its arguments (filenames, record_defaults, ...) are left
# as a placeholder, as in the original.
data_2 = tf.data.experimental.CsvDataset(...)
# zip stops at its shortest input, so the infinite data_1 is cut off after the
# last CSV record.
data = tf.data.Dataset.zip((data_1, data_2))
Add Comment
Please, Sign In to add comment