Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py
1448 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import gc
17
import tempfile
18
import unittest
19
20
import numpy as np
21
import torch
22
23
from diffusers import VersatileDiffusionDualGuidedPipeline
24
from diffusers.utils.testing_utils import load_image, nightly, require_torch_gpu, torch_device
25
26
27
# Turn off TF32 for CUDA matmuls so the hard-coded expected values in the tests
# of this module are compared against plain FP32 matmul results.
_cuda_matmul = torch.backends.cuda.matmul
_cuda_matmul.allow_tf32 = False
del _cuda_matmul
28
29
30
@nightly
@require_torch_gpu
class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
    """Nightly GPU integration tests for ``VersatileDiffusionDualGuidedPipeline``.

    Each test loads the ``shi-labs/versatile-diffusion`` checkpoint, calls
    ``remove_unused_weights()`` on it, and runs dual-guided generation with a
    text prompt plus the reference image ``benz.jpg`` as guidance inputs.
    """

    # Image passed to the pipeline through its ``image=`` argument.
    _IMAGE_URL = (
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
    )

    def tearDown(self):
        # Free GPU memory left behind by the test that just ran.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _move_to_device(pipeline):
        """Move *pipeline* to ``torch_device``, set ``disable=None`` on its progress bar, and return it."""
        pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)
        return pipeline

    def _load_pruned_pipeline(self):
        """Load the hub checkpoint, drop its unused weights, and place it on the test device."""
        pipeline = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
        # remove text_unet
        pipeline.remove_unused_weights()
        return self._move_to_device(pipeline)

    @staticmethod
    def _generate(pipeline, prompt, image, generator, num_inference_steps):
        """Run one dual-guided generation with the shared settings and return ``.images``."""
        result = pipeline(
            prompt=prompt,
            image=image,
            text_to_image_strength=0.75,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        )
        return result.images

    def test_remove_unused_weights_save_load(self):
        """A pipeline with unused weights removed must give the same images after save/load."""
        pipe = self._load_pruned_pipeline()
        guidance_image = load_image(self._IMAGE_URL)

        generator = torch.manual_seed(0)
        before = self._generate(pipe, "first prompt", guidance_image, generator, num_inference_steps=2)

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            # The reloaded copy is NOT pruned again: it should already lack the removed weights.
            pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)
        pipe = self._move_to_device(pipe)

        # Re-seed the same generator object so the second run starts from the same state.
        generator = generator.manual_seed(0)
        after = self._generate(pipe, "first prompt", guidance_image, generator, num_inference_steps=2)

        assert np.abs(before - after).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_dual_guided(self):
        """Compare a 50-step dual-guided generation against a stored 3x3 pixel slice."""
        pipe = self._load_pruned_pipeline()

        guidance_image = load_image(self._IMAGE_URL)
        generator = torch.manual_seed(0)
        image = self._generate(pipe, "cyberpunk 2077", guidance_image, generator, num_inference_steps=50)

        # 3x3 patch of the last channel of the first image, taken near the 512x512 centre.
        image_slice = image[0, 253:256, 253:256, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0787, 0.0849, 0.0826, 0.0812, 0.0807, 0.0795, 0.0818, 0.0798, 0.0779])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
108
109