from keras import backend as K
import numpy as np
from skimage import morphology as morph
from utils import custom_bce, watersplit
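
# ---- Hypothetical sketches of the `utils` helpers ---------------------------
# `custom_bce` and `watersplit` live in `utils`, which is not part of this
# paste. The versions below are assumptions inferred from the call sites in
# the loss -- NOT the original implementations; swap in the real ones.
from scipy import ndimage
from skimage.segmentation import find_boundaries, watershed  # older skimage: skimage.morphology.watershed

def custom_bce_sketch(y_true, y_pred, ignore_label=None):
    """Elementwise binary cross-entropy; pixels whose target equals
    `ignore_label` contribute zero loss (assumed signature)."""
    eps = K.epsilon()
    y_pred = K.clip(y_pred, eps, 1. - eps)
    bce = -(y_true * K.log(y_pred) + (1. - y_true) * K.log(1. - y_pred))
    if ignore_label is not None:
        bce = bce * K.cast(K.not_equal(y_true, ignore_label), K.floatx())
    return bce

def watersplit_sketch(probs, points):
    """Return a 0/1 numpy mask of watershed boundaries between the seed
    points, computed on the probability map (assumed behaviour)."""
    markers = points.copy().astype(int)
    markers[markers != 0] = np.arange(1, (markers != 0).sum() + 1)  # one unique seed label per point
    topo = ndimage.black_tophat(probs, 7)  # emphasize valleys between predicted peaks
    seg = watershed(topo, markers)
    return find_boundaries(seg).astype(float)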

def __loss(thres=0.5, fp_scale=1.):
    """Build a point-supervised localization loss: image-level + point-level
    + false-positive + split + global-split terms."""
    def loss(y_true, y_pred):
        points = K.squeeze(K.squeeze(y_true, axis=0), axis=2)  # drop the batch and channel dimensions from y_true
        predictions = K.sigmoid(K.squeeze(K.squeeze(y_pred, axis=0), axis=2))  # same for y_pred, then sigmoid to get probabilities
        # NOTE: K.eval pulls concrete values out of the graph, so this loss only
        # works where the tensors are eagerly evaluable.
        pred_mask = K.eval(predictions > thres).astype(int)  # binarize predictions into a 0/1 output mask
        blobs = morph.label(pred_mask)  # label connected components ("blobs") with unique, consecutive IDs
        points_np = K.eval(points)  # target point annotations as a numpy array
        predictions_np = K.eval(predictions)  # predictions as a numpy array
        # blob_uniques: labels of blobs that intersect at least one annotated point (plus 0 for background).
        # blob_counts: the number of annotated points falling inside each of those blobs.
        # NOTE: after morph.label, blob_uniques holds the unique blob *labels*, not just 0s and 1s.
        blob_uniques, blob_counts = np.unique(blobs * points_np, return_counts=True)
        # uniques: labels of blobs that intersect no annotated point (false positives).
        # np.delete removes by *index*; that is correct here only because morph.label
        # assigns consecutive labels 0..N, so index == label.
        uniques = np.delete(np.unique(blobs), blob_uniques)
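        # Worked example (hypothetical values):
        #   blobs = [[0, 1, 1],    points_np = [[0, 1, 0],
        #            [2, 2, 0]]                 [0, 0, 0]]
        #   blobs * points_np = [[0, 1, 0], [0, 0, 0]]
        #   -> blob_uniques = [0, 1], blob_counts = [5, 1]  (blob 1 covers one point)
        #   -> uniques = np.delete([0, 1, 2], [0, 1]) = [2]  (blob 2 is a false positive)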
        # ----------------------------------------- IMAGE LEVEL LOSS -----------------------------------------
        # Makes sure the model realizes that at least one rebar *exists* in the image.
        image_level_loss = custom_bce(K.max(points), K.max(predictions))
        # ----------------------------------------- POINT LEVEL LOSS -----------------------------------------
        # Pushes the model to predict each annotated rebar's approximate location in the image.
        point_level_loss = K.mean(custom_bce(points, predictions))
        # ----------------------------------------- FALSE POSITIVE LOSS --------------------------------------
        # Penalizes the model for predicting blobs that contain no annotated point.
        false_positive_mask = np.zeros(predictions_np.shape)  # mask selecting only the false-positive pixels
        for u in uniques:  # iterate over false-positive blob IDs
            if u == 0:  # skip the background label
                continue
            false_positive_mask += blobs == u  # accumulate the locations of false positives
        assert (false_positive_mask <= 1).all()  # labelled blobs are disjoint, so the mask must stay binary
        false_positive_target = K.variable(1 - false_positive_mask)  # target that drives those locations to 0
        # Compute the loss, ignoring locations where the target is 1.
        false_positive_loss = fp_scale * K.mean(custom_bce(false_positive_target, predictions, ignore_label=1))
        # ----------------------------------------- SPLIT LOSS -----------------------------------------------
        # Penalizes blobs that cover more than one point annotation, forcing the model to split them.
        T = np.zeros(predictions_np.shape)
        scale_multi = 0.  # accumulates (point count + 1) over multi-point blobs, to scale the loss
        for i in range(len(blob_uniques)):  # iterate over blobs that intersect annotations
            if blob_counts[i] < 2 or blob_uniques[i] == 0:  # skip background and blobs covering fewer than two points
                continue
            blob_ind = blobs == blob_uniques[i]  # mask for this particular blob
            T += watersplit(predictions_np, points_np * blob_ind) * blob_ind  # watershed boundaries inside this blob, accumulated over the whole image
            scale_multi += float(blob_counts[i] + 1)  # add (point count + 1) to the overall scale
        assert (T <= 1).all()  # boundary masks from different blobs must not overlap
        multi_blob_target = K.variable(1 - T)  # invert so boundaries become background (0)
        # Compute the loss, ignoring locations where the target is 1.
        split_loss = scale_multi * K.mean(custom_bce(multi_blob_target, predictions, ignore_label=1))
        # ----------------------------------------- GLOBAL SPLIT LOSS ----------------------------------------
        # Forces the model to make as many splits as possible (ideally one blob per annotated point).
        T = 1 - watersplit(predictions_np, points_np)  # global watershed boundaries, set to 0 (background)
        scale = float(points_np.sum())
        global_split_loss = scale * K.mean(custom_bce(K.variable(T), predictions, ignore_label=1))
        return image_level_loss + point_level_loss + false_positive_loss + split_loss + global_split_loss
    return loss
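
# Usage sketch (assumptions: `model`, the optimizer choice, and the training
# data names below are placeholders, not from this paste). Because the loss
# calls K.eval on symbolic tensors, it needs an eagerly evaluable setup, and
# the axis-0 squeezes assume a batch size of 1:
#
#   model.compile(optimizer='adam', loss=__loss(thres=0.5, fp_scale=1.))
#   model.fit(images, point_maps, batch_size=1)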