Path: blob/main/tests/pipelines/test_pipeline_utils.py
1441 views
import unittest

from diffusers.pipelines.pipeline_utils import is_safetensors_compatible


class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Checks on ``is_safetensors_compatible`` for various weight-file listings.

    Each test builds a list of repository-relative weight filenames and asserts
    whether the helper reports the listing as safetensors compatible. The
    ``*_variant`` tests additionally pass ``variant="fp16"``.
    """

    def test_all_is_compatible(self):
        # Four components, each with a .bin file next to a .safetensors file.
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible(self):
        # A single component using the diffusion_pytorch_model.* naming.
        weight_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible(self):
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # "unet/diffusion_pytorch_model.safetensors" is deliberately left
            # out, so the unet .bin file has no safetensors counterpart.
        ]
        compatible = is_safetensors_compatible(weight_files)
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible(self):
        # A single component using the pytorch_model.bin / model.safetensors naming.
        weight_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files)
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible(self):
        weight_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # "text_encoder/model.safetensors" is deliberately left out, so
            # the text_encoder .bin file has no safetensors counterpart.
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files)
        self.assertFalse(compatible)

    def test_all_is_compatible_variant(self):
        # Same layout as test_all_is_compatible, with ".fp16" in every filename.
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant(self):
        weight_files = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant_partial(self):
        # The variant is requested, but the listing only has non-variant
        # filenames; the helper is still expected to report compatibility.
        weight_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible_variant(self):
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # "unet/diffusion_pytorch_model.fp16.safetensors" is deliberately
            # left out, so the unet fp16 .bin file has no safetensors counterpart.
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible_variant(self):
        weight_files = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_compatible_variant_partial(self):
        # The variant is requested, but the listing only has non-variant
        # filenames; the helper is still expected to report compatibility.
        weight_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible_variant(self):
        weight_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # "text_encoder/model.fp16.safetensors" is deliberately left out,
            # so the text_encoder fp16 .bin file has no safetensors counterpart.
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(weight_files, variant="fp16")
        self.assertFalse(compatible)