Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Training input pipeline and loop (TensorFlow 1.x graph mode).
#
# NOTE(review): this chunk relies on names defined elsewhere in the file:
# `tf`, `padded_lbl`, `padded_src`, `sess`, `loss_fn`, `optimizer`,
# `summary_op`. It assumes `padded_src` holds (at least) four source arrays,
# matching the four `y_src*` keys / `handle_src*` placeholders — confirm.
batch_size = 100

# Placeholders through which each fetched batch is fed to the model.
# shape=None (unknown shape) replaces the original shape=[] (scalar): a batch
# produced by `next_element` has a leading batch dimension, and TF refuses to
# feed a non-scalar value into a scalar placeholder.
# NOTE(review): dtype float64 is kept from the original — confirm it matches
# the dtype of padded_lbl / padded_src.
handle_mix = tf.placeholder(tf.float64, shape=None)
handle_src0 = tf.placeholder(tf.float64, shape=None)
handle_src1 = tf.placeholder(tf.float64, shape=None)
handle_src2 = tf.placeholder(tf.float64, shape=None)
handle_src3 = tf.placeholder(tf.float64, shape=None)

# One dataset element per example: the mixture plus its four sources.
# Fixed: "y_src2" and "y_src3" previously both reused padded_src[1].
dataset = tf.data.Dataset.from_tensor_slices({
    "x_mixed": padded_lbl,
    "y_src0": padded_src[0],
    "y_src1": padded_src[1],
    "y_src2": padded_src[2],
    "y_src3": padded_src[3],
})
# .repeat() makes the stream infinite, so the loop below never hits
# OutOfRangeError; .batch() adds the leading batch dimension.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

# Reinitializable iterator: built from the dataset's structure so that the
# same `next_element` tensor can later be re-pointed at another dataset.
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)

for _ in range(20):
    # (Re)initialize the iterator over the training dataset. This restarts
    # the pipeline, including a freshly filled shuffle buffer.
    sess.run(training_init_op)
    for _ in range(100):
        # Fetch one batch and feed its components to the placeholders.
        # Previously the fetched batch was discarded and the feed_dict used
        # batch_* names that were never assigned here.
        batch = sess.run(next_element)
        batch_mix = batch["x_mixed"]
        batch_src0 = batch["y_src0"]
        batch_src1 = batch["y_src1"]
        batch_src2 = batch["y_src2"]
        batch_src3 = batch["y_src3"]
        # NOTE(review): feeding round-trips each batch through Python; building
        # the model directly on `next_element` would avoid that copy.
        l, _, summary = sess.run(
            [loss_fn, optimizer, summary_op],
            feed_dict={
                handle_mix: batch_mix,
                handle_src0: batch_src0,
                handle_src1: batch_src1,
                handle_src2: batch_src2,
                handle_src3: batch_src3,
            },
        )
Add Comment
Please, Sign In to add comment