Guest User

Untitled

a guest
Mar 24th, 2018
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.60 KB | None | 0 0
  1. from __future__ import print_function
  2. import tensorflow as tf
  3. import numpy as np
  4. from PIL import Image
  5. import os
  6. import glob
  7. import csv
  8. import time
  9.  
  10. class dataset_api(object):
  11. def __init__(self, n_examples, im_height, im_width, channels, n_episodes, mean_sub):
  12. self.n_examples = n_examples
  13. self.im_height = im_height
  14. self.im_width = im_width
  15. self.channels = channels
  16. self.n_episodes = n_episodes
  17. self.root_dir = '../data/miniImagenet'
  18.  
  19. self.mean_sub = mean_sub # if use mean substraction for preprocessing
  20.  
  21. self.mean_value = [120.15719937, 114.71930599, 102.78186757]
  22.  
  23. def subtract_mean(self, X):
  24. # per image mean subtraction
  25. # X: N_c x N x H x W x C
  26. '''
  27. N_c,N,H,W,C = np.shape(X)
  28. Xf = np.reshape(X, [-1,C])
  29. means = np.mean(Xf, axis=0, dtype=np.float64) # float64 is necessray
  30. print(means)
  31. '''
  32.  
  33. return X - self.mean_value
  34.  
  35.  
  36. def load_data(self, n_way, n_shot, n_query, stage='train', n_epochs=1000):
  37. """"
  38. Load data main func: use stage params to decide train, val or test
  39. """
  40. ## main dataset code from here
  41. classes = os.listdir(os.path.join(self.root_dir,'data', stage))
  42. n_classes = len(classes)
  43. print('Load {}, {} classes'.format(stage, n_classes))
  44.  
  45. all_eps = np.zeros((n_epochs*100, n_way), dtype=np.int32)
  46. for itr in range(n_epochs*100):
  47. epi_classes = np.random.permutation(n_classes)[:n_way]
  48. all_eps[itr] = epi_classes
  49.  
  50. dataset = tf.data.Dataset.from_tensor_slices(all_eps)
  51.  
  52. def _parse_py_function(ep_class):
  53. """ pyfunc to deal with dirs and img names """
  54. selected = [[]]*len(ep_class)
  55. for i in range(len(ep_class)):
  56. cls_name = classes[ep_class[i]]
  57. cls_pattern = self.root_dir + '/data/'+stage+'/' + cls_name +'/*jpg'
  58. imgs = glob.glob(cls_pattern)
  59.  
  60. np.random.shuffle(imgs) # in-place operation
  61. selected[i] = imgs[0:n_shot+n_query]
  62.  
  63. selected = np.array(selected)
  64. return selected
  65.  
  66. def precess_imgs(selected):
  67. """ tf operation to load and preprocess images """
  68. selected = tf.reshape(selected, [n_way,n_shot+n_query])
  69. out_imgs = []
  70.  
  71. for i in range(n_way):
  72. out_imgs.append([])
  73. for j in range(n_shot+n_query):
  74. fname = selected[i][j]
  75. image_string = tf.read_file(fname)
  76. image_decoded = tf.image.decode_jpeg(image_string)
  77. image_resized = tf.image.resize_images(image_decoded, [self.im_height, self.im_width],
  78. method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  79. if self.mean_sub:
  80. image_resized = self.subtract_mean(image_resized)
  81. out_imgs[i].append(image_resized)
  82.  
  83. labels = tf.tile(tf.expand_dims(tf.range(n_way),-1), (1, n_query))
  84. out_imgs = tf.convert_to_tensor(out_imgs)
  85.  
  86. return out_imgs, labels
  87.  
  88.  
  89. dataset = dataset.map(
  90. lambda e_class: tuple(tf.py_func(_parse_py_function, [e_class], [tf.string])))
  91.  
  92. dataset = dataset.map(precess_imgs)
  93.  
  94. iterator = dataset.make_one_shot_iterator()
  95. out_imgs, labels = iterator.get_next()
  96. support, query = tf.split(out_imgs, [n_shot,n_query], 1)
  97.  
  98. return support, query, labels
Add Comment
Please, Sign In to add comment