Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
mikf
GitHub Repository: mikf/gallery-dl
Path: blob/master/scripts/export_tests.py
5457 views
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
# Copyright 2023 Mike Fährmann
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License version 2 as
8
# published by the Free Software Foundation.
9
10
import os
11
import re
12
import sys
13
import itertools
14
import collections
15
16
import util
17
from pyprint import pyprint
18
from gallery_dl import extractor
19
20
21
FORMAT = '''\
22
# -*- coding: utf-8 -*-
23
24
# This program is free software; you can redistribute it and/or modify
25
# it under the terms of the GNU General Public License version 2 as
26
# published by the Free Software Foundation.
27
28
{imports}
29
30
31
__tests__ = (
32
{tests}\
33
)
34
'''
35
36
37
def extract_tests_from_source(lines):
    """Map each test URL found in 'lines' to the source lines of its test.

    'lines' is a list of source lines (each normally ending in a newline).
    A test starts on a line matching the URL pattern below and ends either
    on that same line (when the text after the URL is ',)' or '),') or on
    the next line matching the end pattern. The stored slice always starts
    one line before the URL line, so a comment directly above the test is
    included as its first element.
    """
    url_match = re.compile(
        r''' (?:test = | )?\(\(?"([^"]+)"(.*)''').match
    end_match = re.compile(
        r" (\}\)| \}\),)\n$").match

    found = {}
    start = 0           # index of the URL line of the currently open test
    current_url = ""

    for lineno, text in enumerate(lines):
        # an open multi-line test is closed by the first end marker
        if start and end_match(text):
            found[current_url] = lines[start-1:lineno+1]
            start = 0
            continue

        hit = url_match(text)
        if not hit:
            continue

        # Nothing after the closing quote means the URL literal continues
        # on the following line(s): drop the trailing '"' and newline, then
        # append whatever follows the first '"' of the next line.
        cursor = lineno
        while not hit[2]:
            cursor += 1
            following = lines[cursor]
            text = text[:-2] + following[following.index('"')+1:]
            hit = url_match(text)

        current_url = hit[1]
        if hit[2] in (",)", "),"):
            # single-line test without a result dict
            found[current_url] = lines[lineno-1:lineno+1]
            start = 0
        else:
            start = lineno

    return found
67
68
69
def get_test_source(extr, *, cache={}):
    """Return the source lines of the test belonging to 'extr'.

    The module defining 'extr' is read and parsed once with
    extract_tests_from_source(); the resulting URL -> lines mapping is
    memoized in 'cache' under the module name. The mutable default is
    intentional: it is the memo shared between calls.

    Returns the lines stored for 'extr.url', or the placeholder tuple
    ("",) when no (or an empty) entry exists for that URL.
    """
    try:
        tests = cache[extr.__module__]
    except KeyError:
        path = sys.modules[extr.__module__].__file__
        # util.open is a project helper; presumably a text-mode open()
        with util.open(path) as fp:
            lines = fp.readlines()
        tests = cache[extr.__module__] = extract_tests_from_source(lines)
    return tests.get(extr.url) or ("",)
79
80
81
def comment_from_source(source):
    """Return the text of a '#' comment on the first line of 'source'.

    The first element must be leading whitespace followed by '#'; the
    comment text (without the marker and the whitespace after it) is
    returned. Any other first line yields an empty string.
    """
    found = re.match(r"\s+#\s*(.+)", source[0])
    if not found:
        return ""
    return found[1]
84
85
86
def build_test(extr, data):
    """Convert one old-style test definition into its new representation.

    'extr' is the extractor instance for the test URL and 'data' the
    old-style result dict (or a falsy value when there is none).

    Returns a (head, instr, metadata) tuple:
      head     - '#url', '#category', '#class' and, when a comment was
                 found above the test in the source file, '#comment'
      instr    - '#'-prefixed test instructions translated from 'data'
      metadata - the former 'keyword' value; None when 'data' has none,
                 {} when it was a 40-character hash string (which is
                 stored as '#sha1_metadata' instead)

    If 'data' contains keys that are not handled here, the extractor and
    the leftover keys are printed and the program exits.
    """
    source = get_test_source(extr)
    comment = comment_from_source(source)

    head = {
        "#url"     : extr.url,
        "#comment" : comment.replace('"', "'"),
        "#category": (extr.basecategory,
                      extr.category,
                      extr.subcategory),
        "#class"   : extr.__class__,
    }

    if not comment:
        del head["#comment"]

    instr = {}

    # Work on a shallow copy: the keys are consumed with pop() below, and
    # the caller's dict (the extractor's own test definition) must not be
    # emptied as a side effect.
    data = dict(data) if data else {}

    if (options := data.pop("options", None)):
        instr["#options"] = {
            name: value
            for name, value in options
        }
    if (pattern := data.pop("pattern", None)):
        # refer to a known extractor's pattern by name instead of
        # repeating its text
        if pattern in PATTERNS:
            cls = PATTERNS[pattern]
            pattern = f"lit:{pyprint(cls)}.pattern"
        instr["#pattern"] = pattern
    if (exception := data.pop("exception", None)):
        instr["#exception"] = exception
    if (range_ := data.pop("range", None)):
        instr["#range"] = range_
    # falsy values such as 0 or False are meaningful for these three
    if (count := data.pop("count", None)) is not None:
        instr["#count"] = count
    if (archive := data.pop("archive", None)) is not None:
        instr["#archive"] = archive
    if (extractor_flag := data.pop("extractor", None)) is not None:
        instr["#extractor"] = extractor_flag
    if (url := data.pop("url", None)):
        instr["#sha1_url"] = url
    if (metadata := data.pop("keyword", None)):
        # a 40-character string is treated as a hash of the metadata
        if isinstance(metadata, str) and len(metadata) == 40:
            instr["#sha1_metadata"] = metadata
            metadata = {}
    if (content := data.pop("content", None)):
        if isinstance(content, tuple):
            content = list(content)
        instr["#sha1_content"] = content

    if data:
        # unhandled keys: report them and stop
        print(extr)
        for k in data:
            print(k)
        sys.exit()

    return head, instr, metadata
144
145
146
def collect_patterns():
    """Return a dict mapping 'cls.pattern.pattern' to its extractor class.

    Built from every class returned by extractor._list_classes(); when
    several classes share the same key, the last one listed wins.
    """
    by_pattern = {}
    for cls in extractor._list_classes():
        # presumably the source text of the class's compiled URL regex
        by_pattern[cls.pattern.pattern] = cls
    return by_pattern
151
152
153
def collect_tests(whitelist=None):
    """Build the converted tests of all extractors, grouped by category.

    'whitelist' is an optional collection of category names; when it is
    non-empty, tests of every other category are skipped.

    Returns a defaultdict(list) mapping each category name to the list of
    build_test() results for that category.
    """
    grouped = collections.defaultdict(list)

    for extractor_cls in extractor._list_classes():
        for test_url, test_data in extractor_cls._get_tests():
            instance = extractor_cls.from_url(test_url)
            if not whitelist or instance.category in whitelist:
                converted = build_test(instance, test_data)
                grouped[instance.category].append(converted)

    return grouped
166
167
168
def _import_statement(module_path):
    """Return the statement(s) that make module 'module_path' available."""
    parent, _, leaf = module_path.rpartition(".")
    if leaf[0].isdecimal():
        # A name starting with a digit cannot appear in an import
        # statement; import by string and bind it with a '_' prefix.
        return f'''\
{parent.partition(".")[0]} = __import__("{module_path}")
_{leaf} = getattr({parent}, "{leaf}")'''
    if parent:
        return f"from {parent} import {leaf}"
    return f"import {leaf}"


def export_tests(data):
    """Render an iterable of (head, instr, metadata) tests as module text.

    Every class found among the values of the three dicts (classes from
    'builtins' excepted) gets an import of its defining module. Metadata
    values equal to "type:datetime" are replaced in place by
    "lit:datetime.datetime" and add 'import datetime'.

    Returns FORMAT with the collected import statements and the
    concatenated test entries filled in.
    """
    import_stmts = {}       # module path -> import statement
    rendered = []

    for head, instr, metadata in data:
        values = itertools.chain(
            head.values(),
            instr.values() if instr else (),
            metadata.values() if metadata else (),
        )
        for value in values:
            if not isinstance(value, type):
                continue
            module_path = value.__module__
            if module_path != "builtins":
                import_stmts[module_path] = _import_statement(module_path)

        # The slices below splice the separately printed dicts into one;
        # the exact offsets depend on pyprint's output format.
        text = pyprint(head)
        if instr:
            text = f"{text[:-2]}{pyprint(instr)[1:]}"
        if metadata:
            for key, value in metadata.items():
                if value == "type:datetime":
                    import_stmts["datetime"] = "import datetime"
                    metadata[key] = "lit:datetime.datetime"
            text = f"{text[:-1]}{pyprint(metadata, lmin=0)[1:]}"

        rendered.append(f"{text},\n\n")

    return FORMAT.format(
        imports="\n".join(import_stmts.values()),
        tests="".join(rendered),
    )
209
210
211
PATTERNS = None
212
DIRECTORY = "/tmp/_/results"
213
214
215
def main():
    """Command-line entry point: export converted tests, one file each.

    Options:
      -t/--target    output directory; defaults to 'test/results' inside
                     the parent of the directory containing this script
      -c/--category  extractor category to export (may be repeated);
                     without it, all categories are exported

    Writes '<target>/<category>.py' for every collected category, with
    any '.' removed from the category name.
    """
    global PATTERNS

    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-t", "--target",
        help="target directory",
    )
    cli.add_argument(
        "-c", "--category", action="append",
        help="extractor categories to export",
    )
    options = cli.parse_args()

    target = options.target
    if not target:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        target = os.path.join(os.path.dirname(script_dir), "test", "results")

    # build_test() looks patterns up in this module-level table
    PATTERNS = collect_patterns()

    os.makedirs(target, exist_ok=True)
    for category, tests in collect_tests(options.category).items():
        module_name = category.replace(".", "")
        # util.lazy is a project helper; used here as a context manager
        # yielding an object with write()
        with util.lazy(f"{target}/{module_name}.py") as fp:
            fp.write(export_tests(tests))
244
245
246
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()
248
249