Path: blob/main/tests/repo_utils/test_check_copies.py
1441 views
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import shutil
import sys
import tempfile
import unittest

import black


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

import check_copies  # noqa: E402


# This is the reference code that will be used in the tests.
# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated.
REFERENCE_CODE = """    \"""
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    \"""

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
"""


class CopyCheckTester(unittest.TestCase):
    """Tests for `utils/check_copies.py`, run against a temporary copy of `scheduling_ddpm.py`."""

    def setUp(self):
        # Remember the checker's configured root first, so tearDown restores exactly what was there before.
        self._original_diffusers_path = check_copies.DIFFUSERS_PATH
        # Point the checker at a throwaway tree containing only the DDPM scheduler module.
        self.diffusers_dir = tempfile.mkdtemp()
        os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
        check_copies.DIFFUSERS_PATH = self.diffusers_dir
        shutil.copy(
            os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
            os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
        )

    def tearDown(self):
        check_copies.DIFFUSERS_PATH = self._original_diffusers_path
        shutil.rmtree(self.diffusers_dir)

    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
        """Write a `# Copied from` class to a temporary file and run `check_copies.is_copy_consistent` on it.

        Args:
            comment: the `# Copied from ...` marker line placed directly above the class definition.
            class_name: name of the generated `nn.Module` subclass.
            class_code: body written under the class definition.
            overwrite_result: if `None`, assert the checker reports no inconsistency. Otherwise run the checker
                with `overwrite=True` and assert the file then holds the class with this body instead.
        """
        mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
        header = comment + f"\nclass {class_name}(nn.Module):\n"
        code = black.format_str(header + class_code, mode=mode)

        fname = os.path.join(self.diffusers_dir, "new_code.py")
        with open(fname, "w", encoding="utf-8", newline="\n") as f:
            f.write(code)

        if overwrite_result is None:
            self.assertEqual(len(check_copies.is_copy_consistent(fname)), 0)
            return

        check_copies.is_copy_consistent(fname, overwrite=True)
        # Format the expectation with the same black mode as the written file, so the comparison is not
        # sensitive to whitespace normalization.
        expected = black.format_str(header + overwrite_result, mode=mode)
        with open(fname, "r", encoding="utf-8") as f:
            # assertEqual (not assertTrue): assertTrue(a, b) treats `b` as a failure message and compares nothing.
            self.assertEqual(f.read(), expected)

    def test_find_code_in_diffusers(self):
        code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
        self.assertEqual(code, REFERENCE_CODE)

    def test_is_copy_consistent(self):
        # Base copy consistency
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE + "\n",
        )

        # With no empty line at the end
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE,
        )

        # Copy consistency with rename
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            re.sub("DDPM", "Test", REFERENCE_CODE),
        )

        # Copy consistency with a really long name; the substitution mirrors the `DDPM->...` marker.
        long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
        self.check_copy_consistency(
            f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}",
            f"{long_class_name}SchedulerOutput",
            re.sub("DDPM", long_class_name, REFERENCE_CODE),
        )

        # Copy consistency with overwrite
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            REFERENCE_CODE,
            overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
        )