Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/utils/strategy.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Strategy util functions"""
16
import tensorflow as tf
17
18
19
def return_strategy():
    """Pick a ``tf.distribute`` strategy from the number of visible GPUs.

    Returns:
        A ``OneDeviceStrategy`` on ``/cpu:0`` when TensorFlow lists no GPU,
        a ``OneDeviceStrategy`` on ``/gpu:0`` when it lists exactly one,
        and a ``MirroredStrategy`` when it lists more than one.
    """
    gpu_count = len(tf.config.list_physical_devices("GPU"))
    if gpu_count > 1:
        # Several GPUs: replicate the model across all of them.
        return tf.distribute.MirroredStrategy()
    device = "/gpu:0" if gpu_count == 1 else "/cpu:0"
    return tf.distribute.OneDeviceStrategy(device=device)
27
28
29
def calculate_3d_loss(y_gt, y_pred, loss_fn):
    """Calculate a 3d loss, normally the mel-spectrogram loss.

    The ground truth and the prediction can differ in length along axis 1
    when training on multiple GPUs. The longer tensor is therefore sliced
    down to the shorter length before ``loss_fn`` is applied.

    Args:
        y_gt: Rank-3 ground-truth tensor; axis 1 is the length axis that may
            mismatch (presumably [B, T, n_mels] -- TODO confirm).
        y_pred: Rank-3 predicted tensor laid out like ``y_gt``.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning either a single
            loss tensor or a tuple/list of loss tensors.

    Returns:
        The loss averaged over every axis except axis 0 (shape [B]), or a
        list of such per-sample losses when ``loss_fn`` returns a tuple or
        list.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There is a length mismatch when training on multiple GPUs, so slice
    # the longer tensor along axis 1 to make sure the loss is calculated
    # over matching lengths.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0, 0], [-1, y_pred_T, -1])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0, 0], [-1, y_gt_T, -1])

    def _per_sample_mean(tensor):
        # Average over every axis after axis 0, keeping one value per sample.
        return tf.reduce_mean(tensor, axis=list(range(1, len(tensor.shape))))

    loss = loss_fn(y_gt, y_pred)
    if isinstance(loss, (tuple, list)):
        # Multiple losses: reduce each one independently, returned as a list.
        return [_per_sample_mean(item) for item in loss]
    return _per_sample_mean(loss)  # shape = [B]
52
53
54
def calculate_2d_loss(y_gt, y_pred, loss_fn):
    """Calculate a 2d loss, normally the durations/f0s/energies loss.

    The ground truth and the prediction can differ in length along axis 1
    when training on multiple GPUs. The longer tensor is therefore sliced
    down to the shorter length before ``loss_fn`` is applied.

    Args:
        y_gt: Rank-2 ground-truth tensor; axis 1 is the length axis that may
            mismatch (presumably [B, T] -- TODO confirm).
        y_pred: Rank-2 predicted tensor laid out like ``y_gt``.
        loss_fn: Callable ``loss_fn(y_gt, y_pred)`` returning either a single
            loss tensor or a tuple/list of loss tensors.

    Returns:
        The loss averaged over every axis except axis 0 (shape [B]), or a
        list of such per-sample losses when ``loss_fn`` returns a tuple or
        list.
    """
    y_gt_T = tf.shape(y_gt)[1]
    y_pred_T = tf.shape(y_pred)[1]

    # There is a length mismatch when training on multiple GPUs, so slice
    # the longer tensor along axis 1 to make sure the loss is calculated
    # over matching lengths.
    if y_gt_T > y_pred_T:
        y_gt = tf.slice(y_gt, [0, 0], [-1, y_pred_T])
    elif y_pred_T > y_gt_T:
        y_pred = tf.slice(y_pred, [0, 0], [-1, y_gt_T])

    def _per_sample_mean(tensor):
        # Average over every axis after axis 0, keeping one value per sample.
        return tf.reduce_mean(tensor, axis=list(range(1, len(tensor.shape))))

    loss = loss_fn(y_gt, y_pred)
    if isinstance(loss, (tuple, list)):
        # Multiple losses: reduce each one independently, returned as a list.
        return [_per_sample_mean(item) for item in loss]
    return _per_sample_mean(loss)  # shape = [B]
78
79