TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/optimizers/gradient_accumulate.py
"""Gradient accumulation for TF2 custom training loops.

Copied from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py.
"""

import tensorflow as tf


class GradientAccumulator(object):
    """Gradient accumulation utility.

    When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica and
    without synchronization. Users should then call ``.gradients``, scale the
    gradients if required, and pass the result to ``apply_gradients``. A minimal
    usage sketch appears at the end of this module.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value() which returns the
    # value on the current replica without synchronization.

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )

        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def __call__(self, gradients):
        """Accumulates :obj:`gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                    )
                    if gradient is not None
                    else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(
                "Expected %d gradients, but got %d"
                % (len(self._gradients), len(gradients))
            )

        for accum_gradient, gradient in zip(self._gradients, gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient, read_value=False)

        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            if gradient is not None:
                gradient.assign(tf.zeros_like(gradient), read_value=False)
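

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original file).
# It follows the flow described in the class docstring: accumulate gradients
# over several micro-batches, read ``.gradients``, scale them, pass the result
# to ``apply_gradients``, then ``reset()``. The toy model, optimizer, data,
# and the choice of 4 accumulation steps are assumptions for this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    accum_steps = 4
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    accumulator = GradientAccumulator()

    for _ in range(accum_steps):
        x = tf.random.normal([16, 8])
        y = tf.random.normal([16, 1])
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        # Accumulate this micro-batch's gradients instead of applying them now.
        accumulator(tape.gradient(loss, model.trainable_variables))

    # Average the accumulated sums over the number of steps, apply once,
    # and reset the accumulator for the next cycle.
    scaled = [
        gradient / tf.cast(accumulator.step, gradient.dtype)
        for gradient in accumulator.gradients
    ]
    optimizer.apply_gradients(zip(scaled, model.trainable_variables))
    accumulator.reset()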