Path: blob/master/tensorflow_tts/utils/strategy.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy util functions"""
import tensorflow as tf


def return_strategy():
    """Pick a ``tf.distribute`` strategy based on the GPUs TensorFlow lists.

    Returns:
        ``OneDeviceStrategy`` on ``"/cpu:0"`` when no GPU is listed,
        ``OneDeviceStrategy`` on ``"/gpu:0"`` when exactly one GPU is listed,
        otherwise a default-constructed ``MirroredStrategy``.
    """
    physical_devices = tf.config.list_physical_devices("GPU")
    if not physical_devices:
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    if len(physical_devices) == 1:
        return tf.distribute.OneDeviceStrategy(device="/gpu:0")
    return tf.distribute.MirroredStrategy()


def _trim_to_common_length(y_gt, y_pred):
    """Slice both tensors along axis 1 to the shorter of their two lengths."""
    # The lengths along axis 1 can mismatch when training on multiple GPUs,
    # so the longer tensor is cut down before computing the loss. Taking the
    # minimum with tf ops (instead of a Python `if` on tensor values) keeps
    # this valid in eager mode, under tf.function, and without AutoGraph.
    # When the lengths already agree the slice is a no-op on the values.
    common_len = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    return y_gt[:, :common_len], y_pred[:, :common_len]


def _reduce_to_batch(loss):
    """Average a loss tensor over every axis except axis 0 (result shape [B])."""
    return tf.reduce_mean(loss, list(range(1, len(loss.shape))))


def _per_example_loss(y_gt, y_pred, loss_fn):
    """Trim to a common length, apply ``loss_fn``, and reduce each term to [B]."""
    y_gt, y_pred = _trim_to_common_length(y_gt, y_pred)
    loss = loss_fn(y_gt, y_pred)
    if isinstance(loss, tuple):
        # Multi-term losses are reduced term by term and returned as a list.
        return [_reduce_to_batch(term) for term in loss]
    return _reduce_to_batch(loss)


def calculate_3d_loss(y_gt, y_pred, loss_fn):
    """Calculate 3d loss, normally it's mel-spectrogram loss.

    Args:
        y_gt: Ground-truth tensor; axis 0 is the batch axis and axis 1 is the
            length axis that gets trimmed (assumed rank 3 for this caller).
        y_pred: Predicted tensor laid out like ``y_gt``; its axis-1 length
            may differ from ``y_gt``'s.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning a tensor or a
            tuple of tensors, each with the batch on axis 0.

    Returns:
        A tensor of shape [B] (mean over all non-batch axes), or a list of
        such tensors when ``loss_fn`` returns a tuple.
    """
    return _per_example_loss(y_gt, y_pred, loss_fn)


def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate 2d loss, normally it's durations/f0s/energies loss.

    Args:
        y_gt: Ground-truth tensor; axis 0 is the batch axis and axis 1 is the
            length axis that gets trimmed (assumed rank 2 for this caller).
        y_pred: Predicted tensor laid out like ``y_gt``; its axis-1 length
            may differ from ``y_gt``'s.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning a tensor or a
            tuple of tensors, each with the batch on axis 0.

    Returns:
        A tensor of shape [B] (mean over all non-batch axes), or a list of
        such tensors when ``loss_fn`` returns a tuple.
    """
    return _per_example_loss(y_gt, y_pred, loss_fn)