Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/release.py
1440 views
1
# coding=utf-8
2
# Copyright 2021 The HuggingFace Team. All rights reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import argparse
17
import os
18
import re
19
20
import packaging.version
21
22
23
PATH_TO_EXAMPLES = "examples/"
# Maps a pattern name to a (compiled regex, replacement template) pair. The literal "VERSION"
# in the template is replaced by the target version before the regex substitution is applied.
REPLACE_PATTERNS = {
    # NOTE(review): with a trailing `\s*$` the match can include line breaks after the call, which
    # is why these two templates end with "\n"; check surrounding blank lines after a run.
    "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
    "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
    "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
    # This regex captures the leading indentation in group 1 and stops right before the line break,
    # so the template restores the indentation via \1 and must not append a newline of its own.
    "doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), r'\1release = "VERSION"'),
}
# Files rewritten on every release, keyed by the REPLACE_PATTERNS entry applied to them.
REPLACE_FILES = {
    "init": "src/diffusers/__init__.py",
    "setup": "setup.py",
}
README_FILE = "README.md"
35
36
37
def update_version_in_file(fname, version, pattern):
    """Update the version in one file using a specific pattern.

    Args:
        fname (`str`): Path of the file to update.
        version (`str`): The version string substituted for the ``VERSION`` placeholder of the
            replacement template in ``REPLACE_PATTERNS[pattern]``.
        pattern (`str`): Key of ``REPLACE_PATTERNS`` selecting the (regex, template) pair to apply.

    Returns:
        `int`: The number of substitutions performed. The file is only rewritten when its
        content actually changes.
    """
    # newline="\n" disables newline translation, so the file's line endings round-trip untouched.
    with open(fname, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()
    re_pattern, replace = REPLACE_PATTERNS[pattern]
    replace = replace.replace("VERSION", version)
    new_code, num_subs = re_pattern.subn(replace, code)
    # Files without a match (e.g. example scripts with no `check_min_version` call) or already at
    # the target version are left alone instead of being rewritten with identical content.
    if new_code == code:
        return num_subs
    with open(fname, "w", encoding="utf-8", newline="\n") as f:
        f.write(new_code)
    return num_subs
46
47
48
def update_version_in_examples(version):
    """Apply the ``"examples"`` version pattern to every maintained example script.

    Walks ``PATH_TO_EXAMPLES`` recursively, never descending into a sub-folder named
    ``research_projects`` or ``legacy``, and updates each ``.py`` file found.

    Args:
        version (`str`): The version string to write.
    """
    unmaintained = ("research_projects", "legacy")
    for root, subdirs, filenames in os.walk(PATH_TO_EXAMPLES):
        # Prune the non-actively maintained folders in place so os.walk skips them entirely.
        subdirs[:] = [name for name in subdirs if name not in unmaintained]
        python_files = (name for name in filenames if name.endswith(".py"))
        for filename in python_files:
            update_version_in_file(os.path.join(root, filename), version, pattern="examples")
59
60
61
def global_version_update(version, patch=False):
    """Write ``version`` into every file that carries the package version.

    Args:
        version (`str`): The version string to write.
        patch (`bool`, *optional*, defaults to `False`): When `True`, only the files listed in
            ``REPLACE_FILES`` are updated and the example scripts are left untouched.
    """
    for pattern_name in REPLACE_FILES:
        update_version_in_file(REPLACE_FILES[pattern_name], version, pattern_name)
    if patch:
        return
    update_version_in_examples(version)
67
68
69
def clean_main_ref_in_model_list(readme_file=None):
    """Replace the links from main doc to stable doc in the model list of the README.

    The model list is the run of lines strictly between the first line starting with the start
    prompt and the next line starting with the end prompt. In that range, every line starting with
    ``"1."`` has its ``.../docs/diffusers/main/model_doc`` links rewritten to
    ``.../docs/diffusers/model_doc``. The file is rewritten in place.

    Args:
        readme_file (`str`, *optional*): Path of the README to update. Defaults to ``README_FILE``.

    Raises:
        ValueError: If the start prompt, or an end prompt after it, cannot be found.
    """
    if readme_file is None:
        readme_file = README_FILE
    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    # NOTE(review): the start prompt mentions "Transformers"; confirm it actually appears in this
    # repository's README, otherwise this function raises ValueError.
    _start_prompt = "🤗 Transformers currently provides the following architectures"
    _end_prompt = "1. Want to contribute a new model?"
    with open(readme_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Find the start of the list: the line right after the introduction.
    start_index = next((i + 1 for i, line in enumerate(lines) if line.startswith(_start_prompt)), None)
    if start_index is None:
        raise ValueError(f"Could not find a line starting with {_start_prompt!r} in {readme_file}.")

    # Find the conclusion line; it is not part of the list and is left untouched.
    end_index = next(
        (i for i in range(start_index, len(lines)) if lines[i].startswith(_end_prompt)),
        None,
    )
    if end_index is None:
        raise ValueError(f"Could not find a line starting with {_end_prompt!r} in {readme_file}.")

    # Update the lines in the model list.
    for index in range(start_index, end_index):
        if lines[index].startswith("1."):
            lines[index] = lines[index].replace(
                "https://huggingface.co/docs/diffusers/main/model_doc",
                "https://huggingface.co/docs/diffusers/model_doc",
            )

    with open(readme_file, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines)
95
96
97
def get_version():
    """Reads the current version in the __init__.

    Returns:
        The ``__version__`` declared in ``REPLACE_FILES["init"]``, parsed with
        ``packaging.version.parse``.

    Raises:
        ValueError: If no ``__version__ = "..."`` line is found in that file.
    """
    init_file = REPLACE_FILES["init"]
    # Use an explicit encoding, like every other file access in this script, so the result does
    # not depend on the platform's default locale encoding.
    with open(init_file, "r", encoding="utf-8") as f:
        code = f.read()
    match = REPLACE_PATTERNS["init"][0].search(code)
    if match is None:
        raise ValueError(f'Could not find a `__version__ = "..."` line in {init_file}.')
    return packaging.version.parse(match.group(1))
103
104
105
def pre_release_work(patch=False):
    """Do all the necessary pre-release steps.

    Proposes a default release version derived from the current ``__version__``, lets the user
    confirm or override it interactively, then writes it into all tracked files.

    Args:
        patch (`bool`, *optional*, defaults to `False`): Whether this is a patch release. Patch
            releases bump the micro version, and ``global_version_update`` is called with
            ``patch=True``, so the example scripts are not updated.

    Raises:
        ValueError: If a patch release is requested while the current version is a dev release.
    """
    # First let's get the default version: base version if we are in dev, bump minor otherwise.
    default_version = get_version()
    if patch and default_version.is_devrelease:
        raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
    if default_version.is_devrelease:
        default_version = default_version.base_version
    elif patch:
        default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
    else:
        default_version = f"{default_version.major}.{default_version.minor + 1}.0"

    # Now let's ask nicely if that's the right one.
    # Strip the answer: stray whitespace would otherwise be written verbatim into the version strings.
    version = input(f"Which version are you releasing? [{default_version}]").strip()
    if not version:
        version = default_version

    print(f"Updating version to {version}.")
    global_version_update(version, patch=patch)
125
126
127
# if not patch:
#     print("Cleaning main README, don't forget to run `make fix-copies`.")
#     clean_main_ref_in_model_list()
130
131
132
def post_release_work():
    """Do all the necessary post-release steps.

    Proposes the next minor dev version (``MAJOR.(MINOR + 1).0.dev0``) based on the current
    ``__version__``, lets the user confirm or override it interactively, then writes it into all
    tracked files, including the example scripts.
    """
    # First let's get the current version
    current_version = get_version()
    dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"

    # Check with the user we got that right.
    # Strip the answer: stray whitespace would otherwise be written verbatim into the version strings.
    version = input(f"Which version are we developing now? [{dev_version}]").strip()
    if not version:
        version = dev_version

    print(f"Updating version to {version}.")
    global_version_update(version)
146
147
148
# print("Cleaning main README, don't forget to run `make fix-copies`.")
# clean_main_ref_in_model_list()
150
151
152
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
    cli.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
    options = cli.parse_args()
    # Pre-release work runs by default; post-release work has nothing to do for patch releases.
    if options.post_release and options.patch:
        print("Nothing to do after a patch :-)")
    elif options.post_release:
        post_release_work()
    else:
        pre_release_work(patch=options.patch)
163
164