Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from typing import List
- from . import options
- import tensorflow as tf
- import numpy as np
- import math
def _relu(x: tf.Tensor) -> tf.Tensor:
    """Apply a leaky ReLU whose negative-side slope is 0.2.

    Args:
        x: Tensor to activate.

    Returns:
        The result of `tf.nn.leaky_relu` applied to `x`.
    """
    negative_slope = 0.2
    return tf.nn.leaky_relu(x, alpha=negative_slope)
- # def make_init(inp):
- # def init_function(shape, dtype=None):
- # print("=== Conv2d init ===")
- # print("Init Shape", shape)
- # print("Numpy Shape", inp.shape)
- # kernel = np.zeros(shape)
- # # print(dir(weights))
- # return kernel
- # return init_function
def make_init(idx, fusion_layer):
    """Build an initializer that hands back an existing variable's values.

    The returned callable has the `(shape, dtype=None)` signature used for
    the `kernel_initializer` / `bias_initializer` arguments of a layer.

    Args:
        idx: Index into `fusion_layer.variables` selecting the variable whose
            values should be reused.
        fusion_layer: Object exposing a `variables` sequence. The lookup is
            deferred until the initializer is actually invoked.

    Returns:
        A function `init_function(shape, dtype=None)`. When `shape` has four
        entries (treated as a convolution kernel) the selected variable is
        returned unchanged; otherwise (treated as a bias) its first
        `shape[0]` entries are returned.
    """

    def init_function(shape, dtype=None):
        # NOTE(review): `dtype` is accepted but never used, and the stored
        # variable's shape is only printed, never checked against `shape`.
        print("=== Conv2d init idx (%d) ===" % (idx))
        print("Init shape", shape)

        stored = fusion_layer.variables[idx]
        print("Pretrained shape", stored.shape)

        # Four dimensions: treated as a convolution kernel, reused as-is.
        if len(shape) == 4:
            return stored

        # Anything else is treated as a bias and trimmed to the requested
        # length.
        return stored[:shape[0]]

    return init_function
# Number of channels produced by the final convolution (one per RGB channel).
_NUMBER_OF_COLOR_CHANNELS = 3


class Fusion(tf.keras.layers.Layer):
    """Decoder that fuses a feature pyramid into one batch of RGB images.

    Every convolution is created with initializers from `make_init`, so its
    kernel and bias are taken from `fusion_layer.variables`. Variables are
    consumed at consecutive indices in construction order: kernel at an even
    index, the matching bias at the next odd index.
    """

    def __init__(self, name: str, config: options.Options, fusion_layer):
        """Create the per-level convolutions and the RGB output convolution.

        Args:
            name: Name given to this Keras layer.
            config: Options providing `fusion_pyramid_levels`,
                `specialized_levels` and `filters`.
            fusion_layer: Object whose `variables` sequence supplies the
                initial kernel/bias values, in the order described on the
                class.
        """
        super().__init__(name=name)

        # convs[i] is the list of convolutions applied at pyramid level i,
        # stored fine-to-coarse so the index matches the usual level
        # numbering (0 = finest level).
        self.convs: List[List[tf.keras.layers.Layer]] = []

        # Remembered so call() can verify it receives the right number of
        # levels.
        self.levels = config.fusion_pyramid_levels

        # (kernel size, activation) for the three convolutions of one level:
        # a 2x2 convolution without activation applied to the upsampled
        # coarser features, then two 3x3 convolutions with leaky ReLU applied
        # after concatenation.
        conv_specs = (
            ((2, 2), None),
            ((3, 3), _relu),
            ((3, 3), _relu),
        )

        # Index of the next kernel in fusion_layer.variables; its bias
        # follows at the next index.
        next_var = 0

        # The coarsest level gets no convolutions of its own, hence the "- 1".
        for level in range(config.fusion_pyramid_levels - 1):
            specialized = config.specialized_levels
            base_filters = config.filters
            # Filters double each time the resolution halves, but only up to
            # `specialized_levels`; beyond that the count stays constant.
            num_filters = base_filters << min(level, specialized)

            level_convs: List[tf.keras.layers.Layer] = []
            for kernel_size, activation in conv_specs:
                level_convs.append(
                    tf.keras.layers.Conv2D(
                        filters=num_filters,
                        kernel_size=list(kernel_size),
                        padding='same',
                        activation=activation,
                        kernel_initializer=make_init(next_var, fusion_layer),
                        bias_initializer=make_init(
                            next_var + 1, fusion_layer)))
                next_var += 2
            self.convs.append(level_convs)

        # The final 1x1 convolution that outputs RGB.
        self.output_conv = tf.keras.layers.Conv2D(
            filters=_NUMBER_OF_COLOR_CHANNELS,
            kernel_size=1,
            kernel_initializer=make_init(next_var, fusion_layer),
            bias_initializer=make_init(next_var + 1, fusion_layer))

    def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
        """Runs the fusion module.

        Args:
            pyramid: The input feature pyramid as a list of tensors, each in
                (B x H x W x C) format, with the finest level tensor first.

        Returns:
            A batch of RGB images.

        Raises:
            ValueError: If len(pyramid) differs from the
                `config.fusion_pyramid_levels` given to the constructor.
        """
        if len(pyramid) != self.levels:
            raise ValueError(
                'Fusion called with different number of pyramid levels '
                f'{len(pyramid)} than it was configured for, {self.levels}.')

        # Unlike a conventional decoder (e.g. U-net), no extra convolutions
        # are applied to the coarsest level; it is only passed on to finer
        # levels for concatenation. This has not been thoroughly evaluated,
        # but the fusion part probably does not need large spatial context
        # because the features are already aligned by the preceding warp.
        print("== [FUSE CALL]== ")
        net = pyramid[-1]

        # Walk from the second-coarsest level down to the finest one.
        for level in range(self.levels - 2, -1, -1):
            upsample_conv, fuse_conv_a, fuse_conv_b = self.convs[level]
            skip = pyramid[level]

            # Resize the coarser tensor to this level's height and width so
            # the two can be concatenated along the channel axis.
            target_size = tf.shape(skip)[1:3]
            net = tf.image.resize(
                net, target_size,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
            net = upsample_conv(net)
            net = tf.concat([skip, net], axis=-1)
            net = fuse_conv_a(net)
            net = fuse_conv_b(net)

        net = self.output_conv(net)
        print("== [FUSE DONE] ==")
        return net
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement