Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """
- read and preprocess cifar100 data
- """
- import tensorflow as tf
- import functools
- import pickle
- import numpy as np
- import os
def data_augment(images_dataset, label_dataset):
    """Apply training-time augmentation to a CIFAR-100 image dataset.

    Pipeline: reshape the flat pickle row to C/H/W, transpose to H/W/C,
    pad 4px per spatial side and randomly crop back to 32x32, random
    horizontal flip, small rotation, then per-image standardization
    (mean subtraction and std-deviation normalization).

    Args:
        images_dataset: tf.data.Dataset of flat 3*32*32 image tensors in
            channel-first (CIFAR pickle) order.
        label_dataset: tf.data.Dataset of labels; passed through untouched.

    Returns:
        Tuple of (augmented images dataset, unchanged label dataset).
    """
    # Raw CIFAR rows are stored channel-first: 3 x 32 x 32.
    reshape = functools.partial(tf.reshape, shape=[3, 32, 32])
    images_dataset = images_dataset.map(reshape)
    # Transpose C/H/W -> H/W/C, the layout the tf.image ops expect.
    transpose = functools.partial(tf.transpose, perm=(1, 2, 0))
    images_dataset = images_dataset.map(transpose)
    # Pad 4 pixels on each spatial side, then random-crop back to 32x32.
    # BUG FIX: after the transpose the tensor is H/W/C, so the spatial
    # dims are 0 and 1; the original paddings [[0,0],[4,4],[4,4]] padded
    # W and the channel dim instead.
    pad = functools.partial(tf.pad,
                            paddings=tf.constant([[4, 4], [4, 4], [0, 0]]))
    images_dataset = images_dataset.map(pad)
    crop = functools.partial(tf.random_crop, size=[32, 32, 3])
    images_dataset = images_dataset.map(crop)
    # Random horizontal flip.
    images_dataset = images_dataset.map(tf.image.random_flip_left_right)
    # Small fixed rotation. tf.contrib.image.rotate takes radians; the
    # original passed 10 (i.e. ~573 degrees), almost certainly intended
    # as 10 degrees. NOTE(review): the angle is still constant, not
    # random — sample per image if true random rotation is wanted.
    rotation = functools.partial(tf.contrib.image.rotate,
                                 angles=10.0 * np.pi / 180.0)
    images_dataset = images_dataset.map(rotation)
    # Per-image standardization: zero mean, unit variance.
    images_dataset = images_dataset.map(tf.image.per_image_standardization)
    return images_dataset, label_dataset
def cifar100_train(data_dir, batch_size):
    """Build the CIFAR-100 training input pipeline.

    Loads the pickled 'train' file, converts images and fine labels to
    tensors, applies data_augment, one-hot encodes the labels, and
    returns both streams repeated and batched.

    Args:
        data_dir: directory containing the extracted CIFAR-100 files.
        batch_size: number of examples per batch.

    Returns:
        Tuple of (images dataset, one-hot labels dataset).
    """
    train_file = os.path.join(data_dir, 'train')
    with open(train_file, 'rb') as fin:
        raw = pickle.load(fin, encoding='bytes')

    # The python-pickle keys are bytes, hence the .encode() lookups.
    image_tensor = tf.convert_to_tensor(raw['data'.encode()],
                                        dtype=tf.float32)
    label_tensor = tf.convert_to_tensor(raw['fine_labels'.encode()],
                                        dtype=tf.int64)

    images = tf.data.Dataset.from_tensor_slices(image_tensor)
    labels = tf.data.Dataset.from_tensor_slices(label_tensor)
    images, labels = data_augment(images, labels)

    # One-hot encode the 100 fine-grained classes.
    labels = labels.map(functools.partial(tf.one_hot, depth=100))

    # Repeat indefinitely and batch both streams identically so the
    # image/label pairing stays aligned.
    images = images.repeat().batch(batch_size)
    labels = labels.repeat().batch(batch_size)
    return images, labels
Add Comment
Please, Sign In to add comment