Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_outputs.py
1440 views
1
import unittest
2
from dataclasses import dataclass
3
from typing import List, Union
4
5
import numpy as np
6
import PIL.Image
7
8
from diffusers.utils.outputs import BaseOutput
9
10
11
@dataclass()
class CustomOutput(BaseOutput):
    """Single-field ``BaseOutput`` subclass used as a test fixture.

    Holds one ``images`` field so the tests can exercise attribute, key and
    positional access on a ``BaseOutput``.
    """

    # Either a list of PIL images or a numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
14
15
16
class ConfigTester(unittest.TestCase):
    """Check that a ``BaseOutput`` subclass exposes its field three ways.

    Each check reads the ``images`` field through attribute access
    (``outputs.images``), key access (``outputs["images"]``) and positional
    access (``outputs[0]``), and verifies the same type/shape on all three.

    NOTE(review): despite the class name, these tests exercise ``BaseOutput``
    (via ``CustomOutput``), not configuration handling.
    """

    # Shape of the dummy image batch used by the ndarray checks.
    IMAGE_SHAPE = (1, 3, 4, 4)

    # (label, getter) pairs covering every supported way of reading `images`.
    # The label is only used to make assertion failures self-explanatory.
    _ACCESSORS = (
        ("attribute", lambda outputs: outputs.images),
        ("key", lambda outputs: outputs["images"]),
        ("index", lambda outputs: outputs[0]),
    )

    def _assert_array_output(self, outputs):
        """Assert every access path to ``images`` yields an ndarray of ``IMAGE_SHAPE``."""
        for label, get in self._ACCESSORS:
            value = get(outputs)
            self.assertIsInstance(value, np.ndarray, msg=f"{label} access")
            self.assertEqual(value.shape, self.IMAGE_SHAPE, msg=f"{label} access")

    def _assert_pil_output(self, outputs):
        """Assert every access path to ``images`` yields a list whose first item is a PIL image."""
        for label, get in self._ACCESSORS:
            value = get(outputs)
            self.assertIsInstance(value, list, msg=f"{label} access")
            self.assertIsInstance(value[0], PIL.Image.Image, msg=f"{label} access")

    def test_outputs_single_attribute(self):
        # Keyword construction with a tensor-like (ndarray) attribute.
        outputs = CustomOutput(images=np.random.rand(*self.IMAGE_SHAPE))
        self._assert_array_output(outputs)

        # Test with a non-tensor attribute.
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        self._assert_pil_output(outputs)

    def test_outputs_dict_init(self):
        # Test output reinitialization with a `dict` for compatibility with `accelerate`.
        outputs = CustomOutput({"images": np.random.rand(*self.IMAGE_SHAPE)})
        self._assert_array_output(outputs)

        # Test with a non-tensor attribute.
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
        self._assert_pil_output(outputs)
61
62