Advertisement
Guest User

Untitled

a guest
Nov 14th, 2018
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.56 KB | None | 0 0
  1. from __future__ import absolute_import
  2.  
  3. from keras import backend as K
  4. from keras.engine import Layer
  5. from keras.utils.generic_utils import get_custom_objects
  6. from keras.utils.conv_utils import normalize_data_format
  7.  
  8. if K.backend() == 'theano':
  9.     import nets.theano_backend as K_BACKEND
  10. else:
  11.     import nets.tensorflow_backend as K_BACKEND
  12.  
  13. class SubPixelUpscaling(Layer):
  14.     """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time Single Image
  15.    and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"
  16.    (https://arxiv.org/abs/1609.05158).
  17.    This layer requires a Convolution2D prior to it, having output filters computed according to
  18.    the formula :
  19.        filters = k * (scale_factor * scale_factor)
  20.        where k = a user defined number of filters (generally larger than 32)
  21.              scale_factor = the upscaling factor (generally 2)
  22.    This layer performs the depth to space operation on the convolution filters, and returns a
  23.    tensor with the size as defined below.
  24.    # Example :
  25.    ```python
  26.        # A standard subpixel upscaling block
  27.        x = Convolution2D(256, 3, 3, padding='same', activation='relu')(...)
  28.        u = SubPixelUpscaling(scale_factor=2)(x)
  29.        [Optional]
  30.        x = Convolution2D(256, 3, 3, padding='same', activation='relu')(u)
  31.    ```
  32.        In practice, it is useful to have a second convolution layer after the
  33.        SubPixelUpscaling layer to speed up the learning process.
  34.        However, if you are stacking multiple SubPixelUpscaling blocks, it may increase
  35.        the number of parameters greatly, so the Convolution layer after SubPixelUpscaling
  36.        layer can be removed.
  37.    # Arguments
  38.        scale_factor: Upscaling factor.
  39.        data_format: Can be None, 'channels_first' or 'channels_last'.
  40.    # Input shape
  41.        4D tensor with shape:
  42.        `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` if data_format='channels_first'
  43.        or 4D tensor with shape:
  44.        `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` if data_format='channels_last'.
  45.    # Output shape
  46.        4D tensor with shape:
  47.        `(samples, k channels, rows * scale_factor, cols * scale_factor))` if data_format='channels_first'
  48.        or 4D tensor with shape:
  49.        `(samples, rows * scale_factor, cols * scale_factor, k channels)` if data_format='channels_last'.
  50.    """
  51.  
  52.     def __init__(self, scale_factor=2, data_format=None, **kwargs):
  53.         super(SubPixelUpscaling, self).__init__(**kwargs)
  54.  
  55.         self.scale_factor = scale_factor
  56.         self.data_format = normalize_data_format(data_format)
  57.  
  58.     def build(self, input_shape):
  59.         pass
  60.  
  61.     def call(self, x, mask=None):
  62.         y = K_BACKEND.depth_to_space(x, self.scale_factor, self.data_format)
  63.         return y
  64.  
  65.     def compute_output_shape(self, input_shape):
  66.         if self.data_format == 'channels_first':
  67.             b, k, r, c = input_shape
  68.             return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor)
  69.         else:
  70.             b, r, c, k = input_shape
  71.             return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2))
  72.  
  73.     def get_config(self):
  74.         config = {'scale_factor': self.scale_factor,
  75.                   'data_format': self.data_format}
  76.         base_config = super(SubPixelUpscaling, self).get_config()
  77.         return dict(list(base_config.items()) + list(config.items()))
  78.  
  79.  
  80. get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling})
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement