#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2018-2025 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.

import os
import sys
import unittest
from unittest.mock import patch

import time
import string

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gallery_dl import extractor, util, dt, config  # noqa E402
from gallery_dl.extractor import mastodon  # noqa E402
from gallery_dl.extractor.common import Extractor, Message  # noqa E402
from gallery_dl.extractor.directlink import DirectlinkExtractor  # noqa E402

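# Keep a reference to the original extractor._list_classes so that
# TestExtractorModule.setUp() can restore it after tests replace it.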
_list_classes = extractor._list_classes
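# Load the expected test results: either from the file named by the
# GDL_TEST_RESULTS environment variable or from the bundled 'test.results'
# module; 'results' stays None when neither is available.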
try:
    RESULTS = os.environ.get("GDL_TEST_RESULTS")
    if RESULTS:
        results = util.import_file(RESULTS)
    else:
        from test import results
except ImportError:
    results = None


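# Minimal stub extractor used by the add()/add_module() tests below.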
class FakeExtractor(Extractor):
    category = "fake"
    subcategory = "test"
    pattern = "fake:"

    def items(self):
        yield Message.Noop
        yield Message.Url, "text:foobar", {}


class TestExtractorModule(unittest.TestCase):
    VALID_URIS = (
        "https://example.org/file.jpg",
        "tumblr:foobar",
        "oauth:flickr",
        "generic:https://example.org/",
        "recursive:https://example.org/document.html",
    )

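    # Reset the module-level extractor cache and registry before each test
    # so that extractors added by one test do not leak into the next.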
    def setUp(self):
        extractor._cache.clear()
        extractor._module_iter = extractor._modules_internal()
        extractor._list_classes = _list_classes

    def test_find(self):
        for uri in self.VALID_URIS:
            result = extractor.find(uri)
            self.assertIsInstance(result, Extractor, uri)

        for not_found in ("", "/tmp/file.ext"):
            self.assertIsNone(extractor.find(not_found))

        for invalid in (None, [], {}, 123, b"test:"):
            with self.assertRaises(TypeError):
                extractor.find(invalid)

    def test_add(self):
        uri = "fake:foobar"
        self.assertIsNone(extractor.find(uri))

        extractor.add(FakeExtractor)
        self.assertIsInstance(extractor.find(uri), FakeExtractor)

    def test_add_module(self):
        uri = "fake:foobar"
        self.assertIsNone(extractor.find(uri))

        classes = extractor.add_module(sys.modules[__name__])
        self.assertEqual(len(classes), 1)
        self.assertEqual(classes[0].pattern, FakeExtractor.pattern)
        self.assertEqual(classes[0], FakeExtractor)
        self.assertIsInstance(extractor.find(uri), FakeExtractor)

    def test_from_url(self):
        for uri in self.VALID_URIS:
            cls = extractor.find(uri).__class__
            extr = cls.from_url(uri)
            self.assertIs(type(extr), cls)
            self.assertIsInstance(extr, Extractor)

        for not_found in ("", "/tmp/file.ext"):
            self.assertIsNone(FakeExtractor.from_url(not_found))

        for invalid in (None, [], {}, 123, b"test:"):
            with self.assertRaises(TypeError):
                FakeExtractor.from_url(invalid)

    @unittest.skipIf(not results, "no test data")
    def test_categories(self):
        for result in results.all():
            if result.get("#fail"):
                try:
                    self.assertCategories(result)
                except AssertionError:
                    pass
                else:
                    self.fail(f"{result['#url']}: Test did not fail")
            else:
                self.assertCategories(result)

    def assertCategories(self, result):
        url = result["#url"]
        cls = result["#class"]

        try:
            extr = cls.from_url(url)
        except ImportError as exc:
            if exc.name in ("youtube_dl", "yt_dlp"):
                return sys.stdout.write(
                    f"Skipping '{cls.category}' category checks\n")
            raise
        self.assertTrue(extr, url)

        categories = result.get("#category")
        if categories:
            base, cat, sub = categories
        else:
            cat = cls.category
            sub = cls.subcategory
            base = cls.basecategory
        self.assertEqual(extr.category, cat, url)
        self.assertEqual(extr.subcategory, sub, url)
        self.assertEqual(extr.basecategory, base, url)

        if base not in ("reactor", "wikimedia"):
            self.assertEqual(extr._cfgpath, ("extractor", cat, sub), url)

    def test_init(self):
        """Test for exceptions in Extractor.initialize() and .finalize()"""
        def fail_request(*args, **kwargs):
            self.fail("called 'request()' during initialization")

        for cls in extractor.extractors():
            if cls.category == "ytdl":
                continue
            extr = cls.from_url(cls.example)
            if not extr:
                if cls.basecategory and not cls.instances:
                    continue
                self.fail(f"{cls.__name__} pattern does not match "
                          f"example URL '{cls.example}'")

            self.assertEqual(cls, extr.__class__)
            self.assertEqual(cls, extractor.find(cls.example).__class__)

            extr.request = fail_request
            extr.initialize()
            if extr.finalize is not None:
                extr.finalize(0)

    def test_init_ytdl(self):
        try:
            extr = extractor.find("ytdl:")
            extr.initialize()
            if extr.finalize is not None:
                extr.finalize(0)
        except ImportError as exc:
            if exc.name in ("youtube_dl", "yt_dlp"):
                raise unittest.SkipTest(f"cannot import module '{exc.name}'")
            raise

    def test_docstrings(self):
        """Ensure docstring uniqueness"""
        for extr1 in extractor.extractors():
            for extr2 in extractor.extractors():
                if extr1 != extr2 and extr1.__doc__ and extr2.__doc__:
                    self.assertNotEqual(
                        extr1.__doc__,
                        extr2.__doc__,
                        f"{extr1} <-> {extr2}",
                    )

    def test_names(self):
        """Ensure extractor classes are named CategorySubcategoryExtractor"""
        def capitalize(c):
            if "-" in c:
                return string.capwords(c.replace("-", " ")).replace(" ", "")
            return c.capitalize()

        for extr in extractor.extractors():
            if extr.category not in ("", "oauth", "ytdl"):
                expected = (f"{capitalize(extr.category)}"
                            f"{capitalize(extr.subcategory)}Extractor")
                if expected[0].isdigit():
                    expected = f"_{expected}"
                self.assertEqual(expected, extr.__name__)


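# Extractor.wait() tests: time.sleep and the extractor's logger are patched
# so no real waiting happens; the tests check the computed sleep duration
# and the ISO-formatted resume time passed to log.info().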
class TestExtractorWait(unittest.TestCase):

    def test_wait_seconds(self):
        extr = extractor.find("generic:https://example.org/")
        seconds = 5
        until = time.time() + seconds

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(seconds=seconds)

        sleep.assert_called_once_with(6.0)

        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until)

    def test_wait_until(self):
        extr = extractor.find("generic:https://example.org/")
        until = time.time() + 5

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(until=until)

        calls = sleep.mock_calls
        self.assertEqual(len(calls), 1)
        self.assertAlmostEqual(calls[0][1][0], 6.0, places=0)

        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until)

    def test_wait_until_datetime(self):
        extr = extractor.find("generic:https://example.org/")
        until = dt.now() + dt.timedelta(seconds=5)
        until_local = dt.datetime.now() + dt.timedelta(seconds=5)

        if not until.microsecond:
            until = until.replace(microsecond=until_local.microsecond)

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(until=until)

        calls = sleep.mock_calls
        self.assertEqual(len(calls), 1)
        self.assertAlmostEqual(calls[0][1][0], 6.0, places=1)

        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until_local)

    def _assert_isotime(self, output, until):
        if not isinstance(until, dt.datetime):
            until = dt.datetime.fromtimestamp(until)
        o = self._isotime_to_seconds(output)
        u = self._isotime_to_seconds(until.time().isoformat()[:8])
        self.assertLessEqual(o-u, 1.0)

    def _isotime_to_seconds(self, isotime):
        parts = isotime.split(":")
        return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])


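# Tests for Extractor._get_date_min_max() with the 'date-min', 'date-max',
# and 'date-format' config options (timestamps, ISO strings, custom formats).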
class TestExtractorCommonDateminmax(unittest.TestCase):

    def setUp(self):
        config.clear()

    tearDown = setUp

    def test_date_min_max_default(self):
        extr = extractor.find("generic:https://example.org/")

        dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, None)
        self.assertEqual(dmax, None)

        dmin, dmax = extr._get_date_min_max(..., -1)
        self.assertEqual(dmin, ...)
        self.assertEqual(dmax, -1)

    def test_date_min_max_timestamp(self):
        extr = extractor.find("generic:https://example.org/")
        config.set((), "date-min", 1262304000)
        config.set((), "date-max", 1262304000.123)

        dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, 1262304000)
        self.assertEqual(dmax, 1262304000.123)

    def test_date_min_max_iso(self):
        extr = extractor.find("generic:https://example.org/")
        config.set((), "date-min", "2010-01-01")
        config.set((), "date-max", "2010-01-01T00:01:03")

        dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, 1262304000)
        self.assertEqual(dmax, 1262304063)

    def test_date_min_max_iso_invalid(self):
        extr = extractor.find("generic:https://example.org/")
        config.set((), "date-min", "2010-01-01")
        config.set((), "date-max", "2010-01")

        with self.assertLogs() as log_info:
            dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, 1262304000)
        self.assertEqual(dmax, None)

        self.assertEqual(len(log_info.output), 1)
        self.assertEqual(
            log_info.output[0],
            "WARNING:generic:Unable to parse 'date-max': "
            "Invalid isoformat string '2010-01'")

    def test_date_min_max_fmt(self):
        extr = extractor.find("generic:https://example.org/")
        config.set((), "date-format", "%B %d %Y")
        config.set((), "date-min", "January 01 2010")
        config.set((), "date-max", "August 18 2022")

        dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, 1262304000)
        self.assertEqual(dmax, 1660780800)

    def test_date_min_max_mix(self):
        extr = extractor.find("generic:https://example.org/")
        config.set((), "date-format", "%B %d %Y")
        config.set((), "date-min", "January 01 2010")
        config.set((), "date-max", 1262304061)

        dmin, dmax = extr._get_date_min_max()
        self.assertEqual(dmin, 1262304000)
        self.assertEqual(dmax, 1262304061)


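# OAuth extractor tests: the authorization-flow methods are patched out, and
# iterating over each extractor is expected to trigger exactly one call.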
class TestExtractorOAuth(unittest.TestCase):

    def test_oauth1(self):
        for category in ("flickr", "smugmug", "tumblr"):
            extr = extractor.find(f"oauth:{category}")

            with patch.object(extr, "_oauth1_authorization_flow") as m:
                for msg in extr:
                    pass
            self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2(self):
        for category in ("deviantart", "reddit"):
            extr = extractor.find(f"oauth:{category}")

            with patch.object(extr, "_oauth2_authorization_code_grant") as m:
                for msg in extr:
                    pass
            self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2_mastodon(self):
        extr = extractor.find("oauth:mastodon:pawoo.net")

        with patch.object(extr, "_oauth2_authorization_code_grant") as m, \
                patch.object(extr, "_register") as r:
            for msg in extr:
                pass
        self.assertEqual(len(r.mock_calls), 0)
        self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2_mastodon_unknown(self):
        extr = extractor.find("oauth:mastodon:example.com")

        with patch.object(extr, "_oauth2_authorization_code_grant") as m, \
                patch.object(extr, "_register") as r:
            r.return_value = {
                "client-id"    : "foo",
                "client-secret": "bar",
            }

            for msg in extr:
                pass

        self.assertEqual(len(r.mock_calls), 1)
        self.assertEqual(len(m.mock_calls), 1)


if __name__ == "__main__":
    unittest.main()