Path: blob/main/tests/test_unet_blocks_common.py
1440 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Tuple

import torch

from diffusers.utils import floats_tensor, randn_tensor, torch_all_close, torch_device
from diffusers.utils.testing_utils import require_torch


@require_torch
class UNetBlockTesterMixin:
    """Shared forward/backward checks for individual UNet blocks.

    Concrete test classes are expected to mix this into ``unittest.TestCase``
    (``self.assertEqual`` is used below) and to define:

    * ``block_class`` -- the block under test; it is instantiated with the
      keyword arguments returned by ``prepare_init_args_and_inputs_for_common``.
    * ``block_type`` -- one of ``"down"``, ``"mid"`` or ``"up"``; selects the
      expected output shape and which constructor arguments are passed.
    """

    @property
    def dummy_input(self):
        """Default forward inputs: ``get_dummy_input()`` with its default flags."""
        return self.get_dummy_input()

    @property
    def output_shape(self) -> Tuple[int, int, int, int]:
        """Expected shape of the block's (first) output for the dummy input.

        The dummy ``hidden_states`` are ``(4, 32, 32, 32)``; the expected output
        is 16x16 for ``"down"``, 32x32 for ``"mid"`` and 64x64 for ``"up"``.

        Raises:
            ValueError: if ``self.block_type`` is not ``"down"``, ``"mid"`` or ``"up"``.
        """
        if self.block_type == "down":
            return (4, 32, 16, 16)
        if self.block_type == "mid":
            return (4, 32, 32, 32)
        if self.block_type == "up":
            return (4, 32, 64, 64)

        raise ValueError(f"'{self.block_type}' is not a supported block_type. Set it to 'up', 'mid', or 'down'.")

    def get_dummy_input(
        self,
        include_temb=True,
        include_res_hidden_states_tuple=False,
        include_encoder_hidden_states=False,
        include_skip_sample=False,
    ):
        """Build keyword arguments for a block's forward call.

        Always contains ``hidden_states`` of shape ``(4, 32, 32, 32)``. Optional keys:

        * ``temb`` (default on): shape ``(4, 128)``.
        * ``res_hidden_states_tuple``: a one-element tuple holding a tensor shaped
          like ``hidden_states``, drawn after reseeding with 1.
        * ``encoder_hidden_states``: shape ``(4, 32, 32)``, built with
          ``floats_tensor`` (not the seeded torch generator; whether it is
          reproducible depends on ``floats_tensor`` -- NOTE(review): confirm).
        * ``skip_sample``: shape ``(4, 3, 32, 32)``.

        Returns:
            dict mapping argument names to tensors on ``torch_device``.
        """
        batch_size = 4
        num_channels = 32
        sizes = (32, 32)

        # torch.manual_seed seeds torch's global RNG, so the values produced below
        # depend on the exact order of the draws. Subclasses' expected output
        # slices presumably rely on that order -- do not reorder these statements.
        generator = torch.manual_seed(0)
        device = torch.device(torch_device)
        shape = (batch_size, num_channels) + sizes
        hidden_states = randn_tensor(shape, generator=generator, device=device)
        dummy_input = {"hidden_states": hidden_states}

        if include_temb:
            temb_channels = 128
            dummy_input["temb"] = randn_tensor((batch_size, temb_channels), generator=generator, device=device)

        if include_res_hidden_states_tuple:
            generator_1 = torch.manual_seed(1)
            dummy_input["res_hidden_states_tuple"] = (randn_tensor(shape, generator=generator_1, device=device),)

        if include_encoder_hidden_states:
            dummy_input["encoder_hidden_states"] = floats_tensor((batch_size, 32, 32)).to(torch_device)

        if include_skip_sample:
            dummy_input["skip_sample"] = randn_tensor(((batch_size, 3) + sizes), generator=generator, device=device)

        return dummy_input

    def prepare_init_args_and_inputs_for_common(self) -> Tuple[dict, dict]:
        """Return ``(init_kwargs, forward_kwargs)`` for constructing and calling the block.

        All blocks get ``in_channels=32`` and ``temb_channels=128``. ``"up"`` blocks
        additionally get ``prev_output_channel=32``. Every block type except
        ``"mid"`` gets ``out_channels=32``.
        """
        init_dict = {
            "in_channels": 32,
            "out_channels": 32,
            "temb_channels": 128,
        }
        if self.block_type == "up":
            init_dict["prev_output_channel"] = 32

        if self.block_type == "mid":
            # Mid blocks are constructed without an ``out_channels`` argument.
            init_dict.pop("out_channels")

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @staticmethod
    def _first_output(output):
        """Return ``output[0]`` if the block returned a tuple, else ``output`` unchanged."""
        # Some blocks return a tuple (e.g. hidden states plus extra tensors); the
        # checks below only look at the first element.
        return output[0] if isinstance(output, tuple) else output

    def test_output(self, expected_slice):
        """Run the block in eval mode and compare a slice of its output.

        Subclasses are expected to override this and call it with their own
        reference values.

        Args:
            expected_slice: 9 reference values for ``output[0, -1, -3:, -3:]``
                (the bottom-right 3x3 patch of the last channel of the first
                sample), flattened. Compared with ``atol=5e-3``.
        """
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        unet_block = self.block_class(**init_dict)
        unet_block.to(torch_device)
        unet_block.eval()

        with torch.no_grad():
            output = self._first_output(unet_block(**inputs_dict))

        self.assertEqual(output.shape, self.output_shape)

        output_slice = output[0, -1, -3:, -3:]
        expected = torch.tensor(expected_slice).to(torch_device)
        assert torch_all_close(output_slice.flatten(), expected, atol=5e-3)

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_training(self):
        """Check that a train-mode forward pass followed by ``backward()`` runs.

        The loss is the MSE between the block's (first) output and random noise.
        Only the absence of errors is checked, not the loss value.
        """
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.block_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = self._first_output(model(**inputs_dict))

        device = torch.device(torch_device)
        noise = randn_tensor(output.shape, device=device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()