GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/trainers/base_trainer.py

# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base Trainer."""

import abc
import logging
import os

import tensorflow as tf
from tqdm import tqdm

from tensorflow_tts.optimizers import GradientAccumulator
from tensorflow_tts.utils import utils


class BasedTrainer(metaclass=abc.ABCMeta):
    """Customized trainer module for all models."""

    def __init__(self, steps, epochs, config):
        self.steps = steps
        self.epochs = epochs
        self.config = config
        self.finish_train = False
        self.writer = tf.summary.create_file_writer(config["outdir"])
        self.train_data_loader = None
        self.eval_data_loader = None
        self.train_metrics = None
        self.eval_metrics = None
        self.list_metrics_name = None

    def init_train_eval_metrics(self, list_metrics_name):
        """Initialize train and eval metrics for logging to TensorBoard."""
        self.train_metrics = {}
        self.eval_metrics = {}
        for name in list_metrics_name:
            self.train_metrics.update(
                {name: tf.keras.metrics.Mean(name="train_" + name, dtype=tf.float32)}
            )
            self.eval_metrics.update(
                {name: tf.keras.metrics.Mean(name="eval_" + name, dtype=tf.float32)}
            )

    def reset_states_train(self):
        """Reset train metrics after writing them to TensorBoard."""
        for metric in self.train_metrics.keys():
            self.train_metrics[metric].reset_states()

    def reset_states_eval(self):
        """Reset eval metrics after writing them to TensorBoard."""
        for metric in self.eval_metrics.keys():
            self.eval_metrics[metric].reset_states()

    def update_train_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.train_metrics[name].update_state(value)

    def update_eval_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.eval_metrics[name].update_state(value)

    def set_train_data_loader(self, train_dataset):
        """Set train data loader (MUST)."""
        self.train_data_loader = train_dataset

    def get_train_data_loader(self):
        """Get train data loader."""
        return self.train_data_loader

    def set_eval_data_loader(self, eval_dataset):
        """Set eval data loader (MUST)."""
        self.eval_data_loader = eval_dataset

    def get_eval_data_loader(self):
        """Get eval data loader."""
        return self.eval_data_loader

    @abc.abstractmethod
    def compile(self):
        pass

    @abc.abstractmethod
    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint manager."""
        pass

    def run(self):
        """Run training."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.config["train_max_steps"], desc="[train]"
        )
        while True:
            self._train_epoch()

            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    @abc.abstractmethod
    def save_checkpoint(self):
        """Save checkpoint."""
        pass

    @abc.abstractmethod
    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        pass

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.train_data_loader, 1):
            # one step training
            self._train_step(batch)

            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

    @abc.abstractmethod
    def _eval_epoch(self):
        """One epoch evaluation."""
        pass

    @abc.abstractmethod
    def _train_step(self, batch):
        """One step training."""
        pass

    @abc.abstractmethod
    def _check_log_interval(self):
        """Log metrics at the configured interval."""
        pass

    @abc.abstractmethod
    def fit(self):
        pass

    def _check_eval_interval(self):
        """Run evaluation at the configured interval."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_save_interval(self):
        """Save a checkpoint at the configured interval."""
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint()
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate results."""
        pass

    def _write_to_tensorboard(self, list_metrics, stage="train"):
        """Write metrics to TensorBoard."""
        with self.writer.as_default():
            for key, value in list_metrics.items():
                tf.summary.scalar(stage + "/" + key, value.result(), step=self.steps)
            self.writer.flush()
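
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how a concrete trainer
# typically drives the metric bookkeeping above. The metric names are
# hypothetical; they must match whatever list the subclass passes to
# init_train_eval_metrics.
#
#     trainer.list_metrics_name = ["gen_loss", "dis_loss"]
#     trainer.init_train_eval_metrics(trainer.list_metrics_name)
#     # inside a train step:
#     trainer.update_train_metrics({"gen_loss": 0.31, "dis_loss": 0.12})
#     # at the log interval:
#     trainer._write_to_tensorboard(trainer.train_metrics, stage="train")
#     trainer.reset_states_train()
# ---------------------------------------------------------------------------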


class GanBasedTrainer(BasedTrainer):
    """Customized trainer module for GAN TTS training (MelGAN, GAN-TTS, ParallelWaveGAN)."""

    def __init__(
        self,
        steps,
        epochs,
        config,
        strategy,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute): Strategy for distributed training.
            is_generator_mixed_precision (bool): Train the generator with mixed precision or not.
            is_discriminator_mixed_precision (bool): Train the discriminator with mixed precision or not.

        """
        super().__init__(steps, epochs, config)
        self._is_generator_mixed_precision = is_generator_mixed_precision
        self._is_discriminator_mixed_precision = is_discriminator_mixed_precision
        self._strategy = strategy
        self._already_apply_input_signature = False
        self._generator_gradient_accumulator = GradientAccumulator()
        self._discriminator_gradient_accumulator = GradientAccumulator()
        self._generator_gradient_accumulator.reset()
        self._discriminator_gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def set_gen_model(self, generator_model):
        """Set generator class model (MUST)."""
        self._generator = generator_model

    def get_gen_model(self):
        """Get generator model."""
        return self._generator

    def set_dis_model(self, discriminator_model):
        """Set discriminator class model (MUST)."""
        self._discriminator = discriminator_model

    def get_dis_model(self):
        """Get discriminator model."""
        return self._discriminator

    def set_gen_optimizer(self, generator_optimizer):
        """Set generator optimizer (MUST)."""
        self._gen_optimizer = generator_optimizer
        if self._is_generator_mixed_precision:
            self._gen_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._gen_optimizer, "dynamic"
            )

    def get_gen_optimizer(self):
        """Get generator optimizer."""
        return self._gen_optimizer

    def set_dis_optimizer(self, discriminator_optimizer):
        """Set discriminator optimizer (MUST)."""
        self._dis_optimizer = discriminator_optimizer
        if self._is_discriminator_mixed_precision:
            self._dis_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._dis_optimizer, "dynamic"
            )

    def get_dis_optimizer(self):
        """Get discriminator optimizer."""
        return self._dis_optimizer

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        self.set_gen_model(gen_model)
        self.set_dis_model(dis_model)
        self.set_gen_optimizer(gen_optimizer)
        self.set_dis_optimizer(dis_optimizer)

    def _train_step(self, batch):
        if self._already_apply_input_signature is False:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    @abc.abstractmethod
    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and return dict_metrics_losses.

        Note that every loss element MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per-example losses on each replica, shape [batch_size]
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses
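
    # ------------------------------------------------------------------
    # Illustrative sketch (an assumption, not the repository's actual loss):
    # a vocoder subclass might implement the method above roughly like this,
    # reducing an L1 term to shape [batch_size]. "audios" is a hypothetical
    # batch key; outputs are generated waveforms of shape [batch_size, T, 1].
    #
    #     def compute_per_example_generator_losses(self, batch, outputs):
    #         per_example_losses = tf.reduce_mean(
    #             tf.abs(batch["audios"] - tf.squeeze(outputs, -1)), axis=1
    #         )
    #         dict_metrics_losses = {"adversarial_loss": per_example_losses}
    #         return per_example_losses, dict_metrics_losses
    # ------------------------------------------------------------------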

    @abc.abstractmethod
    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per-example discriminator losses and return dict_metrics_losses.

        Note that every loss element MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader
            gen_outputs: outputs of the generator model

        Returns:
            per_example_losses: per-example losses on each replica, shape [batch_size]
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses
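
    # ------------------------------------------------------------------
    # Illustrative sketch (an assumption): an LSGAN-style implementation of
    # the method above, with a hypothetical "audios" batch key and a
    # discriminator that returns per-frame logits:
    #
    #     def compute_per_example_discriminator_losses(self, batch, gen_outputs):
    #         real_logits = self._discriminator(batch["audios"])
    #         fake_logits = self._discriminator(gen_outputs)
    #         real_loss = tf.reduce_mean(tf.square(real_logits - 1.0), axis=[1, 2])
    #         fake_loss = tf.reduce_mean(tf.square(fake_logits), axis=[1, 2])
    #         per_example_losses = real_loss + fake_loss  # shape [batch_size]
    #         dict_metrics_losses = {"discriminator_loss": per_example_losses}
    #         return per_example_losses, dict_metrics_losses
    # ------------------------------------------------------------------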

    def _calculate_generator_gradient_per_batch(self, batch):
        outputs = self._generator(**batch, training=True)
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_generator_losses(batch, outputs)
        per_replica_gen_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_generator_mixed_precision:
            scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss(
                per_replica_gen_losses
            )

        if self._is_generator_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_gen_losses, self._generator.trainable_variables
            )
            gradients = self._gen_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_gen_losses, self._generator.trainable_variables
            )

        # gradient accumulate for generator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._generator_gradient_accumulator(gradients)

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_gen_losses
        else:
            return per_replica_gen_losses

    def _calculate_discriminator_gradient_per_batch(self, batch):
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_discriminator_losses(
            batch, self._generator(**batch, training=True)
        )

        per_replica_dis_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_discriminator_mixed_precision:
            scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss(
                per_replica_dis_losses
            )

        if self._is_discriminator_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_dis_losses,
                self._discriminator.trainable_variables,
            )
            gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_dis_losses, self._discriminator.trainable_variables
            )

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        # gradient accumulate for discriminator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._discriminator_gradient_accumulator(gradients)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_dis_losses
        else:
            return per_replica_dis_losses

    def _one_step_forward_per_replica(self, batch):
        per_replica_gen_losses = 0.0
        per_replica_dis_losses = 0.0

        if self.config["gradient_accumulation_steps"] == 1:
            (
                gradients,
                per_replica_gen_losses,
            ) = self._calculate_generator_gradient_per_batch(batch)
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
        else:
            # gradient accumulation: the loaded batch holds
            # gradient_accumulation_steps sub-batches of batch_size examples,
            # so slice out sub-batch i before each accumulation step.
            for i in tf.range(self.config["gradient_accumulation_steps"]):
                reduced_batch = {
                    k: v[
                        i
                        * self.config["batch_size"] : (i + 1)
                        * self.config["batch_size"]
                    ]
                    for k, v in batch.items()
                }

                # run 1 step accumulate
                reduced_batch_losses = self._calculate_generator_gradient_per_batch(
                    reduced_batch
                )

                # sum per_replica_losses
                per_replica_gen_losses += reduced_batch_losses

            gradients = self._generator_gradient_accumulator.gradients
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
            self._generator_gradient_accumulator.reset()

        # one step discriminator
        # recompute y_hat after 1 step generator for discriminator training.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            if self.config["gradient_accumulation_steps"] == 1:
                (
                    gradients,
                    per_replica_dis_losses,
                ) = self._calculate_discriminator_gradient_per_batch(batch)
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
            else:
                # gradient accumulation here.
                for i in tf.range(self.config["gradient_accumulation_steps"]):
                    reduced_batch = {
                        k: v[
                            i
                            * self.config["batch_size"] : (i + 1)
                            * self.config["batch_size"]
                        ]
                        for k, v in batch.items()
                    }

                    # run 1 step accumulate
                    reduced_batch_losses = self._calculate_discriminator_gradient_per_batch(
                        reduced_batch
                    )

                    # sum per_replica_losses
                    per_replica_dis_losses += reduced_batch_losses

                gradients = self._discriminator_gradient_accumulator.gradients
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
                self._discriminator_gradient_accumulator.reset()

        return per_replica_gen_losses + per_replica_dis_losses

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.eval_data_loader, desc="[eval]"), 1
        ):
            # eval one step
            self.one_step_evaluate(batch)

            if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:
                # save intermediate results
                self.generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.eval_metrics.keys():
            logging.info(
                f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."
            )

        # record
        self._write_to_tensorboard(self.eval_metrics, stage="eval")

        # reset
        self.reset_states_eval()

    def _one_step_evaluate_per_replica(self, batch):
        ################################################
        # one step generator.
        outputs = self._generator(**batch, training=False)
        _, dict_metrics_losses = self.compute_per_example_generator_losses(
            batch, outputs
        )

        # accumulate loss into metrics
        self.update_eval_metrics(dict_metrics_losses)

        ################################################
        # one step discriminator
        if self.steps >= self.config["discriminator_train_start_steps"]:
            _, dict_metrics_losses = self.compute_per_example_discriminator_losses(
                batch, outputs
            )

            # accumulate loss into metrics
            self.update_eval_metrics(dict_metrics_losses)

        ################################################

    def _one_step_evaluate(self, batch):
        self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))

    def _one_step_predict_per_replica(self, batch):
        outputs = self._generator(**batch, training=False)
        return outputs

    def _one_step_predict(self, batch):
        outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))
        return outputs

    @abc.abstractmethod
    def generate_and_save_intermediate_result(self, batch):
        return

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint manager."""
        if saved_path is None:
            saved_path = self.config["outdir"] + "/checkpoints/"

        os.makedirs(saved_path, exist_ok=True)

        self.saved_path = saved_path
        self.ckpt = tf.train.Checkpoint(
            steps=tf.Variable(1),
            epochs=tf.Variable(1),
            gen_optimizer=self.get_gen_optimizer(),
            dis_optimizer=self.get_dis_optimizer(),
        )
        self.ckp_manager = tf.train.CheckpointManager(
            self.ckpt, saved_path, max_to_keep=max_to_keep
        )

    def save_checkpoint(self):
        """Save checkpoint."""
        self.ckpt.steps.assign(self.steps)
        self.ckpt.epochs.assign(self.epochs)
        self.ckp_manager.save(checkpoint_number=self.steps)
        utils.save_weights(
            self._generator, self.saved_path + "generator-{}.h5".format(self.steps)
        )
        utils.save_weights(
            self._discriminator,
            self.saved_path + "discriminator-{}.h5".format(self.steps),
        )

    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        self.ckpt.restore(pretrained_path)
        self.steps = self.ckpt.steps.numpy()
        self.epochs = self.ckpt.epochs.numpy()
        self._gen_optimizer = self.ckpt.gen_optimizer
        # re-assign iterations (global steps) for gen_optimizer.
        self._gen_optimizer.iterations.assign(tf.cast(self.steps, tf.int64))
        # re-assign iterations (global steps) for dis_optimizer.
        try:
            discriminator_train_start_steps = self.config[
                "discriminator_train_start_steps"
            ]
            discriminator_train_start_steps = tf.math.maximum(
                0, self.steps - discriminator_train_start_steps
            )
        except Exception:
            discriminator_train_start_steps = self.steps
        self._dis_optimizer = self.ckpt.dis_optimizer
        self._dis_optimizer.iterations.assign(
            tf.cast(discriminator_train_start_steps, tf.int64)
        )

        # load weights.
        utils.load_weights(
            self._generator, self.saved_path + "generator-{}.h5".format(self.steps)
        )
        utils.load_weights(
            self._discriminator,
            self.saved_path + "discriminator-{}.h5".format(self.steps),
        )

    def _check_train_finish(self):
        """Check whether training is finished."""
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True

        if (
            self.steps != 0
            and self.steps == self.config["discriminator_train_start_steps"]
        ):
            self.finish_train = True
            logging.info(
                f"Finished training only the generator at {self.steps} steps; please resume to continue training."
            )

    def _check_log_interval(self):
        """Log to tensorboard."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for metric_name in self.list_metrics_name:
                logging.info(
                    f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."
                )
            self._write_to_tensorboard(self.train_metrics, stage="train")

            # reset
            self.reset_states_train()

    def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):
        self.set_train_data_loader(train_data_loader)
        self.set_eval_data_loader(valid_data_loader)
        self.train_data_loader = self._strategy.experimental_distribute_dataset(
            self.train_data_loader
        )
        self.eval_data_loader = self._strategy.experimental_distribute_dataset(
            self.eval_data_loader
        )
        with self._strategy.scope():
            self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)
            if resume is not None and len(resume) > 1:
                self.load_checkpoint(resume)
                logging.info(f"Successfully resumed from {resume}.")
            self.run()
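
# ---------------------------------------------------------------------------
# Hypothetical driving code for a GAN trainer (the config keys mirror the
# ones read above; the model/optimizer/dataset objects are placeholders):
#
#     strategy = tf.distribute.MirroredStrategy()
#     trainer = MyMelGanTrainer(steps=0, epochs=0, config=config,
#                               strategy=strategy)
#     trainer.list_metrics_name = ["adversarial_loss", "discriminator_loss"]
#     trainer.init_train_eval_metrics(trainer.list_metrics_name)
#     with strategy.scope():
#         trainer.compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
#     trainer.fit(train_dataset, valid_dataset,
#                 saved_path=config["outdir"] + "/checkpoints/", resume="")
# ---------------------------------------------------------------------------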


class Seq2SeqBasedTrainer(BasedTrainer, metaclass=abc.ABCMeta):
    """Customized trainer module for Seq2Seq TTS training (Tacotron, FastSpeech)."""

    def __init__(
        self, steps, epochs, config, strategy, is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute): Strategy for distributed training.
            is_mixed_precision (bool): Use mixed_precision training or not.

        """
        super().__init__(steps, epochs, config)
        self._is_mixed_precision = is_mixed_precision
        self._strategy = strategy
        self._model = None
        self._optimizer = None
        self._trainable_variables = None

        # check if we already apply input_signature for train_step.
        self._already_apply_input_signature = False

        # create gradient accumulator
        self._gradient_accumulator = GradientAccumulator()
        self._gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def set_model(self, model):
        """Set class model (MUST)."""
        self._model = model

    def get_model(self):
        """Get model."""
        return self._model

    def set_optimizer(self, optimizer):
        """Set optimizer (MUST)."""
        self._optimizer = optimizer
        if self._is_mixed_precision:
            self._optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._optimizer, "dynamic"
            )

    def get_optimizer(self):
        """Get optimizer."""
        return self._optimizer

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def compile(self, model, optimizer):
        self.set_model(model)
        self.set_optimizer(optimizer)
        self._trainable_variables = self._train_vars()

    def _train_vars(self):
        if self.config["var_train_expr"]:
            list_train_var = self.config["var_train_expr"].split("|")
            return [
                v
                for v in self._model.trainable_variables
                if self._check_string_exist(list_train_var, v.name)
            ]
        return self._model.trainable_variables

    def _check_string_exist(self, list_string, inp_string):
        for string in list_string:
            if string in inp_string:
                return True
        return False
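
    # ------------------------------------------------------------------
    # Illustrative note (hypothetical config value): with
    # `var_train_expr: "embeddings|encoder"`, _train_vars() keeps only the
    # variables whose names contain "embeddings" or "encoder" (e.g. for
    # fine-tuning part of a model); with a falsy value it trains everything.
    # ------------------------------------------------------------------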

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def _train_step(self, batch):
        if self._already_apply_input_signature is False:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    def _calculate_gradient_per_batch(self, batch):
        outputs = self._model(**batch, training=True)
        per_example_losses, dict_metrics_losses = self.compute_per_example_losses(
            batch, outputs
        )
        per_replica_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_mixed_precision:
            scaled_per_replica_losses = self._optimizer.get_scaled_loss(
                per_replica_losses
            )

        if self._is_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_losses, self._trainable_variables
            )
            gradients = self._optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(per_replica_losses, self._trainable_variables)

        # gradient accumulate here
        if self.config["gradient_accumulation_steps"] > 1:
            self._gradient_accumulator(gradients)

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_losses
        else:
            return per_replica_losses

    def _one_step_forward_per_replica(self, batch):
        if self.config["gradient_accumulation_steps"] == 1:
            gradients, per_replica_losses = self._calculate_gradient_per_batch(batch)
            self._optimizer.apply_gradients(
                zip(gradients, self._trainable_variables), 1.0
            )
        else:
            # gradient accumulation: the loaded batch holds
            # gradient_accumulation_steps sub-batches of batch_size examples,
            # so slice out sub-batch i before each accumulation step.
            per_replica_losses = 0.0
            for i in tf.range(self.config["gradient_accumulation_steps"]):
                reduced_batch = {
                    k: v[
                        i
                        * self.config["batch_size"] : (i + 1)
                        * self.config["batch_size"]
                    ]
                    for k, v in batch.items()
                }

                # run 1 step accumulate
                reduced_batch_losses = self._calculate_gradient_per_batch(reduced_batch)

                # sum per_replica_losses
                per_replica_losses += reduced_batch_losses

            gradients = self._gradient_accumulator.gradients
            self._optimizer.apply_gradients(
                zip(gradients, self._trainable_variables), 1.0
            )
            self._gradient_accumulator.reset()

        return per_replica_losses

    @abc.abstractmethod
    def compute_per_example_losses(self, batch, outputs):
        """Compute per-example losses and return dict_metrics_losses.

        Note that every loss element MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per-example losses on each replica, shape [batch_size]
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses
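
    # ------------------------------------------------------------------
    # Illustrative sketch (an assumption): a FastSpeech-like subclass might
    # implement the method above as an L1 loss over mel-spectrograms,
    # reduced to shape [batch_size]. The output layout and "mel_gts" batch
    # key are hypothetical:
    #
    #     def compute_per_example_losses(self, batch, outputs):
    #         mel_before, mel_after = outputs
    #         per_example_losses = tf.reduce_mean(
    #             tf.abs(batch["mel_gts"] - mel_after), axis=[1, 2]
    #         )
    #         dict_metrics_losses = {"mel_loss": per_example_losses}
    #         return per_example_losses, dict_metrics_losses
    # ------------------------------------------------------------------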

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.eval_data_loader, desc="[eval]"), 1
        ):
            # eval one step
            self.one_step_evaluate(batch)

            if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:
                # save intermediate results
                self.generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.eval_metrics.keys():
            logging.info(
                f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."
            )

        # record
        self._write_to_tensorboard(self.eval_metrics, stage="eval")

        # reset
        self.reset_states_eval()

    def _one_step_evaluate_per_replica(self, batch):
        outputs = self._model(**batch, training=False)
        _, dict_metrics_losses = self.compute_per_example_losses(batch, outputs)

        self.update_eval_metrics(dict_metrics_losses)

    def _one_step_evaluate(self, batch):
        self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))

    def _one_step_predict_per_replica(self, batch):
        outputs = self._model(**batch, training=False)
        return outputs

    def _one_step_predict(self, batch):
        outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))
        return outputs

    @abc.abstractmethod
    def generate_and_save_intermediate_result(self, batch):
        return

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint manager."""
        if saved_path is None:
            saved_path = self.config["outdir"] + "/checkpoints/"

        os.makedirs(saved_path, exist_ok=True)

        self.saved_path = saved_path
        self.ckpt = tf.train.Checkpoint(
            steps=tf.Variable(1), epochs=tf.Variable(1), optimizer=self.get_optimizer()
        )
        self.ckp_manager = tf.train.CheckpointManager(
            self.ckpt, saved_path, max_to_keep=max_to_keep
        )

    def save_checkpoint(self):
        """Save checkpoint."""
        self.ckpt.steps.assign(self.steps)
        self.ckpt.epochs.assign(self.epochs)
        self.ckp_manager.save(checkpoint_number=self.steps)
        utils.save_weights(
            self._model, self.saved_path + "model-{}.h5".format(self.steps)
        )

    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        self.ckpt.restore(pretrained_path)
        self.steps = self.ckpt.steps.numpy()
        self.epochs = self.ckpt.epochs.numpy()
        self._optimizer = self.ckpt.optimizer
        # re-assign iterations (global steps) for optimizer.
        self._optimizer.iterations.assign(tf.cast(self.steps, tf.int64))

        # load weights.
        utils.load_weights(
            self._model, self.saved_path + "model-{}.h5".format(self.steps)
        )

    def _check_train_finish(self):
        """Check whether training is finished."""
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True

    def _check_log_interval(self):
        """Log to tensorboard."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for metric_name in self.list_metrics_name:
                logging.info(
                    f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."
                )
            self._write_to_tensorboard(self.train_metrics, stage="train")

            # reset
            self.reset_states_train()

    def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):
        self.set_train_data_loader(train_data_loader)
        self.set_eval_data_loader(valid_data_loader)
        self.train_data_loader = self._strategy.experimental_distribute_dataset(
            self.train_data_loader
        )
        self.eval_data_loader = self._strategy.experimental_distribute_dataset(
            self.eval_data_loader
        )
        with self._strategy.scope():
            self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)
            if resume is not None and len(resume) > 1:
                self.load_checkpoint(resume)
                logging.info(f"Successfully resumed from {resume}.")
            self.run()
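

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal concrete
# trainer built on Seq2SeqBasedTrainer. The "mel_gts" batch key and the
# single-tensor model output are assumptions for illustration; real trainers
# live under examples/ in this repository.
# ---------------------------------------------------------------------------
class _ExampleSeq2SeqTrainer(Seq2SeqBasedTrainer):
    """Minimal sketch of a concrete seq2seq trainer."""

    def compute_per_example_losses(self, batch, outputs):
        # Losses MUST have shape [batch_size]; keys MUST be in
        # self.list_metrics_name (see the abstract docstring above).
        mel_loss = tf.reduce_mean(tf.abs(batch["mel_gts"] - outputs), axis=[1, 2])
        return mel_loss, {"mel_loss": mel_loss}

    def generate_and_save_intermediate_result(self, batch):
        # Plot/audio dumps are model specific; omitted in this sketch.
        pass


# Hypothetical driving code (the config keys mirror the ones read above;
# model/optimizer/dataset objects are placeholders):
#
#     strategy = tf.distribute.MirroredStrategy()
#     trainer = _ExampleSeq2SeqTrainer(
#         steps=0, epochs=0, config=config, strategy=strategy
#     )
#     trainer.list_metrics_name = ["mel_loss"]
#     trainer.init_train_eval_metrics(trainer.list_metrics_name)
#     with strategy.scope():
#         trainer.compile(model=model, optimizer=optimizer)
#     trainer.fit(train_dataset, valid_dataset, saved_path=None, resume="")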