Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
mikf
GitHub Repository: mikf/gallery-dl
Path: blob/master/test/test_extractor.py
5457 views
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
# Copyright 2018-2025 Mike Fährmann
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License version 2 as
8
# published by the Free Software Foundation.
9
10
import os
11
import sys
12
import unittest
13
from unittest.mock import patch
14
15
import time
16
import string
17
from datetime import datetime, timedelta
18
19
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
from gallery_dl import extractor, util # noqa E402
21
from gallery_dl.extractor import mastodon # noqa E402
22
from gallery_dl.extractor.common import Extractor, Message # noqa E402
23
from gallery_dl.extractor.directlink import DirectlinkExtractor # noqa E402
24
25
# Snapshot of the registry's class-listing function, taken at import time;
# TestExtractorModule.setUp() assigns it back before every test.
_list_classes = extractor._list_classes

# Optional test-result data used by TestExtractorModule.test_categories.
# It is loaded from the file named by the GDL_TEST_RESULTS environment
# variable (via util.import_file) or, if that is unset/empty, from the
# bundled 'test.results' module.  An ImportError leaves 'results' as None,
# which makes test_categories skip itself.
try:
    RESULTS = os.environ.get("GDL_TEST_RESULTS")
    if not RESULTS:
        from test import results
    else:
        results = util.import_file(RESULTS)
except ImportError:
    results = None
35
36
37
class FakeExtractor(Extractor):
    """Minimal stand-in extractor (pattern 'fake:') used by registry tests"""
    pattern = "fake:"
    category = "fake"
    subcategory = "test"

    def items(self):
        # One version message followed by a single URL with empty metadata
        yield from (
            (Message.Version, 1),
            (Message.Url, "text:foobar", {}),
        )
45
46
47
class TestExtractorModule(unittest.TestCase):
    """Tests for the extractor registry and the extractor classes in it"""

    # URLs that are expected to be matched by some available extractor
    VALID_URIS = (
        "https://example.org/file.jpg",
        "tumblr:foobar",
        "oauth:flickr",
        "generic:https://example.org/",
        "recursive:https://example.org/document.html",
    )

    def setUp(self):
        # Reset module-level registry state so that classes registered by
        # one test (e.g. via extractor.add()) do not leak into the next.
        extractor._cache.clear()
        extractor._module_iter = extractor._modules_internal()
        extractor._list_classes = _list_classes

    def test_find(self):
        for uri in self.VALID_URIS:
            result = extractor.find(uri)
            self.assertIsInstance(result, Extractor, uri)

        for not_found in ("", "/tmp/file.ext"):
            self.assertIsNone(extractor.find(not_found))

        for invalid in (None, [], {}, 123, b"test:"):
            with self.assertRaises(TypeError):
                extractor.find(invalid)

    def test_add(self):
        uri = "fake:foobar"
        self.assertIsNone(extractor.find(uri))

        extractor.add(FakeExtractor)
        self.assertIsInstance(extractor.find(uri), FakeExtractor)

    def test_add_module(self):
        uri = "fake:foobar"
        self.assertIsNone(extractor.find(uri))

        # This test module defines exactly one Extractor subclass
        classes = extractor.add_module(sys.modules[__name__])
        self.assertEqual(len(classes), 1)
        self.assertEqual(classes[0].pattern, FakeExtractor.pattern)
        self.assertEqual(classes[0], FakeExtractor)
        self.assertIsInstance(extractor.find(uri), FakeExtractor)

    def test_from_url(self):
        for uri in self.VALID_URIS:
            cls = extractor.find(uri).__class__
            extr = cls.from_url(uri)
            self.assertIs(type(extr), cls)
            self.assertIsInstance(extr, Extractor)

        for not_found in ("", "/tmp/file.ext"):
            self.assertIsNone(FakeExtractor.from_url(not_found))

        for invalid in (None, [], {}, 123, b"test:"):
            with self.assertRaises(TypeError):
                FakeExtractor.from_url(invalid)

    @unittest.skipIf(not results, "no test data")
    def test_categories(self):
        """Check categories of every URL in the test-result data

        Entries flagged with '#fail' must make assertCategories() fail.
        """
        for result in results.all():
            if result.get("#fail"):
                try:
                    self.assertCategories(result)
                except AssertionError:
                    pass
                else:
                    self.fail(f"{result['#url']}: Test did not fail")
            else:
                self.assertCategories(result)

    def assertCategories(self, result):
        """Assert the categories of the extractor for result['#url']

        Expected values come from result['#category'] as a
        (basecategory, category, subcategory) triple or, if absent, from
        the attributes of result['#class'].  When instantiating the class
        needs an unavailable 'youtube_dl'/'yt_dlp' module, the checks are
        skipped and a notice is written to stdout.
        """
        url = result["#url"]
        cls = result["#class"]

        try:
            extr = cls.from_url(url)
        except ImportError as exc:
            if exc.name in ("youtube_dl", "yt_dlp"):
                sys.stdout.write(
                    f"Skipping '{cls.category}' category checks\n")
                return
            raise
        self.assertTrue(extr, url)

        categories = result.get("#category")
        if categories:
            base, cat, sub = categories
        else:
            cat = cls.category
            sub = cls.subcategory
            base = cls.basecategory
        self.assertEqual(extr.category, cat, url)
        self.assertEqual(extr.subcategory, sub, url)
        self.assertEqual(extr.basecategory, base, url)

        # NOTE(review): these base categories are exempt from the config
        # path check -- presumably they build a custom _cfgpath; confirm.
        if base not in ("reactor", "wikimedia"):
            self.assertEqual(extr._cfgpath, ("extractor", cat, sub), url)

    def test_init(self):
        """Test for exceptions in Extractor.initialize() and .finalize()"""
        def fail_request(*args, **kwargs):
            self.fail("called 'request()' during initialization")

        for cls in extractor.extractors():
            if cls.category == "ytdl":
                continue  # covered by test_init_ytdl
            extr = cls.from_url(cls.example)
            if not extr:
                # Base-category classes without configured instances are
                # allowed to not match their example URL.
                if cls.basecategory and not cls.instances:
                    continue
                self.fail(f"{cls.__name__} pattern does not match "
                          f"example URL '{cls.example}'")

            # Initialization must not perform any HTTP requests
            extr.request = fail_request
            extr.initialize()
            extr.finalize()

    def test_init_ytdl(self):
        try:
            extr = extractor.find("ytdl:")
            extr.initialize()
            extr.finalize()
        except ImportError as exc:
            if exc.name in ("youtube_dl", "yt_dlp"):
                raise unittest.SkipTest(f"cannot import module '{exc.name}'")
            raise

    def test_docstrings(self):
        """Ensure docstring uniqueness"""
        # Remember the first class seen for each docstring; a single pass
        # finds duplicates without comparing every pair of classes.
        seen = {}
        for extr in extractor.extractors():
            doc = extr.__doc__
            if not doc:
                continue
            other = seen.setdefault(doc, extr)
            if other != extr:
                self.fail(f"{other} <-> {extr}: duplicate docstring {doc!r}")

    def test_names(self):
        """Ensure extractor classes are named CategorySubcategoryExtractor"""
        def capitalize(c):
            # "foo-bar" -> "FooBar", "foo" -> "Foo"
            if "-" in c:
                return string.capwords(c.replace("-", " ")).replace(" ", "")
            return c.capitalize()

        for extr in extractor.extractors():
            if extr.category not in ("", "oauth", "ytdl"):
                expected = (f"{capitalize(extr.category)}"
                            f"{capitalize(extr.subcategory)}Extractor")
                # Python identifiers cannot start with a digit
                if expected[0].isdigit():
                    expected = f"_{expected}"
                self.assertEqual(expected, extr.__name__)
198
199
200
class TestExtractorWait(unittest.TestCase):
    """Tests for Extractor.wait() with time.sleep and logging mocked out"""

    def test_wait_seconds(self):
        extr = extractor.find("generic:https://example.org/")
        seconds = 5
        until = time.time() + seconds

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(seconds=seconds)

        # NOTE(review): the expected sleep is one second longer than the
        # requested wait -- presumably a margin added by wait(); confirm.
        sleep.assert_called_once_with(6.0)

        # The second positional argument of the single log.info() call is
        # the "HH:MM:SS" time the wait ends at.
        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until)

    def test_wait_until(self):
        extr = extractor.find("generic:https://example.org/")
        until = time.time() + 5

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(until=until)

        calls = sleep.mock_calls
        self.assertEqual(len(calls), 1)
        self.assertAlmostEqual(calls[0][1][0], 6.0, places=0)

        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until)

    def test_wait_until_datetime(self):
        extr = extractor.find("generic:https://example.org/")
        # NOTE(review): util.datetime_utcnow() presumably returns the
        # current UTC time as a datetime; the logged time is compared
        # against the equivalent local time below.
        until = util.datetime_utcnow() + timedelta(seconds=5)
        until_local = datetime.now() + timedelta(seconds=5)

        if not until.microsecond:
            until = until.replace(microsecond=until_local.microsecond)

        with patch("time.sleep") as sleep, patch.object(extr, "log") as log:
            extr.wait(until=until)

        calls = sleep.mock_calls
        self.assertEqual(len(calls), 1)
        self.assertAlmostEqual(calls[0][1][0], 6.0, places=1)

        calls = log.info.mock_calls
        self.assertEqual(len(calls), 1)
        self._assert_isotime(calls[0][1][1], until_local)

    def _assert_isotime(self, output, until):
        """Assert that 'output' ("HH:MM:SS") is within 1 second of 'until'

        'until' is a datetime or a POSIX timestamp; timestamps are
        converted to local time with datetime.fromtimestamp().
        """
        if not isinstance(until, datetime):
            until = datetime.fromtimestamp(until)
        o = self._isotime_to_seconds(output)
        u = self._isotime_to_seconds(until.time().isoformat()[:8])

        # Check both directions, and measure the distance on a 24-hour
        # clock so a wait that crosses midnight is not misreported.
        diff = abs(o - u)
        diff = min(diff, 86400 - diff)
        self.assertLessEqual(diff, 1, f"{output} vs {until.time()}")

    def _isotime_to_seconds(self, isotime):
        """Convert an "HH:MM:SS" string to seconds since midnight"""
        parts = isotime.split(":")
        return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
260
261
262
class TextExtractorOAuth(unittest.TestCase):
    """Tests for the 'oauth:' extractors' authorization flows

    Each test patches out the authorization step and checks that iterating
    the extractor calls it exactly once.
    """
    # NOTE(review): the class name was probably meant to be
    # 'TestExtractorOAuth'; it is kept as-is so references to it still work.

    def _find(self, url):
        """Return the extractor for 'url', failing clearly if none matches"""
        extr = extractor.find(url)
        self.assertIsNotNone(extr, url)
        return extr

    def test_oauth1(self):
        for category in ("flickr", "smugmug", "tumblr"):
            with self.subTest(category=category):
                extr = self._find(f"oauth:{category}")

                with patch.object(extr, "_oauth1_authorization_flow") as m:
                    for _ in extr:
                        pass
                    self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2(self):
        for category in ("deviantart", "reddit"):
            with self.subTest(category=category):
                extr = self._find(f"oauth:{category}")

                with patch.object(
                        extr, "_oauth2_authorization_code_grant") as m:
                    for _ in extr:
                        pass
                    self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2_mastodon(self):
        extr = self._find("oauth:mastodon:pawoo.net")

        with patch.object(extr, "_oauth2_authorization_code_grant") as m, \
                patch.object(extr, "_register") as r:
            for _ in extr:
                pass
            # No client registration is expected for this instance
            self.assertEqual(len(r.mock_calls), 0)
            self.assertEqual(len(m.mock_calls), 1)

    def test_oauth2_mastodon_unknown(self):
        extr = self._find("oauth:mastodon:example.com")

        with patch.object(extr, "_oauth2_authorization_code_grant") as m, \
                patch.object(extr, "_register") as r:
            # The mocked registration supplies client credentials
            r.return_value = {
                "client-id"    : "foo",
                "client-secret": "bar",
            }

            for _ in extr:
                pass

            self.assertEqual(len(r.mock_calls), 1)
            self.assertEqual(len(m.mock_calls), 1)


if __name__ == "__main__":
    unittest.main()
311
312