Advertisement
Guest User

Untitled

a guest
Oct 18th, 2019
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.10 KB | None | 0 0
  1. def custom_generator(generator, train_gt, train_img_dir):
  2.     train_data = np.zeros((len(train_gt), 100, 100, 3))
  3.     train_labels = np.zeros((len(train_gt), 100, 100, 1)).astype(int)
  4.  
  5.     for idx, (img_name, img_points) in enumerate(train_gt.items()):
  6.         img = imread(join(train_img_dir, img_name), plugin="matplotlib")
  7.         points = img_points.reshape((14, 2))
  8.         new_img = resize(img, (100, 100, 3), anti_aliasing=False)
  9.         new_points = np.around(np.multiply(points, 100 / np.array([img.shape[1], img.shape[0]]))).astype(int)
  10.         new_mask = np.zeros((100, 100, 1)).astype(int)
  11.         for numer, point in enumerate(new_points[[1, 0], :]):
  12.             up_border = point[0] - 1 if point[0] - 1 >= 0 else 0
  13.             down_border = point[0] + 2 if point[0] + 2 <= new_mask.shape[0] else new_mask.shape[0]
  14.             left_border = point[1] - 1 if point[1] - 1 >= 0 else 0
  15.             right_border = point[1] + 2 if point[1] + 2 <= new_mask.shape[1] else new_mask.shape[1]
  16.             new_mask[up_border:down_border, left_border:right_border] = (numer + 1) * 10
  17.         plt.imshow(new_img + new_mask * 255)
  18.         plt.show()
  19.         break
  20.         train_data[idx, ...] = new_img
  21.         train_labels[idx, ...] = new_mask
  22.  
  23.     it = generator.flow(train_data, batch_size=64, seed=1337)
  24.     il = generator.flow(train_labels, batch_size=64, seed=1337)
  25.     while True:
  26.         batch = it.next()
  27.         masks = il.next()
  28.         ext_batch = []
  29.         ext_points = []
  30.         for idx, current in enumerate(masks):
  31.             points = np.zeros((14, 2))
  32.             try:
  33.                 for numer in range(14):
  34.                     indeces = np.argwhere(current == (numer + 1) * 10)
  35.                     if len(indeces) == 0:
  36.                         raise StopIteration
  37.                     points[numer] = np.around(np.array([np.mean(indeces[1, ...]), np.mean(indeces[0, ...])]))
  38.                 ext_batch.extend(batch[idx])
  39.                 ext_points.extend(points.reshape(-1))
  40.             except StopIteration:
  41.                 continue
  42.         yield np.array(ext_batch), np.array(ext_points)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement