# Scraped from: shivamshrirao/diffusers — tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

import numpy as np
import torch

from diffusers import VersatileDiffusionPipeline
from diffusers.utils.testing_utils import load_image, nightly, require_torch_gpu, torch_device
26
27
torch.backends.cuda.matmul.allow_tf32 = False
28
29
30
class VersatileDiffusionMegaPipelineFastTests(unittest.TestCase):
31
pass
32
33
34
@nightly
35
@require_torch_gpu
36
class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
37
def tearDown(self):
38
# clean up the VRAM after each test
39
super().tearDown()
40
gc.collect()
41
torch.cuda.empty_cache()
42
43
def test_from_save_pretrained(self):
44
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
45
pipe.to(torch_device)
46
pipe.set_progress_bar_config(disable=None)
47
48
prompt_image = load_image(
49
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
50
)
51
52
generator = torch.manual_seed(0)
53
image = pipe.dual_guided(
54
prompt="first prompt",
55
image=prompt_image,
56
text_to_image_strength=0.75,
57
generator=generator,
58
guidance_scale=7.5,
59
num_inference_steps=2,
60
output_type="numpy",
61
).images
62
63
with tempfile.TemporaryDirectory() as tmpdirname:
64
pipe.save_pretrained(tmpdirname)
65
pipe = VersatileDiffusionPipeline.from_pretrained(tmpdirname, torch_dtype=torch.float16)
66
pipe.to(torch_device)
67
pipe.set_progress_bar_config(disable=None)
68
69
generator = generator.manual_seed(0)
70
new_image = pipe.dual_guided(
71
prompt="first prompt",
72
image=prompt_image,
73
text_to_image_strength=0.75,
74
generator=generator,
75
guidance_scale=7.5,
76
num_inference_steps=2,
77
output_type="numpy",
78
).images
79
80
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
81
82
def test_inference_dual_guided_then_text_to_image(self):
83
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
84
pipe.to(torch_device)
85
pipe.set_progress_bar_config(disable=None)
86
87
prompt = "cyberpunk 2077"
88
init_image = load_image(
89
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
90
)
91
generator = torch.manual_seed(0)
92
image = pipe.dual_guided(
93
prompt=prompt,
94
image=init_image,
95
text_to_image_strength=0.75,
96
generator=generator,
97
guidance_scale=7.5,
98
num_inference_steps=50,
99
output_type="numpy",
100
).images
101
102
image_slice = image[0, 253:256, 253:256, -1]
103
104
assert image.shape == (1, 512, 512, 3)
105
expected_slice = np.array([0.1448, 0.1619, 0.1741, 0.1086, 0.1147, 0.1128, 0.1199, 0.1165, 0.1001])
106
107
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
108
109
prompt = "A painting of a squirrel eating a burger "
110
generator = torch.manual_seed(0)
111
image = pipe.text_to_image(
112
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
113
).images
114
115
image_slice = image[0, 253:256, 253:256, -1]
116
117
assert image.shape == (1, 512, 512, 3)
118
expected_slice = np.array([0.3367, 0.3169, 0.2656, 0.3870, 0.4790, 0.3796, 0.4009, 0.4878, 0.4778])
119
120
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
121
122
image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
123
124
image_slice = image[0, 253:256, 253:256, -1]
125
126
assert image.shape == (1, 512, 512, 3)
127
expected_slice = np.array([0.3076, 0.3123, 0.3284, 0.3782, 0.3770, 0.3894, 0.4297, 0.4331, 0.4456])
128
129
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
130
131