Advertisement
Guest User

CNN visualization TF2

a guest
Dec 13th, 2019
1,254
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.09 KB | None | 0 0
  1. """
  2. #Visualization of the filters of VGG16, via gradient ascent in input space.
  3.  
  4. This script can run on CPU in a few minutes.
  5.  
  6. Results example: ![Visualization](http://i.imgur.com/4nj4KjN.jpg)
  7. """
  8. from __future__ import print_function
  9.  
  10. import time
  11. import numpy as np
  12. from PIL import Image as pil_image
  13. from tensorflow.keras.preprocessing.image import save_img
  14. from tensorflow.keras import layers, models
  15. from tensorflow.keras.applications import vgg16
  16. from tensorflow.keras import backend as K
  17. import tensorflow as tf
  18.  
  19.  
  20. def normalize(x):
  21.     """utility function to normalize a tensor.
  22.  
  23.    # Arguments
  24.        x: An input tensor.
  25.  
  26.    # Returns
  27.        The normalized input tensor.
  28.    """
  29.     return x / (K.sqrt(K.mean(K.square(x))) + K.epsilon())
  30.  
  31.  
  32. def deprocess_image(x):
  33.     """utility function to convert a float array into a valid uint8 image.
  34.  
  35.    # Arguments
  36.        x: A numpy-array representing the generated image.
  37.  
  38.    # Returns
  39.        A processed numpy-array, which could be used in e.g. imshow.
  40.    """
  41.     # normalize tensor: center on 0., ensure std is 0.25
  42.     x -= x.mean()
  43.     x /= (x.std() + K.epsilon())
  44.     x *= 0.25
  45.  
  46.     # clip to [0, 1]
  47.     x += 0.5
  48.     x = np.clip(x, 0, 1)
  49.  
  50.     # convert to RGB array
  51.     x *= 255
  52.     if K.image_data_format() == 'channels_first':
  53.         x = x.transpose((1, 2, 0))
  54.     x = np.clip(x, 0, 255).astype('uint8')
  55.     return x
  56.  
  57.  
  58. def process_image(x, former):
  59.     """utility function to convert a valid uint8 image back into a float array.
  60.       Reverses `deprocess_image`.
  61.  
  62.    # Arguments
  63.        x: A numpy-array, which could be used in e.g. imshow.
  64.        former: The former numpy-array.
  65.                Need to determine the former mean and variance.
  66.  
  67.    # Returns
  68.        A processed numpy-array representing the generated image.
  69.    """
  70.     if K.image_data_format() == 'channels_first':
  71.         x = x.transpose((2, 0, 1))
  72.     return (x / 255 - 0.5) * 4 * former.std() + former.mean()
  73.  
  74.  
  75. def visualize_layer(model,
  76.                     layer_name,
  77.                     step=1.,
  78.                     epochs=15,
  79.                     upscaling_steps=9,
  80.                     upscaling_factor=1.2,
  81.                     output_dim=(412, 412),
  82.                     filter_range=(0, None)):
  83.     """Visualizes the most relevant filters of one conv-layer in a certain model.
  84.  
  85.    # Arguments
  86.        model: The model containing layer_name.
  87.        layer_name: The name of the layer to be visualized.
  88.                    Has to be a part of model.
  89.        step: step size for gradient ascent.
  90.        epochs: Number of iterations for gradient ascent.
  91.        upscaling_steps: Number of upscaling steps.
  92.                         Starting image is in this case (80, 80).
  93.        upscaling_factor: Factor to which to slowly upgrade
  94.                          the image towards output_dim.
  95.        output_dim: [img_width, img_height] The output image dimensions.
  96.        filter_range: Tupel[lower, upper]
  97.                      Determines the to be computed filter numbers.
  98.                      If the second value is `None`,
  99.                      the last filter will be inferred as the upper boundary.
  100.    """
  101.  
  102.     def _generate_filter_image(input_img,
  103.                                layer_output,
  104.                                filter_index):
  105.         """Generates image for one particular filter.
  106.  
  107.        # Arguments
  108.            input_img: The input-image Tensor.
  109.            layer_output: The output-image Tensor.
  110.            filter_index: The to be processed filter number.
  111.                          Assumed to be valid.
  112.  
  113.        #Returns
  114.            Either None if no image could be generated.
  115.            or a tuple of the image (array) itself and the last loss.
  116.        """
  117.         s_time = time.time()
  118. #        input_img1 = np.random.random((1, 224, 224, 3))
  119. #        input_img1 = (input_img1 - 0.5) * 20 + 128.
  120. #        input_img = tf.Variable(tf.cast(input_img1, tf.float32))
  121.         # we build a loss function that maximizes the activation
  122.         # of the nth filter of the layer considered
  123.         with tf.GradientTape() as tape:
  124.             tape.watch(input_img)
  125.             outputs = submodel(input_img)
  126.             loss_value = tf.reduce_mean(outputs[:, :, :, filter_index])
  127.         grads = tape.gradient(loss_value, input_img)
  128. #        print("grads",grads)
  129. #        print("loss",loss_value)
  130. #        print("input", input_img1)
  131.         # normalization trick: we normalize the gradient
  132.         grads = normalize(grads)
  133.  
  134.         # this function returns the loss and grads given the input picture
  135.         iterate = K.function([input_img], [loss_value, grads])
  136.  
  137.         # we start from a gray image with some random noise
  138.         intermediate_dim = tuple(
  139.             int(x / (upscaling_factor ** upscaling_steps)) for x in output_dim)
  140.         if K.image_data_format() == 'channels_first':
  141.             input_img_data = np.random.random(
  142.                 (1, 3, intermediate_dim[0], intermediate_dim[1]))
  143.         else:
  144.             input_img_data = np.random.random(
  145.                 (1, intermediate_dim[0], intermediate_dim[1], 3))
  146.         input_img_data = (input_img_data - 0.5) * 20 + 128
  147.  
  148.         # Slowly upscaling towards the original size prevents
  149.         # a dominating high-frequency of the to visualized structure
  150.         # as it would occur if we directly compute the 412d-image.
  151.         # Behaves as a better starting point for each following dimension
  152.         # and therefore avoids poor local minima
  153.         for up in reversed(range(upscaling_steps)):
  154.             # we run gradient ascent for e.g. 20 steps
  155.             for _ in range(epochs):
  156.                 loss_value, grads_value = iterate([input_img_data])
  157.                 input_img_data += grads_value * step
  158.  
  159.                 # some filters get stuck to 0, we can skip them
  160.                 if loss_value <= K.epsilon():
  161.                     return None
  162.  
  163.             # Calculate upscaled dimension
  164.             intermediate_dim = tuple(
  165.                 int(x / (upscaling_factor ** up)) for x in output_dim)
  166.             # Upscale
  167.             img = deprocess_image(input_img_data[0])
  168.             img = np.array(pil_image.fromarray(img).resize(intermediate_dim,
  169.                                                            pil_image.BICUBIC))
  170.             input_img_data = np.expand_dims(
  171.                 process_image(img, input_img_data[0]), 0)
  172.  
  173.         # decode the resulting input image
  174.         img = deprocess_image(input_img_data[0])
  175.         e_time = time.time()
  176.         print('Costs of filter {:3}: {:5.0f} ( {:4.2f}s )'.format(filter_index,
  177.                                                                   loss_value,
  178.                                                                   e_time - s_time))
  179.         return img, loss_value
  180.  
  181.     def _draw_filters(filters, n=None):
  182.         """Draw the best filters in a nxn grid.
  183.  
  184.        # Arguments
  185.            filters: A List of generated images and their corresponding losses
  186.                     for each processed filter.
  187.            n: dimension of the grid.
  188.               If none, the largest possible square will be used
  189.        """
  190.         if n is None:
  191.             n = int(np.floor(np.sqrt(len(filters))))
  192.  
  193.         # the filters that have the highest loss are assumed to be better-looking.
  194.         # we will only keep the top n*n filters.
  195.         filters.sort(key=lambda x: x[1], reverse=True)
  196.         filters = filters[:n * n]
  197.  
  198.         # build a black picture with enough space for
  199.         # e.g. our 8 x 8 filters of size 412 x 412, with a 5px margin in between
  200.         MARGIN = 5
  201.         width = n * output_dim[0] + (n - 1) * MARGIN
  202.         height = n * output_dim[1] + (n - 1) * MARGIN
  203.         stitched_filters = np.zeros((width, height, 3), dtype='uint8')
  204.  
  205.         # fill the picture with our saved filters
  206.         for i in range(n):
  207.             for j in range(n):
  208.                 img, _ = filters[i * n + j]
  209.                 width_margin = (output_dim[0] + MARGIN) * i
  210.                 height_margin = (output_dim[1] + MARGIN) * j
  211.                 stitched_filters[
  212.                     width_margin: width_margin + output_dim[0],
  213.                     height_margin: height_margin + output_dim[1], :] = img
  214.  
  215.         # save the result to disk
  216.         save_img('vgg_{0:}_{1:}x{1:}.png'.format(layer_name, n), stitched_filters)
  217.  
  218.     # this is the placeholder for the input images
  219.     assert len(model.inputs) == 1
  220.     input_img = model.inputs[0]
  221.  
  222.     # get the symbolic outputs of each "key" layer (we gave them unique names).
  223.     layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])
  224.  
  225.     output_layer = layer_dict[layer_name]
  226.     assert isinstance(output_layer, layers.Conv2D)
  227.  
  228.     # Compute to be processed filter range
  229.     filter_lower = filter_range[0]
  230.     filter_upper = (filter_range[1]
  231.                     if filter_range[1] is not None
  232.                     else len(output_layer.get_weights()[1]))
  233.     assert(filter_lower >= 0
  234.            and filter_upper <= len(output_layer.get_weights()[1])
  235.            and filter_upper > filter_lower)
  236.     print('Compute filters {:} to {:}'.format(filter_lower, filter_upper))
  237.  
  238.     # iterate through each filter and generate its corresponding image
  239.     processed_filters = []
  240.     for f in range(filter_lower, filter_upper):
  241.         img_loss = _generate_filter_image(input_img, output_layer.output, f)
  242.  
  243.         if img_loss is not None:
  244.             processed_filters.append(img_loss)
  245.  
  246.     print('{} filter processed.'.format(len(processed_filters)))
  247.     # Finally draw and store the best filters to disk
  248.     _draw_filters(processed_filters)
  249.  
  250.  
  251. if __name__ == '__main__':
  252.     # the name of the layer we want to visualize
  253.     # (see model definition at keras/applications/vgg16.py)
  254.     LAYER_NAME = 'block5_conv1'
  255.  
  256.     # build the VGG16 network with ImageNet weights
  257.     vgg = vgg16.VGG16(weights='imagenet', include_top=False)
  258.     submodel = models.Model([vgg.inputs[0]], [vgg.get_layer(LAYER_NAME).output])
  259.     print('Model loaded.')
  260.     vgg.summary()
  261.  
  262.     # example function call
  263.     visualize_layer(vgg, LAYER_NAME, filter_range=(0,20))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement