Path: tests/schedulers/test_scheduler_flax.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple

from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from jax import random

    jax_device = jax.default_backend()


@require_flax
class FlaxSchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        key1, key2 = random.split(random.PRNGKey(0))
        sample = random.uniform(key1, (batch_size, num_channels, height, width))

        return sample, key2

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = jnp.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        return jnp.transpose(sample, (3, 0, 1, 2))

    def get_scheduler_config(self):
        raise NotImplementedError

    def dummy_model(self):
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_deprecated_kwargs(self):
        for scheduler_class in self.scheduler_classes:
            has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
            has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0

            if has_kwarg_in_model_class and not has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
                    " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                    " [<deprecated_argument>]`"
                )

            if not has_kwarg_in_model_class and has_deprecated_kwarg:
                raise ValueError(
                    f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
                    " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
                    f" argument to {scheduler_class}.__init__ if there are deprecated arguments or remove the"
                    " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
                )


@require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 255.0714) < 1e-2
            assert abs(result_mean - 0.332124) < 1e-3
        else:
            assert abs(result_sum - 255.1113) < 1e-2
            assert abs(result_mean - 0.332176) < 1e-3


@require_flax
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDIMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        key1, key2 = random.split(random.PRNGKey(0))

        num_inference_steps = 10

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        state = scheduler.set_timesteps(state, num_inference_steps)

        for t in state.timesteps:
            residual = model(sample, t)
            output = scheduler.step(state, residual, t, sample)
            sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

        return sample

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 5)
        assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 172.0067) < 1e-2
        assert abs(result_mean - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 149.8409) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3
        else:
            assert abs(result_sum - 149.8295) < 1e-2
            assert abs(result_mean - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            pass
            # FIXME: both result_sum and result_mean are nan on TPU
            # assert jnp.isnan(result_sum)
            # assert jnp.isnan(result_mean)
        else:
            assert abs(result_sum - 149.0784) < 1e-2
            assert abs(result_mean - 0.1941) < 1e-3

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)


@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
                # copy over dummy past residuals
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        pass

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object).any()}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # copy over dummy past residuals (must be after setting timesteps; the state is
            # immutable, so `replace` returns a new state that has to be reassigned)
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)

                # copy over dummy past residuals (must be after setting timesteps)
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

        for i, t in enumerate(state.prk_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_prk(state, residual, t, sample)

        for i, t in enumerate(state.plms_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_plms(state, residual, t, sample)

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
            state = state.replace(ets=dummy_past_residuals[:])

            output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 10, shape=())
        assert jnp.equal(
            state.timesteps,
            jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
        ).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # before power of 3 fix, would error on first step, so we only need to do two
            for i, t in enumerate(state.prk_timesteps[:2]):
                sample, state = scheduler.step_prk(state, residual, t, sample)

    def test_inference_plms_no_past_residuals(self):
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            # stepping PLMS without calling set_timesteps / seeding past residuals must raise
            sample, _ = self.dummy_sample
            scheduler.step_plms(state, 0.1 * sample, 1, sample).prev_sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 198.1275) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3
        else:
            assert abs(result_sum - 198.1318) < 1e-2
            assert abs(result_mean - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9466) < 1e-2
            assert abs(result_mean - 0.24342) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        if jax_device == "tpu":
            assert abs(result_sum - 186.83226) < 1e-2
            assert abs(result_mean - 0.24327) < 1e-3
        else:
            assert abs(result_sum - 186.9482) < 1e-2
            assert abs(result_mean - 0.2434) < 1e-3