Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import os
def parse_fn(example):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Args:
        example: scalar string tensor holding one serialized Example record
            (e.g. an element yielded by ``tf.data.TFRecordDataset``).

    Returns:
        A 2-tuple ``(image, label)``:
          * ``image`` -- output of ``tf.image.decode_image`` applied to the
            ``image/encoded`` bytes.
          * ``label`` -- the int64 ``image/class/label`` value, or ``-1`` when
            the record has no such key.

    NOTE(review): in TF 1.x ``decode_image`` gives a tensor with no static
    shape/rank; downstream layers may need ``set_shape`` -- confirm.
    """

    def scalar(dtype, default):
        # Every feature is a scalar; ``default`` fills in missing keys.
        return tf.FixedLenFeature((), dtype, default)

    # format/height/width are parsed but not used below.
    spec = {
        'image/encoded': scalar(tf.string, ""),
        'image/format': scalar(tf.string, ""),
        'image/class/label': scalar(tf.int64, -1),
        'image/height': scalar(tf.int64, -1),
        'image/width': scalar(tf.int64, -1),
    }
    features = tf.parse_single_example(example, spec)
    return (
        tf.image.decode_image(features['image/encoded']),
        features['image/class/label'],
    )
# Input pipeline smoke test: read CIFAR-10 TFRecords, shuffle individual
# records, decode and batch them with ``parse_fn``, then print one batch.
# Uses the TF 1.x graph-mode / ``tf.contrib.data`` API.
BATCH_SIZE = 128
# The original pipeline shuffled a window of 128 *batches* (128 * 128
# examples); keep the same window size now that shuffling is per record.
SHUFFLE_BUFFER = 128 * BATCH_SIZE

# Build the path from components: os.path.join with one backslash-separated
# literal is a no-op and only resolves on Windows.
files = tf.data.Dataset.list_files(
    os.path.join(".", "datasets", "cifar-10-2", "train.tfrecords"))
dataset = files.interleave(tf.data.TFRecordDataset, 1)  # cycle_length = num files
# Shuffle records *before* batching so each batch mixes examples; shuffling
# after map_and_batch only permuted whole, fixed batches.
dataset = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=SHUFFLE_BUFFER, count=None))
dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        map_func=parse_fn, batch_size=BATCH_SIZE, drop_remainder=True))
dataset = dataset.prefetch(buffer_size=1)
itr = dataset.make_one_shot_iterator().get_next()
print(itr)

with tf.Session('') as s:
    # Fetch image and label in ONE run call. Every run of a get_next() output
    # advances the iterator, so running the tuple's tensors one at a time
    # (as iterating over ``itr`` did) fetches only the images, or mismatched
    # image/label batches.
    images, labels = s.run(itr)
    print(images)
    print(labels)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement