Advertisement
Guest User

Untitled

a guest
Sep 15th, 2019
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.39 KB | None | 0 0
  1. def input_fn(filenames: [],
  2. labels: [],
  3. num_classes: int,
  4. batch_size: int,
  5. epochs: int,
  6. is_training: bool):
  7.  
  8. num_entries = len(filenames)
  9. assert num_entries == len(labels), 'Length of labels is not equal to image list'
  10.  
  11. load_fn = lambda f, l: load_img(f, l, input_size)
  12. one_hot_fn = lambda f, l: one_hot_encode(f, l, num_classes)
  13.  
  14. if is_training:
  15. dataset = (tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
  16. .shuffle(num_entries)
  17. .repeat(epochs)
  18. .map(load_fn, num_parallel_calls=4)
  19. .map(one_hot_fn, num_parallel_calls=4)
  20. .batch(batch_size)
  21. .prefetch(2)
  22. )
  23. else:
  24. dataset = (tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
  25. .map(load_fn, num_parallel_calls=4)
  26. .map(one_hot_fn, num_parallel_calls=4)
  27. .batch(batch_size)
  28. .prefetch(1)
  29. )
  30.  
  31. # Create re-initializable iterator from dataset
  32. iterator = dataset.make_initializable_iterator()
  33. iterator_init_op = iterator.initializer
  34. tf.keras.backend.get_session().run(iterator_init_op)
  35. images, labels = iterator.get_next()
  36. return images, labels, iterator_init_op
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement