Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
mikf
GitHub Repository: mikf/gallery-dl
Path: blob/master/test/test_downloader.py
8858 views
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
# Copyright 2018-2025 Mike Fährmann
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License version 2 as
8
# published by the Free Software Foundation.
9
10
import os
11
import sys
12
import unittest
13
from unittest.mock import Mock, MagicMock, patch
14
15
import re
16
import logging
17
import os.path
18
import binascii
19
import tempfile
20
import threading
21
import http.server
22
23
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
from gallery_dl import downloader, extractor, output, config, path # noqa E402
25
from gallery_dl.downloader.http import MIME_TYPES, SIGNATURE_CHECKS # noqa E402
26
27
28
class MockDownloaderModule(Mock):
    """Mock returned by the patched ``__import__`` in the cache tests.

    Carries a ``__downloader__`` class attribute; presumably this is the
    hook ``downloader.find()`` reads from a loaded module (not verified
    from this file).
    """

    __downloader__ = "mock"


class FakeJob:
    """Minimal stand-in for the job object that downloaders are built with.

    Exposes an initialized generic extractor for ``https://example.org/``,
    a ``PathFormat`` bound to that extractor, a ``NullOutput`` and
    ``logging.getLogger`` as the logger factory.
    """

    def __init__(self):
        extr = extractor.find("generic:https://example.org/")
        extr.initialize()
        self.extractor = extr
        self.pathfmt = path.PathFormat(extr)
        self.out = output.NullOutput()
        self.get_logger = logging.getLogger


class TestDownloaderModule(unittest.TestCase):
43
44
@classmethod
45
def setUpClass(cls):
46
# allow import of ytdl downloader module without youtube_dl installed
47
cls._orig_ytdl = sys.modules.get("youtube_dl")
48
sys.modules["youtube_dl"] = MagicMock()
49
50
@classmethod
51
def tearDownClass(cls):
52
if cls._orig_ytdl:
53
sys.modules["youtube_dl"] = cls._orig_ytdl
54
else:
55
del sys.modules["youtube_dl"]
56
57
def setUp(self):
58
downloader._cache.clear()
59
60
def tearDown(self):
61
downloader._cache.clear()
62
63
def test_find(self):
64
cls = downloader.find("http")
65
self.assertEqual(cls.__name__, "HttpDownloader")
66
self.assertEqual(cls.scheme , "http")
67
68
cls = downloader.find("https")
69
self.assertEqual(cls.__name__, "HttpDownloader")
70
self.assertEqual(cls.scheme , "http")
71
72
cls = downloader.find("text")
73
self.assertEqual(cls.__name__, "TextDownloader")
74
self.assertEqual(cls.scheme , "text")
75
76
cls = downloader.find("ytdl")
77
self.assertEqual(cls.__name__, "YoutubeDLDownloader")
78
self.assertEqual(cls.scheme , "ytdl")
79
80
self.assertEqual(downloader.find("ftp"), None)
81
self.assertEqual(downloader.find("foo"), None)
82
self.assertEqual(downloader.find(1234) , None)
83
self.assertEqual(downloader.find(None) , None)
84
85
@patch("builtins.__import__")
86
def test_cache(self, import_module):
87
import_module.return_value = MockDownloaderModule()
88
downloader.find("http")
89
downloader.find("text")
90
downloader.find("ytdl")
91
self.assertEqual(import_module.call_count, 3)
92
downloader.find("http")
93
downloader.find("text")
94
downloader.find("ytdl")
95
self.assertEqual(import_module.call_count, 3)
96
97
@patch("builtins.__import__")
98
def test_cache_http(self, import_module):
99
import_module.return_value = MockDownloaderModule()
100
downloader.find("http")
101
downloader.find("https")
102
self.assertEqual(import_module.call_count, 1)
103
104
@patch("builtins.__import__")
105
def test_cache_https(self, import_module):
106
import_module.return_value = MockDownloaderModule()
107
downloader.find("https")
108
downloader.find("http")
109
self.assertEqual(import_module.call_count, 1)
110
111
112
class TestDownloaderConfig(unittest.TestCase):
113
114
def setUp(self):
115
config.clear()
116
117
def tearDown(self):
118
config.clear()
119
120
def test_default_http(self):
121
job = FakeJob()
122
extr = job.extractor
123
dl = downloader.find("http")(job)
124
125
self.assertEqual(dl.adjust_extension, True)
126
self.assertEqual(dl.chunk_size, 32768)
127
self.assertEqual(dl.metadata, None)
128
self.assertEqual(dl.progress, 3.0)
129
self.assertEqual(dl.validate, True)
130
self.assertEqual(dl.headers, None)
131
self.assertEqual(dl.minsize, None)
132
self.assertEqual(dl.maxsize, None)
133
self.assertEqual(dl.mtime, True)
134
self.assertEqual(dl.rate, None)
135
self.assertEqual(dl.part, True)
136
self.assertEqual(dl.partdir, None)
137
138
self.assertIs(dl.interval_429, extr._interval_429)
139
self.assertIs(dl.retry_codes, extr._retry_codes)
140
self.assertIs(dl.retries, extr._retries)
141
self.assertIs(dl.timeout, extr._timeout)
142
self.assertIs(dl.proxies, extr._proxies)
143
self.assertIs(dl.verify, extr._verify)
144
145
def test_config_http(self):
146
config.set((), "rate", 42)
147
config.set((), "mtime", False)
148
config.set((), "headers", {"foo": "bar"})
149
config.set(("downloader",), "retries", -1)
150
config.set(("downloader", "http"), "filesize-min", "10k")
151
config.set(("extractor", "generic"), "verify", False)
152
config.set(("extractor", "generic", "example.org"), "timeout", 10)
153
config.set(("extractor", "generic", "http"), "part", False)
154
config.set(
155
("extractor", "generic", "example.org", "http"), "headers", {})
156
157
job = FakeJob()
158
dl = downloader.find("http")(job)
159
160
self.assertEqual(dl.headers, {"foo": "bar"})
161
self.assertEqual(dl.minsize, 10240)
162
self.assertEqual(dl.retries, float("inf"))
163
self.assertEqual(dl.timeout, 10)
164
self.assertEqual(dl.verify, False)
165
self.assertEqual(dl.mtime, False)
166
self.assertEqual(dl.rate(), 42)
167
self.assertEqual(dl.part, False)
168
169
170
class TestDownloaderBase(unittest.TestCase):
    """Shared fixture and helpers for concrete downloader test cases.

    setUpClass() creates a temporary directory, points the global
    ``base-directory`` option at it and builds a FakeJob; subclasses are
    expected to provide ``self.downloader``.
    """

    @classmethod
    def setUpClass(cls):
        cls.dir = tempfile.TemporaryDirectory()
        cls.fnum = 0
        config.set((), "base-directory", cls.dir.name)
        cls.job = FakeJob()

    @classmethod
    def tearDownClass(cls):
        cls.dir.cleanup()
        config.clear()

    @classmethod
    def _prepare_destination(cls, content=None, part=True, extension=None):
        """Point the job's PathFormat at a fresh ``file-<n>`` target.

        If *content* is non-empty it is written there first (binary mode
        for bytes, text mode otherwise) to simulate a partial download.
        *part* is accepted but not used.  Returns the PathFormat.
        """
        name = f"file-{cls.fnum}"
        cls.fnum = cls.fnum + 1

        kwdict = dict(
            category="test",
            subcategory="test",
            filename=name,
            extension=extension,
        )

        fmt = cls.job.pathfmt
        fmt.set_directory(kwdict)
        fmt.set_filename(kwdict)
        fmt.build_path()

        if not content:
            return fmt

        with fmt.open("wb" if isinstance(content, bytes) else "w") as fp:
            fp.write(content)
        return fmt

    def _run_test(self, url, input, output,
                  extension, expected_extension=None):
        """Download *url* on top of *input* and check the result.

        Asserts that the download succeeds, that the file now contains
        *output*, and that both the PathFormat's extension and the real
        path's suffix equal *expected_extension*.
        """
        fmt = self._prepare_destination(input, extension=extension)
        ok = self.downloader.download(url, fmt)

        # the download itself must report success
        self.assertTrue(ok, f"downloading '{url}' failed")

        # the stored data must match what the server/URL provided
        if isinstance(output, bytes):
            mode = "rb"
        else:
            mode = "r"
        with fmt.open(mode) as fp:
            data = fp.read()
        self.assertEqual(data, output)

        # the extension must have been adjusted to the expected one;
        # the first bytes of the content serve as failure message
        self.assertEqual(fmt.extension, expected_extension, data[0:16])
        _, suffix = os.path.splitext(fmt.realpath)
        self.assertEqual(suffix[1:], expected_extension)


class TestHTTPDownloader(TestDownloaderBase):
    """Run the HTTP downloader against a local HttpRequestHandler server."""

    @classmethod
    def setUpClass(cls):
        host = "127.0.0.1"
        port = 0  # select random not-in-use port

        # Spawn the server before the base fixture: tearDownClass() is not
        # run when setUpClass() raises (including SkipTest), so creating
        # the temp dir / config first would leak them on failure.
        try:
            server = http.server.HTTPServer((host, port), HttpRequestHandler)
        except OSError as exc:
            raise unittest.SkipTest(
                f"cannot spawn local HTTP server ({exc})")

        try:
            # super() keeps fixture state on this class instead of on
            # TestDownloaderBase, where all subclasses would share it
            super().setUpClass()
            cls.downloader = downloader.find("http")(cls.job)
        except BaseException:
            server.server_close()
            raise

        cls.server = server
        host, port = server.server_address
        cls.address = f"http://{host}:{port}"
        threading.Thread(target=server.serve_forever, daemon=True).start()

    @classmethod
    def tearDownClass(cls):
        # stop the serve_forever() loop and release the listening socket
        try:
            cls.server.shutdown()
            cls.server.server_close()
        finally:
            super().tearDownClass()

    def _run_test(self, ext, input, output,
                  extension, expected_extension=None):
        """Like TestDownloaderBase._run_test, with *ext* as the URL path."""
        TestDownloaderBase._run_test(
            self, f"{self.address}/{ext}", input, output,
            extension, expected_extension)

    def tearDown(self):
        # undo per-test size limits on the shared downloader
        self.downloader.minsize = self.downloader.maxsize = None

    def test_http_download(self):
        self._run_test("jpg", None, DATA["jpg"], "jpg", "jpg")
        self._run_test("png", None, DATA["png"], "png", "png")
        self._run_test("gif", None, DATA["gif"], "gif", "gif")

    def test_http_offset(self):
        # a partial file on disk triggers a Range request
        self._run_test("jpg", DATA["jpg"][:123], DATA["jpg"], "jpg", "jpg")
        self._run_test("png", DATA["png"][:12], DATA["png"], "png", "png")
        self._run_test("gif", DATA["gif"][:1], DATA["gif"], "gif", "gif")

    def test_http_extension(self):
        self._run_test("jpg", None, DATA["jpg"], None, "jpg")
        self._run_test("png", None, DATA["png"], None, "png")
        self._run_test("gif", None, DATA["gif"], None, "gif")

    def test_http_adjust_extension(self):
        self._run_test("jpg", None, DATA["jpg"], "png", "jpg")
        self._run_test("png", None, DATA["png"], "gif", "png")
        self._run_test("gif", None, DATA["gif"], "jpg", "gif")

    def test_http_filesize_min(self):
        url = f"{self.address}/gif"
        pathfmt = self._prepare_destination(None, extension=None)
        self.downloader.minsize = 100
        with self.assertLogs(self.downloader.log, "WARNING"):
            success = self.downloader.download(url, pathfmt)
        self.assertTrue(success)
        self.assertEqual(pathfmt.temppath, "")

    def test_http_filesize_max(self):
        url = f"{self.address}/jpg"
        pathfmt = self._prepare_destination(None, extension=None)
        self.downloader.maxsize = 100
        with self.assertLogs(self.downloader.log, "WARNING"):
            success = self.downloader.download(url, pathfmt)
        self.assertTrue(success)
        self.assertEqual(pathfmt.temppath, "")

    def test_http_empty(self):
        url = f"{self.address}/~NUL"
        pathfmt = self._prepare_destination(None, extension=None)
        with self.assertLogs(self.downloader.log, "WARNING") as log_info:
            success = self.downloader.download(url, pathfmt)
        self.assertFalse(success)
        self.assertEqual(log_info.output[0],
                         "WARNING:downloader.http:Empty file")


class TestTextDownloader(TestDownloaderBase):
    """Tests for the ``text:`` downloader, which stores the URL's payload."""

    @classmethod
    def setUpClass(cls):
        # super() stores the fixture state (dir, fnum, job) on this class;
        # calling TestDownloaderBase.setUpClass() directly would put it on
        # the base class, shared with and overwritten by other subclasses
        super().setUpClass()
        cls.downloader = downloader.find("text")(cls.job)

    def test_text_download(self):
        self._run_test("text:foobar", None, "foobar", "txt", "txt")

    def test_text_offset(self):
        # existing partial content is completed, not duplicated
        self._run_test("text:foobar", "foo", "foobar", "txt", "txt")

    def test_text_empty(self):
        self._run_test("text:", None, "", "txt", "txt")


class HttpRequestHandler(http.server.BaseHTTPRequestHandler):
    """Serve entries of the module-level ``DATA`` mapping over HTTP.

    ``GET /<key>`` answers with ``DATA[<key>]``.  A ``Range: bytes=N-``
    request header yields a 206 partial response starting at offset N, or
    416 when N lies beyond the end.  Unknown keys get a 404 whose body is
    the requested path.
    """

    def do_GET(self):
        try:
            output = DATA[self.path[1:]]
        except KeyError:
            self._respond(404, self.path.encode())
            return

        total = len(output)
        match = None
        if "Range" in self.headers:
            match = re.match(r"bytes=(\d+)-", self.headers["Range"])

        if match is None:
            # no Range header, or a form this handler does not support;
            # servers may ignore Range and send the full representation
            self._respond(200, output)
            return

        start = int(match[1])
        if start >= total:
            # nothing left to send from this offset
            self._respond(416, b"", {"Content-Range": f"bytes */{total}"})
            return

        self._respond(206, output[start:], {
            "Content-Range": f"bytes {start}-{total - 1}/{total}"})

    def _respond(self, status, body, headers=None):
        """Send a complete response: status line, headers, then *body*."""
        self.send_response(status)
        self.send_header("Content-Length", len(body))
        for key, value in (headers or {}).items():
            self.send_header(key, value)
        # end_headers() flushes the buffered status line and headers;
        # without it the client would only see the raw body bytes
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Suppress the default per-request logging to stderr."""


# (extension, content) pairs.  Content is either a complete tiny file or
# just the leading signature bytes the downloader should recognize.
# A tuple (not a set) keeps iteration order fixed across runs, so the
# "S<idx>" keys in DATA and the generated test names are reproducible.
SAMPLES = (
    ("jpg" , binascii.a2b_base64(
        "/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB"
        "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEB"
        "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB"
        "AQEBAQEBAQEBAQEBAQH/wAARCAABAAEDAREAAhEBAxEB/8QAFAABAAAAAAAAAAAA"
        "AAAAAAAACv/EABQQAQAAAAAAAAAAAAAAAAAAAAD/xAAUAQEAAAAAAAAAAAAAAAAA"
        "AAAA/8QAFBEBAAAAAAAAAAAAAAAAAAAAAP/aAAwDAQACEQMRAD8AfwD/2Q==")),
    ("png" , binascii.a2b_base64(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAAAAAA6fptVAAAACklEQVQIHWP4DwAB"
        "AQEANl9ngAAAAABJRU5ErkJggg==")),
    ("gif" , binascii.a2b_base64(
        "R0lGODdhAQABAIAAAP///////ywAAAAAAQABAAACAkQBADs=")),
    ("bmp" , b"BM"),
    ("webp", b"RIFF????WEBP"),
    ("avif", b"????ftypavif"),
    ("avif", b"????ftypavis"),
    ("heic", b"????ftypheic"),
    ("heic", b"????ftypheim"),
    ("heic", b"????ftypheis"),
    ("heic", b"????ftypheix"),
    ("svg" , b"<?xml"),
    ("html", b"<!DOCTYPE html><html>...</html>"),
    ("html", b"   \n \n\r\t\n  <!DOCTYPE html><html>...</html>"),
    ("ico" , b"\x00\x00\x01\x00"),
    ("cur" , b"\x00\x00\x02\x00"),
    ("psd" , b"8BPS"),
    ("mp4" , b"????ftypmp4"),
    ("mp4" , b"????ftypavc1"),
    ("mp4" , b"????ftypiso3"),
    ("m4v" , b"????ftypM4V"),
    ("mov" , b"????ftypqt  "),
    ("webm", b"\x1A\x45\xDF\xA3"),
    ("ogg" , b"OggS"),
    ("wav" , b"RIFF????WAVE"),
    ("mp3" , b"ID3"),
    ("mp3" , b"\xFF\xFB"),
    ("mp3" , b"\xFF\xF3"),
    ("mp3" , b"\xFF\xF2"),
    ("aac" , b"\xFF\xF9"),
    ("aac" , b"\xFF\xF1"),
    ("m3u8", b"#EXTM3U\n#EXT-X-STREAM-INF:PROGRAM-ID=1, BANDWIDTH=200000"),
    ("mpd" , b'<MPD xmlns="urn:mpeg:dash:schema:mpd:2011"'),
    ("zip" , b"PK\x03\x04"),
    ("zip" , b"PK\x05\x06"),
    ("zip" , b"PK\x07\x08"),
    # b"Rar!\x1A\x07" == b"\x52\x61\x72\x21\x1A\x07"; listed once
    ("rar" , b"Rar!\x1A\x07"),
    ("7z"  , b"\x37\x7A\xBC\xAF\x27\x1C"),
    ("pdf" , b"%PDF-"),
    ("swf" , b"FWS"),
    ("swf" , b"CWS"),
    ("blend", b"BLENDER-v303RENDH"),
    ("obj" , b"# Blender v3.2.0 OBJ File: 'foo.blend'"),
    ("clip", b"CSFCHUNK\x00\x00\x00\x00"),
    # empty body; its key does not start with an alphanumeric character,
    # so no extension test is generated for it
    ("~NUL", b""),
)


# Lookup table served by HttpRequestHandler: the first sample of each
# extension under its extension name (e.g. "jpg"), plus every sample
# under a positional key "S00", "S01", ...
DATA = {}

for ext, content in SAMPLES:
    DATA.setdefault(ext, content)

for idx, (_, content) in enumerate(SAMPLES):
    DATA[f"S{idx:>02}"] = content


# reverse mime types mapping: rebind the imported MIME_TYPES
# (mime type -> extension) as extension -> mime type.  When several mime
# types share one extension, the last one in the original mapping wins.
# Nothing below in this file reads the reversed table.
MIME_TYPES = dict(zip(MIME_TYPES.values(), MIME_TYPES.keys()))


def generate_tests():
    """Attach one ``test_http_ext_NN_<ext>`` method per sample.

    Each generated TestHTTPDownloader test downloads sample ``SNN`` with a
    provisional ``bin`` extension and expects the downloader to end up
    with ``<ext>``.  Samples whose extension does not start with an
    alphanumeric character (such as ``~NUL``) are skipped.
    """

    def make_test(index, expected, data):
        # factory function, so every test captures its own values rather
        # than the loop variables' final ones
        def test(self):
            self._run_test(f"S{index:>02}", None, data, "bin", expected)
        test.__name__ = f"test_http_ext_{index:>02}_{expected}"
        return test

    for index, (expected, data) in enumerate(SAMPLES):
        if not expected[0].isalnum():
            continue
        method = make_test(index, expected, data)
        setattr(TestHTTPDownloader, method.__name__, method)


generate_tests()
if __name__ == "__main__":
    # run every TestCase defined (or generated) in this module
    unittest.main(module="__main__")