Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
1448 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import gc
17
import unittest
18
19
import numpy as np
20
import torch
21
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
22
23
from diffusers import (
24
AutoencoderKL,
25
DDIMScheduler,
26
StableDiffusionSAGPipeline,
27
UNet2DConditionModel,
28
)
29
from diffusers.utils import slow, torch_device
30
from diffusers.utils.testing_utils import require_torch_gpu
31
32
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
33
from ...test_pipelines_common import PipelineTesterMixin
34
35
36
# Turn off TF32 for CUDA matrix multiplications for every test in this module.
# NOTE(review): presumably done so results stay numerically comparable to the
# hard-coded expected slices in the integration tests - confirm.
torch.backends.cuda.matmul.allow_tf32 = False
37
38
39
class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fixture provider for ``StableDiffusionSAGPipeline`` built on small dummy models.

    This class defines no test methods of its own; it supplies the pipeline
    class, the parameter sets and the dummy components/inputs.
    NOTE(review): the actual tests presumably come from ``PipelineTesterMixin``,
    which is not visible in this file - confirm.
    """

    pipeline_class = StableDiffusionSAGPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    # NOTE(review): presumably read by the mixin to skip its CPU-offload test - confirm.
    test_cpu_offload = False

    def get_dummy_components(self):
        """Build the pipeline components from small, seeded configurations.

        The global torch seed is reset to 0 before the UNet, the VAE and the
        text encoder are constructed, so each model starts from the same RNG
        state.

        Returns:
            dict: Components keyed by ``"unet"``, ``"scheduler"``, ``"vae"``,
            ``"text_encoder"``, ``"tokenizer"``, ``"safety_checker"`` and
            ``"feature_extractor"``; the last two are ``None``.
        """
        torch.manual_seed(0)
        denoiser = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        noise_scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

        torch.manual_seed(0)
        autoencoder = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        torch.manual_seed(0)
        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        clip_encoder = CLIPTextModel(clip_config)
        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": denoiser,
            "scheduler": noise_scheduler,
            "vae": autoencoder,
            "text_encoder": clip_encoder,
            "tokenizer": clip_tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return the keyword arguments for one short pipeline call.

        Args:
            device: Target device; anything whose ``str()`` starts with
                ``"mps"`` gets the global generator from ``torch.manual_seed``
                instead of a device-bound ``torch.Generator``.
            seed: Seed applied to the generator (default ``0``).

        Returns:
            dict: Prompt, generator, step count, guidance scale, SAG scale
            and output type.
        """
        targets_mps = str(device).startswith("mps")
        rng = (
            torch.manual_seed(seed)
            if targets_mps
            else torch.Generator(device=device).manual_seed(seed)
        )
        return {
            "prompt": ".",
            "generator": rng,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "sag_scale": 1.0,
            "output_type": "numpy",
        }
113
114
115
@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    """Integration tests running ``StableDiffusionSAGPipeline`` on pretrained checkpoints.

    The class is decorated with ``slow`` and ``require_torch_gpu``. Every test
    loads a checkpoint, generates one image for the prompt ``"."`` with a
    generator seeded to 0, and checks the result.
    """

    def tearDown(self):
        """Run the base tear-down, then collect garbage and empty the CUDA cache."""
        super().tearDown()
        # Free VRAM between tests.
        gc.collect()
        torch.cuda.empty_cache()

    def _generate_images(self, checkpoint, **size_overrides):
        """Load ``checkpoint``, run one seeded 20-step generation and return its images.

        Args:
            checkpoint: Model identifier passed to ``from_pretrained``.
            **size_overrides: Extra keyword arguments forwarded to the
                pipeline call (used by the tests for ``width``/``height``).

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        pipeline = StableDiffusionSAGPipeline.from_pretrained(checkpoint)
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        # Seed only after the pipeline is loaded and moved, right before the call.
        rng = torch.manual_seed(0)
        result = pipeline(
            ["."],
            generator=rng,
            guidance_scale=7.5,
            sag_scale=1.0,
            num_inference_steps=20,
            output_type="np",
            **size_overrides,
        )
        return result.images

    def test_stable_diffusion_1(self):
        """Check shape and a corner slice for ``CompVis/stable-diffusion-v1-4``."""
        images = self._generate_images("CompVis/stable-diffusion-v1-4")

        # Bottom-right 3x3 patch of the last channel of the first image.
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.1568, 0.1738, 0.1695, 0.1693, 0.1507, 0.1705, 0.1547, 0.1751, 0.1949])

        largest_gap = np.abs(corner.flatten() - reference).max()
        assert largest_gap < 5e-2

    def test_stable_diffusion_2(self):
        """Check shape and a corner slice for ``stabilityai/stable-diffusion-2-1-base``."""
        images = self._generate_images("stabilityai/stable-diffusion-2-1-base")

        # Bottom-right 3x3 patch of the last channel of the first image.
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 512, 512, 3)
        reference = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])

        largest_gap = np.abs(corner.flatten() - reference).max()
        assert largest_gap < 5e-2

    def test_stable_diffusion_2_non_square(self):
        """Check only the output shape when requesting a 768x512 (width x height) image."""
        images = self._generate_images(
            "stabilityai/stable-diffusion-2-1-base",
            width=768,
            height=512,
        )

        assert images.shape == (1, 512, 768, 3)
185
186