Advertisement
Guest User

Untitled

a guest
Apr 22nd, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.69 KB | None | 0 0
  1. import collections
  2. import logging
  3. import pickle
  4. from threading import Thread
  5.  
  6. import cv2
  7. import numpy as np
  8. import torch
  9. # torch.multiprocessing.set_start_method('spawn', force="True")
  10. from tqdm import tqdm
  11.  
  12. from t.model_create import get_model, get_loader, cuda
  13.  
  14.  
  15. class PredictThread(Thread):
  16. def __init__(self, device_ids, num_classes, queue, batch_size, workers, mask_size_threshold=1000, timeout=15,
  17. **kwargs):
  18. Thread.__init__(self, **kwargs)
  19.  
  20. self.num_classes = num_classes
  21. self.device_ids = device_ids
  22. self.queue = queue
  23. self.batch_size = batch_size
  24. self.workers = workers
  25. self.mask_size_threshold = mask_size_threshold
  26. self.timeout = timeout
  27.  
  28. self.model = get_model(num_classes, device_ids)
  29. self.interrupt = False
  30.  
  31. def run(self):
  32. with torch.no_grad():
  33. while not self.interrupt:
  34. result = []
  35. try:
  36. result = [self.queue.get()]
  37. while len(result) < self.batch_size:
  38. result.append(self.queue.get(timeout=self.timeout))
  39. except Exception as error:
  40. logging.info("Writer: Timeout occurred {}".format(str(error)))
  41. self.interrupt = True
  42.  
  43. ids, imgs, paths = self.save_imgs(result)
  44.  
  45. for batch_num, (inputs, file_id, paths) in enumerate(
  46. tqdm(get_loader(ids, imgs, paths, self.batch_size, self.workers), desc='Predict')):
  47. inputs = cuda(inputs)
  48. r = self.model(inputs).cpu().numpy()
  49. self.process_n_save_masks(r, paths, ids, result)
  50.  
  51. def save_imgs(self, r):
  52. paths = []
  53. ids = []
  54. imgs = []
  55. for i in r:
  56. imgs.append(i[2])
  57. paths.append(i[1])
  58. ids.append(i[0])
  59. return ids, imgs, paths
  60.  
  61. def process_n_save_masks(self, batch, paths, ids, images):
  62. for i in range(len(batch)):
  63. as_mask = do_threshold(batch[i], threshold=0.5)
  64.  
  65. l = as_mask[1, :, :]
  66. r = as_mask[2, :, :]
  67. w = as_mask[3, :, :]
  68. w[w > 0] = 2
  69.  
  70. if np.count_nonzero(l) > self.mask_size_threshold:
  71. l_pad_only = get_n_rotate_rect(l)
  72. with open(paths[i] + 'sizes/' + str(ids[i]) + '.pickle', 'wb') as f:
  73. pickle.dump((pads_len(l_pad_only), 'l', str(ids[i])), f)
  74. save(ids[i], paths[i], l, w, images[i][2])
  75.  
  76. if np.count_nonzero(r) > self.mask_size_threshold:
  77. r_pad_only = get_n_rotate_rect(r)
  78. with open(paths[i] + 'sizes/' + str(ids[i]) + '.pickle', 'wb') as f:
  79. pickle.dump((pads_len(r_pad_only), 'r', str(ids[i])), f)
  80. save(ids[i], paths[i], r, w, images[i][2])
  81.  
  82.  
  83. def save(id_, path, s, w, image):
  84. mask = np.stack((s, w), 2)
  85. mask = np.max(mask, 2)
  86. cv2.imwrite(path + 'masks/' + str(id_) + '.png', mask)
  87. cv2.imwrite(path + 'images/' + str(id_) + '.png', image)
  88.  
  89.  
  90. def pads_len(img):
  91. w = img.shape[1]
  92. h = img.shape[0]
  93.  
  94. p = int(h / 1)
  95. x = collections.deque(p * [0], p)
  96.  
  97. hq0, h4_q_0, h4_q_1 = [], int((h / 4) - (h / 50)), int((h / 4) + (h / 50))
  98. hq1, h34_q_0, h34_q_1 = [], int(((3 * h) / 4) - (h / 50)), int(((3 * h) / 4) + (h / 50))
  99.  
  100. for i in range(h):
  101. sum = 0
  102. for j in range(w):
  103. v = img[i][j]
  104. if v > 0:
  105. sum += v
  106.  
  107. x.append(sum)
  108. if h4_q_0 <= i <= h4_q_1:
  109. hq0.append(sum)
  110. if h34_q_0 <= i <= h34_q_1:
  111. hq1.append(sum)
  112.  
  113. v = np.asarray([np.mean(x), h, np.mean(hq0),
  114. np.mean(hq1)]) # mean w, height, width first quarter +/- 2%, width third quarter +/- 2%
  115. # print(f'aver_len={v[0]}, h={v[1]}, hq0={v[2]}, hq1={v[3]}')
  116. return v
  117.  
  118.  
  119. def do_threshold(v, threshold=0.5):
  120. return (v >= threshold).astype(np.uint8)
  121.  
  122.  
  123. def get_n_rotate_rect(image):
  124. contours, hierarchy = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  125. for contour in contours:
  126. rect = cv2.minAreaRect(contour)
  127. img_cropped = crop_min_area_rect(image, rect)
  128. return img_cropped
  129.  
  130.  
  131. def crop_min_area_rect(img, rect):
  132. angle = rect[2]
  133. rows, cols = img.shape[0], img.shape[1]
  134. M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
  135. img_rot = cv2.warpAffine(img, M, (cols, rows))
  136.  
  137. box = cv2.boxPoints(rect)
  138. pts = np.int0(cv2.transform(np.array([box]), M))[0]
  139. pts[pts < 0] = 0
  140.  
  141. img_crop = img_rot[pts[1][1]:pts[0][1],
  142. pts[1][0]:pts[2][0]]
  143.  
  144. if img_crop.shape[0] < img_crop.shape[1]:
  145. img_crop = np.rot90(img_crop)
  146.  
  147. return img_crop
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement