Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def reinitializable_input_fn(filenames, labels, train_val_ratio=0.8, *,
                             train_batch_size=4, val_batch_size=1,
                             shuffle_buffer_size=1000):
    """Build one reinitializable iterator that can be pointed at train or val data.

    The first ``int(len(filenames) * train_val_ratio)`` (filename, label)
    pairs form the training split; the remaining pairs form the validation
    split. The split is positional -- nothing is shuffled before
    partitioning -- so if ``filenames`` is ordered (e.g. grouped by label),
    shuffle it before calling this function.

    Both splits are decoded with the module-level ``_parse_data`` function.
    NOTE(review): ``_parse_data`` is defined elsewhere; its output structure
    (and therefore the structure of ``next_element``) is not visible here.

    Args:
        filenames: Sequence of inputs accepted by
            ``tf.data.Dataset.from_tensor_slices`` (must support ``len`` and
            slicing).
        labels: Labels aligned index-for-index with ``filenames``.
        train_val_ratio: Fraction in ``[0, 1]`` of the pairs used for training.
        train_batch_size: Batch size of the training pipeline (default 4).
        val_batch_size: Batch size of the validation pipeline (default 1).
        shuffle_buffer_size: Shuffle buffer size for the training pairs
            (default 1000).

    Returns:
        A tuple ``(next_element, train_init_op, val_init_op)``.
        ``next_element`` is the iterator's ``get_next()`` output. Running
        ``train_init_op`` or ``val_init_op`` (re)initializes the shared
        iterator on the training or validation dataset respectively.

    Raises:
        ValueError: If ``train_val_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_val_ratio <= 1.0:
        # A negative ratio would yield a negative slice index and a silently
        # wrong split; a ratio > 1 would silently leave validation empty.
        raise ValueError(
            "train_val_ratio must be in [0, 1], got {!r}".format(train_val_ratio))

    num_train_files = int(len(filenames) * train_val_ratio)
    train_filenames = filenames[:num_train_files]
    train_labels = labels[:num_train_files]
    val_filenames = filenames[num_train_files:]
    val_labels = labels[num_train_files:]

    # Shuffle the lightweight (filename, label) pairs *before* decoding so the
    # shuffle buffer holds strings rather than decoded elements. Shuffling
    # before repeat() reshuffles on every epoch.
    train_data = (
        tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
        .shuffle(shuffle_buffer_size)
        .map(_parse_data)
        .repeat()
        .batch(train_batch_size)
    )
    val_data = (
        tf.data.Dataset.from_tensor_slices((val_filenames, val_labels))
        .map(_parse_data)
        .batch(val_batch_size)
    )

    # Build the iterator from the training structure. batch() without
    # drop_remainder leaves the batch dimension as None, so the validation
    # dataset (different batch size) is structurally compatible.
    iterator = tf.data.Iterator.from_structure(train_data.output_types,
                                               train_data.output_shapes)
    next_element = iterator.get_next()
    train_init_op = iterator.make_initializer(train_data)
    val_init_op = iterator.make_initializer(val_data)
    return next_element, train_init_op, val_init_op
Add Comment
Please, Sign In to add comment