Source code for ashpy.layers.instance_normalization

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Instance Normalization Layer implementation."""

from typing import Tuple

import tensorflow as tf


class InstanceNormalization(tf.keras.layers.Layer):
    r"""
    Instance Normalization Layer (used by Pix2Pix [1]_ and Pix2PixHD [2]_ ).

    Basically it's a normalization done at instance level: mean and variance
    are computed over axes 1 and 2 of the input, separately for every sample
    and every entry of the last axis, and the input is then normalized,
    scaled by ``gamma`` and shifted by ``beta``.
    The implementation follows the basic implementation of the Batch
    Normalization Layer.

    * Direct Usage:

        .. testcode::

            x = tf.ones((1, 10, 10, 64))

            # instantiate the normalization layer.
            normalization = InstanceNormalization()

            # evaluate passing x.
            output = normalization(x)

            # the output shape is.
            # the same as the input shape.
            print(output.shape)

    * Inside a Model:

        .. testcode::

            def MyModel():
                inputs = tf.keras.layers.Input(shape=[None, None, 64])
                normalization = InstanceNormalization()
                return tf.keras.Model(inputs=inputs, outputs=normalization(inputs))

            x = tf.ones((1, 10, 10, 64))
            model = MyModel()
            output = model(x)

            print(output.shape)

        .. testoutput::

            (1, 10, 10, 64)

    .. [1] Image-to-Image Translation with Conditional Adversarial Networks
        https://arxiv.org/abs/1611.07004
    .. [2] High-Resolution Image Synthesis and Semantic Manipulation with
        Conditional GANs https://arxiv.org/abs/1711.11585

    """

    def __init__(
        self,
        eps: float = 1e-6,
        beta_initializer: str = "zeros",
        gamma_initializer: str = "ones",
        **kwargs,
    ) -> None:
        r"""
        Initialize the layer.

        Args:
            eps (float): Variance_epsilon used by batch_norm layer.
            beta_initializer (str): Initializer for the beta variable.
            gamma_initializer (str): Initializer for the gamma variable.
            **kwargs: Additional keyword arguments forwarded unchanged to
                :py:class:`tf.keras.layers.Layer` (e.g. ``name``, ``dtype``,
                ``trainable``).

        """
        # Forwarding kwargs lets callers (and `from_config`) set the standard
        # Keras layer options instead of silently being unable to pass them.
        super().__init__(**kwargs)
        self._eps = eps
        # The weights are created lazily in `build`, once the number of
        # channels of the input is known.
        self.gamma = None
        self.beta = None
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)

    def build(self, input_shape: Tuple[int]) -> None:
        """
        Assemble the layer.

        Creates the trainable ``gamma`` (scale) and ``beta`` (offset) weights,
        one value per entry of the last input axis.

        Args:
            input_shape (:obj:`tuple` of (:obj:`int`)): Specifies the shape of
                the input accepted by the layer.

        """
        # Shape (1, 1, 1, C): broadcastable against a rank-4 input whose last
        # axis has size C.
        shape = (1, 1, 1, input_shape[-1])
        self.gamma = self.add_weight(
            shape=shape,
            initializer=self.gamma_initializer,
            name="gamma",
            trainable=True,
        )
        self.beta = self.add_weight(
            shape=shape, initializer=self.beta_initializer, name="beta", trainable=True
        )

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Perform the computation.

        Args:
            inputs (:py:class:`tf.Tensor`): Inputs for the computation.
            training (bool): Controls for training or evaluation mode. It is
                not used by the computation: the statistics are always
                computed from ``inputs``.

        Returns:
            :py:class:`tf.Tensor`: Output Tensor.

        """
        # Statistics over axes 1 and 2 only, keeping the reduced dimensions so
        # that they broadcast against `inputs`.
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        return tf.nn.batch_normalization(
            inputs,
            mean=mean,
            variance=variance,
            offset=self.beta,
            scale=self.gamma,
            variance_epsilon=self._eps,
        )

    def get_config(self) -> dict:
        """
        Return the configuration of the layer.

        The returned dictionary extends the base layer configuration with the
        constructor arguments of this layer, so that an equivalent layer can
        be re-created by passing it back to the constructor.

        Returns:
            :obj:`dict`: The configuration of the layer.

        """
        config = super().get_config()
        config.update(
            {
                "eps": self._eps,
                # Initializers are stored as objects; serialize them into a
                # form that `tf.keras.initializers.get` accepts in `__init__`.
                "beta_initializer": tf.keras.initializers.serialize(
                    self.beta_initializer
                ),
                "gamma_initializer": tf.keras.initializers.serialize(
                    self.gamma_initializer
                ),
            }
        )
        return config