GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/paint_by_example/test_paint_by_example.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionConfig

from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin

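# Disable TF32 matmuls on Ampere+ GPUs so outputs stay numerically
# comparable with the hard-coded expected slices below.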
torch.backends.cuda.matmul.allow_tf32 = False


class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = PaintByExamplePipeline
    params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

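    # Tiny stand-ins for every pipeline component keep the fast tests
    # runnable on CPU without downloading pretrained weights.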
    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            # 9 input channels: 4 noisy latents + 4 masked-image latents + 1 mask
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        config = CLIPVisionConfig(
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            image_size=32,
            patch_size=4,
        )
        image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "image_encoder": image_encoder,
            "safety_checker": None,
            "feature_extractor": feature_extractor,
        }
        return components

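    # Helper: PIL image -> NCHW float tensor scaled to [-1, 1], the value
    # range the pipeline expects for tensor inputs.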
    def convert_to_pt(self, image):
        image = np.array(image.convert("RGB"))
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
        return image

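    # Seeded dummy inputs: a random init image, a mask derived from it, and a
    # 32x32 example image that conditions the generation.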
    def get_dummy_inputs(self, device="cpu", seed=0):
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))

        if str(device).startswith("mps"):
            # MPS does not support device-local generators; fall back to the global one
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "example_image": example_image,
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

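    # End-to-end smoke test: two PNDM steps on the tiny components, checked
    # against a stored 3x3 corner slice of the output image.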
    def test_paint_by_example_inpaint(self):
        components = self.get_dummy_components()

        # make sure here that pndm scheduler skips prk
        pipe = PaintByExamplePipeline(**components)
        pipe = pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs()
        output = pipe(**inputs)
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4701, 0.5555, 0.3994, 0.5107, 0.5691, 0.4517, 0.5125, 0.4769, 0.4539])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

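    # Tensor inputs should produce (nearly) the same result as the equivalent
    # PIL inputs from get_dummy_inputs.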
    def test_paint_by_example_image_tensor(self):
        device = "cpu"
        inputs = self.get_dummy_inputs()
        inputs.pop("mask_image")
        image = self.convert_to_pt(inputs.pop("image"))
        mask_image = image.clamp(0, 1) / 2

        # make sure here that pndm scheduler skips prk
        pipe = PaintByExamplePipeline(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(image=image, mask_image=mask_image[:, 0], **inputs)
        out_1 = output.images

        image = image.cpu().permute(0, 2, 3, 1)[0]
        mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]

        image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")

        output = pipe(**self.get_dummy_inputs())
        out_2 = output.images

        assert out_1.shape == (1, 64, 64, 3)
        assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2

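
# Integration tests load the full Fantasy-Studio/Paint-by-Example checkpoint,
# so they only run on GPU and only in the slow test suite.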
@slow
@require_torch_gpu
class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

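    # Full 50-step run against the released checkpoint, compared with a
    # stored reference slice of the 512x512 output.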
    def test_paint_by_example(self):
        # make sure here that pndm scheduler skips prk
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/dog_in_bucket.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/mask.png"
        )
        example_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/panda.jpg"
        )

        pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(321)
        output = pipe(
            image=init_image,
            mask_image=mask_image,
            example_image=example_image,
            generator=generator,
            guidance_scale=5.0,
            num_inference_steps=50,
            output_type="np",
        )

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
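
# Run the fast tests locally with, e.g.:
#   pytest tests/pipelines/paint_by_example/test_paint_by_example.py -k FastTests
# The integration tests additionally require a GPU and RUN_SLOW=1.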