Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
1450 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import gc
17
import random
18
import unittest
19
20
import numpy as np
21
import torch
22
from PIL import Image
23
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
24
25
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
26
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
27
from diffusers.utils.testing_utils import require_torch_gpu, slow
28
29
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
30
from ...test_pipelines_common import PipelineTesterMixin
31
32
33
# Turn off TF32 for CUDA matmuls for the whole module (set once at import time).
# NOTE(review): presumably this keeps the numeric comparisons in the tests below from being
# affected by reduced-precision matmuls -- confirm against the expected-value fixtures.
torch.backends.cuda.matmul.allow_tf32 = False
34
35
36
class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests of `StableDiffusionInpaintPipeline` built from small, hand-configured SD2-style components."""

    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        """Return the keyword arguments used to construct the pipeline under test.

        The global torch seed is reset to 0 right before the UNet, the VAE and the text encoder
        are built. No safety checker or feature extractor is supplied.
        """
        torch.manual_seed(0)
        denoiser = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
        )
        noise_scheduler = PNDMScheduler(skip_prk_steps=True)

        torch.manual_seed(0)
        autoencoder = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )

        torch.manual_seed(0)
        clip_text_model = CLIPTextModel(
            CLIPTextConfig(
                bos_token_id=0,
                eos_token_id=2,
                hidden_size=32,
                intermediate_size=37,
                layer_norm_eps=1e-05,
                num_attention_heads=4,
                num_hidden_layers=5,
                pad_token_id=1,
                vocab_size=1000,
                # SD2-specific config below
                hidden_act="gelu",
                projection_dim=512,
            )
        )
        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": denoiser,
            "scheduler": noise_scheduler,
            "vae": autoencoder,
            "text_encoder": clip_text_model,
            "tokenizer": clip_tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return the call arguments for one short (2-step) inpainting run.

        Args:
            device: Device the random source tensor is moved to and on which the generator is
                created (devices whose name starts with "mps" use the global torch generator).
            seed: Seed for both the Python RNG behind the source tensor and the torch generator.
        """

        def as_pil(pixels):
            # uint8 cast -> RGB PIL image -> 64x64, exactly as the expected slices were recorded.
            return Image.fromarray(np.uint8(pixels)).convert("RGB").resize((64, 64))

        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        source = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        source = source.cpu().permute(0, 2, 3, 1)[0]  # first sample, channels moved last
        init_image = as_pil(source)
        mask_image = as_pil(source + 4)

        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }

    def test_stable_diffusion_inpaint(self):
        # Compare the bottom-right 3x3 patch of the last channel against recorded reference values.
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = StableDiffusionInpaintPipeline(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)

        images = pipe(**self.get_dummy_inputs(device)).images
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 64, 64, 3)

        reference = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])
        assert np.abs(corner.flatten() - reference).max() < 1e-2
132
133
134
@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
    """Integration tests that run the `stabilityai/stable-diffusion-2-inpainting` checkpoint on `torch_device`."""

    # Fixtures shared by every test, kept in one place so the tests cannot drift apart.
    _MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"
    _ASSETS_URL = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
    _PROMPT = "Face of a yellow cat, high resolution, sitting on a park bench"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def _load_input_images(self):
        """Load and return the `(init_image, mask_image)` pair used by all tests in this class."""
        init_image = load_image(self._ASSETS_URL + "/init_image.png")
        mask_image = load_image(self._ASSETS_URL + "/mask.png")
        return init_image, mask_image

    def test_stable_diffusion_inpaint_pipeline(self):
        # Default-precision run compared element-wise against a stored reference image.
        init_image, mask_image = self._load_input_images()
        expected_image = load_numpy(self._ASSETS_URL + "/yellow_cat_sitting_on_a_park_bench.npy")

        pipe = StableDiffusionInpaintPipeline.from_pretrained(self._MODEL_ID, safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=self._PROMPT,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_inpaint_pipeline_fp16(self):
        # Same run with float16 weights; compared against its own fp16 reference with a much looser bound.
        init_image, mask_image = self._load_input_images()
        expected_image = load_numpy(self._ASSETS_URL + "/yellow_cat_sitting_on_a_park_bench_fp16.npy")

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            self._MODEL_ID,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=self._PROMPT,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 5e-1

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        # Checks peak CUDA memory of a 2-step fp16 run with attention slicing and sequential CPU offload.
        torch.cuda.empty_cache()
        # `reset_peak_memory_stats` also resets the max-allocated counter, so the deprecated
        # `reset_max_memory_allocated` call is not needed in addition.
        torch.cuda.reset_peak_memory_stats()

        init_image, mask_image = self._load_input_images()

        pndm = PNDMScheduler.from_pretrained(self._MODEL_ID, subfolder="scheduler")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            self._MODEL_ID,
            safety_checker=None,
            scheduler=pndm,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        generator = torch.manual_seed(0)
        _ = pipe(
            prompt=self._PROMPT,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.65 GB is allocated
        assert mem_bytes < 2.65 * 10**9
256
257