Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def input_fn(filenames: list,
             labels: list,
             num_classes: int,
             batch_size: int,
             epochs: int,
             is_training: bool):
    """Build a tf.data input pipeline and return its next-element tensors.

    Args:
        filenames: List of image file paths.
        labels: List of labels, same length as ``filenames``.
        num_classes: Number of classes used for one-hot encoding.
        batch_size: Number of examples per batch.
        epochs: Number of passes over the data (applied only when training).
        is_training: If True, shuffle over the full dataset and repeat for
            ``epochs``; otherwise iterate the data once, in order.

    Returns:
        Tuple ``(images, labels, iterator_init_op)`` where the first two are
        the iterator's next-element tensors and the last re-initializes the
        iterator.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """
    num_entries = len(filenames)
    # Raise instead of assert: asserts are stripped under `python -O`,
    # silently disabling this validation.
    if num_entries != len(labels):
        raise ValueError('Length of labels is not equal to image list')

    # NOTE(review): `input_size` is a free variable not defined in this
    # function — presumably a module-level constant; confirm it exists.
    load_fn = lambda f, l: load_img(f, l, input_size)
    one_hot_fn = lambda f, l: one_hot_encode(f, l, num_classes)

    # Build the common source once instead of duplicating it per branch.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(filenames), tf.constant(labels)))
    if is_training:
        # Shuffle buffer covers the whole dataset (full shuffle); repeat
        # before map/batch so every epoch is reshuffled. prefetch overlaps
        # preprocessing with consumption.
        dataset = (dataset
                   .shuffle(num_entries)
                   .repeat(epochs)
                   .map(load_fn, num_parallel_calls=4)
                   .map(one_hot_fn, num_parallel_calls=4)
                   .batch(batch_size)
                   .prefetch(2))
    else:
        dataset = (dataset
                   .map(load_fn, num_parallel_calls=4)
                   .map(one_hot_fn, num_parallel_calls=4)
                   .batch(batch_size)
                   .prefetch(1))

    # Create re-initializable iterator from dataset (TF 1.x API) and run
    # its init op in the current Keras session so get_next() yields data.
    iterator = dataset.make_initializable_iterator()
    iterator_init_op = iterator.initializer
    tf.keras.backend.get_session().run(iterator_init_op)
    images, labels = iterator.get_next()
    return images, labels, iterator_init_op
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement