daily pastebin goal
26%
SHARE
TWEET

Untitled

a guest Dec 18th, 2017 59 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy
  2.  
  3. import chainer
  4. from chainer import configuration
  5. from chainer import cuda
  6. from chainer import functions
  7. from chainer import initializers
  8. from chainer import link
  9. from chainer.utils import argument
  10. from chainer import variable
  11.  
  12.  
  13. class InstanceNormalization(link.Link):
  14.  
  15.     def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
  16.                  valid_test=False, use_gamma=True, use_beta=True,
  17.                  initial_gamma=None, initial_beta=None):
  18.         super(InstanceNormalization, self).__init__()
  19.         self.valid_test = valid_test
  20.         self.avg_mean = numpy.zeros(size, dtype=dtype)
  21.         self.avg_var = numpy.zeros(size, dtype=dtype)
  22.         self.N = 0
  23.         self.register_persistent('avg_mean')
  24.         self.register_persistent('avg_var')
  25.         self.register_persistent('N')
  26.         self.decay = decay
  27.         self.eps = eps
  28.  
  29.         with self.init_scope():
  30.             if use_gamma:
  31.                 if initial_gamma is None:
  32.                     initial_gamma = 1
  33.                 initial_gamma = initializers._get_initializer(initial_gamma)
  34.                 initial_gamma.dtype = dtype
  35.                 self.gamma = variable.Parameter(initial_gamma, size)
  36.             if use_beta:
  37.                 if initial_beta is None:
  38.                     initial_beta = 0
  39.                 initial_beta = initializers._get_initializer(initial_beta)
  40.                 initial_beta.dtype = dtype
  41.                 self.beta = variable.Parameter(initial_beta, size)
  42.  
  43.     def __call__(self, x, **kwargs):
  44.         """__call__(self, x, finetune=False)
  45.         Invokes the forward propagation of BatchNormalization.
  46.         In training mode, the BatchNormalization computes moving averages of
  47.         mean and variance for evaluation during training, and normalizes the
  48.         input using batch statistics.
  49.         .. warning::
  50.            ``test`` argument is not supported anymore since v2.
  51.            Instead, use ``chainer.using_config('train', False)``.
  52.            See :func:`chainer.using_config`.
  53.         Args:
  54.             x (Variable): Input variable.
  55.             finetune (bool): If it is in the training mode and ``finetune`` is
  56.                 ``True``, BatchNormalization runs in fine-tuning mode; it
  57.                 accumulates the input array to compute population statistics
  58.                 for normalization, and normalizes the input using batch
  59.                 statistics.
  60.         """
  61.         # check argument
  62.         argument.check_unexpected_kwargs(
  63.             kwargs, test='test argument is not supported anymore. '
  64.             'Use chainer.using_config')
  65.         finetune, = argument.parse_kwargs(kwargs, ('finetune', False))
  66.  
  67.         # reshape input x
  68.         original_shape = x.shape
  69.         batch_size, n_ch = original_shape[:2]
  70.         new_shape = (1, batch_size * n_ch) + original_shape[2:]
  71.         reshaped_x = functions.reshape(x, new_shape)
  72.  
  73.         if hasattr(self, 'gamma'):
  74.             gamma = self.gamma
  75.         else:
  76.             with cuda.get_device_from_id(self._device_id):
  77.                 gamma = variable.Variable(self.xp.ones(
  78.                     self.avg_mean.shape, dtype=x.dtype))
  79.         if hasattr(self, 'beta'):
  80.             beta = self.beta
  81.         else:
  82.             with cuda.get_device_from_id(self._device_id):
  83.                 beta = variable.Variable(self.xp.zeros(
  84.                     self.avg_mean.shape, dtype=x.dtype))
  85.  
  86.         mean = chainer.as_variable(self.xp.hstack([self.avg_mean] * batch_size))
  87.         var = chainer.as_variable(self.xp.hstack([self.avg_var] * batch_size))
  88.         gamma = chainer.as_variable(self.xp.hstack([gamma.array] * batch_size))
  89.         beta = chainer.as_variable(self.xp.hstack([beta.array] * batch_size))
  90.         if configuration.config.train:
  91.             if finetune:
  92.                 self.N += 1
  93.                 decay = 1. - 1. / self.N
  94.             else:
  95.                 decay = self.decay
  96.  
  97.             ret = functions.batch_normalization(
  98.                 reshaped_x, gamma, beta, eps=self.eps, running_mean=mean,
  99.                 running_var=var, decay=decay)
  100.         else:
  101.             # Use running average statistics or fine-tuned statistics.
  102.             ret = functions.fixed_batch_normalization(
  103.                 reshaped_x, gamma, beta, mean, var, self.eps)
  104.  
  105.         # ret is normalized input x
  106.         return functions.reshape(ret, original_shape)
  107.  
  108.  
  109. if __name__ == '__main__':
  110.     import numpy as np
  111.     base_shape = [10, 3]
  112.     with chainer.using_config('debug', True):
  113.         for i, n_element in enumerate([32, 32, 32]):
  114.             base_shape.append(n_element)
  115.             print('# {} th: input shape: {}'.format(i, base_shape))
  116.             x_array = np.random.normal(size=base_shape).astype(np.float32)
  117.             x = chainer.as_variable(x_array)
  118.             layer = InstanceNormalization(base_shape[1])
  119.             y = layer(x)
  120.             # calculate y_hat manually
  121.             axes = tuple(range(2, len(base_shape)))
  122.             x_mean = np.mean(x_array, axis=axes, keepdims=True)
  123.             x_var = np.var(x_array, axis=axes, keepdims=True) + 1e-5
  124.             x_std = np.sqrt(x_var)
  125.             y_hat = (x_array - x_mean) / x_std
  126.             diff = y.array - y_hat
  127.             print('*** diff ***')
  128.             print('\tmean: {:03f},\n\tstd: {:.03f}'.format(
  129.                 np.mean(diff), np.std(diff)))
  130.  
  131.         base_shape = [10, 3]
  132.         with chainer.using_config('train', False):
  133.             print('\n# test mode\n')
  134.             for i, n_element in enumerate([32, 32, 32]):
  135.                 base_shape.append(n_element)
  136.                 print('# {} th: input shape: {}'.format(i, base_shape))
  137.                 x_array = np.random.normal(size=base_shape).astype(np.float32)
  138.                 x = chainer.as_variable(x_array)
  139.                 layer = InstanceNormalization(base_shape[1])
  140.                 y = layer(x)
  141.                 axes = tuple(range(2, len(base_shape)))
  142.                 x_mean = np.mean(x_array, axis=axes, keepdims=True)
  143.                 x_var = np.var(x_array, axis=axes, keepdims=True) + 1e-5
  144.                 x_std = np.sqrt(x_var)
  145.                 y_hat = (x_array - x_mean) / x_std
  146.                 diff = y.array - y_hat
  147.                 print('*** diff ***')
  148.                 print('\tmean: {:03f},\n\tstd: {:.03f}'.format(
  149.                     np.mean(diff), np.std(diff)))
  150.  
  151.  
  152. """
  153. ○ → python instance_norm.py
  154. # 0 th: input shape: [10, 3, 32]
  155. *** diff ***
  156.         mean: -0.000000,
  157.         std: 0.000
  158. # 1 th: input shape: [10, 3, 32, 32]
  159. *** diff ***
  160.         mean: -0.000000,
  161.         std: 0.000
  162. # 2 th: input shape: [10, 3, 32, 32, 32]
  163. *** diff ***
  164.         mean: -0.000000,
  165.         std: 0.000
  166.        
  167. # test mode
  168.  
  169. # 0 th: input shape: [10, 3, 32]
  170. *** diff ***
  171.         mean: 14.126040,
  172.         std: 227.823
  173. # 1 th: input shape: [10, 3, 32, 32]
  174. *** diff ***
  175.         mean: -0.286635,
  176.         std: 221.926
  177. # 2 th: input shape: [10, 3, 32, 32, 32]
  178. *** diff ***
  179.         mean: -0.064297,
  180.         std: 222.492
  181. """
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top