Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/test_pipeline_utils.py
1441 views
1
import unittest
2
3
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
4
5
6
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Tests for ``is_safetensors_compatible``.

    Each case passes a list of repository file paths of the form
    ``<component>/<filename>`` and, for the ``*_variant*`` cases, a
    ``variant`` string.

    - The "compatible" cases pair every ``.bin`` weight file with a
      ``.safetensors`` counterpart.
    - The "not compatible" cases drop exactly one such counterpart (marked
      with a ``# Removed:`` comment) and expect ``False``.

    The fixtures use two naming pairs:

    - diffusers components:
      ``diffusion_pytorch_model[.<variant>].bin`` and
      ``diffusion_pytorch_model[.<variant>].safetensors``
    - transformers components (e.g. ``text_encoder``):
      ``pytorch_model[.<variant>].bin`` and ``model[.<variant>].safetensors``
    """

    def test_all_is_compatible(self):
        """A full pipeline with a safetensors file for every .bin is compatible."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_compatible(self):
        """A lone diffusers-style component with both formats is compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_not_compatible(self):
        """Missing the unet (diffusers-style) safetensors file is incompatible."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # Removed: "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_transformer_model_is_compatible(self):
        """A lone transformers-style component (pytorch_model.bin / model.safetensors) is compatible."""
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_transformer_model_is_not_compatible(self):
        """Missing the text_encoder (transformers-style) safetensors file is incompatible."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # Removed: "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_all_is_compatible_variant(self):
        """A full fp16-variant pipeline with every safetensors file is compatible."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant(self):
        """A lone diffusers-style fp16 component with both formats is compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant_partial(self):
        """Requesting a variant while only non-variant diffusers files exist is still compatible."""
        # Pass a variant but use the non-variant filenames.
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_not_compatible_variant(self):
        """Missing the unet fp16 safetensors file is incompatible."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # Removed: "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant(self):
        """A lone transformers-style fp16 component with both formats is compatible."""
        filenames = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant_partial(self):
        """Requesting a variant while only non-variant transformers files exist is still compatible."""
        # Pass a variant but use the non-variant filenames.
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_not_compatible_variant(self):
        """Missing the text_encoder fp16 safetensors file is incompatible."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # Removed: "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
135
136