Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/optimizers/adamweightdecay.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""AdamW for training self-attention."""
16
17
18
import re
19
20
import tensorflow as tf
21
22
23
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Applies a warmup schedule on top of a given learning rate decay schedule.

    For the first `warmup_steps` steps the learning rate follows a polynomial
    ramp from 0 up to `initial_learning_rate`; after that the rate is delegated
    to `decay_schedule_fn`.

    Args:
        initial_learning_rate: Peak learning rate reached at the end of warmup.
        decay_schedule_fn: Schedule (callable taking `step`) used once warmup
            is over.
        warmup_steps: Number of steps over which to warm up.
        power: Exponent of the polynomial warmup curve (1.0 means linear).
        name: Optional name scope for the ops created in `__call__`.
    """

    def __init__(
        self,
        initial_learning_rate,
        decay_schedule_fn,
        warmup_steps,
        power=1.0,
        name=None,
    ):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as name:
            # Polynomial warmup: while global_step < warmup_steps the rate is
            # `(global_step / warmup_steps) ** power * initial_learning_rate`.
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(
                warmup_percent_done, self.power
            )
            return tf.cond(
                global_step_float < warmup_steps_float,
                lambda: warmup_learning_rate,
                lambda: self.decay_schedule_fn(step),
                name=name,
            )

    def get_config(self):
        # NOTE(review): `decay_schedule_fn` is returned as-is; it must itself
        # be serializable for get_config/from_config round-trips to work.
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_schedule_fn": self.decay_schedule_fn,
            "warmup_steps": self.warmup_steps,
            "power": self.power,
            "name": self.name,
        }
66
67
68
class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Adam enables L2 weight decay and clip_by_global_norm on gradients.

    Just adding the square of the weights to the loss function is *not* the
    correct way of using L2 regularization/weight decay with Adam, since that will
    interact with the m and v parameters in strange ways.

    Instead we want to decay the weights in a manner that doesn't interact with
    the m/v parameters. This is equivalent to adding the square of the weights to
    the loss with plain (non-momentum) SGD.
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay_rate=0.0,
        include_in_weight_decay=None,
        exclude_from_weight_decay=None,
        name="AdamWeightDecay",
        **kwargs
    ):
        # weight_decay_rate: decoupled decay applied directly to the variables
        #   in `_decay_weights_op` (not folded into the gradient).
        # include_in_weight_decay / exclude_from_weight_decay: optional lists of
        #   regex patterns matched against variable names; an include match
        #   forces decay on before the exclude list is consulted
        #   (see `_do_use_weight_decay`).
        super(AdamWeightDecay, self).__init__(
            learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs
        )
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {"WarmUp": WarmUp}
        return super(AdamWeightDecay, cls).from_config(
            config, custom_objects=custom_objects
        )

    def _prepare_local(self, var_device, var_dtype, apply_state):
        # NOTE(review): the decay rate is stored at the top level of
        # `apply_state`, not under the (device, dtype) key Keras uses for its
        # other coefficients — `_decay_weights_op` reads it back the same way,
        # so the two must stay in sync.
        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state["weight_decay_rate"] = tf.constant(
            self.weight_decay_rate, name="adam_weight_decay_rate"
        )

    def _decay_weights_op(self, var, learning_rate, apply_state):
        """Applies decoupled weight decay: var -= lr * var * weight_decay_rate."""
        do_decay = self._do_use_weight_decay(var.name)
        if do_decay:
            return var.assign_sub(
                learning_rate * var * apply_state["weight_decay_rate"],
                use_locking=self._use_locking,
            )
        # No decay for this variable: return a no-op so callers can always
        # place the result in a control dependency.
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, clip_norm=0.5, **kwargs):
        # Clip all gradients jointly by global norm before delegating to the
        # regular Adam update.
        grads, tvars = list(zip(*grads_and_vars))
        (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), **kwargs)

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""
        if apply_state is None:
            return self._decayed_lr_t[var_dtype], {}

        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            # Lazily compute and cache coefficients for this (device, dtype).
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients

        return coefficients["lr_t"], dict(apply_state=apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        # Run the weight-decay update first, then the standard Adam step,
        # ordered via the control dependency.
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(
                grad, var, **kwargs
            )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        # Sparse variant: the decay op still updates the full variable before
        # the sparse Adam step is applied.
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(
                grad, var, indices, **kwargs
            )

    def get_config(self):
        config = super(AdamWeightDecay, self).get_config()
        config.update(
            {"weight_decay_rate": self.weight_decay_rate,}
        )
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        # An explicit include match wins over any exclude match below.
        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True

        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True
178
179