# Path: blob/master/tensorflow_tts/optimizers/gradient_accumulate.py
# 1558 views
"""Gradient Accummlate for training TF2 custom training loop.1Copy from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py.2"""345import re67import tensorflow as tf8910class GradientAccumulator(object):11"""Gradient accumulation utility.12When used with a distribution strategy, the accumulator should be called in a13replica context. Gradients will be accumulated locally on each replica and14without synchronization. Users should then call ``.gradients``, scale the15gradients if required, and pass the result to ``apply_gradients``.16"""1718# We use the ON_READ synchronization policy so that no synchronization is19# performed on assignment. To get the value, we call .value() which returns the20# value on the current replica without synchronization.2122def __init__(self):23"""Initializes the accumulator."""24self._gradients = []25self._accum_steps = None2627@property28def step(self):29"""Number of accumulated steps."""30if self._accum_steps is None:31self._accum_steps = tf.Variable(32tf.constant(0, dtype=tf.int64),33trainable=False,34synchronization=tf.VariableSynchronization.ON_READ,35aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,36)3738return self._accum_steps.value()3940@property41def gradients(self):42"""The accumulated gradients on the current replica."""43if not self._gradients:44raise ValueError(45"The accumulator should be called first to initialize the gradients"46)47return list(48gradient.value() if gradient is not None else gradient49for gradient in self._gradients50)5152def __call__(self, gradients):53"""Accumulates :obj:`gradients` on the current replica."""54if not self._gradients:55_ = self.step # Create the step variable.56self._gradients.extend(57[58tf.Variable(59tf.zeros_like(gradient),60trainable=False,61synchronization=tf.VariableSynchronization.ON_READ,62)63if gradient is not None64else gradient65for gradient in gradients66]67)68if len(gradients) != len(self._gradients):69raise ValueError(70"Expected %s gradients, 
but got %d"71% (len(self._gradients), len(gradients))72)7374for accum_gradient, gradient in zip(self._gradients, gradients):75if accum_gradient is not None and gradient is not None:76accum_gradient.assign_add(gradient, read_value=False)7778self._accum_steps.assign_add(1)7980def reset(self):81"""Resets the accumulated gradients on the current replica."""82if not self._gradients:83return84self._accum_steps.assign(0)85for gradient in self._gradients:86if gradient is not None:87gradient.assign(tf.zeros_like(gradient), read_value=False)888990