Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/apa_aug.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/utils/apa_aug.py
6
7
import torch
8
9
10
def apply_apa_aug(real_images, fake_images, apa_p, local_rank):
    """Apply Adaptive Pseudo Augmentation (APA) to a batch of real images.

    Each sample of the batch is independently replaced by the corresponding
    generated sample with probability ``apa_p``; otherwise the real sample is
    kept. Adapted from
    https://github.com/EndlessSora/DeceiveD/blob/main/training/loss.py

    Args:
        real_images: Batch of real images; dimension 0 is the batch. Any
            number of trailing dimensions is supported (for 4-D inputs the
            per-sample mask has shape ``[batch, 1, 1, 1]`` as before).
        fake_images: Generated images, broadcast-compatible with
            ``real_images``. May be ``None`` only if no sample ends up being
            replaced.
        apa_p: Probability of replacing each real sample with a fake one
            (a Python float or a scalar tensor).
        local_rank: Kept for backward compatibility. The random mask is now
            drawn on ``real_images.device`` so it always lives on the same
            device as the images it selects from.

    Returns:
        ``real_images`` itself when no sample was selected, otherwise a new
        tensor whose selected samples come from ``fake_images`` and whose
        remaining samples come from ``real_images``.

    Raises:
        ValueError: If at least one sample was selected but ``fake_images``
            is ``None``.
    """
    batch_size = real_images.shape[0]
    # One Bernoulli(apa_p) draw per sample, broadcast over all trailing dims.
    mask_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    pseudo_mask = torch.rand(mask_shape, device=real_images.device) < apa_p

    if not pseudo_mask.any():
        return real_images

    if fake_images is None:
        raise ValueError(
            "fake_images must be provided when APA selects samples for pseudo augmentation."
        )

    # torch.where selects values exactly: unlike fake * flag + real * (1 - flag)
    # it never turns inf/NaN in the unselected tensor into NaN, and it does not
    # promote half-precision images to float32 via a float32 flag tensor.
    # Gradients still flow only to the selected source, as with the blend.
    return torch.where(pseudo_mask, fake_images, real_images)
22
23