Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_image_processor.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import unittest
17
18
import numpy as np
19
import PIL
20
import torch
21
22
from diffusers.image_processor import VaeImageProcessor
23
24
25
class ImageProcessorTest(unittest.TestCase):
    """Round-trip tests for ``VaeImageProcessor.preprocess`` / ``postprocess``.

    Every test builds the processor with ``do_resize=False`` and
    ``do_normalize=False`` so that a preprocess -> postprocess round trip is
    expected to reproduce its input (up to 0-255 quantization whenever PIL
    images are involved).
    """

    @property
    def dummy_sample(self):
        """Return a fresh random tensor of shape (1, 3, 8, 8) with values from ``torch.rand``.

        A new tensor is drawn on every access, so two reads give different data.
        """
        batch_size = 1
        num_channels = 3
        height = 8
        width = 8

        sample = torch.rand((batch_size, num_channels, height, width))

        return sample

    def to_np(self, image):
        """Convert a processor input/output to a numpy array for comparison.

        - a sequence of PIL images is converted with ``np.array`` and stacked
          along a new leading batch axis;
        - a torch tensor (assumed NCHW) is moved to CPU and transposed to NHWC;
        - anything else (e.g. a numpy array) is returned unchanged.
        """
        if isinstance(image[0], PIL.Image.Image):
            return np.stack([np.array(i) for i in image], axis=0)
        elif isinstance(image, torch.Tensor):
            return image.cpu().numpy().transpose(0, 2, 3, 1)
        return image

    def _roundtrip(self, image_processor, image, output_type):
        """Run ``image`` through ``preprocess`` and then ``postprocess`` with ``output_type``."""
        return image_processor.postprocess(image_processor.preprocess(image), output_type=output_type)

    def test_vae_image_processor_pt(self):
        """A torch input survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt = self.dummy_sample
        input_np = self.to_np(input_pt)

        for output_type in ["pt", "np", "pil"]:
            out_np = self.to_np(self._roundtrip(image_processor, input_pt, output_type))
            # PIL outputs are on a 0-255 integer scale, so quantize the reference the same way.
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_np(self):
        """A channels-last numpy input survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)

        for output_type in ["pt", "np", "pil"]:
            out_np = self.to_np(self._roundtrip(image_processor, input_np, output_type))
            # PIL outputs are on a 0-255 integer scale, so quantize the reference the same way.
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_pil(self):
        """A list of PIL images survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
        input_pil = image_processor.numpy_to_pil(input_np)

        for output_type in ["pt", "np", "pil"]:
            out_np = self.to_np(self._roundtrip(image_processor, input_pil, output_type))
            if output_type != "pil":
                # Non-PIL outputs are on the same scale as the [0, 1) inputs; rescale them
                # to the 0-255 scale of the PIL inputs before comparing.
                out_np = (out_np * 255).round()

            assert len(out_np) == len(input_pil), f"unexpected batch size for output_type {output_type}"
            # Compare each input image with its own decoded counterpart instead of
            # broadcasting it against the whole batch; cast to float so a uint8
            # subtraction cannot wrap around.
            for in_image, out_image in zip(input_pil, out_np):
                in_np = np.array(in_image).astype(np.float64)
                assert (
                    np.abs(in_np - out_image).max() < 1e-6
                ), f"decoded output does not match input for output_type {output_type}"

    def test_preprocess_input_3d(self):
        """Unbatched 3D inputs (torch and numpy) decode the same as their 4D batched form."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_3d = input_pt_4d.squeeze(0)  # drop the batch axis

        out_pt_4d = self._roundtrip(image_processor, input_pt_4d, "np")
        out_pt_3d = self._roundtrip(image_processor, input_pt_3d, "np")

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_3d = input_np_4d.squeeze(0)  # drop the batch axis

        out_np_4d = self._roundtrip(image_processor, input_np_4d, "np")
        out_np_3d = self._roundtrip(image_processor, input_np_3d, "np")

        assert np.abs(out_pt_4d - out_pt_3d).max() < 1e-6
        assert np.abs(out_np_4d - out_np_3d).max() < 1e-6

    def test_preprocess_input_list(self):
        """A list of per-image inputs (torch and numpy) decodes the same as the 4D batch."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_list = list(input_pt_4d)  # split the batch into per-image tensors

        out_pt_4d = self._roundtrip(image_processor, input_pt_4d, "np")
        out_pt_list = self._roundtrip(image_processor, input_pt_list, "np")

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_list = list(input_np_4d)  # split the batch into per-image arrays

        # Feed the numpy inputs here so the numpy 4D-vs-list path is actually exercised.
        out_np_4d = self._roundtrip(image_processor, input_np_4d, "np")
        out_np_list = self._roundtrip(image_processor, input_np_list, "np")

        assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
        assert np.abs(out_np_4d - out_np_list).max() < 1e-6