GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_ema.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils.testing_utils import skip_mps, torch_device

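# `EMAModel` keeps a set of "shadow" copies of the model parameters. Roughly,
# each `step()` nudges them toward the current parameters:
#     shadow = decay * shadow + (1 - decay) * param
# (the effective decay may additionally be warmed up over the first optimization
# steps, depending on the EMAModel configuration). The tests below exercise this
# behavior.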
class EMAModelTests(unittest.TestCase):
    model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
    batch_size = 1
    prompt_length = 77
    text_encoder_hidden_dim = 32
    num_in_channels = 4
    latent_height = latent_width = 64
    generator = torch.manual_seed(0)

    def get_models(self, decay=0.9999):
        unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet")
        unet = unet.to(torch_device)
        ema_unet = EMAModel(unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config)
        return unet, ema_unet

    def get_dummy_inputs(self):
        noisy_latents = torch.randn(
            self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator
        ).to(torch_device)
        timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device)
        encoder_hidden_states = torch.randn(
            self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator
        ).to(torch_device)
        return noisy_latents, timesteps, encoder_hidden_states

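    # Stand-in for a real optimizer update: randomly perturb every entry of the
    # state dict so that the "trained" parameters differ from the originals.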
    def simulate_backprop(self, unet):
        updated_state_dict = {}
        for k, param in unet.state_dict().items():
            updated_param = torch.randn_like(param) + (param * torch.randn_like(param))
            updated_state_dict.update({k: updated_param})
        unet.load_state_dict(updated_state_dict)
        return unet

    def test_optimization_steps_updated(self):
        unet, ema_unet = self.get_models()
        # Take the first (hypothetical) EMA step.
        ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 1

        # Take two more.
        for _ in range(2):
            ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 3

    def test_shadow_params_not_updated(self):
        unet, ema_unet = self.get_models()
        # Since the `unet` is not being updated (i.e., backprop'd),
        # there won't be any difference between the params of `unet`
        # and `ema_unet` even if we call `ema_unet.step(unet.parameters())`.
        ema_unet.step(unet.parameters())
        orig_params = list(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

        # The above holds true even if we call `ema_unet.step()` multiple times,
        # since the `unet` params are still not being updated.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

    def test_shadow_params_updated(self):
        unet, ema_unet = self.get_models()
        # Here we simulate the parameter updates for `unet`. Since there might
        # be some parameters which are initialized to zero, we take extra care to
        # initialize their values to something non-zero before the multiplication.
        unet_pseudo_updated_step_one = self.simulate_backprop(unet)

        # Take the EMA step.
        ema_unet.step(unet_pseudo_updated_step_one.parameters())

        # Now the EMA'd parameters won't be equal to the (updated) model parameters.
        orig_params = list(unet_pseudo_updated_step_one.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert not torch.allclose(s_param, param)

        # Ensure this is the case when we take multiple EMA steps.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert not torch.allclose(s_param, param)

    def test_consecutive_shadow_params_updated(self):
        # If we take the EMA step after a backpropagation twice in a row,
        # the shadow params from those two steps should be different.
        unet, ema_unet = self.get_models()

        # First backprop + EMA. The shadow params are updated in place, so we
        # snapshot copies of them here to compare against the next step.
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        step_one_shadow_params = [p.clone() for p in ema_unet.shadow_params]

        # Second backprop + EMA
        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        step_two_shadow_params = ema_unet.shadow_params

        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
            assert not torch.allclose(step_one, step_two)

    def test_zero_decay(self):
        # If there is no decay, EMA steps won't have any effect even when there
        # are backprops, i.e., the shadow params should remain the same.
        unet, ema_unet = self.get_models(decay=0.0)
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        step_one_shadow_params = ema_unet.shadow_params

        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        step_two_shadow_params = ema_unet.shadow_params

        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
            assert torch.allclose(step_one, step_two)

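    # Serialization round trip: `EMAModel.save_pretrained` stores the shadow weights
    # in the underlying `model_cls` format (hence `model_cls`/`model_config` being
    # passed to `EMAModel` above), so they can be reloaded directly with
    # `UNet2DConditionModel.from_pretrained`.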
    @skip_mps
    def test_serialization(self):
        unet, ema_unet = self.get_models()
        noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

        with tempfile.TemporaryDirectory() as tmpdir:
            ema_unet.save_pretrained(tmpdir)
            loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
            loaded_unet = loaded_unet.to(unet.device)

        # Since no EMA step has been performed, the outputs should match.
        output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

        assert torch.allclose(output, output_loaded, atol=1e-4)