rahools

GradCAM heatmap

May 28th, 2020
164
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.80 KB | None | 0 0
  1. def getHeatmapGradCAM(img, eps=1e-8, alpha = .3):
  2.     # load up graph and session
  3.     # global graph
  4.     # global sess
  5.  
  6.     # load image
  7.     originalImage = img.convert('RGB') # load up image
  8.     resizedImage = originalImage.resize((224, 224)) # resize image
  9.  
  10.     # preprocess and save rbg img arr
  11.     rgbArr = np.array(resizedImage) # arrayfy
  12.     rgbArr = rgbArr[np.newaxis, :, :, :] # batchify
  13.     rgbArrPreproc = tf.keras.applications.densenet.preprocess_input(rgbArr) # preproc densenet
  14.  
  15.     # save gray img arr
  16.     grayArr = np.array(resizedImage.convert('L')) # convert to gray arr
  17.  
  18.     # get final conv layer
  19.     finalLayerName = 'relu'
  20.  
  21.     with tf.device('cpu:0'):
  22.         with graph.as_default():
  23.             tf.compat.v1.keras.backend.set_session(sess)
  24.  
  25.             # get gradients
  26.             grads = tf.keras.backend.gradients(model.output[:, 0], model.get_layer(finalLayerName).output)[0]
  27.        
  28.             # make grad func
  29.             gradient_function = tf.keras.backend.function([model.input], [model.get_layer(finalLayerName).output, grads])
  30.  
  31.             # compute grads
  32.             output, grads_val = gradient_function([rgbArrPreproc])
  33.  
  34.     output, grads_val = output[0, :], grads_val[0, :, :, :]
  35.  
  36.     weights = np.mean(grads_val, axis=(0, 1))
  37.     cam = np.dot(output, weights)
  38.  
  39.     # resize and normalize
  40.     heatmap = cv2.resize(np.array(cam), (224, 224))
  41.     numer = heatmap - np.min(heatmap)
  42.     denom = (heatmap.max() - heatmap.min()) + eps
  43.     heatmap = numer / denom
  44.     heatmap = np.uint8(heatmap * 255)
  45.  
  46.     # visualize findings
  47.     # original image arr = grayArr
  48.     # heatmap image arr = heatmap
  49.     plt.figure(figsize = (8, 8))
  50.     plt.imshow(grayArr, cmap = 'gray')
  51.     plt.imshow(heatmap, cmap = 'jet', alpha = alpha)
  52.  
  53.     plt.savefig('heatmap.png')
Add Comment
Please sign in to add a comment.