Guest User

Untitled

a guest
Jul 19th, 2018
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.33 KB | None | 0 0
  1. """
  2. read and preprocess cifar100 data
  3. """
  4.  
  5. import tensorflow as tf
  6. import functools
  7. import pickle
  8. import numpy as np
  9. import os
  10.  
  11.  
  12. def data_augment(images_dataset, label_dataset):
  13. """Data argument for dataset, including: random
  14. ration, random crop, random flip, mean substraction
  15. and std diviation, transpose of matrix from D/W/H
  16. to W/H/D
  17. """
  18.  
  19. #reshape to 3, 32, 32
  20. reshape = functools.partial(tf.reshape, shape=[3, 32, 32])
  21. images_dataset = images_dataset.map(reshape)
  22.  
  23. #transpose to W/H/D format
  24. transpose = functools.partial(tf.transpose, perm=(1, 2, 0))
  25. images_dataset = images_dataset.map(transpose)
  26.  
  27. #random crop
  28. pad = functools.partial(tf.pad, paddings=tf.constant([[0, 0], [4, 4], [4, 4]]))
  29. images_dataset = images_dataset.map(pad)
  30. crop = functools.partial(tf.random_crop, size=[32, 32, 3])
  31. images_dataset = images_dataset.map(crop)
  32.  
  33. #random flip
  34. images_dataset = images_dataset.map(tf.image.random_flip_left_right)
  35.  
  36. #random rotation
  37. rotation = functools.partial(tf.contrib.image.rotate, angles=10)
  38. images_dataset = images_dataset.map(rotation)
  39.  
  40. #standard
  41. images_dataset = images_dataset.map(tf.image.per_image_standardization)
  42.  
  43. return images_dataset, label_dataset
  44.  
  45.  
  46. def cifar100_train(data_dir, batch_size):
  47. """Read and return cifar100 training dataset
  48.  
  49. Args:
  50. data_dir: cifar100 dataset path to cifar100
  51. batch_size: batch_size for cifar100 training
  52. dataset
  53.  
  54. Returns: an image dataset and a label dataset
  55. """
  56.  
  57. data_dir = os.path.join(data_dir, 'train')
  58. with open(data_dir, 'rb') as cifar100_train:
  59. cifar100 = pickle.load(cifar100_train, encoding='bytes')
  60.  
  61. images = tf.convert_to_tensor(cifar100['data'.encode()], dtype=tf.float32)
  62. labels = tf.convert_to_tensor(cifar100['fine_labels'.encode()], dtype=tf.int64)
  63.  
  64. images_dataset = tf.data.Dataset.from_tensor_slices(images)
  65. labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
  66.  
  67. images_dataset, labels_dataset = data_augment(images_dataset, labels_dataset)
  68.  
  69. #one-hot
  70. one_hot = functools.partial(tf.one_hot, depth=100)
  71. labels_dataset = labels_dataset.map(one_hot)
  72.  
  73. images_dataset = images_dataset.repeat().batch(batch_size)
  74. labels_dataset = labels_dataset.repeat().batch(batch_size)
  75.  
  76. return images_dataset, labels_dataset
Add Comment
Please, Sign In to add comment