Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/repo_utils/test_check_copies.py
1441 views
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
# http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
import os
16
import re
17
import shutil
18
import sys
19
import tempfile
20
import unittest
21
22
import black
23
24
25
# This file is <repo root>/tests/repo_utils/test_check_copies.py: climb three directory levels to reach the root.
git_repo_path = __file__
for _depth in range(3):
    git_repo_path = os.path.dirname(git_repo_path)
git_repo_path = os.path.abspath(git_repo_path)
del _depth
# Put the repository's `utils/` folder on the import path so `check_copies` can be imported below.
sys.path.append(os.path.join(git_repo_path, "utils"))
27
28
import check_copies # noqa: E402
29
30
31
# This is the reference code that will be used in the tests.
32
# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated.
33
REFERENCE_CODE = """ \"""
34
Output class for the scheduler's step function output.
35
36
Args:
37
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
39
denoising loop.
40
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
42
`pred_original_sample` can be used to preview progress or for guidance.
43
\"""
44
45
prev_sample: torch.FloatTensor
46
pred_original_sample: Optional[torch.FloatTensor] = None
47
"""
48
49
50
class CopyCheckTester(unittest.TestCase):
    """Tests for `utils/check_copies.py`, run against a scratch copy of `schedulers/scheduling_ddpm.py`."""

    def setUp(self):
        # Point `check_copies` at a temporary tree containing only the DDPM scheduler module, so the tests
        # never read from or write to the real source tree.
        self.diffusers_dir = tempfile.mkdtemp()
        os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
        # Save the current value instead of assuming a default, so tearDown restores exactly what was there.
        self._original_diffusers_path = check_copies.DIFFUSERS_PATH
        check_copies.DIFFUSERS_PATH = self.diffusers_dir
        shutil.copy(
            os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
            os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
        )

    def tearDown(self):
        check_copies.DIFFUSERS_PATH = self._original_diffusers_path
        shutil.rmtree(self.diffusers_dir)

    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
        """Write a class with a `# Copied from` header to a temp file and run `check_copies` on it.

        Args:
            comment: The `# Copied from ...` line placed directly above the generated class.
            class_name: Name of the generated class (declared as a subclass of `nn.Module`).
            class_code: Body of the generated class.
            overwrite_result: If None, assert that the file is already consistent. Otherwise run the checker
                with `overwrite=True` and assert that the rewritten file has this text as the class body.
        """
        header = comment + f"\nclass {class_name}(nn.Module):\n"
        mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
        code = black.format_str(header + class_code, mode=mode)
        fname = os.path.join(self.diffusers_dir, "new_code.py")
        with open(fname, "w", encoding="utf-8", newline="\n") as f:
            f.write(code)

        if overwrite_result is None:
            diffs = check_copies.is_copy_consistent(fname)
            self.assertEqual(len(diffs), 0, diffs)
        else:
            check_copies.is_copy_consistent(fname, overwrite=True)
            with open(fname, "r", encoding="utf-8") as f:
                # assertEqual, not assertTrue: assertTrue(a, b) treats `b` as the failure message and passes
                # for any non-empty file, which would leave the overwritten content unchecked.
                self.assertEqual(f.read(), header + overwrite_result)

    def test_find_code_in_diffusers(self):
        code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
        self.assertEqual(code, REFERENCE_CODE)

    def test_is_copy_consistent(self):
        copied_from = "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput"

        # Base copy consistency
        self.check_copy_consistency(copied_from, "DDPMSchedulerOutput", REFERENCE_CODE + "\n")

        # With no empty line at the end
        self.check_copy_consistency(copied_from, "DDPMSchedulerOutput", REFERENCE_CODE)

        # Copy consistency with rename
        self.check_copy_consistency(
            f"{copied_from} with DDPM->Test",
            "TestSchedulerOutput",
            re.sub("DDPM", "Test", REFERENCE_CODE),
        )

        # Copy consistency with a really long name; the substitution mirrors the `DDPM->...` directive.
        long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
        self.check_copy_consistency(
            f"{copied_from} with DDPM->{long_class_name}",
            f"{long_class_name}SchedulerOutput",
            re.sub("DDPM", long_class_name, REFERENCE_CODE),
        )

        # Copy consistency with overwrite
        self.check_copy_consistency(
            f"{copied_from} with DDPM->Test",
            "TestSchedulerOutput",
            REFERENCE_CODE,
            overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
        )
121
122