Path: blob/master/tensorflow_tts/optimizers/adamweightdecay.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AdamW for training self-attention."""

import re

import tensorflow as tf


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Applies a warmup schedule on a given learning rate decay schedule.

    While `step < warmup_steps`, the learning rate follows a polynomial
    ramp from 0 to `initial_learning_rate`; afterwards it delegates to
    `decay_schedule_fn`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_schedule_fn,
        warmup_steps,
        power=1.0,
        name=None,
    ):
        """Initializes the schedule.

        Args:
            initial_learning_rate: float, the LR reached at the end of warmup.
            decay_schedule_fn: callable mapping `step` to an LR, used once
                warmup is finished.
            warmup_steps: int, number of warmup steps (must be > 0, otherwise
                the warmup ratio below divides by zero).
            power: float, exponent of the polynomial warmup (1.0 == linear).
            name: optional name scope for the ops created by `__call__`.
        """
        super(WarmUp, self).__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as name:
            # Implements polynomial warmup. i.e., if global_step < warmup_steps,
            # the learning rate will be `global_step/num_warmup_steps * init_lr`
            # (for power == 1.0).
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(
                warmup_percent_done, self.power
            )
            return tf.cond(
                global_step_float < warmup_steps_float,
                lambda: warmup_learning_rate,
                lambda: self.decay_schedule_fn(step),
                name=name,
            )

    def get_config(self):
        """Returns the serializable config of the schedule."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_schedule_fn": self.decay_schedule_fn,
            "warmup_steps": self.warmup_steps,
            "power": self.power,
            "name": self.name,
        }


class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Adam enables L2 weight decay and clip_by_global_norm on gradients.

    Just adding the square of the weights to the loss function is *not* the
    correct way of using L2 regularization/weight decay with Adam, since that
    will interact with the m and v parameters in strange ways.

    Instead we want to decay the weights in a manner that doesn't interact
    with the m/v parameters. This is equivalent to adding the square of the
    weights to the loss with plain (non-momentum) SGD.
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay_rate=0.0,
        include_in_weight_decay=None,
        exclude_from_weight_decay=None,
        name="AdamWeightDecay",
        **kwargs
    ):
        """Initializes the optimizer.

        Args:
            learning_rate: float or `LearningRateSchedule`.
            beta_1/beta_2/epsilon/amsgrad: standard Adam hyperparameters.
            weight_decay_rate: float, decoupled weight decay coefficient;
                0.0 disables decay entirely.
            include_in_weight_decay: optional list of regex patterns; a
                variable matching any of them is always decayed (overrides
                the exclude list).
            exclude_from_weight_decay: optional list of regex patterns; a
                matching variable is not decayed (unless included above).
            name: optimizer name.
            **kwargs: forwarded to `tf.keras.optimizers.Adam`.
        """
        super(AdamWeightDecay, self).__init__(
            learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs
        )
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {"WarmUp": WarmUp}
        return super(AdamWeightDecay, cls).from_config(
            config, custom_objects=custom_objects
        )

    def _prepare_local(self, var_device, var_dtype, apply_state):
        # NOTE: the decay rate is stored under a top-level key (not under the
        # per-(device, dtype) sub-dict Keras uses for its own coefficients),
        # so _decay_weights_op reads it back the same way.
        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state["weight_decay_rate"] = tf.constant(
            self.weight_decay_rate, name="adam_weight_decay_rate"
        )

    def _decay_weights_op(self, var, learning_rate, apply_state):
        """Returns the op that applies decoupled weight decay to `var`.

        Returns `tf.no_op()` when `var` is excluded from decay or the decay
        rate is 0.
        """
        if self._do_use_weight_decay(var.name):
            if apply_state is not None:
                decay_rate = apply_state["weight_decay_rate"]
            else:
                # Fix: _resource_apply_dense/_sparse may legitimately be
                # called with apply_state=None; the original code crashed
                # here (`None["weight_decay_rate"]`). Fall back to the raw
                # hyperparameter, cast to the variable dtype so mixed
                # precision variables also work.
                decay_rate = tf.cast(self.weight_decay_rate, var.dtype.base_dtype)
            return var.assign_sub(
                learning_rate * var * decay_rate,
                use_locking=self._use_locking,
            )
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, clip_norm=0.5, **kwargs):
        """Applies gradients after clipping them by global norm `clip_norm`."""
        grads, tvars = list(zip(*grads_and_vars))
        (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), **kwargs)

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""
        if apply_state is None:
            return self._decayed_lr_t[var_dtype], {}

        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients

        return coefficients["lr_t"], dict(apply_state=apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        # Run the decay op first, then the regular Adam dense update.
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(
                grad, var, **kwargs
            )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        # Run the decay op first, then the regular Adam sparse update.
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(
                grad, var, indices, **kwargs
            )

    def get_config(self):
        """Returns the config, extending Adam's with the decay rate."""
        config = super(AdamWeightDecay, self).get_config()
        config.update({"weight_decay_rate": self.weight_decay_rate})
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        # Explicit includes take precedence over excludes.
        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True

        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True