Path: blob/master/tensorflow_tts/trainers/base_trainer.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Based Trainer."""

import abc
import logging
import os

import tensorflow as tf
from tqdm import tqdm

from tensorflow_tts.optimizers import GradientAccumulator
from tensorflow_tts.utils import utils


class BasedTrainer(metaclass=abc.ABCMeta):
    """Customized trainer module for all models."""

    def __init__(self, steps, epochs, config):
        self.steps = steps
        self.epochs = epochs
        self.config = config
        self.finish_train = False
        self.writer = tf.summary.create_file_writer(config["outdir"])
        self.train_data_loader = None
        self.eval_data_loader = None
        self.train_metrics = None
        self.eval_metrics = None
        self.list_metrics_name = None

    def init_train_eval_metrics(self, list_metrics_name):
        """Init train and eval metrics to save them to tensorboard."""
        self.train_metrics = {}
        self.eval_metrics = {}
        for name in list_metrics_name:
            self.train_metrics.update(
                {name: tf.keras.metrics.Mean(name="train_" + name, dtype=tf.float32)}
            )
            self.eval_metrics.update(
                {name: tf.keras.metrics.Mean(name="eval_" + name, dtype=tf.float32)}
            )

    def reset_states_train(self):
        """Reset train metrics after saving them to tensorboard."""
        for metric in self.train_metrics.keys():
            self.train_metrics[metric].reset_states()

    def reset_states_eval(self):
        """Reset eval metrics after saving them to tensorboard."""
        for metric in self.eval_metrics.keys():
            self.eval_metrics[metric].reset_states()

    def update_train_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.train_metrics[name].update_state(value)

    def update_eval_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.eval_metrics[name].update_state(value)

    def set_train_data_loader(self, train_dataset):
        """Set train data loader (MUST)."""
        self.train_data_loader = train_dataset

    def get_train_data_loader(self):
        """Get train data loader."""
        return self.train_data_loader

    def set_eval_data_loader(self, eval_dataset):
        """Set eval data loader (MUST)."""
        self.eval_data_loader = eval_dataset

    def get_eval_data_loader(self):
        """Get eval data loader."""
        return self.eval_data_loader

    @abc.abstractmethod
    def compile(self):
        pass

    @abc.abstractmethod
    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint management."""
        pass

    def run(self):
        """Run training."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.config["train_max_steps"], desc="[train]"
        )
        while True:
            self._train_epoch()

            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finish training.")

    @abc.abstractmethod
    def save_checkpoint(self):
        """Save checkpoint."""
        pass

    @abc.abstractmethod
    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        pass

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.train_data_loader, 1):
            # one step training
            self._train_step(batch)

            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

    @abc.abstractmethod
    def _eval_epoch(self):
        """One epoch evaluation."""
        pass

    @abc.abstractmethod
    def _train_step(self, batch):
        """One step training."""
        pass

    @abc.abstractmethod
    def _check_log_interval(self):
        """Save log interval."""
        pass

    @abc.abstractmethod
    def fit(self):
        pass

    def _check_eval_interval(self):
        """Evaluation interval step."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_save_interval(self):
        """Save interval checkpoint."""
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint()
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        pass

    def _write_to_tensorboard(self, list_metrics, stage="train"):
        """Write variables to tensorboard."""
        with self.writer.as_default():
            for key, value in list_metrics.items():
                tf.summary.scalar(stage + "/" + key, value.result(), step=self.steps)
            self.writer.flush()


class GanBasedTrainer(BasedTrainer):
    """Customized trainer module for GAN TTS training (MelGAN, GAN-TTS, ParallelWaveGAN)."""

    def __init__(
        self,
        steps,
        epochs,
        config,
        strategy,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute): Strategy for distributed training.
            is_generator_mixed_precision (bool): Use mixed precision for the generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for the discriminator or not.

        """
        super().__init__(steps, epochs, config)
        self._is_generator_mixed_precision = is_generator_mixed_precision
        self._is_discriminator_mixed_precision = is_discriminator_mixed_precision
        self._strategy = strategy
        self._already_apply_input_signature = False
        self._generator_gradient_accumulator = GradientAccumulator()
        self._discriminator_gradient_accumulator = GradientAccumulator()
        self._generator_gradient_accumulator.reset()
        self._discriminator_gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def set_gen_model(self, generator_model):
        """Set generator class model (MUST)."""
        self._generator = generator_model

    def get_gen_model(self):
        """Get generator model."""
        return self._generator

    def set_dis_model(self, discriminator_model):
        """Set discriminator class model (MUST)."""
        self._discriminator = discriminator_model

    def get_dis_model(self):
        """Get discriminator model."""
        return self._discriminator

    def set_gen_optimizer(self, generator_optimizer):
        """Set generator optimizer (MUST)."""
        self._gen_optimizer = generator_optimizer
        if self._is_generator_mixed_precision:
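            # Wrap the optimizer for dynamic loss scaling. Note: this uses the
            # pre-TF-2.4 experimental API; on TF >= 2.4 the equivalent call
            # would be tf.keras.mixed_precision.LossScaleOptimizer(self._gen_optimizer),
            # where dynamic scaling is the default.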
            self._gen_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._gen_optimizer, "dynamic"
            )

    def get_gen_optimizer(self):
        """Get generator optimizer."""
        return self._gen_optimizer

    def set_dis_optimizer(self, discriminator_optimizer):
        """Set discriminator optimizer (MUST)."""
        self._dis_optimizer = discriminator_optimizer
        if self._is_discriminator_mixed_precision:
            self._dis_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._dis_optimizer, "dynamic"
            )

    def get_dis_optimizer(self):
        """Get discriminator optimizer."""
        return self._dis_optimizer

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        self.set_gen_model(gen_model)
        self.set_dis_model(dis_model)
        self.set_gen_optimizer(gen_optimizer)
        self.set_dis_optimizer(dis_optimizer)

    def _train_step(self, batch):
        if self._already_apply_input_signature is False:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    @abc.abstractmethod
    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per example generator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses

    @abc.abstractmethod
    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per example discriminator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from dataloader
            gen_outputs: outputs of the generator model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses

    def _calculate_generator_gradient_per_batch(self, batch):
        outputs = self._generator(**batch, training=True)
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_generator_losses(batch, outputs)
        per_replica_gen_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

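        # The divisor above is the effective global batch size. For example
        # (hypothetical numbers), batch_size=32, 2 replicas and
        # gradient_accumulation_steps=4 give a global batch size of 256, so
        # summing the reduced per-replica losses over replicas and
        # accumulation steps yields the mean loss over the full global batch.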
        if self._is_generator_mixed_precision:
            scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss(
                per_replica_gen_losses
            )

        if self._is_generator_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_gen_losses, self._generator.trainable_variables
            )
            gradients = self._gen_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_gen_losses, self._generator.trainable_variables
            )

        # gradient accumulate for generator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._generator_gradient_accumulator(gradients)

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_gen_losses
        else:
            return per_replica_gen_losses

    def _calculate_discriminator_gradient_per_batch(self, batch):
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_discriminator_losses(
            batch, self._generator(**batch, training=True)
        )

        per_replica_dis_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_discriminator_mixed_precision:
            scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss(
                per_replica_dis_losses
            )

        if self._is_discriminator_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_dis_losses,
                self._discriminator.trainable_variables,
            )
            gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_dis_losses, self._discriminator.trainable_variables
            )

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        # gradient accumulate for discriminator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._discriminator_gradient_accumulator(gradients)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_dis_losses
        else:
            return per_replica_dis_losses

    def _one_step_forward_per_replica(self, batch):
        per_replica_gen_losses = 0.0
        per_replica_dis_losses = 0.0

        if self.config["gradient_accumulation_steps"] == 1:
            (
                gradients,
                per_replica_gen_losses,
            ) = self._calculate_generator_gradient_per_batch(batch)
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
        else:
            # gradient accumulation here.
            for i in tf.range(self.config["gradient_accumulation_steps"]):
                reduced_batch = {
                    k: v[
                        i
                        * self.config["batch_size"] : (i + 1)
                        * self.config["batch_size"]
                    ]
                    for k, v in batch.items()
                }

                # run 1 step accumulate
                reduced_batch_losses = self._calculate_generator_gradient_per_batch(
                    reduced_batch
                )

                # sum per_replica_losses
                per_replica_gen_losses += reduced_batch_losses

            gradients = self._generator_gradient_accumulator.gradients
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
            self._generator_gradient_accumulator.reset()

        # one step discriminator
        # recompute y_hat after 1 step generator for discriminator training.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            if self.config["gradient_accumulation_steps"] == 1:
                (
                    gradients,
                    per_replica_dis_losses,
                ) = self._calculate_discriminator_gradient_per_batch(batch)
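                # single (non-accumulated) discriminator update for this replica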
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
            else:
                # gradient accumulation here.
                for i in tf.range(self.config["gradient_accumulation_steps"]):
                    reduced_batch = {
                        k: v[
                            i
                            * self.config["batch_size"] : (i + 1)
                            * self.config["batch_size"]
                        ]
                        for k, v in batch.items()
                    }

                    # run 1 step accumulate
                    reduced_batch_losses = (
                        self._calculate_discriminator_gradient_per_batch(reduced_batch)
                    )

                    # sum per_replica_losses
                    per_replica_dis_losses += reduced_batch_losses

                gradients = self._discriminator_gradient_accumulator.gradients
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
                self._discriminator_gradient_accumulator.reset()

        return per_replica_gen_losses + per_replica_dis_losses

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.eval_data_loader, desc="[eval]"), 1
        ):
            # eval one step
            self.one_step_evaluate(batch)

            if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:
                # save intermediate results
                self.generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.eval_metrics.keys():
            logging.info(
                f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."
            )

        # record
        self._write_to_tensorboard(self.eval_metrics, stage="eval")

        # reset
        self.reset_states_eval()

    def _one_step_evaluate_per_replica(self, batch):
        ################################################
        # one step generator.
        outputs = self._generator(**batch, training=False)
        _, dict_metrics_losses = self.compute_per_example_generator_losses(
            batch, outputs
        )

        # accumulate loss into metrics
        self.update_eval_metrics(dict_metrics_losses)

        ################################################
        # one step discriminator
        if self.steps >= self.config["discriminator_train_start_steps"]:
            _, dict_metrics_losses = self.compute_per_example_discriminator_losses(
                batch, outputs
            )

            # accumulate loss into metrics
            self.update_eval_metrics(dict_metrics_losses)

        ################################################

    def _one_step_evaluate(self, batch):
        self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))

    def _one_step_predict_per_replica(self, batch):
        outputs = self._generator(**batch, training=False)
        return outputs

    def _one_step_predict(self, batch):
        outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))
        return outputs

    @abc.abstractmethod
    def generate_and_save_intermediate_result(self, batch):
        return

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint management."""
        if saved_path is None:
            saved_path = self.config["outdir"] + "/checkpoints/"

        os.makedirs(saved_path, exist_ok=True)

        self.saved_path = saved_path
        self.ckpt = tf.train.Checkpoint(
            steps=tf.Variable(1),
            epochs=tf.Variable(1),
            gen_optimizer=self.get_gen_optimizer(),
            dis_optimizer=self.get_dis_optimizer(),
        )
        self.ckp_manager = tf.train.CheckpointManager(
            self.ckpt, saved_path, max_to_keep=max_to_keep
        )
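
    # Note: the tf.train.Checkpoint above tracks the optimizers plus the
    # step/epoch counters; generator and discriminator weights are written
    # separately as .h5 files by save_checkpoint() below.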
checkpoint."""593self.ckpt.steps.assign(self.steps)594self.ckpt.epochs.assign(self.epochs)595self.ckp_manager.save(checkpoint_number=self.steps)596utils.save_weights(597self._generator,598self.saved_path + "generator-{}.h5".format(self.steps)599)600utils.save_weights(601self._discriminator,602self.saved_path + "discriminator-{}.h5".format(self.steps)603)604605def load_checkpoint(self, pretrained_path):606"""Load checkpoint."""607self.ckpt.restore(pretrained_path)608self.steps = self.ckpt.steps.numpy()609self.epochs = self.ckpt.epochs.numpy()610self._gen_optimizer = self.ckpt.gen_optimizer611# re-assign iterations (global steps) for gen_optimizer.612self._gen_optimizer.iterations.assign(tf.cast(self.steps, tf.int64))613# re-assign iterations (global steps) for dis_optimizer.614try:615discriminator_train_start_steps = self.config[616"discriminator_train_start_steps"617]618discriminator_train_start_steps = tf.math.maximum(6190, self.steps - discriminator_train_start_steps620)621except Exception:622discriminator_train_start_steps = self.steps623self._dis_optimizer = self.ckpt.dis_optimizer624self._dis_optimizer.iterations.assign(625tf.cast(discriminator_train_start_steps, tf.int64)626)627628# load weights.629utils.load_weights(630self._generator,631self.saved_path + "generator-{}.h5".format(self.steps)632)633utils.load_weights(634self._discriminator,635self.saved_path + "discriminator-{}.h5".format(self.steps)636)637638def _check_train_finish(self):639"""Check training finished."""640if self.steps >= self.config["train_max_steps"]:641self.finish_train = True642643if (644self.steps != 0645and self.steps == self.config["discriminator_train_start_steps"]646):647self.finish_train = True648logging.info(649f"Finished training only generator at {self.steps}steps, pls resume and continue training."650)651652def _check_log_interval(self):653"""Log to tensorboard."""654if self.steps % self.config["log_interval_steps"] == 0:655for metric_name in self.list_metrics_name:656logging.info(657f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."658)659self._write_to_tensorboard(self.train_metrics, stage="train")660661# reset662self.reset_states_train()663664def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):665self.set_train_data_loader(train_data_loader)666self.set_eval_data_loader(valid_data_loader)667self.train_data_loader = self._strategy.experimental_distribute_dataset(668self.train_data_loader669)670self.eval_data_loader = self._strategy.experimental_distribute_dataset(671self.eval_data_loader672)673with self._strategy.scope():674self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)675if len(resume) > 1:676self.load_checkpoint(resume)677logging.info(f"Successfully resumed from {resume}.")678self.run()679680681class Seq2SeqBasedTrainer(BasedTrainer, metaclass=abc.ABCMeta):682"""Customized trainer module for Seq2Seq TTS training (Tacotron, FastSpeech)."""683684def __init__(685self, steps, epochs, config, strategy, is_mixed_precision=False,686):687"""Initialize trainer.688689Args:690steps (int): Initial global steps.691epochs (int): Initial global epochs.692config (dict): Config dict loaded from yaml format configuration file.693strategy (tf.distribute): Strategy for distributed training.694is_mixed_precision (bool): Use mixed_precision training or not.695696"""697super().__init__(steps, epochs, config)698self._is_mixed_precision = is_mixed_precision699self._strategy = strategy700self._model = None701self._optimizer = 
        self._optimizer = None
        self._trainable_variables = None

        # check if we already apply input_signature for train_step.
        self._already_apply_input_signature = False

        # create gradient accumulator
        self._gradient_accumulator = GradientAccumulator()
        self._gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def set_model(self, model):
        """Set class model (MUST)."""
        self._model = model

    def get_model(self):
        """Get model."""
        return self._model

    def set_optimizer(self, optimizer):
        """Set optimizer (MUST)."""
        self._optimizer = optimizer
        if self._is_mixed_precision:
            self._optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._optimizer, "dynamic"
            )

    def get_optimizer(self):
        """Get optimizer."""
        return self._optimizer

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def compile(self, model, optimizer):
        self.set_model(model)
        self.set_optimizer(optimizer)
        self._trainable_variables = self._train_vars()

    def _train_vars(self):
        if self.config["var_train_expr"]:
            list_train_var = self.config["var_train_expr"].split("|")
            return [
                v
                for v in self._model.trainable_variables
                if self._check_string_exist(list_train_var, v.name)
            ]
        return self._model.trainable_variables

    def _check_string_exist(self, list_string, inp_string):
        for string in list_string:
            if string in inp_string:
                return True
        return False

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def _train_step(self, batch):
        if self._already_apply_input_signature is False:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    def _calculate_gradient_per_batch(self, batch):
        outputs = self._model(**batch, training=True)
        per_example_losses, dict_metrics_losses = self.compute_per_example_losses(
            batch, outputs
        )
        per_replica_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_mixed_precision:
            scaled_per_replica_losses = self._optimizer.get_scaled_loss(
                per_replica_losses
            )

        if self._is_mixed_precision:
            scaled_gradients = tf.gradients(
                scaled_per_replica_losses, self._trainable_variables
            )
            gradients = self._optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(per_replica_losses, self._trainable_variables)

        # gradient accumulate here
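        # When gradient_accumulation_steps > 1, gradients are summed into
        # self._gradient_accumulator across sub-batches and applied once per
        # optimizer step in _one_step_forward_per_replica below; with a single
        # step they are returned and applied directly.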
self.config["gradient_accumulation_steps"] > 1:823self._gradient_accumulator(gradients)824825# accumulate loss into metrics826self.update_train_metrics(dict_metrics_losses)827828if self.config["gradient_accumulation_steps"] == 1:829return gradients, per_replica_losses830else:831return per_replica_losses832833def _one_step_forward_per_replica(self, batch):834if self.config["gradient_accumulation_steps"] == 1:835gradients, per_replica_losses = self._calculate_gradient_per_batch(batch)836self._optimizer.apply_gradients(837zip(gradients, self._trainable_variables), 1.0838)839else:840# gradient acummulation here.841per_replica_losses = 0.0842for i in tf.range(self.config["gradient_accumulation_steps"]):843reduced_batch = {844k: v[845i846* self.config["batch_size"] : (i + 1)847* self.config["batch_size"]848]849for k, v in batch.items()850}851852# run 1 step accumulate853reduced_batch_losses = self._calculate_gradient_per_batch(reduced_batch)854855# sum per_replica_losses856per_replica_losses += reduced_batch_losses857858gradients = self._gradient_accumulator.gradients859self._optimizer.apply_gradients(860zip(gradients, self._trainable_variables), 1.0861)862self._gradient_accumulator.reset()863864return per_replica_losses865866867@abc.abstractmethod868def compute_per_example_losses(self, batch, outputs):869"""Compute per example losses and return dict_metrics_losses870Note that all element of the loss MUST has a shape [batch_size] and871the keys of dict_metrics_losses MUST be in self.list_metrics_name.872873Args:874batch: dictionary batch input return from dataloader875outputs: outputs of the model876877Returns:878per_example_losses: per example losses for each GPU, shape [B]879dict_metrics_losses: dictionary loss.880"""881per_example_losses = 0.0882dict_metrics_losses = {}883return per_example_losses, dict_metrics_losses884885def _eval_epoch(self):886"""Evaluate model one epoch."""887logging.info(f"(Steps: {self.steps}) Start evaluation.")888889# calculate loss for each batch890for eval_steps_per_epoch, batch in enumerate(891tqdm(self.eval_data_loader, desc="[eval]"), 1892):893# eval one step894self.one_step_evaluate(batch)895896if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:897# save intermedia898self.generate_and_save_intermediate_result(batch)899900logging.info(901f"(Steps: {self.steps}) Finished evaluation "902f"({eval_steps_per_epoch} steps per epoch)."903)904905# average loss906for key in self.eval_metrics.keys():907logging.info(908f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."909)910911# record912self._write_to_tensorboard(self.eval_metrics, stage="eval")913914# reset915self.reset_states_eval()916917def _one_step_evaluate_per_replica(self, batch):918outputs = self._model(**batch, training=False)919_, dict_metrics_losses = self.compute_per_example_losses(batch, outputs)920921self.update_eval_metrics(dict_metrics_losses)922923def _one_step_evaluate(self, batch):924self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))925926def _one_step_predict_per_replica(self, batch):927outputs = self._model(**batch, training=False)928return outputs929930def _one_step_predict(self, batch):931outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))932return outputs933934@abc.abstractmethod935def generate_and_save_intermediate_result(self, batch):936return937938def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):939"""Create checkpoint management."""940if saved_path is None:941saved_path = 
self.config["outdir"] + "/checkpoints/"942943os.makedirs(saved_path, exist_ok=True)944945self.saved_path = saved_path946self.ckpt = tf.train.Checkpoint(947steps=tf.Variable(1), epochs=tf.Variable(1), optimizer=self.get_optimizer()948)949self.ckp_manager = tf.train.CheckpointManager(950self.ckpt, saved_path, max_to_keep=max_to_keep951)952953def save_checkpoint(self):954"""Save checkpoint."""955self.ckpt.steps.assign(self.steps)956self.ckpt.epochs.assign(self.epochs)957self.ckp_manager.save(checkpoint_number=self.steps)958utils.save_weights(959self._model,960self.saved_path + "model-{}.h5".format(self.steps)961)962963def load_checkpoint(self, pretrained_path):964"""Load checkpoint."""965self.ckpt.restore(pretrained_path)966self.steps = self.ckpt.steps.numpy()967self.epochs = self.ckpt.epochs.numpy()968self._optimizer = self.ckpt.optimizer969# re-assign iterations (global steps) for optimizer.970self._optimizer.iterations.assign(tf.cast(self.steps, tf.int64))971972# load weights.973utils.load_weights(974self._model,975self.saved_path + "model-{}.h5".format(self.steps)976)977978def _check_train_finish(self):979"""Check training finished."""980if self.steps >= self.config["train_max_steps"]:981self.finish_train = True982983def _check_log_interval(self):984"""Log to tensorboard."""985if self.steps % self.config["log_interval_steps"] == 0:986for metric_name in self.list_metrics_name:987logging.info(988f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."989)990self._write_to_tensorboard(self.train_metrics, stage="train")991992# reset993self.reset_states_train()994995def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):996self.set_train_data_loader(train_data_loader)997self.set_eval_data_loader(valid_data_loader)998self.train_data_loader = self._strategy.experimental_distribute_dataset(999self.train_data_loader1000)1001self.eval_data_loader = self._strategy.experimental_distribute_dataset(1002self.eval_data_loader1003)1004with self._strategy.scope():1005self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)1006if len(resume) > 1:1007self.load_checkpoint(resume)1008logging.info(f"Successfully resumed from {resume}.")1009self.run()101010111012