GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import tempfile
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
from diffusers.utils import floats_tensor, nightly, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# disable TF32 matmuls so results are deterministic and match the hard-coded expected slices
torch.backends.cuda.matmul.allow_tf32 = False


class SafeDiffusionPipelineFastTests(unittest.TestCase):
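    """Fast CPU checks of StableDiffusionPipelineSafe built from tiny dummy components."""
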
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    # torch.Tensor.to is not in-place, so reassign the result
                    self.pixel_values = self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_safe_diffusion_ddim(self):
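        """The safe pipeline with a DDIM scheduler matches its reference slice for both dict and tuple outputs."""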
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # assemble the safe pipeline around the DDIM scheduler, with the safety checker disabled
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_pndm(self):
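        """Same check as the DDIM test, but with the PNDM scheduler (PRK steps skipped)."""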
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_no_safety_checker(self):
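        """Loading with safety_checker=None should work, including a save/load round-trip."""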
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
        )
        assert isinstance(pipe, StableDiffusionPipeline)
        assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
        assert pipe.safety_checker is None

        image = pipe("example prompt", num_inference_steps=2).images[0]
        assert image is not None

        # check that there's no error when saving a pipeline with one of the models being None
        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)

        # sanity check that the pipeline still works
        assert pipe.safety_checker is None
        image = pipe("example prompt", num_inference_steps=2).images[0]
        assert image is not None

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_fp16(self):
        """Test that stable diffusion works with fp16"""
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # put models in fp16
        unet = unet.half()
        vae = vae.half()
        bert = bert.half()

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images

        assert image.shape == (1, 64, 64, 3)


@nightly
@require_torch_gpu
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
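    """Nightly GPU tests of Safe Latent Diffusion (SLD) guidance on the full runwayml/stable-diffusion-v1-5 checkpoint."""
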
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_harm_safe_stable_diffusion(self):
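        """With sld_guidance_scale=0 the pipeline reproduces vanilla SD; the strong SLD config shifts the output."""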
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = (
            "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
            " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
            " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
            " children from bahnhof zoo, detailed "
        )
        seed = 4003660346
        guidance_scale = 7

        # without safety guidance (sld_guidance_scale = 0)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

        # with safety guidance (strong configuration)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_nudity_safe_stable_diffusion(self):
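        """Same comparison on a prompt that tends to produce nudity without SLD guidance."""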
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
        seed = 2734971755
        guidance_scale = 7

        # without safety guidance (sld_guidance_scale = 0)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

        # with safety guidance (strong configuration)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_nudity_safetychecker_safe_stable_diffusion(self):
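        """With the safety checker enabled, the unguided image is blacked out; SLD guidance yields a non-black image."""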
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = (
            "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
            " leyendecker"
        )
        seed = 1044355234
        guidance_scale = 12

        # without safety guidance the safety checker fires and returns an all-black image
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7

        # with strong safety guidance the output passes the safety checker
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
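
# To run only the fast CPU tests locally (a sketch; adjust the path to your checkout):
#   python -m pytest tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py -k FastTests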