GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/repaint/test_repaint.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, load_numpy, nightly, require_torch_gpu, skip_mps, torch_device

from ...pipeline_params import IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


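# The fast tests below build RePaintPipeline from a tiny, randomly initialized UNet2DModel, so they run
# quickly on CPU. From a diffusers development checkout they can be invoked with, for example:
#   pytest tests/pipelines/repaint/test_repaint.py::RepaintPipelineFastTests
# The nightly class at the bottom of this file downloads google/ddpm-ema-celebahq-256, requires a CUDA
# device (@require_torch_gpu), and is skipped unless diffusers' nightly test flag is enabled.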
class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = RePaintPipeline
    params = IMAGE_INPAINTING_PARAMS - {"width", "height", "guidance_scale"}
    required_optional_params = PipelineTesterMixin.required_optional_params - {
        "latents",
        "num_images_per_prompt",
        "callback",
        "callback_steps",
    }
    batch_params = IMAGE_INPAINTING_BATCH_PARAMS
    test_cpu_offload = False

    def get_dummy_components(self):
        # Seed once so the randomly initialized dummy UNet weights are reproducible.
        torch.manual_seed(0)
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = RePaintScheduler()
        components = {"unet": unet, "scheduler": scheduler}
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
        image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
        mask = (image > 0).to(device=device, dtype=torch.float32)
        inputs = {
            "image": image,
            "mask_image": mask,
            "generator": generator,
            "num_inference_steps": 5,
            "eta": 0.0,
            "jump_length": 2,
            "jump_n_sample": 2,
            "output_type": "numpy",
        }
        return inputs

    def test_repaint(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = RePaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    # RePaint can hardly be made deterministic, since the scheduler is currently always nondeterministic.
    @unittest.skip("non-deterministic pipeline")
    def test_inference_batch_single_identical(self):
        return super().test_inference_batch_single_identical()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()


@nightly
@require_torch_gpu
class RepaintPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_celebahq(self):
        original_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
            "repaint/celeba_hq_256.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
            "repaint/celeba_hq_256_result.npy"
        )

        model_id = "google/ddpm-ema-celebahq-256"
        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = RePaintScheduler.from_pretrained(model_id)

        repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
        repaint.set_progress_bar_config(disable=None)
        repaint.enable_attention_slicing()

        generator = torch.manual_seed(0)
        output = repaint(
            original_image,
            mask_image,
            num_inference_steps=250,
            eta=0.0,
            jump_length=10,
            jump_n_sample=10,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).mean() < 1e-2
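
# Illustrative sketch (not part of the test suite): a minimal standalone run of the same checkpoint that
# test_celebahq exercises above, assuming the model and images download successfully. The output file
# name "repaint_celebahq.npy" is an arbitrary choice for this example.
if __name__ == "__main__":
    original_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
        "repaint/celeba_hq_256.png"
    )
    mask_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
    )

    model_id = "google/ddpm-ema-celebahq-256"
    pipe = RePaintPipeline(
        unet=UNet2DModel.from_pretrained(model_id),
        scheduler=RePaintScheduler.from_pretrained(model_id),
    ).to(torch_device)

    # Same sampling settings as test_celebahq above: 250 steps with RePaint's jump/resample schedule.
    result = pipe(
        original_image,
        mask_image,
        num_inference_steps=250,
        eta=0.0,
        jump_length=10,
        jump_n_sample=10,
        generator=torch.manual_seed(0),
        output_type="np",
    ).images[0]

    np.save("repaint_celebahq.npy", result)
    print("inpainted image shape:", result.shape)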