# Path: blob/master/tensorflow_tts/utils/group_conv.py
# (source mirror metadata: 1558 views)
# -*- coding: utf-8 -*-
# This code is copied from https://github.com/tensorflow/tensorflow/pull/36773.
"""Group Convolution Modules."""

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations, constraints, initializers, regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers import Conv1D, SeparableConv1D
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops, nn, nn_ops


class Convolution(object):
    """Helper class for convolution.

    Note that this class assumes that shapes of input and filter passed to
    __call__ are compatible with input_shape and filter_shape passed to the
    constructor.

    Arguments
      input_shape: static shape of input. i.e. input.get_shape().
      filter_shape: static shape of the filter. i.e. filter.get_shape().
      padding: see convolution.
      strides: see convolution.
      dilation_rate: see convolution.
      name: see convolution.
      data_format: see convolution.
    """

    def __init__(
        self,
        input_shape,
        filter_shape,
        padding,
        strides=None,
        dilation_rate=None,
        name=None,
        data_format=None,
    ):
        """Helper function for convolution."""
        num_total_dims = filter_shape.ndims
        if num_total_dims is None:
            num_total_dims = input_shape.ndims
        if num_total_dims is None:
            raise ValueError("rank of input or filter must be known")

        num_spatial_dims = num_total_dims - 2

        try:
            input_shape.with_rank(num_spatial_dims + 2)
        except ValueError:
            raise ValueError("input tensor must have rank %d" % (num_spatial_dims + 2))

        try:
            filter_shape.with_rank(num_spatial_dims + 2)
        except ValueError:
            raise ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))

        if data_format is None or not data_format.startswith("NC"):
            # channels_last: channels are the trailing dimension.
            input_channels_dim = tensor_shape.dimension_at_index(
                input_shape, num_spatial_dims + 1
            )
            spatial_dims = range(1, num_spatial_dims + 1)
        else:
            # channels_first ("NC..."): channels directly follow the batch dim.
            input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
            spatial_dims = range(2, num_spatial_dims + 2)

        filter_dim = tensor_shape.dimension_at_index(filter_shape, num_spatial_dims)
        # Requiring divisibility (rather than equality) of input channels by the
        # filter's in-channel dimension is what enables grouped convolution.
        if not (input_channels_dim % filter_dim).is_compatible_with(0):
            raise ValueError(
                "number of input channels is not divisible by corresponding "
                "dimension of filter, {} % {} != 0".format(
                    input_channels_dim, filter_dim
                )
            )

        strides, dilation_rate = nn_ops._get_strides_and_dilation_rate(
            num_spatial_dims, strides, dilation_rate
        )

        self.input_shape = input_shape
        self.filter_shape = filter_shape
        self.data_format = data_format
        self.strides = strides
        self.padding = padding
        self.name = name
        self.dilation_rate = dilation_rate
        # _WithSpaceToBatch implements dilation via space-to-batch around the
        # non-atrous convolution built by self._build_op.
        self.conv_op = nn_ops._WithSpaceToBatch(
            input_shape,
            dilation_rate=dilation_rate,
            padding=padding,
            build_op=self._build_op,
            filter_shape=filter_shape,
            spatial_dims=spatial_dims,
            data_format=data_format,
        )

    def _build_op(self, _, padding):
        # Callback for _WithSpaceToBatch; builds the inner non-dilated conv op.
        return nn_ops._NonAtrousConvolution(
            self.input_shape,
            filter_shape=self.filter_shape,
            padding=padding,
            data_format=self.data_format,
            strides=self.strides,
            name=self.name,
        )

    def __call__(self, inp, filter):
        return self.conv_op(inp, filter)


class Conv(Layer):
    """Abstract N-D convolution layer (private, used as implementation base).

    This layer creates a convolution kernel that is convolved
    (actually cross-correlated) with the layer input to produce a tensor of
    outputs. If `use_bias` is True (and a `bias_initializer` is provided),
    a bias vector is created and added to the outputs. Finally, if
    `activation` is not `None`, it is applied to the outputs as well.

    Note: layer attributes cannot be modified after the layer has been called
    once (except the `trainable` attribute).

    Arguments:
      rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
      filters: Integer, the dimensionality of the output space (i.e. the number
        of filters in the convolution).
      kernel_size: An integer or tuple/list of n integers, specifying the
        length of the convolution window.
      strides: An integer or tuple/list of n integers,
        specifying the stride length of the convolution.
        Specifying any stride value != 1 is incompatible with specifying
        any `dilation_rate` value != 1.
      padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
      data_format: A string, one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, ..., channels)` while `channels_first` corresponds to
        inputs with shape `(batch_size, channels, ...)`.
      dilation_rate: An integer or tuple/list of n integers, specifying
        the dilation rate to use for dilated convolution.
        Currently, specifying any `dilation_rate` value != 1 is
        incompatible with specifying any `strides` value != 1.
      groups: Integer, the number of channel groups controlling the connections
        between inputs and outputs. Input channels and `filters` must both be
        divisible by `groups`. For example,
          - At `groups=1`, all inputs are convolved to all outputs.
          - At `groups=2`, the operation becomes equivalent to having two
            convolutional layers side by side, each seeing half the input
            channels, and producing half the output channels, and both
            subsequently concatenated.
          - At `groups=input_channels`, each input channel is convolved with its
            own set of filters, of size `input_channels / filters`
      activation: Activation function to use.
        If you don't specify anything, no activation is applied.
      use_bias: Boolean, whether the layer uses a bias.
      kernel_initializer: An initializer for the convolution kernel.
      bias_initializer: An initializer for the bias vector. If None, the default
        initializer will be used.
      kernel_regularizer: Optional regularizer for the convolution kernel.
      bias_regularizer: Optional regularizer for the bias vector.
      activity_regularizer: Optional regularizer function for the output.
      kernel_constraint: Optional projection function to be applied to the
        kernel after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected variable and must return the
        projected variable (which must have the same shape). Constraints are
        not safe to use when doing asynchronous distributed training.
      bias_constraint: Optional projection function to be applied to the
        bias after being updated by an `Optimizer`.
      trainable: Boolean, if `True` the weights of this layer will be marked as
        trainable (and listed in `layer.trainable_weights`).
      name: A string, the name of the layer.
    """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        **kwargs
    ):
        super(Conv, self).__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs
        )
        self.rank = rank
        if filters is not None and not isinstance(filters, int):
            filters = int(filters)
        self.filters = filters
        self.groups = groups or 1
        if filters is not None and filters % self.groups != 0:
            raise ValueError(
                "The number of filters must be evenly divisible by the number of "
                "groups. Received: groups={}, filters={}".format(groups, filters)
            )
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, "kernel_size")
        if not all(self.kernel_size):
            raise ValueError(
                "The argument `kernel_size` cannot contain 0(s). "
                "Received: %s" % (kernel_size,)
            )
        self.strides = conv_utils.normalize_tuple(strides, rank, "strides")
        self.padding = conv_utils.normalize_padding(padding)
        # NOTE(review): this isinstance check is against the stock Keras Conv1D /
        # SeparableConv1D classes, so GroupConv1D below (which subclasses this
        # Conv, not Conv1D) rejects padding="causal" here.
        if self.padding == "causal" and not isinstance(self, (Conv1D, SeparableConv1D)):
            raise ValueError(
                "Causal padding is only supported for `Conv1D` "
                "and `SeparableConv1D`."
            )
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, rank, "dilation_rate"
        )
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(ndim=self.rank + 2)

    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        if input_channel % self.groups != 0:
            raise ValueError(
                "The number of input channels must be evenly divisible by the number "
                "of groups. Received groups={}, but the input has {} channels "
                "(full input shape is {}).".format(
                    self.groups, input_channel, input_shape
                )
            )
        # Each group convolves input_channel // groups channels; the kernel's
        # in-channel dim being smaller than input_channel triggers grouping in
        # the Convolution helper above.
        kernel_shape = self.kernel_size + (input_channel // self.groups, self.filters)

        self.kernel = self.add_weight(
            name="kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype,
            )
        else:
            self.bias = None
        channel_axis = self._get_channel_axis()
        self.input_spec = InputSpec(
            ndim=self.rank + 2, axes={channel_axis: input_channel}
        )

        self._build_conv_op_input_shape = input_shape
        self._build_input_channel = input_channel
        self._padding_op = self._get_padding_op()
        self._conv_op_data_format = conv_utils.convert_data_format(
            self.data_format, self.rank + 2
        )
        self._convolution_op = Convolution(
            input_shape,
            filter_shape=self.kernel.shape,
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self._padding_op,
            data_format=self._conv_op_data_format,
        )
        self.built = True

    def call(self, inputs):
        # Rebuild the conv op if the runtime input shape diverged from the one
        # seen at build() time (avoids stateful shape capture).
        if self._recreate_conv_op(inputs):
            self._convolution_op = Convolution(
                inputs.get_shape(),
                filter_shape=self.kernel.shape,
                dilation_rate=self.dilation_rate,
                strides=self.strides,
                padding=self._padding_op,
                data_format=self._conv_op_data_format,
            )
            self._build_conv_op_input_shape = inputs.get_shape()

        # Apply causal padding to inputs for Conv1D.
        if self.padding == "causal" and self.__class__.__name__ == "Conv1D":
            inputs = array_ops.pad(inputs, self._compute_causal_padding())

        outputs = self._convolution_op(inputs, self.kernel)

        if self.use_bias:
            if self.data_format == "channels_first":
                if self.rank == 1:
                    # nn.bias_add does not accept a 1D input tensor.
                    bias = array_ops.reshape(self.bias, (1, self.filters, 1))
                    outputs += bias
                else:
                    outputs = nn.bias_add(outputs, self.bias, data_format="NCHW")
            else:
                outputs = nn.bias_add(outputs, self.bias, data_format="NHWC")

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == "channels_last":
            space = input_shape[1:-1]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i],
                )
                new_space.append(new_dim)
            return tensor_shape.TensorShape(
                [input_shape[0]] + new_space + [self.filters]
            )
        else:
            space = input_shape[2:]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i],
                )
                new_space.append(new_dim)
            return tensor_shape.TensorShape([input_shape[0], self.filters] + new_space)

    def get_config(self):
        config = {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "groups": self.groups,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(self.kernel_initializer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(self.activity_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super(Conv, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _compute_causal_padding(self):
        """Calculates padding for 'causal' option for 1-d conv layers."""
        left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
        if self.data_format == "channels_last":
            causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
        else:
            causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
        return causal_padding

    def _get_channel_axis(self):
        if self.data_format == "channels_first":
            return 1
        else:
            return -1

    def _get_input_channel(self, input_shape):
        channel_axis = self._get_channel_axis()
        if input_shape.dims[channel_axis].value is None:
            raise ValueError(
                "The channel dimension of the inputs "
                "should be defined. Found `None`."
            )
        return int(input_shape[channel_axis])

    def _get_padding_op(self):
        # "causal" is implemented as explicit left-padding + a VALID conv.
        if self.padding == "causal":
            op_padding = "valid"
        else:
            op_padding = self.padding
        if not isinstance(op_padding, (list, tuple)):
            op_padding = op_padding.upper()
        return op_padding

    def _recreate_conv_op(self, inputs):
        """Recreate conv_op if necessary.

        Check if the input_shape in call() is different from that in build().
        For the values that are not None, if they are different, recreate
        the _convolution_op to avoid the stateful behavior.

        Args:
          inputs: The input data to call() method.

        Returns:
          `True` or `False` to indicate whether to recreate the conv_op.
        """
        call_input_shape = inputs.get_shape()
        # Skip axis 0 (batch); only known (non-None) dims are compared.
        for axis in range(1, len(call_input_shape)):
            if (
                call_input_shape[axis] is not None
                and self._build_conv_op_input_shape[axis] is not None
                and call_input_shape[axis] != self._build_conv_op_input_shape[axis]
            ):
                return True
        return False


class GroupConv1D(Conv):
    """1D convolution layer (e.g. temporal convolution).

    This layer creates a convolution kernel that is convolved
    with the layer input over a single spatial (or temporal) dimension
    to produce a tensor of outputs.
    If `use_bias` is True, a bias vector is created and added to the outputs.
    Finally, if `activation` is not `None`,
    it is applied to the outputs as well.

    When using this layer as the first layer in a model,
    provide an `input_shape` argument
    (tuple of integers or `None`, e.g.
    `(10, 128)` for sequences of 10 vectors of 128-dimensional vectors,
    or `(None, 128)` for variable-length sequences of 128-dimensional vectors.

    Examples:
    >>> # The inputs are 128-length vectors with 10 timesteps, and the batch size
    >>> # is 4.
    >>> input_shape = (4, 10, 128)
    >>> x = tf.random.normal(input_shape)
    >>> y = tf.keras.layers.Conv1D(
    ... 32, 3, activation='relu',input_shape=input_shape)(x)
    >>> print(y.shape)
    (4, 8, 32)

    Arguments:
      filters: Integer, the dimensionality of the output space
        (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of a single integer,
        specifying the length of the 1D convolution window.
      strides: An integer or tuple/list of a single integer,
        specifying the stride length of the convolution.
        Specifying any stride value != 1 is incompatible with specifying
        any `dilation_rate` value != 1.
      padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive).
        `"causal"` results in causal (dilated) convolutions, e.g. `output[t]`
        does not depend on `input[t+1:]`. Useful when modeling temporal data
        where the model should not violate the temporal order.
        See [WaveNet: A Generative Model for Raw Audio, section
        2.1](https://arxiv.org/abs/1609.03499).
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
      groups: Integer, the number of channel groups controlling the connections
        between inputs and outputs. Input channels and `filters` must both be
        divisible by `groups`. For example,
          - At `groups=1`, all inputs are convolved to all outputs.
          - At `groups=2`, the operation becomes equivalent to having two
            convolutional layers side by side, each seeing half the input
            channels, and producing half the output channels, and both
            subsequently concatenated.
          - At `groups=input_channels`, each input channel is convolved with its
            own set of filters, of size `input_channels / filters`
      dilation_rate: an integer or tuple/list of a single integer, specifying
        the dilation rate to use for dilated convolution.
        Currently, specifying any `dilation_rate` value != 1 is
        incompatible with specifying any `strides` value != 1.
      activation: Activation function to use.
        If you don't specify anything, no activation is applied (
        see `keras.activations`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix (
        see `keras.initializers`).
      bias_initializer: Initializer for the bias vector (
        see `keras.initializers`).
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix (see `keras.regularizers`).
      bias_regularizer: Regularizer function applied to the bias vector (
        see `keras.regularizers`).
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation") (
        see `keras.regularizers`).
      kernel_constraint: Constraint function applied to the kernel matrix (
        see `keras.constraints`).
      bias_constraint: Constraint function applied to the bias vector (
        see `keras.constraints`).

    Input shape:
      3D tensor with shape: `(batch_size, steps, input_dim)`

    Output shape:
      3D tensor with shape: `(batch_size, new_steps, filters)`
      `steps` value might have changed due to padding or strides.

    Returns:
      A tensor of rank 3 representing
      `activation(conv1d(inputs, kernel) + bias)`.

    Raises:
      ValueError: when both `strides` > 1 and `dilation_rate` > 1.
    """

    def __init__(
        self,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format="channels_last",
        dilation_rate=1,
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super().__init__(
            rank=1,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            groups=groups,
            activation=activations.get(activation),
            use_bias=use_bias,
            kernel_initializer=initializers.get(kernel_initializer),
            bias_initializer=initializers.get(bias_initializer),
            kernel_regularizer=regularizers.get(kernel_regularizer),
            bias_regularizer=regularizers.get(bias_regularizer),
            activity_regularizer=regularizers.get(activity_regularizer),
            kernel_constraint=constraints.get(kernel_constraint),
            bias_constraint=constraints.get(bias_constraint),
            **kwargs
        )