SHARE
TWEET

DataGenerator

a guest Oct 18th, 2019 94 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  class DataGenerator(tf.keras.utils.Sequence):
      """Keras ``Sequence`` producing randomly augmented image / key-point batches.

      Each batch is sampled at random (with replacement) from ``x`` and ``y``;
      every sample is copied and run through the enabled augmentations before
      the batch is stacked into arrays.

      Key points are handled as one flat vector per sample. Even entries are
      scaled by the image's second axis (width) and odd entries by its first
      axis (height), i.e. the vector is treated as ``x0, y0, x1, y1, ...``.
      """

      def __init__(self,
                   x=None,
                   y=None,
                   batch_size=20,
                   target_shape=(100, 100, 3),
                   sample_std=False,
                   feature_std=False,
                   proj_parameters=None,
                   blur_parameters=None,
                   nois_parameters=None,
                   flip_parameters=None,
                   gamm_parameters=None):
          """Store the data set and the augmentation settings.

          Args:
              x: Images; converted with ``np.array``.
              y: Flat key-point vectors, one per image; converted with
                  ``np.array``.
              batch_size: Number of samples drawn per batch.
              target_shape: Shape every image is resized to when it differs.
              sample_std: If true, each batch is passed through
                  ``smap_standardize``.
              feature_std: If true, each image is passed through
                  ``feat_standardize`` (a no-op until ``fit`` has been called).
              proj_parameters: ``(low, high, max_angle, alpha)`` for the random
                  projective warp, or ``None`` to disable it.
              blur_parameters: Upper bound of the Gaussian blur sigma, or
                  ``None`` to disable blurring.
              nois_parameters: ``(mean, sigma)`` upper bounds of the additive
                  Gaussian noise, or ``None`` to disable it.
              flip_parameters: Probability of mirroring a sample, or ``None``
                  to disable flipping.
              gamm_parameters: ``(min_gamma, max_gamma)`` range of the random
                  gamma correction, or ``None`` to disable it.
          """
          super().__init__()
          self.x = np.array(x)
          self.y = np.array(y)
          self.batch_size = batch_size
          self.target_shape = target_shape
          self.sample_std = sample_std
          self.feature_std = feature_std
          self.proj_parameters = proj_parameters
          self.blur_parameters = blur_parameters
          self.nois_parameters = nois_parameters
          self.flip_parameters = flip_parameters
          self.gamm_parameters = gamm_parameters
          # Feature-wise statistics are only available after fit(); define the
          # attributes up front so feat_standardize() can be called safely.
          self.fited = False
          self.mean = None
          self.std = None

      def __len__(self):
          """Return the number of batches per epoch (incomplete batch dropped)."""
          return len(self.x) // self.batch_size

      def __getitem__(self, index):
          """Return one augmented batch ``(x, y)``.

          ``index`` is ignored: the samples are drawn uniformly at random with
          replacement, so repeated calls with the same index differ.
          """
          picks = np.random.randint(len(self.x), size=self.batch_size)
          return self.__data_generation(self.x[picks], self.y[picks])

      def on_epoch_end(self):
          """Do nothing; there is no per-epoch state to refresh."""
          pass

      def __data_generation(self, imgs, key_points):
          """Augment every sample and stack the results into two arrays."""
          batch_x = []
          batch_y = []
          for img, points in zip(imgs, key_points):
              # Work on copies so the stored data set is never modified.
              img, points = self.__preprocess_input(np.copy(img), np.copy(points))
              batch_x.append(img)
              batch_y.append(points)

          batch_x = np.array(batch_x)
          if self.sample_std:
              # Use the returned array instead of relying on in-place mutation.
              batch_x = self.smap_standardize(batch_x)
          return batch_x, np.array(batch_y)

      def __preprocess_input(self, image, key_points):
          """Apply the enabled augmentations to one sample, in a fixed order.

          Order: projective warp, resize to ``target_shape``, flip, gamma,
          blur, noise, feature-wise standardization.
          """
          if self.proj_parameters is not None:
              image, key_points = self.__projection(image, key_points, self.proj_parameters)

          if image.shape != self.target_shape:
              image, key_points = self.__resize(image, key_points, self.target_shape)

          if self.flip_parameters is not None:
              image, key_points = self.__flip(image, key_points, self.flip_parameters)
          if self.gamm_parameters is not None:
              image = self.__gamma(image, self.gamm_parameters)
          if self.blur_parameters is not None:
              image = self.__gaussian_blur(image, self.blur_parameters)
          if self.nois_parameters is not None:
              image = self.__gaussian_noise(image, self.nois_parameters)
          if self.feature_std:
              image = self.feat_standardize(image)
          return image, key_points

      def __gaussian_noise(self, image, params):
          """Add Gaussian noise and clip the result to ``[0, 1]``.

          ``params`` is ``(mean, sigma)``; each is multiplied by its own
          ``random()`` draw (presumably uniform in [0, 1) - confirm the import).
          """
          mean, sigma = params
          noise = np.random.normal(mean * random(), sigma * random(), image.shape)
          return np.clip(image + noise, 0.0, 1.0)

      def __gaussian_blur(self, image, params):
          """Blur the image with a sigma of ``params * random()``."""
          # NOTE(review): newer scikit-image releases replace `multichannel`
          # with `channel_axis`; kept as-is because the installed version is
          # not visible from this file.
          return skimage.filters.gaussian(image, sigma=params * random(), multichannel=True)

      def __gamma(self, image, params):
          """Apply gamma correction with a gamma drawn from ``params``' range."""
          min_gamma, max_gamma = params
          gamma = min_gamma + (max_gamma - min_gamma) * random()
          return skimage.exposure.adjust_gamma(image, gamma)

      def __flip(self, image, key_points, params):
          """Mirror the sample with probability ``params``.

          ``mirr`` is defined outside this class; presumably it mirrors the
          image and remaps the key points - verify against its definition.
          """
          probability = params
          # Flip when the draw falls below the probability, so that 0 never
          # flips and 1 always flips.
          if random() < probability:
              image, key_points = mirr(image, key_points)
          return image, key_points

      @staticmethod
      def __homography(corners, edges):
          """Solve for the 3x3 homography mapping ``corners`` onto ``edges``.

          Both arguments are ``(N, 2)`` arrays in (row, column) order. The
          solution is the right singular vector of the linear system for the
          smallest singular value; its first two rows and columns are swapped
          so the returned matrix works on (x, y) = (column, row) coordinates.
          """
          count = corners.shape[0]
          zeros, ones = np.zeros(count), np.ones(count)
          ax = np.array([-corners[..., 0], -corners[..., 1], -ones,
                         zeros, zeros, zeros,
                         edges[..., 0] * corners[..., 0],
                         edges[..., 0] * corners[..., 1],
                         edges[..., 0]]).T
          ay = np.array([zeros, zeros, zeros,
                         -corners[..., 0], -corners[..., 1], -ones,
                         edges[..., 1] * corners[..., 0],
                         edges[..., 1] * corners[..., 1],
                         edges[..., 1]]).T
          system = np.zeros((count * 2, 9))
          system[::2, ...] = ax
          system[1::2, ...] = ay

          _, _, vt = np.linalg.svd(system)
          return vt[-1, ...].reshape((3, 3))[(1, 0, 2), :][:, (1, 0, 2)]

      def __projection(self, image, key_points, params):
          """Warp the image and key points with a random projective transform.

          ``params`` is ``(low, high, max_angle, alpha)``:
          each image corner is moved inwards by a random fraction of the image
          size between ``-high`` and ``low`` (assuming ``random()`` is uniform
          in [0, 1)); the input is additionally rotated about the origin by up
          to ``max_angle`` degrees; a transform is accepted only if every
          warped key point lies strictly inside the ``alpha`` .. ``1 - alpha``
          fraction of the image. After 199 rejected draws the final attempt
          uses the identity mapping.
          """
          low, high, max_angle, alpha = params
          try_count = 200

          height, width = image.shape[0], image.shape[1]
          corners = np.array([[0, 0],
                              [height, 0],
                              [height, width],
                              [0, width]])
          d_h = low * height + high * height
          d_w = low * width + high * width
          s_h = -high * height
          s_w = -high * width

          # Key points as (N, 2) rows of (x, y) for ProjectiveTransform.
          points = np.zeros((key_points.shape[0] // 2, 2))
          points[..., 0] = key_points[::2]
          points[..., 1] = key_points[1::2]

          # Even entries are x (column) coordinates, so they are bounded by the
          # width; odd entries are y (row) coordinates bounded by the height.
          low_keys = np.empty(key_points.shape)
          high_keys = np.empty(key_points.shape)
          low_keys[::2] = alpha * width
          low_keys[1::2] = alpha * height
          high_keys[::2] = (1 - alpha) * width
          high_keys[1::2] = (1 - alpha) * height

          for attempt in range(try_count):
              edges = np.array([[         s_h + d_h * random(),         s_w + d_w * random()],
                                [height - s_h - d_h * random(),         s_w + d_w * random()],
                                [height - s_h - d_h * random(), width - s_w - d_w * random()],
                                [         s_h + d_h * random(), width - s_w - d_w * random()]])
              theta = np.radians(max_angle * random())
              if attempt == try_count - 1:
                  # Last resort: identity mapping, which leaves the sample as is.
                  edges = np.copy(corners)
                  theta = 0

              cos_t, sin_t = np.cos(theta), np.sin(theta)
              rotation = np.array(((cos_t, -sin_t, 0), (sin_t, cos_t, 0), (0, 0, 1)))
              H = self.__homography(corners, edges) @ rotation
              if np.linalg.det(H) == 0:
                  # Not invertible, so it cannot be used as a warp.
                  continue

              moved = skimage.transform.ProjectiveTransform(H)(points)
              test_points = np.zeros(key_points.shape)
              test_points[::2] = moved[..., 0]
              test_points[1::2] = moved[..., 1]
              if np.greater(test_points, low_keys).all() and np.less(test_points, high_keys).all():
                  break

          # warp() expects the map from output to input coordinates.
          inverse_map = skimage.transform.ProjectiveTransform(np.linalg.inv(H))
          image = skimage.transform.warp(image, inverse_map, output_shape=image.shape, mode='edge')
          return image, test_points

      def __resize(self, image, key_points, shape):
          """Resize the image to ``shape`` and rescale the key points to match.

          The rescaled key points are rounded and returned as ``int64``.
          """
          src_rows, src_cols = image.shape[0], image.shape[1]
          image = skimage.transform.resize(image, shape)
          points = np.zeros(key_points.shape, dtype=np.int64)
          # Scale to the requested size rather than a hard-coded 100 pixels.
          points[::2] = (key_points[::2] / src_cols * shape[1]).round()
          points[1::2] = (key_points[1::2] / src_rows * shape[0]).round()
          return image, points

      @staticmethod
      def __standardize(image, mean, std):
          """Standardize with ``mean`` / ``std``, then rescale to ``[0, 1]``.

          A constant result is mapped to zeros instead of dividing by zero.
          """
          image = (image - mean) / (std + 1e-6)
          low = np.min(image)
          span = np.max(image) - low
          shifted = image - low
          return shifted / span if span > 0 else shifted

      def fit(self, imgs, passes):
          """Compute the feature-wise ``mean`` and ``std`` used by ``feat_standardize``.

          For each pass the images are preprocessed (with placeholder key
          points that are discarded), the per-pixel statistics are stored, and
          the images are standardized before the next pass. The input is not
          modified.

          Args:
              imgs: Images to compute the statistics from.
              passes: Number of preprocess / measure / standardize rounds; with
                  zero passes no statistics are computed.
          """
          imgs = np.array(imgs)
          # Measure raw statistics: keep feat_standardize() inactive meanwhile.
          self.fited = False

          for _ in range(passes):
              # Collect into a new array so a change of shape or dtype during
              # preprocessing cannot break or truncate the assignment.
              imgs = np.array([self.__preprocess_input(img, np.zeros(28))[0] for img in imgs])
              self.std = np.std(imgs, axis=0)
              self.mean = np.mean(imgs, axis=0)
              imgs = np.array([self.__standardize(img, self.mean, self.std) for img in imgs])

          self.fited = self.mean is not None

      def feat_standardize(self, image):
          """Standardize one image with the statistics from ``fit``.

          Returns the image unchanged when ``fit`` has not been called. The
          input array is not modified.
          """
          if not getattr(self, 'fited', False):
              return image
          return self.__standardize(image, self.mean, self.std)

      def smap_standardize(self, images):
          """Shift and scale ``images`` by their overall mean and std.

          Floating-point arrays are modified in place and returned; other
          inputs are converted to a new float array first.
          """
          # NOTE(review): the statistics are taken over the whole array, not
          # per sample - confirm this matches the intent of `sample_std`.
          images = np.asarray(images)
          if not np.issubdtype(images.dtype, np.floating):
              images = images.astype(float)
          images -= np.mean(images, keepdims=True)
          images /= (np.std(images, keepdims=True) + 1e-6)
          return images
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top