Guest User

custom_loss_function

a guest
Apr 7th, 2019
320
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.48 KB | None | 0 0
  1. from keras import backend as K
  2. import numpy as np
  3. from skimage import morphology as morph
  4. from tqdm import tqdm
  5. from utils import custom_bce, watersplit
  6. from time import time
  7.  
  8.  
def __loss(thres=0.5, fp_scale=1.):
    """Build a point-supervision segmentation loss (LC-FCN style).

    Args:
        thres: Probability threshold used to binarize the sigmoid
            predictions into a foreground mask for blob extraction.
        fp_scale: Multiplier applied to the false-positive loss term.

    Returns:
        A Keras-compatible ``loss(y_true, y_pred)`` closure that returns the
        sum of four terms: image-level, point-level, false-positive, and
        (per-blob + global) split losses.

    NOTE(review): the closure calls ``K.eval`` on tensors derived from its
    arguments, which assumes the backend can evaluate them eagerly (e.g.
    TF1 with a default session / batch size 1). Confirm this is not used
    inside a compiled symbolic graph, where ``K.eval`` would fail.
    """
    def loss(y_true, y_pred):

        # Drop the batch and channel dimensions. Assumes inputs are shaped
        # (1, H, W, 1) — TODO confirm against the training pipeline.
        points = K.squeeze(K.squeeze(y_true, axis=0), axis=2)  # (H, W) ground-truth point annotations
        predictions = K.sigmoid(K.squeeze(K.squeeze(y_pred, axis=0), axis=2))  # (H, W) per-pixel probabilities

        pred_mask = K.eval(predictions > thres).astype(int)  # binary foreground mask (0/1)
        blobs = morph.label(pred_mask)  # connected components labelled 1..K, 0 = background

        points_np = K.eval(points)  # targets as a NumPy array
        predictions_np = K.eval(predictions)  # predictions as a NumPy array

        # blob_uniques: labels of blobs containing at least one annotated point
        # (plus 0 for background); blob_counts: number of points inside each.
        # NOTE: after morph.label these are blob *labels*, not just 0/1 values.
        blob_uniques, blob_counts = np.unique(blobs * (points_np), return_counts=True)

        # Labels of blobs that contain no annotated point (false positives).
        # NOTE(review): np.delete interprets blob_uniques as *indices* into
        # np.unique(blobs), not as values. This is only correct because
        # morph.label emits consecutive labels 0..K, so np.unique(blobs) is
        # arange-like and value == index — verify if the labeller changes.
        uniques = np.delete(np.unique(blobs), blob_uniques)

        #-----------------------------------------IMAGE LEVEL LOSS-----------------------------------------
        # Makes the model realize that at least one object *exists* in the
        # image: the max target vs. the max predicted probability.

        image_level_loss = custom_bce(K.max(points), K.max(predictions))

        #-----------------------------------------POINT LEVEL LOSS-----------------------------------------
        # Makes the model predict foreground at each annotated point location.

        point_level_loss = K.mean(custom_bce(points, predictions))

        #-----------------------------------------FALSE POSITIVE LOSS--------------------------------------
        # Penalizes blobs predicted where no point annotation exists.

        false_positive_mask = np.zeros(predictions_np.shape)  # accumulates locations of false-positive blobs

        for u in uniques:  # iterate over false-positive blob labels
            if u==0:  # skip background
                continue
            false_positive_mask += blobs == u  # mark this blob's pixels

        # Blobs are disjoint by construction, so no pixel should exceed 1.
        assert (false_positive_mask <= 1).all()

        # Target is 0 at false-positive pixels, 1 elsewhere; the loss below
        # ignores the label-1 pixels, so only false positives are pushed down.
        false_positive_target = K.variable(1 - false_positive_mask)
        false_positive_loss = fp_scale * K.mean(custom_bce(false_positive_target, predictions, ignore_label=1))

        #-----------------------------------------SPLIT LOSS-----------------------------------------------
        # Penalizes predicted blobs that cover MORE than one point annotation,
        # forcing the model to split them along watershed boundaries.

        T = np.zeros(predictions_np.shape)  # accumulates split boundaries across all multi-point blobs
        scale_multi = 0.  # loss weight, grown per multi-point blob below

        for i in range(len(blob_uniques)):  # iterate over blobs that contain points
            # Skip background and blobs holding fewer than two points.
            if blob_counts[i] < 2 or blob_uniques[i] == 0:
                continue

            blob_ind = blobs == blob_uniques[i]  # mask of the current blob

            # Watershed boundaries between the points inside this blob,
            # restricted to the blob's own pixels.
            T += watersplit(predictions_np, points_np*blob_ind)*blob_ind

            # Weight grows with the blob's point count.
            # NOTE(review): uses count+1; presumably an intentional scaling
            # choice — confirm against the reference implementation.
            scale_multi += float(blob_counts[i]+1)

        # Boundaries from distinct blobs must not overlap.
        assert (T <= 1).all()

        # Target is 0 on boundary pixels, 1 elsewhere; label-1 pixels ignored.
        multi_blob_target = K.variable(1 - T)
        split_loss = scale_multi * K.mean(custom_bce(multi_blob_target, predictions, ignore_label=1))

        #-----------------------------------------GLOBAL SPLIT LOSS----------------------------------------
        # Forces as many splits as there are annotated points, image-wide.

        T = 1 - watersplit(predictions_np, points_np)  # boundaries set to 0 (background)
        scale = float(points_np.sum())  # weight by total number of annotated points

        global_split_loss = scale * K.mean(custom_bce(T, predictions,ignore_label=1))


        return image_level_loss+point_level_loss+false_positive_loss+split_loss+global_split_loss

    return loss
Add Comment
Please, Sign In to add comment