Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def get_data(self):
    """Build the batched (image, label-mask) input pipeline and its initializer.

    Pairs ``self.xfilenames`` with ``self.yfilenames``, decodes each pair with
    ``_parse_function`` (8 parallel calls), batches by ``self.batch_size`` and
    wires a reinitializable iterator to the result.

    Sets on ``self``:
        dataset: the batched ``tf.data.Dataset``.
        img, label: the tensors returned by ``iterator.get_next()``.
        train_init: op that (re)initializes the iterator on ``self.dataset``.

    Relies on names defined outside this method: ``tf``, ``vgg``, ``classes``
    (an RGB color per class, reshaped to ``[n_class, 3]``) and ``n_class``.
    """

    def _parse_function(xfilename, yfilename):
        """Decode one image/label file pair.

        Returns:
            x: float32 RGB image resized to VGG-16's default input size.
            equality: int32 mask of shape [height, width, n_class] (label
                image's own size); 1 where the label pixel's RGB value exactly
                equals that class's color in ``classes``, else 0.
        """
        image_size = vgg.vgg_16.default_image_size

        image_string = tf.read_file(xfilename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image = tf.cast(image_decoded, tf.float32)
        # NOTE(review): only the input image is resized; the label mask keeps
        # the label file's full resolution — confirm the model output matches.
        x = tf.image.resize(image, (image_size, image_size))

        label_string = tf.read_file(yfilename)
        # NOTE(review): labels are decoded as JPEG, which is lossy; compression
        # can perturb colors so exact matching against `classes` may miss
        # pixels. Confirm label files are stored losslessly (e.g. PNG).
        label_decoded = tf.image.decode_jpeg(label_string, channels=3)
        y = tf.cast(label_decoded, tf.int32)

        # decode_jpeg yields [height, width, 3]. The previous code reshaped to
        # [W, H, 1, 3] with W=960 (width) first; reshape only checks element
        # count, so it silently reinterpreted rows and misplaced labels (and
        # failed for any other image size). Inserting the class-broadcast axis
        # keeps the decoded layout: [H, W, 1, 3] vs [n_class, 3]
        # -> [H, W, n_class, 3].
        # NOTE(review): the label's static spatial dims are now unknown; if
        # downstream code needs them, call set_shape with the true [H, W].
        y = tf.expand_dims(y, axis=-2)
        palette = tf.reshape(classes, [n_class, 3])
        equality = tf.equal(y, palette)
        # A pixel belongs to a class only when all three channels match.
        equality = tf.cast(tf.reduce_all(equality, axis=-1), tf.int32)
        return x, equality

    dataset = tf.data.Dataset.from_tensor_slices(
        (self.xfilenames, self.yfilenames))
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.batch(self.batch_size)
    self.dataset = dataset

    # Iterator is built from the dataset's structure (types/shapes) rather than
    # from the dataset itself; make_initializer binds it to self.dataset.
    iterator = tf.data.Iterator.from_structure(self.dataset.output_types,
                                               self.dataset.output_shapes)
    self.img, self.label = iterator.get_next()
    print("get_data label", self.label.shape)
    self.train_init = iterator.make_initializer(self.dataset)  # initializer for train_data
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement