# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/apa_aug.py

import torch


def apply_apa_aug(real_images, fake_images, apa_p, local_rank):
    """Apply Adaptive Pseudo Augmentation (APA) to a batch of real images.

    Each sample in the batch is independently replaced by the corresponding
    fake sample with probability ``apa_p``; otherwise the real sample is kept.
    Reference implementation:
    https://github.com/EndlessSora/DeceiveD/blob/main/training/loss.py

    Args:
        real_images: Tensor whose first dimension is the batch dimension.
            Any number of trailing dimensions is supported.
        fake_images: Tensor broadcast-compatible with ``real_images``. May be
            ``None`` as long as no sample ends up being replaced
            (e.g. when ``apa_p <= 0``).
        apa_p: Replacement probability; each sample draws ``U[0, 1)`` and is
            replaced when the draw is below ``apa_p``.
        local_rank: Device used for the random draws (e.g. a CUDA device index
            or ``"cpu"``); it should be the device the images live on.

    Returns:
        ``real_images`` itself (the same object) when no sample is replaced,
        otherwise a new tensor taking each batch entry from either
        ``fake_images`` or ``real_images``.

    Raises:
        ValueError: If at least one sample must be replaced but
            ``fake_images`` is ``None``.
    """
    batch_size = real_images.shape[0]
    # One flag per sample, broadcast over every remaining dimension. For 4-D
    # inputs this is exactly [batch_size, 1, 1, 1], so the RNG call (and thus
    # the random stream) is identical to the previous implementation.
    flag_shape = [batch_size] + [1] * (real_images.dim() - 1)
    pseudo_flag = torch.rand(flag_shape, device=local_rank) < apa_p

    # Nothing selected (also covers an empty batch): return the input untouched.
    if not pseudo_flag.any():
        return real_images

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if fake_images is None:
        raise ValueError("fake_images must be provided when APA replaces at least one sample")

    # torch.where selects per sample instead of blending arithmetically: it keeps
    # the images' dtype (no float32 promotion of half-precision inputs) and does
    # not leak inf/NaN from the unselected tensor (nan * 0 == nan).
    return torch.where(pseudo_flag, fake_images, real_images)