GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel
from diffusers.utils.testing_utils import require_torch, slow, torch_device

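
# Disable TF32 matmuls so the numeric comparisons below stay reproducible
# across GPUs (TF32 rounding would otherwise perturb the expected slices).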
torch.backends.cuda.matmul.allow_tf32 = False


class LDMPipelineFastTests(unittest.TestCase):
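    # Fast tests: tiny, randomly initialized components so the full pipeline
    # can run end-to-end on CPU in a few seconds.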
    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    @property
    def dummy_vq_model(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )
        return model
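
    # Note: this dummy text encoder is not referenced by the unconditional
    # tests below; it appears to be kept for parity with the conditional
    # latent-diffusion tests.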
    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    def test_inference_uncond(self):
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()
        vae = self.dummy_vq_model

        ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)
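
        # Run the pipeline twice from the same seed: once returning the
        # output dataclass and once returning a plain tuple, to verify that
        # both return paths yield identical images.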
        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
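
        # Compare a 3x3 corner patch of the last channel against reference values.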
        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
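        # Apple MPS accumulates more floating-point error than CPU/CUDA,
        # hence the looser tolerance on that backend.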
        tolerance = 1e-2 if torch_device != "mps" else 3e-2

        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance


@slow
@require_torch
class LDMPipelineIntegrationTests(unittest.TestCase):
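    # Integration test: downloads the pretrained 256x256 CelebA-HQ LDM
    # checkpoint from the Hugging Face Hub; marked @slow, so it is skipped
    # unless slow tests are enabled (RUN_SLOW=1 in the HF testing setup).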
    def test_inference_uncond(self):
        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        tolerance = 1e-2 if torch_device != "mps" else 3e-2

        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance