Path: blob/main/tests/schedulers/test_schedulers.py
1448 views
# coding=utf-81# Copyright 2023 HuggingFace Inc.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14import inspect15import json16import os17import tempfile18import unittest19from typing import Dict, List, Tuple2021import numpy as np22import torch2324import diffusers25from diffusers import (26EulerAncestralDiscreteScheduler,27EulerDiscreteScheduler,28IPNDMScheduler,29LMSDiscreteScheduler,30VQDiffusionScheduler,31logging,32)33from diffusers.configuration_utils import ConfigMixin, register_to_config34from diffusers.schedulers.scheduling_utils import SchedulerMixin35from diffusers.utils import torch_device36from diffusers.utils.testing_utils import CaptureLogger373839torch.backends.cuda.matmul.allow_tf32 = False404142class SchedulerObject(SchedulerMixin, ConfigMixin):43config_name = "config.json"4445@register_to_config46def __init__(47self,48a=2,49b=5,50c=(2, 5),51d="for diffusion",52e=[1, 3],53):54pass555657class SchedulerObject2(SchedulerMixin, ConfigMixin):58config_name = "config.json"5960@register_to_config61def __init__(62self,63a=2,64b=5,65c=(2, 5),66d="for diffusion",67f=[1, 3],68):69pass707172class SchedulerObject3(SchedulerMixin, ConfigMixin):73config_name = "config.json"7475@register_to_config76def __init__(77self,78a=2,79b=5,80c=(2, 5),81d="for diffusion",82e=[1, 3],83f=[1, 3],84):85pass868788class SchedulerBaseTests(unittest.TestCase):89def test_save_load_from_different_config(self):90obj = SchedulerObject()9192# mock add obj class to `diffusers`93setattr(diffusers, "SchedulerObject", SchedulerObject)94logger = logging.get_logger("diffusers.configuration_utils")9596with tempfile.TemporaryDirectory() as tmpdirname:97obj.save_config(tmpdirname)98with CaptureLogger(logger) as cap_logger_1:99config = SchedulerObject2.load_config(tmpdirname)100new_obj_1 = SchedulerObject2.from_config(config)101102# now save a config parameter that is not expected103with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:104data = json.load(f)105data["unexpected"] = True106107with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:108json.dump(data, f)109110with CaptureLogger(logger) as cap_logger_2:111config = SchedulerObject.load_config(tmpdirname)112new_obj_2 = SchedulerObject.from_config(config)113114with CaptureLogger(logger) as cap_logger_3:115config = SchedulerObject2.load_config(tmpdirname)116new_obj_3 = SchedulerObject2.from_config(config)117118assert new_obj_1.__class__ == SchedulerObject2119assert new_obj_2.__class__ == SchedulerObject120assert new_obj_3.__class__ == SchedulerObject2121122assert cap_logger_1.out == ""123assert (124cap_logger_2.out125== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"126" will"127" be ignored. Please verify your config.json configuration file.\n"128)129assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out130131def test_save_load_compatible_schedulers(self):132SchedulerObject2._compatibles = ["SchedulerObject"]133SchedulerObject._compatibles = ["SchedulerObject2"]134135obj = SchedulerObject()136137# mock add obj class to `diffusers`138setattr(diffusers, "SchedulerObject", SchedulerObject)139setattr(diffusers, "SchedulerObject2", SchedulerObject2)140logger = logging.get_logger("diffusers.configuration_utils")141142with tempfile.TemporaryDirectory() as tmpdirname:143obj.save_config(tmpdirname)144145# now save a config parameter that is expected by another class, but not origin class146with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:147data = json.load(f)148data["f"] = [0, 0]149data["unexpected"] = True150151with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:152json.dump(data, f)153154with CaptureLogger(logger) as cap_logger:155config = SchedulerObject.load_config(tmpdirname)156new_obj = SchedulerObject.from_config(config)157158assert new_obj.__class__ == SchedulerObject159160assert (161cap_logger.out162== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"163" will"164" be ignored. Please verify your config.json configuration file.\n"165)166167def test_save_load_from_different_config_comp_schedulers(self):168SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]169SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]170SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]171172obj = SchedulerObject()173174# mock add obj class to `diffusers`175setattr(diffusers, "SchedulerObject", SchedulerObject)176setattr(diffusers, "SchedulerObject2", SchedulerObject2)177setattr(diffusers, "SchedulerObject3", SchedulerObject3)178logger = logging.get_logger("diffusers.configuration_utils")179logger.setLevel(diffusers.logging.INFO)180181with tempfile.TemporaryDirectory() as tmpdirname:182obj.save_config(tmpdirname)183184with CaptureLogger(logger) as cap_logger_1:185config = SchedulerObject.load_config(tmpdirname)186new_obj_1 = SchedulerObject.from_config(config)187188with CaptureLogger(logger) as cap_logger_2:189config = SchedulerObject2.load_config(tmpdirname)190new_obj_2 = SchedulerObject2.from_config(config)191192with CaptureLogger(logger) as cap_logger_3:193config = SchedulerObject3.load_config(tmpdirname)194new_obj_3 = SchedulerObject3.from_config(config)195196assert new_obj_1.__class__ == SchedulerObject197assert new_obj_2.__class__ == SchedulerObject2198assert new_obj_3.__class__ == SchedulerObject3199200assert cap_logger_1.out == ""201assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"202assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"203204205class SchedulerCommonTest(unittest.TestCase):206scheduler_classes = ()207forward_default_kwargs = ()208209@property210def dummy_sample(self):211batch_size = 4212num_channels = 3213height = 8214width = 8215216sample = torch.rand((batch_size, num_channels, height, width))217218return sample219220@property221def dummy_sample_deter(self):222batch_size = 4223num_channels = 3224height = 8225width = 8226227num_elems = batch_size * num_channels * height * width228sample = torch.arange(num_elems)229sample = sample.reshape(num_channels, height, width, batch_size)230sample = sample / num_elems231sample = sample.permute(3, 0, 1, 2)232233return sample234235def get_scheduler_config(self):236raise NotImplementedError237238def dummy_model(self):239def model(sample, t, *args):240return sample * t / (t + 1)241242return model243244def check_over_configs(self, time_step=0, **config):245kwargs = dict(self.forward_default_kwargs)246247num_inference_steps = kwargs.pop("num_inference_steps", None)248249for scheduler_class in self.scheduler_classes:250# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default251if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):252time_step = float(time_step)253254scheduler_config = self.get_scheduler_config(**config)255scheduler = scheduler_class(**scheduler_config)256257if scheduler_class == VQDiffusionScheduler:258num_vec_classes = scheduler_config["num_vec_classes"]259sample = self.dummy_sample(num_vec_classes)260model = self.dummy_model(num_vec_classes)261residual = model(sample, time_step)262else:263sample = self.dummy_sample264residual = 0.1 * sample265266with tempfile.TemporaryDirectory() as tmpdirname:267scheduler.save_config(tmpdirname)268new_scheduler = scheduler_class.from_pretrained(tmpdirname)269270if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):271scheduler.set_timesteps(num_inference_steps)272new_scheduler.set_timesteps(num_inference_steps)273elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):274kwargs["num_inference_steps"] = num_inference_steps275276# Make sure `scale_model_input` is invoked to prevent a warning277if scheduler_class != VQDiffusionScheduler:278_ = scheduler.scale_model_input(sample, 0)279_ = new_scheduler.scale_model_input(sample, 0)280281# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler282if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):283kwargs["generator"] = torch.manual_seed(0)284output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample285286if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):287kwargs["generator"] = torch.manual_seed(0)288new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample289290assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"291292def check_over_forward(self, time_step=0, **forward_kwargs):293kwargs = dict(self.forward_default_kwargs)294kwargs.update(forward_kwargs)295296num_inference_steps = kwargs.pop("num_inference_steps", None)297298for scheduler_class in self.scheduler_classes:299if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):300time_step = float(time_step)301302scheduler_config = self.get_scheduler_config()303scheduler = scheduler_class(**scheduler_config)304305if scheduler_class == VQDiffusionScheduler:306num_vec_classes = scheduler_config["num_vec_classes"]307sample = self.dummy_sample(num_vec_classes)308model = self.dummy_model(num_vec_classes)309residual = model(sample, time_step)310else:311sample = self.dummy_sample312residual = 0.1 * sample313314with tempfile.TemporaryDirectory() as tmpdirname:315scheduler.save_config(tmpdirname)316new_scheduler = scheduler_class.from_pretrained(tmpdirname)317318if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):319scheduler.set_timesteps(num_inference_steps)320new_scheduler.set_timesteps(num_inference_steps)321elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):322kwargs["num_inference_steps"] = num_inference_steps323324if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):325kwargs["generator"] = torch.manual_seed(0)326output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample327328if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):329kwargs["generator"] = torch.manual_seed(0)330new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample331332assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"333334def test_from_save_pretrained(self):335kwargs = dict(self.forward_default_kwargs)336337num_inference_steps = kwargs.pop("num_inference_steps", None)338339for scheduler_class in self.scheduler_classes:340timestep = 1341if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):342timestep = float(timestep)343344scheduler_config = self.get_scheduler_config()345scheduler = scheduler_class(**scheduler_config)346347if scheduler_class == VQDiffusionScheduler:348num_vec_classes = scheduler_config["num_vec_classes"]349sample = self.dummy_sample(num_vec_classes)350model = self.dummy_model(num_vec_classes)351residual = model(sample, timestep)352else:353sample = self.dummy_sample354residual = 0.1 * sample355356with tempfile.TemporaryDirectory() as tmpdirname:357scheduler.save_config(tmpdirname)358new_scheduler = scheduler_class.from_pretrained(tmpdirname)359360if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):361scheduler.set_timesteps(num_inference_steps)362new_scheduler.set_timesteps(num_inference_steps)363elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):364kwargs["num_inference_steps"] = num_inference_steps365366if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):367kwargs["generator"] = torch.manual_seed(0)368output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample369370if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):371kwargs["generator"] = torch.manual_seed(0)372new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample373374assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"375376def test_compatibles(self):377for scheduler_class in self.scheduler_classes:378scheduler_config = self.get_scheduler_config()379380scheduler = scheduler_class(**scheduler_config)381382assert all(c is not None for c in scheduler.compatibles)383384for comp_scheduler_cls in scheduler.compatibles:385comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)386assert comp_scheduler is not None387388new_scheduler = scheduler_class.from_config(comp_scheduler.config)389390new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}391scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}392393# make sure that configs are essentially identical394assert new_scheduler_config == dict(scheduler.config)395396# make sure that only differences are for configs that are not in init397init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()398assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()399400def test_from_pretrained(self):401for scheduler_class in self.scheduler_classes:402scheduler_config = self.get_scheduler_config()403404scheduler = scheduler_class(**scheduler_config)405406with tempfile.TemporaryDirectory() as tmpdirname:407scheduler.save_pretrained(tmpdirname)408new_scheduler = scheduler_class.from_pretrained(tmpdirname)409410assert scheduler.config == new_scheduler.config411412def test_step_shape(self):413kwargs = dict(self.forward_default_kwargs)414415num_inference_steps = kwargs.pop("num_inference_steps", None)416417timestep_0 = 0418timestep_1 = 1419420for scheduler_class in self.scheduler_classes:421if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):422timestep_0 = float(timestep_0)423timestep_1 = float(timestep_1)424425scheduler_config = self.get_scheduler_config()426scheduler = scheduler_class(**scheduler_config)427428if scheduler_class == VQDiffusionScheduler:429num_vec_classes = scheduler_config["num_vec_classes"]430sample = self.dummy_sample(num_vec_classes)431model = self.dummy_model(num_vec_classes)432residual = model(sample, timestep_0)433else:434sample = self.dummy_sample435residual = 0.1 * sample436437if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):438scheduler.set_timesteps(num_inference_steps)439elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):440kwargs["num_inference_steps"] = num_inference_steps441442output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample443output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample444445self.assertEqual(output_0.shape, sample.shape)446self.assertEqual(output_0.shape, output_1.shape)447448def test_scheduler_outputs_equivalence(self):449def set_nan_tensor_to_zero(t):450t[t != t] = 0451return t452453def recursive_check(tuple_object, dict_object):454if isinstance(tuple_object, (List, Tuple)):455for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):456recursive_check(tuple_iterable_value, dict_iterable_value)457elif isinstance(tuple_object, Dict):458for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):459recursive_check(tuple_iterable_value, dict_iterable_value)460elif tuple_object is None:461return462else:463self.assertTrue(464torch.allclose(465set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5466),467msg=(468"Tuple and dict output are not equal. Difference:"469f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"470f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"471f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."472),473)474475kwargs = dict(self.forward_default_kwargs)476num_inference_steps = kwargs.pop("num_inference_steps", 50)477478timestep = 0479if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:480timestep = 1481482for scheduler_class in self.scheduler_classes:483if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):484timestep = float(timestep)485486scheduler_config = self.get_scheduler_config()487scheduler = scheduler_class(**scheduler_config)488489if scheduler_class == VQDiffusionScheduler:490num_vec_classes = scheduler_config["num_vec_classes"]491sample = self.dummy_sample(num_vec_classes)492model = self.dummy_model(num_vec_classes)493residual = model(sample, timestep)494else:495sample = self.dummy_sample496residual = 0.1 * sample497498if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):499scheduler.set_timesteps(num_inference_steps)500elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):501kwargs["num_inference_steps"] = num_inference_steps502503# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler504if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):505kwargs["generator"] = torch.manual_seed(0)506outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)507508if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):509scheduler.set_timesteps(num_inference_steps)510elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):511kwargs["num_inference_steps"] = num_inference_steps512513# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler514if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):515kwargs["generator"] = torch.manual_seed(0)516outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)517518recursive_check(outputs_tuple, outputs_dict)519520def test_scheduler_public_api(self):521for scheduler_class in self.scheduler_classes:522scheduler_config = self.get_scheduler_config()523scheduler = scheduler_class(**scheduler_config)524525if scheduler_class != VQDiffusionScheduler:526self.assertTrue(527hasattr(scheduler, "init_noise_sigma"),528f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",529)530self.assertTrue(531hasattr(scheduler, "scale_model_input"),532(533f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"534" timestep)`"535),536)537self.assertTrue(538hasattr(scheduler, "step"),539f"{scheduler_class} does not implement a required class method `step(...)`",540)541542if scheduler_class != VQDiffusionScheduler:543sample = self.dummy_sample544scaled_sample = scheduler.scale_model_input(sample, 0.0)545self.assertEqual(sample.shape, scaled_sample.shape)546547def test_add_noise_device(self):548for scheduler_class in self.scheduler_classes:549if scheduler_class == IPNDMScheduler:550continue551scheduler_config = self.get_scheduler_config()552scheduler = scheduler_class(**scheduler_config)553scheduler.set_timesteps(100)554555sample = self.dummy_sample.to(torch_device)556scaled_sample = scheduler.scale_model_input(sample, 0.0)557self.assertEqual(sample.shape, scaled_sample.shape)558559noise = torch.randn_like(scaled_sample).to(torch_device)560t = scheduler.timesteps[5][None]561noised = scheduler.add_noise(scaled_sample, noise, t)562self.assertEqual(noised.shape, scaled_sample.shape)563564def test_deprecated_kwargs(self):565for scheduler_class in self.scheduler_classes:566has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters567has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0568569if has_kwarg_in_model_class and not has_deprecated_kwarg:570raise ValueError(571f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"572" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"573" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="574" [<deprecated_argument>]`"575)576577if not has_kwarg_in_model_class and has_deprecated_kwarg:578raise ValueError(579f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"580" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"581f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"582" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"583)584585def test_trained_betas(self):586for scheduler_class in self.scheduler_classes:587if scheduler_class == VQDiffusionScheduler:588continue589590scheduler_config = self.get_scheduler_config()591scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3]))592593with tempfile.TemporaryDirectory() as tmpdirname:594scheduler.save_pretrained(tmpdirname)595new_scheduler = scheduler_class.from_pretrained(tmpdirname)596597assert scheduler.betas.tolist() == new_scheduler.betas.tolist()598599600