# Path: blob/main/tests/models/test_models_unet_2d.py
# 1448 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import unittest

import torch

from diffusers import UNet2DModel
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device

from ..test_modeling_common import ModelTesterMixin


logger = logging.get_logger(__name__)
# Turn off TF32 matmuls on CUDA; presumably so that the hard-coded reference
# slices below stay reproducible across GPU generations.
torch.backends.cuda.matmul.allow_tf32 = False


class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
    """Common-model tests for a small 3-channel ``UNet2DModel`` with attention blocks."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        """Return a random ``(4, 3, 32, 32)`` sample and a single timestep of 10."""
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape (channels, height, width)."""
        return (3, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape (channels, height, width)."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: model constructor kwargs and forward inputs.

        Presumably consumed by the shared tests in ``ModelTesterMixin``.
        """
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": None,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for a 4-channel ``UNet2DModel`` configuration and its dummy hub checkpoint."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        """Return a random ``(4, 4, 32, 32)`` sample and a single timestep of 10."""
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape (channels, height, width)."""
        return (4, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape (channels, height, width)."""
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: model constructor kwargs and forward inputs.

        Presumably consumed by the shared tests in ``ModelTesterMixin``.
        """
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        # The dummy checkpoint must load without missing keys and run a forward pass.
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input).sample

        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate(self):
        # Default loading path: load the checkpoint and run one forward pass on the GPU.
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate_wont_change_results(self):
        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
        model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_accelerate.to(torch_device)
        model_accelerate.eval()

        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        # Run without autograd: otherwise the graph attached to the output keeps
        # the first model's parameters alive after `del model_accelerate`.
        with torch.no_grad():
            arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # two models don't need to stay in the device at the same time
        del model_accelerate
        # Collect garbage first so the freed tensors can be released from the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()

        model_normal_load, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
        )
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        with torch.no_grad():
            arr_normal_load = model_normal_load(noise, time_step)["sample"]

        self.assertTrue(torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3))

    def test_output_pretrained(self):
        # Compare a 3x3 corner of the last output channel against hard-coded values.
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
        model.to(torch_device)

        noise = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for a skip-block ``UNet2DModel`` configuration with Fourier time embeddings."""

    model_class = UNet2DModel

    @property
    def dummy_input(self, sizes=(32, 32)):
        """Return a random ``(4, 3, 32, 32)`` sample and a per-sample int32 timestep of 10.

        NOTE: because this is a property, ``sizes`` always takes its default when
        accessed as an attribute; the parameter is kept only for signature stability.
        """
        batch_size = 4
        num_channels = 3

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape (channels, height, width)."""
        return (3, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape (channels, height, width)."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: model constructor kwargs and forward inputs.

        Presumably consumed by the shared tests in ``ModelTesterMixin``.
        """
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @slow
    def test_from_pretrained_hub(self):
        # The full checkpoint must load without missing keys and accept 256x256 input.
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        # Replace the 32x32 dummy sample with one matching the checkpoint's resolution.
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        self.assertIsNotNone(image, "Make sure output is not None")

    @slow
    def test_output_pretrained_ve_mid(self):
        # All-ones 256x256 input; compare an output slice against hard-coded values.
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        # All-ones 32x32 input on the dummy checkpoint; compare against hard-coded values.
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_forward_with_norm_groups(self):
        # not required for this model
        pass