Path: blob/master/tensorflow_tts/utils/weight_norm.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""

import warnings

import tensorflow as tf


class WeightNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper to decouple magnitude and direction of the layer's weights.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem. It has an optional data-dependent
    initialization scheme, in which initial values of weights are set as
    functions of the first minibatch of data. Both the weight normalization and
    data-dependent initialization are described in [Salimans and Kingma (2016)][1].

    #### Example

    ```python
    net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                              input_shape=(32, 32, 3), data_init=True)(x)
    net = WeightNormalization(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
                              data_init=True)
    net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                              data_init=True)(net)
    net = WeightNormalization(tf.keras.layers.Dense(num_classes),
                              data_init=True)(net)
    ```

    #### References

    [1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
         Reparameterization to Accelerate Training of Deep Neural Networks. In
         _30th Conference on Neural Information Processing Systems_, 2016.
         https://arxiv.org/abs/1602.07868
    """

    # Layer types this wrapper has been exercised against. Any other layer
    # type still works if it exposes a `kernel` attribute, but gets a warning.
    _KNOWN_LAYER_TYPES = (
        "Dense",
        "Conv2D",
        "Conv2DTranspose",
        "Conv1D",
        "GroupConv1D",
        "GroupConv2D",
    )

    def __init__(self, layer, data_init=True, **kwargs):
        """Initialize WeightNorm wrapper.

        Args:
            layer: A `tf.keras.layers.Layer` instance. Supported layer types
                are `Dense`, `Conv1D`, `Conv2D`, `Conv2DTranspose`,
                `GroupConv1D` and `GroupConv2D`. Layers with multiple inputs
                are not supported.
            data_init: `bool`, if `True` use data dependent variable
                initialization (Salimans and Kingma, 2016, Sec. 3).
            **kwargs: Additional keyword args passed to
                `tf.keras.layers.Wrapper`.

        Raises:
            ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` "
                "instance. You passed: {input}".format(input=layer)
            )

        layer_type = type(layer).__name__
        # BUGFIX: the warning text names `GroupConv2D` as a tested type, but
        # the original membership check omitted it, so wrapping a GroupConv2D
        # layer always emitted a spurious warning. The check now matches the
        # message via `_KNOWN_LAYER_TYPES`.
        if layer_type not in self._KNOWN_LAYER_TYPES:
            warnings.warn(
                "`WeightNorm` is tested only for `Dense`, `Conv2D`, `Conv1D`, `GroupConv1D`, "
                "`GroupConv2D`, and `Conv2DTranspose` layers. You passed a layer of type `{}`".format(
                    layer_type
                )
            )

        super().__init__(layer, **kwargs)

        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        # Conv2DTranspose kernels keep their output channels on axis -2
        # (kernel shape (h, w, out, in)); every other supported layer keeps
        # them on the last axis.
        self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1

    def _compute_weights(self):
        """Recompute the wrapped layer's kernel as g * v / ||v||."""
        # Axis along which to expand `g` so that it broadcasts against the
        # normalized kernel: filter_axis == -1 -> -2 (shape (1, filters)),
        # filter_axis == -2 -> -1 (shape (filters, 1)).
        new_axis = -self.filter_axis - 3

        self.layer.kernel = tf.nn.l2_normalize(
            self.v, axis=self.kernel_norm_axes
        ) * tf.expand_dims(self.g, new_axis)

    def _init_norm(self):
        """Initialize `g` as the per-filter L2 norm of the raw kernel `v`."""
        kernel_norm = tf.sqrt(
            tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
        )
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization.

        Scales `g` (and shifts the bias) so the wrapped layer's pre-activation
        output on `inputs` has zero mean and unit variance per output channel.
        """
        # Normalize kernel first so that calling the layer calculates
        # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        # Temporarily disable the activation and zero the bias so the forward
        # pass below yields the raw normalized pre-activation.
        activation = self.layer.activation
        self.layer.activation = None

        use_bias = self.layer.bias is not None
        if use_bias:
            bias = self.layer.bias
            self.layer.bias = tf.zeros_like(bias)

        # Since the bias is initialized as zero, setting the activation to zero
        # and calling the initialized layer (with normalized kernel) yields the
        # correct computation ((5) in Salimans and Kingma (2016)).
        x_init = self.layer(inputs)
        # Per-channel moments over all axes except the last (channel) axis.
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        scale_init = 1.0 / tf.sqrt(v_init + 1e-10)  # epsilon guards v_init == 0

        self.g.assign(self.g * scale_init)
        if use_bias:
            # Restore the bias variable, then overwrite it with the
            # data-dependent shift so the initial output is zero-mean.
            self.layer.bias = bias
            self.layer.bias.assign(-m_init * scale_init)
        self.layer.activation = activation

    def build(self, input_shape=None):
        """Build `Layer`.

        Args:
            input_shape: The shape of the input to `self.layer`.

        Raises:
            ValueError: If `Layer` does not contain a `kernel` of weights.
        """
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, "kernel"):
            raise ValueError(
                "`WeightNorm` must wrap a layer that"
                " contains a `kernel` for weights"
            )

        # Normalize over every kernel axis except the per-filter axis.
        self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
        self.kernel_norm_axes.pop(self.filter_axis)

        # `v` is the raw (direction) kernel; `g` holds per-filter magnitudes.
        self.v = self.layer.kernel

        # to avoid a duplicate `kernel` variable after `build` is called
        self.layer.kernel = None
        self.g = self.add_weight(
            name="g",
            shape=(int(self.v.shape[self.filter_axis]),),
            initializer="ones",
            dtype=self.v.dtype,
            trainable=True,
        )
        # Tracked as a variable so the "already initialized" flag survives
        # checkpointing.
        self.initialized = self.add_weight(
            name="initialized", dtype=tf.bool, trainable=False
        )
        self.initialized.assign(False)

        super().build()

    def call(self, inputs):
        """Call `Layer`.

        On the first call, initializes `g` (data-dependently when `data_init`
        is set), then always recomputes the kernel from `(g, v)` before
        delegating to the wrapped layer.
        """
        # NOTE(review): reading `self.initialized` in a Python `if` assumes
        # eager execution; inside a `tf.function` this condition would need
        # `tf.cond` — confirm intended usage.
        if not self.initialized:
            if self.data_init:
                self._data_dep_init(inputs)
            else:
                # initialize `g` as the norm of the initialized kernel
                self._init_norm()

            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        """Delegate output-shape computation to the wrapped layer."""
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())