Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import collections
- import logging
- import pickle
- from threading import Thread
- import cv2
- import numpy as np
- import torch
- # torch.multiprocessing.set_start_method('spawn', force="True")
- from tqdm import tqdm
- from t.model_create import get_model, get_loader, cuda
class PredictThread(Thread):
    """Consumer thread: drains (id, path, image) tuples from `queue`,
    runs the segmentation model over them in batches, and writes the
    derived masks / pad-size statistics to disk."""

    def __init__(self, device_ids, num_classes, queue, batch_size, workers, mask_size_threshold=1000, timeout=15,
                 **kwargs):
        # device_ids / num_classes configure the model; queue supplies work items;
        # mask_size_threshold: minimum non-zero pixel count for a mask to be kept;
        # timeout: seconds to wait for the next queue item before shutting down.
        Thread.__init__(self, **kwargs)
        self.num_classes = num_classes
        self.device_ids = device_ids
        self.queue = queue
        self.batch_size = batch_size
        self.workers = workers
        self.mask_size_threshold = mask_size_threshold
        self.timeout = timeout
        self.model = get_model(num_classes, device_ids)
        self.interrupt = False  # set True once the queue goes quiet

    def run(self):
        # Pure inference: autograd disabled for speed / memory.
        with torch.no_grad():
            while not self.interrupt:
                result = []
                try:
                    # Block for the first item, then top up to batch_size,
                    # giving up after `timeout` seconds of queue silence.
                    result = [self.queue.get()]
                    while len(result) < self.batch_size:
                        result.append(self.queue.get(timeout=self.timeout))
                except Exception as error:
                    # Timeout (or any other queue error) ends the thread after
                    # this last, possibly partial, batch is processed.
                    logging.info("Writer: Timeout occurred {}".format(str(error)))
                    self.interrupt = True
                ids, imgs, paths = self.save_imgs(result)
                # NOTE(review): `paths` is re-bound by the loop unpacking below,
                # shadowing the list built above. Also, `ids` and `result` passed
                # to process_n_save_masks are the FULL collections while `r` is a
                # single loader batch — index alignment only holds if the loader
                # yields exactly one batch covering everything. Verify.
                for batch_num, (inputs, file_id, paths) in enumerate(
                        tqdm(get_loader(ids, imgs, paths, self.batch_size, self.workers), desc='Predict')):
                    inputs = cuda(inputs)
                    r = self.model(inputs).cpu().numpy()
                    self.process_n_save_masks(r, paths, ids, result)

    def save_imgs(self, r):
        # Despite the name, nothing is saved here: unzip the queued
        # (id, path, img) tuples into three parallel lists.
        paths = []
        ids = []
        imgs = []
        for i in r:
            imgs.append(i[2])
            paths.append(i[1])
            ids.append(i[0])
        return ids, imgs, paths

    def process_n_save_masks(self, batch, paths, ids, images):
        # For each prediction: threshold the class channels, recode the
        # channel-3 mask to label 2, and persist mask + pad-size pickle for
        # every channel whose mask is large enough.
        # (Channels 1/2 are tagged 'l'/'r' in the pickles — presumably
        # left/right pads; channel 3 looks like a weight/contour map.
        # TODO(review): confirm channel semantics against the model.)
        for i in range(len(batch)):
            as_mask = do_threshold(batch[i], threshold=0.5)
            l = as_mask[1, :, :]
            r = as_mask[2, :, :]
            w = as_mask[3, :, :]
            w[w > 0] = 2  # recode channel-3 hits as label 2 in the merged mask
            if np.count_nonzero(l) > self.mask_size_threshold:
                l_pad_only = get_n_rotate_rect(l)
                with open(paths[i] + 'sizes/' + str(ids[i]) + '.pickle', 'wb') as f:
                    pickle.dump((pads_len(l_pad_only), 'l', str(ids[i])), f)
                save(ids[i], paths[i], l, w, images[i][2])
            # NOTE(review): when both branches fire, the second write reuses the
            # same pickle / mask / image filenames (keyed only by ids[i]) and
            # overwrites the 'l' outputs — confirm this is intended.
            if np.count_nonzero(r) > self.mask_size_threshold:
                r_pad_only = get_n_rotate_rect(r)
                with open(paths[i] + 'sizes/' + str(ids[i]) + '.pickle', 'wb') as f:
                    pickle.dump((pads_len(r_pad_only), 'r', str(ids[i])), f)
                save(ids[i], paths[i], r, w, images[i][2])
def save(id_, path, s, w, image):
    """Write the merged mask (element-wise max of `s` and `w`) and the
    source image as PNGs into `path`'s masks/ and images/ subfolders."""
    merged = np.maximum(s, w)
    cv2.imwrite(path + 'masks/' + str(id_) + '.png', merged)
    cv2.imwrite(path + 'images/' + str(id_) + '.png', image)
def pads_len(img):
    """Summarise the width profile of a (cropped) pad mask.

    Each row's "width" is the sum of its positive values.

    Args:
        img: 2-D array-like of shape (h, w).

    Returns:
        np.ndarray of 4 floats:
        [mean row width, height h,
         mean row width within h/4 +/- 2% of h,
         mean row width within 3h/4 +/- 2% of h].
    """
    arr = np.asarray(img)
    h = arr.shape[0]
    # Vectorised replacement for the original O(h*w) Python double loop;
    # int64 accumulator matches the original unbounded Python-int sums.
    # (The original's deque of maxlen h received exactly h values, so it
    # was simply the list of all row sums.)
    row_sums = np.where(arr > 0, arr, 0).sum(axis=1, dtype=np.int64)
    # Row windows around the quarter / three-quarter height marks (+/- 2%).
    q1_lo, q1_hi = int((h / 4) - (h / 50)), int((h / 4) + (h / 50))
    q3_lo, q3_hi = int(((3 * h) / 4) - (h / 50)), int(((3 * h) / 4) + (h / 50))
    return np.asarray([np.mean(row_sums), h,
                       np.mean(row_sums[q1_lo:q1_hi + 1]),
                       np.mean(row_sums[q3_lo:q3_hi + 1])])
def do_threshold(v, threshold=0.5):
    """Binarise `v`: 1 where v >= threshold, else 0, as uint8."""
    mask = np.greater_equal(v, threshold)
    return mask.astype(np.uint8)
def get_n_rotate_rect(image):
    """Crop `image` to the min-area rotated rectangle of its last contour.

    Raises:
        ValueError: if the image contains no contours (the original code
            raised an obscure NameError in that case).

    NOTE(review): the original looped over every contour and cropped each
    one, keeping only the last crop — the largest contour may have been
    intended. Behaviour (last contour) is preserved; the wasted
    intermediate crops are not.
    """
    contours, hierarchy = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError('no contours found in image')
    rect = cv2.minAreaRect(contours[-1])
    return crop_min_area_rect(image, rect)
def crop_min_area_rect(img, rect):
    """Rotate `img` so that `rect` (a cv2.minAreaRect result) becomes
    axis-aligned, then crop to the rectangle.

    Returns the cropped region, rotated 90 degrees if needed so that
    height >= width.
    """
    angle = rect[2]
    rows, cols = img.shape[0], img.shape[1]
    # Rotate the whole image about its centre by the rect's angle.
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    img_rot = cv2.warpAffine(img, M, (cols, rows))
    box = cv2.boxPoints(rect)
    # np.int0 was removed in NumPy 2.0; .astype(np.intp) performs the same
    # truncating cast of the transformed corner coordinates.
    pts = cv2.transform(np.array([box]), M)[0].astype(np.intp)
    pts[pts < 0] = 0  # clamp corners that fell outside the image
    img_crop = img_rot[pts[1][1]:pts[0][1],
                       pts[1][0]:pts[2][0]]
    # Normalise orientation: portrait (height >= width).
    if img_crop.shape[0] < img_crop.shape[1]:
        img_crop = np.rot90(img_crop)
    return img_crop
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement