# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/ema.py

import torch


class Ema(object):
    def __init__(self, source, target, decay=0.9999, start_iter=0):
        self.source = source
        self.target = target
        self.decay = decay
        self.start_iter = start_iter
        self.source_dict = self.source.state_dict()
        self.target_dict = self.target.state_dict()
        print("Initialize the copied generator's parameters to be the source parameters.")
        with torch.no_grad():
            for p_ema, p in zip(self.target.parameters(), self.source.parameters()):
                p_ema.copy_(p)
            for b_ema, b in zip(self.target.buffers(), self.source.buffers()):
                b_ema.copy_(b)

    def update(self, iter=None):
        # Before start_iter the decay is forced to 0.0, so the target stays an
        # exact copy of the source; afterwards the configured decay applies.
        # (Checking `iter is not None` avoids a TypeError when no iteration
        # index is passed.)
        if iter is not None and 0 <= iter < self.start_iter:
            decay = 0.0
        else:
            decay = self.decay

        with torch.no_grad():
            # p.lerp(p_ema, decay) == (1 - decay) * p + decay * p_ema
            for p_ema, p in zip(self.target.parameters(), self.source.parameters()):
                p_ema.copy_(p.lerp(p_ema, decay))
            for (b_ema_name, b_ema), (b_name, b) in zip(self.target.named_buffers(), self.source.named_buffers()):
                if "num_batches_tracked" in b_ema_name:
                    # Integer batch counters cannot be interpolated; copy them through.
                    b_ema.copy_(b)
                else:
                    b_ema.copy_(b.lerp(b_ema, decay))


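# A minimal usage sketch, not part of the original file: `Ema` expects `target`
# to be a structural clone of `source`, so a deep copy of the generator is the
# usual way to build it. The module, optimizer, and loop below are illustrative
# stand-ins for a real GAN training setup.
def _ema_usage_sketch():
    import copy

    generator = torch.nn.Linear(8, 8)          # stand-in for a real generator
    generator_ema = copy.deepcopy(generator)   # EMA copy, same architecture
    ema = Ema(source=generator, target=generator_ema, decay=0.9999, start_iter=100)

    optimizer = torch.optim.SGD(generator.parameters(), lr=0.1)
    for step in range(200):
        loss = generator(torch.randn(4, 8)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # decay is 0.0 for step < start_iter, so the EMA copy tracks the
        # source exactly during warm-up, then averages with decay=0.9999.
        ema.update(step)

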
class EmaStylegan2(object):
    def __init__(self, source, target, ema_kimg, ema_rampup, effective_batch_size):
        self.source = source
        self.target = target
        self.ema_nimg = ema_kimg * 1000
        self.ema_rampup = ema_rampup
        self.batch_size = effective_batch_size
        self.source_dict = self.source.state_dict()
        self.target_dict = self.target.state_dict()
        print("Initialize the copied generator's parameters to be the source parameters.")
        with torch.no_grad():
            for p_ema, p in zip(self.target.parameters(), self.source.parameters()):
                p_ema.copy_(p)
            for b_ema, b in zip(self.target.buffers(), self.source.buffers()):
                b_ema.copy_(b)

    def update(self, iter=None):
        # Optionally ramp the EMA horizon up with the number of images seen so
        # far; when ramp-up is enabled (ema_rampup != "N/A"), iter must be given.
        ema_nimg = self.ema_nimg
        if self.ema_rampup != "N/A":
            cur_nimg = self.batch_size * iter
            ema_nimg = min(self.ema_nimg, cur_nimg * self.ema_rampup)
        # Decay chosen so that the EMA half-life is ema_nimg images.
        ema_beta = 0.5 ** (self.batch_size / max(ema_nimg, 1e-8))
        with torch.no_grad():
            for p_ema, p in zip(self.target.parameters(), self.source.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(self.target.buffers(), self.source.buffers()):
                b_ema.copy_(b)


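# A minimal usage sketch with illustrative names, not part of the original
# file. StyleGAN2-style EMA counts images rather than iterations: with
# ema_kimg=10 and effective_batch_size=32 the steady-state decay is
# ema_beta = 0.5 ** (32 / 10_000) ≈ 0.99778, i.e. a half-life of 10k images,
# while ema_rampup=0.05 shortens the horizon early in training.
def _ema_stylegan2_usage_sketch():
    import copy

    generator = torch.nn.Linear(8, 8)          # stand-in for a real generator
    generator_ema = copy.deepcopy(generator)
    ema = EmaStylegan2(source=generator,
                       target=generator_ema,
                       ema_kimg=10,
                       ema_rampup=0.05,
                       effective_batch_size=32)

    optimizer = torch.optim.SGD(generator.parameters(), lr=0.1)
    for step in range(200):
        loss = generator(torch.randn(4, 8)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(step)                       # iter is required while ramp-up is on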