# Path: blob/main/tests/models/test_models_unet_1d.py
# (page metadata: 1448 views)
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import UNet1DModel
from diffusers.utils import floats_tensor, slow, torch_device

from ..test_modeling_common import ModelTesterMixin


# Disable TF32 for CUDA matmuls so results stay in full float32 precision;
# several tests below compare against hard-coded reference values.
torch.backends.cuda.matmul.allow_tf32 = False


class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for `UNet1DModel` configured as a 1D denoising UNet.

    In this configuration the output shape equals the input shape.
    """

    model_class = UNet1DModel

    @property
    def dummy_input(self):
        """Random sample of shape (batch_size, num_features, seq_len) = (4, 14, 16)
        plus one timestep (value 10) per batch item, both on `torch_device`."""
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 14, 16)

    @property
    def output_shape(self):
        return (4, 14, 16)

    def test_ema_training(self):
        # Overridden as a no-op: the inherited training test is not run for this model.
        pass

    def test_training(self):
        # Overridden as a no-op: the inherited training test is not run for this model.
        pass

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_determinism(self):
        super().test_determinism()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output(self):
        super().test_output()

    def prepare_init_args_and_inputs_for_common(self):
        """Return the constructor kwargs for a small mish-activated 1D UNet, plus the dummy inputs."""
        init_dict = {
            "block_out_channels": (32, 64, 128, 256),
            "in_channels": 14,
            "out_channels": 14,
            "time_embedding_type": "positional",
            "use_timestep_embedding": True,
            "flip_sin_to_cos": False,
            "freq_shift": 1.0,
            "out_block_type": "OutConv1DBlock",
            "mid_block_type": "MidResTemporalBlock1D",
            "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
            "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_hub(self):
        """Load the hub checkpoint's `unet` subfolder with no missing keys and run a forward pass."""
        model, loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output_pretrained(self):
        """Compare a 3x3 slice of the pretrained UNet's output with reference values."""
        model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
        # Seed before drawing the noise so the input matches the reference run.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = model.in_channels
        seq_len = 16
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        # NOTE(review): this creates `num_features` timesteps for a batch of 1.
        # It is presumably kept this way because the reference values were produced with it.
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = model(noise, time_step).sample.permute(0, 2, 1)

        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass

    @slow
    def test_unet_1d_maestro(self):
        """Run harmonai/maestro-150k on a 2-channel sine input and check the output's abs-sum and abs-max."""
        model_id = "harmonai/maestro-150k"
        model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
        model.to(torch_device)

        sample_size = 65536
        # Shape (1, 2, sample_size): the same sine wave repeated on 2 channels.
        noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device)
        timestep = torch.tensor([1]).to(torch_device)

        with torch.no_grad():
            output = model(noise, timestep).sample

        output_sum = output.abs().sum()
        output_max = output.abs().max()

        assert (output_sum - 224.0896).abs() < 4e-2
        assert (output_max - 0.0607).abs() < 4e-4


class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for `UNet1DModel` configured as a value function.

    This configuration uses the `ValueFunction` out block and no up blocks,
    so it produces one value per batch item instead of a full sample.
    """

    model_class = UNet1DModel

    @property
    def dummy_input(self):
        """Random sample of shape (batch_size, num_features, seq_len) = (4, 14, 16)
        plus one timestep (value 10) per batch item, both on `torch_device`."""
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 14, 16)

    @property
    def output_shape(self):
        return (4, 14, 1)

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_determinism(self):
        super().test_determinism()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output(self):
        # UNetRL is a value function, so its output shape differs from the input
        # shape; the inherited shape check does not apply.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        # Expect one scalar value per batch item.
        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_ema_training(self):
        # Overridden as a no-op: the inherited training test is not run for this model.
        pass

    def test_training(self):
        # Overridden as a no-op: the inherited training test is not run for this model.
        pass

    def prepare_init_args_and_inputs_for_common(self):
        """Return the constructor kwargs for a small mish-activated value-function UNet, plus the dummy inputs."""
        init_dict = {
            "in_channels": 14,
            "out_channels": 14,
            "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
            "up_block_types": [],
            "out_block_type": "ValueFunction",
            "mid_block_type": "ValueFunctionMidBlock1D",
            "block_out_channels": [32, 64, 128, 256],
            "layers_per_block": 1,
            "downsample_each_block": True,
            "use_timestep_embedding": True,
            "freq_shift": 1.0,
            "flip_sin_to_cos": False,
            "time_embedding_type": "positional",
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_from_pretrained_hub(self):
        """Load the hub checkpoint's `value_function` subfolder with no missing keys and run a forward pass."""
        value_function, vf_loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
        )
        self.assertIsNotNone(value_function)
        self.assertEqual(len(vf_loading_info["missing_keys"]), 0)

        value_function.to(torch_device)
        image = value_function(**self.dummy_input)

        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
    def test_output_pretrained(self):
        """Compare the pretrained value function's output with a constant reference value."""
        # Loading info is not inspected here (see test_from_pretrained_hub),
        # so load the model directly.
        value_function = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", subfolder="value_function"
        )
        # Seed before drawing the noise so the input matches the reference run.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = value_function.in_channels
        seq_len = 14
        noise = torch.randn((1, seq_len, num_features)).permute(
            0, 2, 1
        )  # match original, we can update values and remove
        # NOTE(review): this creates `num_features` timesteps for a batch of 1.
        # It is presumably kept this way because the reference values were produced with it.
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = value_function(noise, time_step).sample

        # fmt: off
        expected_output_slice = torch.tensor([165.25] * seq_len)
        # fmt: on
        self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass