Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
StevenBlack
GitHub Repository: StevenBlack/hosts
Path: blob/master/testUpdateHostsFile.py
1181 views
1
#!/usr/bin/env python
2
3
# Script by gfyoung
4
# https://github.com/gfyoung
5
#
6
# Python script for testing updateHostFiles.py
7
8
import json
9
import locale
10
import os
11
import platform
12
import re
13
import shutil
14
import sys
15
import tempfile
16
import unittest
17
import unittest.mock as mock
18
from io import BytesIO, StringIO
19
20
import requests
21
22
import updateHostsFile
23
from updateHostsFile import (
24
Colors,
25
colorize,
26
display_exclusion_options,
27
domain_to_idna,
28
exclude_domain,
29
flush_dns_cache,
30
gather_custom_exclusions,
31
get_defaults,
32
get_file_by_url,
33
is_valid_user_provided_domain_format,
34
matches_exclusions,
35
move_hosts_file_into_place,
36
normalize_rule,
37
path_join_robust,
38
print_failure,
39
print_success,
40
prompt_for_exclusions,
41
prompt_for_flush_dns_cache,
42
prompt_for_move,
43
prompt_for_update,
44
query_yes_no,
45
recursive_glob,
46
remove_old_hosts_file,
47
sort_sources,
48
strip_rule,
49
supports_color,
50
update_all_sources,
51
update_readme_data,
52
update_sources_data,
53
write_data,
54
write_opening_header,
55
)
56
57
# Alias for the Python 2 ``unicode`` name; on Python 3 ``str`` is already the
# Unicode text type. NOTE(review): nothing in the visible part of this file
# references ``unicode`` -- confirm no later test needs it before removing.
unicode = str
58
59
60
# Test Helper Objects
61
class Base(unittest.TestCase):
    """Shared helpers for every test case in this module."""

    @staticmethod
    def mock_property(name):
        """Build a patcher that replaces ``name`` with a ``PropertyMock``.

        The tests use the returned patcher as a context manager: inside it
        the patched attribute may be reassigned freely, and the original
        value is put back when the context exits.
        """
        patcher = mock.patch(name, new_callable=mock.PropertyMock)
        return patcher

    @property
    def sep(self):
        """Path separator the code under test is expected to emit."""
        on_windows = platform.system().lower() == "windows"
        return "\\" if on_windows else os.sep

    def assert_called_once(self, mock_method):
        """Fail unless ``mock_method`` was invoked exactly one time."""
        calls = mock_method.call_count
        self.assertEqual(calls, 1)
74
75
76
class BaseStdout(Base):
    """Test case that captures everything written to ``sys.stdout``.

    Each test starts with ``sys.stdout`` redirected to a fresh ``StringIO``,
    so ``sys.stdout.getvalue()`` returns what the code under test printed.
    Tests may swap in a new ``StringIO`` mid-test to reset the capture.
    """

    def setUp(self):
        # Remember whichever stream is active right now. It may be a test
        # runner's capture object rather than the real terminal, so tearDown
        # must restore *this* stream, not blindly install sys.__stdout__.
        self._orig_stdout = sys.stdout
        sys.stdout = StringIO()

    def tearDown(self):
        # Close the current capture buffer (possibly one a test installed),
        # then restore the saved stream. Fall back to sys.__stdout__ for
        # subclasses whose setUp did not run BaseStdout.setUp.
        sys.stdout.close()
        sys.stdout = getattr(self, "_orig_stdout", sys.__stdout__)
83
84
85
class BaseMockDir(Base):
    """Test case that owns a scratch directory for the duration of a test."""

    def setUp(self):
        # A brand-new, empty directory for every test.
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the scratch directory together with everything inside it.
        shutil.rmtree(self.test_dir)

    @property
    def dir_count(self):
        """Number of entries currently inside ``self.test_dir``."""
        entries = os.listdir(self.test_dir)
        return len(entries)
95
96
97
# End Test Helper Objects
98
99
100
# Project Settings
101
class TestGetDefaults(Base):
    def test_get_defaults(self):
        # Take over BASEDIR_PATH so every path in the defaults is rooted at
        # "foo"; the patcher puts the real value back when the block exits.
        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = "foo"
            actual = get_defaults()

            def rooted(name):
                # Expected form of ``name`` joined onto the patched base dir.
                return "foo" + self.sep + name

            expected = {
                "numberofrules": 0,
                "datapath": rooted("data"),
                "freshen": True,
                "replace": False,
                "backup": False,
                "skipstatichosts": False,
                "keepdomaincomments": True,
                "extensionspath": rooted("extensions"),
                "extensions": [],
                "nounifiedhosts": False,
                "compress": False,
                "minimise": False,
                "outputsubfolder": "",
                "hostfilename": "hosts",
                "targetip": "0.0.0.0",
                "sourcedatafilename": "update.json",
                "sourcesdata": [],
                "readmefilename": "readme.md",
                "readmetemplate": rooted("readme_template.md"),
                "readmedata": {},
                "readmedatafilename": rooted("readmeData.json"),
                "exclusionpattern": r"([a-zA-Z\d-]+\.){0,}",
                "exclusionregexes": [],
                "exclusions": [],
                "commonexclusions": ["hulu.com"],
                "blacklistfile": rooted("blacklist"),
                "whitelistfile": rooted("whitelist"),
            }
            self.assertDictEqual(actual, expected)
136
137
138
# End Project Settings
139
140
141
class TestSortSources(Base):
    def test_sort_sources_simple(self):
        unordered = ["sbc.io", "example.com", "github.com"]
        ordered = ["example.com", "github.com", "sbc.io"]

        self.assertEqual(sort_sources(unordered), ordered)

    def test_live_data(self):
        def update_files(root, names):
            # "<root>/<name>/update.json" for each name, preserving order.
            return [root + "/" + name + "/update.json" for name in names]

        given = update_files(
            "data",
            [
                "KADhosts",
                "someonewhocares.org",
                "StevenBlack",
                "adaway.org",
                "URLHaus",
                "UncheckyAds",
                "add.2o7Net",
                "mvps.org",
                "add.Spam",
                "add.Dead",
                "malwaredomainlist.com",
                "Badd-Boyz-Hosts",
                "hostsVN",
                "yoyo.org",
                "add.Risk",
                "tiuxo",
            ],
        ) + update_files(
            "extensions",
            [
                "gambling",
                "porn/clefspeare13",
                "porn/sinfonietta-snuff",
                "porn/tiuxo",
                "porn/sinfonietta",
                "fakenews",
                "social/tiuxo",
                "social/sinfonietta",
            ],
        )

        # StevenBlack's own list leads, the other data sources follow in
        # case-insensitive order, and the extensions come last.
        expected = update_files(
            "data",
            [
                "StevenBlack",
                "adaway.org",
                "add.2o7Net",
                "add.Dead",
                "add.Risk",
                "add.Spam",
                "Badd-Boyz-Hosts",
                "hostsVN",
                "KADhosts",
                "malwaredomainlist.com",
                "mvps.org",
                "someonewhocares.org",
                "tiuxo",
                "UncheckyAds",
                "URLHaus",
                "yoyo.org",
            ],
        ) + update_files(
            "extensions",
            [
                "fakenews",
                "gambling",
                "porn/clefspeare13",
                "porn/sinfonietta",
                "porn/sinfonietta-snuff",
                "porn/tiuxo",
                "social/sinfonietta",
                "social/tiuxo",
            ],
        )

        self.assertEqual(sort_sources(given), expected)
213
214
215
# Prompt the User
216
class TestPromptForUpdate(BaseStdout, BaseMockDir):
    """Tests for ``prompt_for_update`` with captured stdout and a scratch dir.

    ``updateHostsFile.BASEDIR_PATH`` is pointed at ``self.test_dir`` while
    ``prompt_for_update`` runs, so the ``hosts`` file it checks lives in the
    scratch directory rather than in the real project directory.
    """

    def setUp(self):
        BaseStdout.setUp(self)
        BaseMockDir.setUp(self)

    def tearDown(self):
        # Undo both setUps in reverse order. Previously only
        # BaseStdout.tearDown ran, so the directory created by
        # BaseMockDir.setUp leaked after every test. try/finally makes sure
        # stdout is restored even if removing the directory fails.
        try:
            BaseMockDir.tearDown(self)
        finally:
            BaseStdout.tearDown(self)

    def test_no_freshen_no_new_file(self):
        hostsfile = os.path.join(self.test_dir, "hosts")
        hosts_data = "This data should not be overwritten"

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir

            with open(hostsfile, "w") as f:
                f.write(hosts_data)

            for updateauto in (False, True):
                before = self.dir_count
                prompt_for_update(freshen=False, updateauto=updateauto)

                output = sys.stdout.getvalue()
                self.assertEqual(output, "")

                sys.stdout = StringIO()

                # An existing hosts file is neither replaced nor rewritten.
                self.assertEqual(self.dir_count, before)

                with open(hostsfile, "r") as f:
                    contents = f.read()
                self.assertEqual(contents, hosts_data)

    def test_no_freshen_new_file(self):
        hostsfile = os.path.join(self.test_dir, "hosts")

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir

            before = self.dir_count
            prompt_for_update(freshen=False, updateauto=False)

            output = sys.stdout.getvalue()
            self.assertEqual(output, "")

            sys.stdout = StringIO()

            # A missing hosts file is created, and it is empty.
            self.assertEqual(self.dir_count, before + 1)

            with open(hostsfile, "r") as f:
                contents = f.read()
            self.assertEqual(contents, "")

    @mock.patch("builtins.open")
    def test_no_freshen_fail_new_file(self, mock_open):
        # Creating the missing hosts file fails: an error message is printed.
        for exc in (IOError, OSError):
            mock_open.side_effect = exc("failed open")

            with self.mock_property("updateHostsFile.BASEDIR_PATH"):
                updateHostsFile.BASEDIR_PATH = self.test_dir
                prompt_for_update(freshen=False, updateauto=False)

            output = sys.stdout.getvalue()
            expected = (
                "ERROR: No 'hosts' file in the folder. "
                "Try creating one manually."
            )
            self.assertIn(expected, output)

            sys.stdout = StringIO()

    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def test_freshen_no_update(self, _):
        hostsfile = os.path.join(self.test_dir, "hosts")
        hosts_data = "This data should not be overwritten"

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir

            with open(hostsfile, "w") as f:
                f.write(hosts_data)

            before = self.dir_count

            # The user declines the update.
            updatesources = prompt_for_update(freshen=True, updateauto=False)
            self.assertFalse(updatesources)

            output = sys.stdout.getvalue()
            expected = "OK, we'll stick with what we've got locally."
            self.assertIn(expected, output)

            sys.stdout = StringIO()

            self.assertEqual(self.dir_count, before)

            with open(hostsfile, "r") as f:
                contents = f.read()
            self.assertEqual(contents, hosts_data)

    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
    def test_freshen_update(self, _):
        hostsfile = os.path.join(self.test_dir, "hosts")
        hosts_data = "This data should not be overwritten"

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir

            with open(hostsfile, "w") as f:
                f.write(hosts_data)

            before = self.dir_count

            # Either the user accepts or updateauto skips the question.
            for updateauto in (False, True):
                updatesources = prompt_for_update(
                    freshen=True, updateauto=updateauto
                )
                self.assertTrue(updatesources)

                output = sys.stdout.getvalue()
                self.assertEqual(output, "")

                sys.stdout = StringIO()

                self.assertEqual(self.dir_count, before)

                with open(hostsfile, "r") as f:
                    contents = f.read()
                self.assertEqual(contents, hosts_data)
344
345
346
class TestPromptForExclusions(BaseStdout):
    """Tests for ``prompt_for_exclusions`` with ``query_yes_no`` mocked."""

    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testSkipPrompt(self, mock_query):
        # Skipping the prompt: no question, no output, no custom exclusions.
        self.assertFalse(prompt_for_exclusions(skipprompt=True))
        self.assertEqual(sys.stdout.getvalue(), "")
        mock_query.assert_not_called()

    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testNoSkipPromptNoDisplay(self, mock_query):
        # Asked once; the user declines, so only the whitelist applies.
        self.assertFalse(prompt_for_exclusions(skipprompt=False))
        self.assertIn(
            "OK, we'll only exclude domains in the whitelist.",
            sys.stdout.getvalue(),
        )
        self.assert_called_once(mock_query)

    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
    def testNoSkipPromptDisplay(self, mock_query):
        # Asked once; the user accepts and nothing extra is printed.
        self.assertTrue(prompt_for_exclusions(skipprompt=False))
        self.assertEqual(sys.stdout.getvalue(), "")
        self.assert_called_once(mock_query)
377
378
379
class TestPromptForFlushDnsCache(Base):
    """Tests for ``prompt_for_flush_dns_cache``; both collaborators mocked."""

    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testFlushCache(self, mock_query, mock_flush):
        # An explicit flush request never asks and always flushes once.
        for promptflush in [False, True]:
            prompt_for_flush_dns_cache(flushcache=True, promptflush=promptflush)

            mock_query.assert_not_called()
            self.assert_called_once(mock_flush)

            for used in (mock_query, mock_flush):
                used.reset_mock()

    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testNoFlushCacheNoPrompt(self, mock_query, mock_flush):
        # Neither flushing nor prompting requested: nothing happens.
        prompt_for_flush_dns_cache(flushcache=False, promptflush=False)

        for untouched in (mock_query, mock_flush):
            untouched.assert_not_called()

    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testNoFlushCachePromptNoFlush(self, mock_query, mock_flush):
        # The user is asked and answers no.
        prompt_for_flush_dns_cache(flushcache=False, promptflush=True)

        self.assert_called_once(mock_query)
        mock_flush.assert_not_called()

    @mock.patch("updateHostsFile.flush_dns_cache", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
    def testNoFlushCachePromptFlush(self, mock_query, mock_flush):
        # The user is asked, answers yes, and the cache is flushed once.
        prompt_for_flush_dns_cache(flushcache=False, promptflush=True)

        for used in (mock_query, mock_flush):
            self.assert_called_once(used)
415
416
417
class TestPromptForMove(Base):
    """Tests for ``prompt_for_move`` across replace/auto/skipstatichosts.

    ``move_hosts_file_into_place`` is mocked to return 0, so even when a move
    happens the value returned by ``prompt_for_move`` is falsy.
    """

    def setUp(self):
        Base.setUp(self)
        self.final_file = "final.txt"

    def prompt_for_move(self, **move_params):
        # Call the module-level prompt_for_move with this test's final file.
        return prompt_for_move(self.final_file, **move_params)

    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testSkipStaticHosts(self, mock_query, mock_move):
        # Skipping static hosts wins over every replace/auto combination.
        for replace in (False, True):
            for auto in (False, True):
                move_file = self.prompt_for_move(
                    replace=replace, auto=auto, skipstatichosts=True
                )
                self.assertFalse(move_file)

                mock_query.assert_not_called()
                mock_move.assert_not_called()

                mock_query.reset_mock()
                mock_move.reset_mock()

    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testReplaceNoSkipStaticHosts(self, mock_query, mock_move):
        # replace=True moves the file without asking, regardless of auto.
        for auto in (False, True):
            move_file = self.prompt_for_move(
                replace=True, auto=auto, skipstatichosts=False
            )
            self.assertFalse(move_file)

            mock_query.assert_not_called()
            self.assert_called_once(mock_move)

            mock_query.reset_mock()
            mock_move.reset_mock()

    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testAutoNoSkipStaticHosts(self, mock_query, mock_move):
        # auto=True, with no replace request and static hosts NOT skipped,
        # must neither prompt nor move. This test previously passed
        # skipstatichosts=True, which duplicated testSkipStaticHosts and left
        # this auto-only path uncovered. replace stays False because
        # replace=True moves the file (see testReplaceNoSkipStaticHosts).
        move_file = self.prompt_for_move(
            replace=False, auto=True, skipstatichosts=False
        )
        self.assertFalse(move_file)

        mock_query.assert_not_called()
        mock_move.assert_not_called()

    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=False)
    def testPromptNoMove(self, mock_query, mock_move):
        # No flags set: the user is asked and declines.
        move_file = self.prompt_for_move(
            replace=False, auto=False, skipstatichosts=False
        )
        self.assertFalse(move_file)

        self.assert_called_once(mock_query)
        mock_move.assert_not_called()

    @mock.patch("updateHostsFile.move_hosts_file_into_place", return_value=0)
    @mock.patch("updateHostsFile.query_yes_no", return_value=True)
    def testPromptMove(self, mock_query, mock_move):
        # No flags set: the user is asked, accepts, and the move is attempted.
        move_file = self.prompt_for_move(
            replace=False, auto=False, skipstatichosts=False
        )
        self.assertFalse(move_file)

        self.assert_called_once(mock_query)
        self.assert_called_once(mock_move)
492
493
494
# End Prompt the User
495
496
497
# Exclusion Logic
498
class TestDisplayExclusionsOptions(Base):
    """Tests for ``display_exclusion_options`` with its collaborators mocked.

    ``query_yes_no`` answers are scripted via ``side_effect`` (1 = yes,
    0 = no) in the order the function asks its questions.
    """

    @mock.patch("updateHostsFile.query_yes_no", return_value=0)
    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
    def test_no_exclusions(self, mock_gather, mock_exclude, _mock_query):
        # No common exclusions and every answer is "no".
        display_exclusion_options([], "foo", [])

        mock_gather.assert_not_called()
        mock_exclude.assert_not_called()

    @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 1, 0])
    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
    def test_only_common_exclusions(self, mock_gather, mock_exclude, _mock_query):
        # Yes to both common domains, then no to the final question.
        display_exclusion_options(["foo", "bar"], "foo", [])

        mock_gather.assert_not_called()

        mock_exclude.assert_has_calls(
            [
                mock.call("foo", "foo", []),
                # The mocked exclude_domain returns None; the test expects
                # that value as the regex list of the following call.
                mock.call("bar", "foo", None),
            ]
        )

    @mock.patch("updateHostsFile.query_yes_no", side_effect=[0, 0, 1])
    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
    def test_gatherexclusions(self, mock_gather, mock_exclude, _mock_query):
        # No to both common domains, yes to the final question.
        display_exclusion_options(["foo", "bar"], "foo", [])

        mock_exclude.assert_not_called()
        self.assert_called_once(mock_gather)

    @mock.patch("updateHostsFile.query_yes_no", side_effect=[1, 0, 1])
    @mock.patch("updateHostsFile.exclude_domain", return_value=None)
    @mock.patch("updateHostsFile.gather_custom_exclusions", return_value=None)
    def test_mixture_gatherexclusions(self, mock_gather, mock_exclude, _mock_query):
        # Yes to "foo", no to "bar", yes to the final question.
        display_exclusion_options(["foo", "bar"], "foo", [])

        mock_exclude.assert_called_once_with("foo", "foo", [])
        self.assert_called_once(mock_gather)
540
541
542
class TestGatherCustomExclusions(BaseStdout):
    """Tests for ``gather_custom_exclusions`` driven by scripted input.

    Only the invalid-domain path can be tested here, because handling a
    valid domain depends on the module's global settings.
    """

    MORE_PROMPT = "Do you have more domains you want to enter? [Y/n]"

    @mock.patch("updateHostsFile.input", side_effect=["foo", "no"])
    @mock.patch(
        "updateHostsFile.is_valid_user_provided_domain_format", return_value=False
    )
    def test_basic(self, *_):
        # One domain, then "no": the follow-up prompt is shown once.
        gather_custom_exclusions("foo", [])

        self.assertIn(self.MORE_PROMPT, sys.stdout.getvalue())

    @mock.patch("updateHostsFile.input", side_effect=["foo", "yes", "bar", "no"])
    @mock.patch(
        "updateHostsFile.is_valid_user_provided_domain_format", return_value=False
    )
    def test_multiple(self, *_):
        # Two rounds: the prompt appears twice, separated by one space.
        gather_custom_exclusions("foo", [])

        twice = " ".join([self.MORE_PROMPT] * 2)
        self.assertIn(twice, sys.stdout.getvalue())
570
571
572
class TestExcludeDomain(Base):
    """Tests for ``exclude_domain`` building the exclusion regex list."""

    DOMAINS = ("google.com", "hulu.com", "adaway.org")

    def test_invalid_exclude_domain(self):
        exclusion_regexes = []

        # "*.com" + domain starts with "*", which has nothing to repeat,
        # so compiling the combined pattern must raise re.error.
        for domain in self.DOMAINS:
            with self.assertRaises(re.error):
                exclude_domain(domain, "*.com", exclusion_regexes)

        # A failed compile must not leave anything behind in the list.
        self.assertListEqual(exclusion_regexes, [])

    def test_valid_exclude_domain(self):
        pattern = r"[a-z]\."
        expected_regexes = []
        exclusion_regexes = []

        for domain in self.DOMAINS:
            # One regex per domain processed so far.
            self.assertEqual(len(exclusion_regexes), len(expected_regexes))

            exclusion_regexes = exclude_domain(domain, pattern, exclusion_regexes)
            expected_regexes.append(re.compile(pattern + domain))

        self.assertEqual(len(exclusion_regexes), len(expected_regexes))
        self.assertListEqual(exclusion_regexes, expected_regexes)
603
604
605
class TestMatchesExclusions(Base):
    """Tests for ``matches_exclusions`` on full rules and bare domains."""

    def test_no_match_empty_list(self):
        # With no exclusion regexes nothing can match.
        exclusion_regexes = []

        for domain in [
            "1.2.3.4 localhost",
            "5.6.7.8 hulu.com",
            "9.1.2.3 yahoo.com",
            "4.5.6.7 cloudfront.net",
        ]:
            self.assertFalse(matches_exclusions(domain, exclusion_regexes))

    def test_no_match_list(self):
        exclusion_regexes = [r".*\.org", r".*\.edu"]
        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]

        for domain in [
            "1.2.3.4 localhost",
            "5.6.7.8 hulu.com",
            "9.1.2.3 yahoo.com",
            "4.5.6.7 cloudfront.net",
        ]:
            self.assertFalse(matches_exclusions(domain, exclusion_regexes))

    def test_match_list(self):
        exclusion_regexes = [r".*\.com", r".*\.org", r".*\.edu"]
        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]

        for domain in [
            "5.6.7.8 hulu.com",
            "9.1.2.3 yahoo.com",
            "4.5.6.7 adaway.org",
            "8.9.1.2 education.edu",
        ]:
            self.assertTrue(matches_exclusions(domain, exclusion_regexes))

    def test_match_raw_list(self):
        exclusion_regexes = [r".*\.com", r".*\.org", r".*\.edu", r".*@.*"]
        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]

        for domain in [
            "hulu.com",
            "yahoo.com",
            "adaway.org",
            "education.edu",
            # Must contain "@" and none of the other suffixes, so that it
            # exercises the r".*@.*" pattern specifically. The previous
            # literal "[email protected]" was e-mail-obfuscation residue with
            # no "@" in it, so no pattern matched and this assertion failed.
            "user@localhost",
        ]:
            self.assertTrue(matches_exclusions(domain, exclusion_regexes))

    def test_no_match_raw_list(self):
        exclusion_regexes = [r".*\.org", r".*\.edu"]
        exclusion_regexes = [re.compile(regex) for regex in exclusion_regexes]

        for domain in [
            "localhost",
            "hulu.com",
            "yahoo.com",
            "cloudfront.net",
        ]:
            self.assertFalse(matches_exclusions(domain, exclusion_regexes))
665
666
667
# End Exclusion Logic
668
669
670
# Update Logic
671
class TestUpdateSourcesData(Base):
    """Tests for ``update_sources_data`` with globbing and file I/O mocked."""

    def setUp(self):
        Base.setUp(self)

        self.datapath = "data"
        self.extensionspath = "extensions"
        self.source_data_filename = "update.json"

        # Keyword arguments shared by every update_sources_data call below.
        self.update_kwargs = {
            "datapath": self.datapath,
            "extensionspath": self.extensionspath,
            "sourcedatafilename": self.source_data_filename,
            "nounifiedhosts": False,
        }

    def update_sources_data(self, sources_data, extensions):
        # Pass a copy so the caller's list stays usable as the baseline.
        return update_sources_data(
            list(sources_data), extensions=extensions, **self.update_kwargs
        )

    @mock.patch("updateHostsFile.recursive_glob", return_value=[])
    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
    @mock.patch("builtins.open", return_value=mock.Mock())
    def test_no_update(self, mock_open, mock_join_robust, _):
        baseline = [{"source": "source1.txt"}, {"source": "source2.txt"}]

        # No extensions: nothing is joined and no file is opened.
        self.assertEqual(self.update_sources_data(baseline, []), baseline)
        mock_join_robust.assert_not_called()
        mock_open.assert_not_called()

        # Extensions whose globs find nothing: each extension path is
        # joined, but still no file is opened.
        result = self.update_sources_data(baseline, [".json", ".txt"])

        self.assertEqual(result, baseline)
        mock_join_robust.assert_has_calls(
            [
                mock.call(self.extensionspath, ".json"),
                mock.call(self.extensionspath, ".txt"),
            ]
        )
        mock_open.assert_not_called()

    @mock.patch(
        "updateHostsFile.recursive_glob",
        side_effect=[[], ["update1.txt", "update2.txt"]],
    )
    @mock.patch("json.load", return_value={"mock_source": "mock_source.ext"})
    @mock.patch("builtins.open", return_value=mock.Mock())
    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
    def test_update_only_extensions(self, mock_join_robust, *_):
        baseline = [{"source": "source1.txt"}, {"source": "source2.txt"}]
        result = self.update_sources_data(baseline, [".json"])

        # Both extension update files contribute the same mocked payload.
        loaded = {"mock_source": "mock_source.ext"}
        self.assertEqual(result, baseline + [loaded, loaded])
        self.assert_called_once(mock_join_robust)

    @mock.patch(
        "updateHostsFile.recursive_glob",
        side_effect=[["update1.txt", "update2.txt"], ["update3.txt", "update4.txt"]],
    )
    @mock.patch(
        "json.load",
        side_effect=[
            {"mock_source": "mock_source.txt"},
            {"mock_source": "mock_source2.txt"},
            {"mock_source": "mock_source3.txt"},
            {"mock_source": "mock_source4.txt"},
        ],
    )
    @mock.patch("builtins.open", return_value=mock.Mock())
    @mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
    def test_update_both_pathways(self, mock_join_robust, *_):
        baseline = [{"source": "source1.txt"}, {"source": "source2.txt"}]
        result = self.update_sources_data(baseline, [".json"])

        # Data files first, then extension files, in json.load call order.
        loaded = [
            {"mock_source": name}
            for name in (
                "mock_source.txt",
                "mock_source2.txt",
                "mock_source3.txt",
                "mock_source4.txt",
            )
        ]
        self.assertEqual(result, baseline + loaded)
        self.assert_called_once(mock_join_robust)
758
759
760
class TestUpdateAllSources(BaseStdout):
    """Tests for ``update_all_sources`` with discovery, download and writes mocked."""

    def setUp(self):
        BaseStdout.setUp(self)

        self.source_data_filename = "data.json"
        self.hostfilename = "hosts.txt"

    @mock.patch("builtins.open")
    @mock.patch("updateHostsFile.recursive_glob", return_value=[])
    def test_no_sources(self, _, mock_open):
        # No source files discovered: nothing is opened.
        update_all_sources(self.source_data_filename, self.hostfilename)
        mock_open.assert_not_called()

    @mock.patch("builtins.open", return_value=mock.Mock())
    @mock.patch("json.load", return_value={"url": "example.com"})
    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
    @mock.patch("updateHostsFile.write_data", return_value=0)
    @mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
    def test_one_source(self, mock_get, mock_write, *_):
        # One source downloads successfully and is written once.
        update_all_sources(self.source_data_filename, self.hostfilename)
        self.assert_called_once(mock_write)
        self.assert_called_once(mock_get)

        output = sys.stdout.getvalue()
        expected = "Updating source from example.com"

        self.assertIn(expected, output)

    @mock.patch("builtins.open", return_value=mock.Mock())
    @mock.patch("json.load", return_value={"url": "example.com"})
    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
    @mock.patch("updateHostsFile.write_data", return_value=0)
    @mock.patch("updateHostsFile.get_file_by_url", side_effect=Exception("fail"))
    def test_source_fail(self, mock_get, mock_write, *_):
        # The download raises. side_effect is used (as in
        # test_sources_fail_succeed) because return_value=Exception(...)
        # would merely hand an Exception *object* back to the caller instead
        # of simulating a failed download.
        update_all_sources(self.source_data_filename, self.hostfilename)
        mock_write.assert_not_called()
        self.assert_called_once(mock_get)

        output = sys.stdout.getvalue()
        expecteds = [
            "Updating source from example.com",
            "Error in updating source: example.com",
        ]
        for expected in expecteds:
            self.assertIn(expected, output)

    @mock.patch("builtins.open", return_value=mock.Mock())
    @mock.patch(
        "json.load", side_effect=[{"url": "example.com"}, {"url": "example2.com"}]
    )
    @mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
    @mock.patch("updateHostsFile.write_data", return_value=0)
    @mock.patch(
        "updateHostsFile.get_file_by_url", side_effect=[Exception("fail"), "file_data"]
    )
    def test_sources_fail_succeed(self, mock_get, mock_write, *_):
        # The first source fails and the second succeeds: one write only.
        update_all_sources(self.source_data_filename, self.hostfilename)
        self.assert_called_once(mock_write)

        get_calls = [mock.call("example.com"), mock.call("example2.com")]
        mock_get.assert_has_calls(get_calls)

        output = sys.stdout.getvalue()
        expecteds = [
            "Updating source from example.com",
            "Error in updating source: example.com",
            "Updating source from example2.com",
        ]
        for expected in expecteds:
            self.assertIn(expected, output)
830
831
832
# End Update Logic
833
834
835
# File Logic
836
class TestNormalizeRule(BaseStdout):
    """Tests for ``normalize_rule``; rejected rules are echoed to stdout."""

    def _drain_stdout(self):
        """Return everything printed so far and start a fresh capture buffer."""
        captured = sys.stdout.getvalue()
        sys.stdout = StringIO()
        return captured

    def test_no_match(self):
        kwargs = {"targetip": "0.0.0.0", "keep_domain_comments": False}

        # Note: "Bare"- Domains are accepted. IP are excluded.
        rejected = (
            "128.0.0.1",
            "::1",
            "0.0.0.0 128.0.0.2",
            "0.1.2.3 foo/bar",
            "0.3.4.5 example.org/hello/world",
            "0.0.0.0 https",
            "0.0.0.0 https..",
            "0.0.0.0 foo.",
        )
        for rule in rejected:
            self.assertEqual(normalize_rule(rule, **kwargs), (None, None))

            # A rejected rule is echoed between "==>" and "<==" markers.
            self.assertIn("==>" + rule + "<==", self._drain_stdout())

    def test_mixed_cases(self):
        # Domains are lower-cased.
        cases = (
            ("tWiTTer.cOM", "twitter.com"),
            ("goOgLe.Com", "google.com"),
            ("FoO.bAR.edu", "foo.bar.edu"),
        )
        for rule, target in cases:
            actual = normalize_rule(
                rule, targetip="0.0.0.0", keep_domain_comments=False
            )
            self.assertEqual(actual, (target, "0.0.0.0 " + target + "\n"))

            # Nothing gets printed if there's a match.
            self.assertEqual(self._drain_stdout(), "")

    def test_no_comments(self):
        # The source IP is replaced and the trailing text is dropped.
        rule = "127.0.0.1 1.google.com foo"
        for targetip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            actual = normalize_rule(
                rule, targetip=targetip, keep_domain_comments=False
            )
            expected_line = "{} 1.google.com\n".format(targetip)
            self.assertEqual(actual, ("1.google.com", expected_line))

            # Nothing gets printed if there's a match.
            self.assertEqual(self._drain_stdout(), "")

    def test_with_comments(self):
        # With keep_domain_comments the trailing text becomes a "# " comment.
        for targetip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            for comment in ("foo", "bar", "baz"):
                actual = normalize_rule(
                    "127.0.0.1 1.google.co.uk " + comment,
                    targetip=targetip,
                    keep_domain_comments=True,
                )
                expected_line = "{} 1.google.co.uk # {}\n".format(targetip, comment)
                self.assertEqual(actual, ("1.google.co.uk", expected_line))

                # Nothing gets printed if there's a match.
                self.assertEqual(self._drain_stdout(), "")

    def test_two_ips(self):
        # An IP in the domain position is rejected and echoed back.
        rule = "127.0.0.1 11.22.33.44 foo"
        for targetip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            actual = normalize_rule(
                rule, targetip=targetip, keep_domain_comments=False
            )
            self.assertEqual(actual, (None, None))

            self.assertIn("==>" + rule + "<==", self._drain_stdout())

    def test_no_comment_raw(self):
        # Bare domains (no leading IP) get the target IP prepended.
        bare_domains = (
            "twitter.com",
            "google.com",
            "foo.bar.edu",
            "www.example-foo.bar.edu",
            "www.example-3045.foobar.com",
            "www.example.xn--p1ai",
        )
        for rule in bare_domains:
            actual = normalize_rule(
                rule, targetip="0.0.0.0", keep_domain_comments=False
            )
            self.assertEqual(actual, (rule, "0.0.0.0 " + rule + "\n"))

            # Nothing gets printed if there's a match.
            self.assertEqual(self._drain_stdout(), "")

    def test_with_comments_raw(self):
        # Bare domain followed by text: IP prepended, text kept as a comment.
        for targetip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
            for comment in ("foo", "bar", "baz"):
                actual = normalize_rule(
                    "1.google.co.uk " + comment,
                    targetip=targetip,
                    keep_domain_comments=True,
                )
                expected_line = "{} 1.google.co.uk # {}\n".format(targetip, comment)
                self.assertEqual(actual, ("1.google.co.uk", expected_line))

                # Nothing gets printed if there's a match.
                self.assertEqual(self._drain_stdout(), "")
971
972
973
class TestStripRule(Base):
    """Tests that ``strip_rule`` leaves already-clean rules unchanged."""

    # Hosts-file rules of the form "<ip> <domain>".
    RULES = (
        "0.0.0.0 twitter.com",
        "127.0.0.1 facebook.com",
        "8.8.8.8 google.com",
        "1.2.3.4 foo.bar.edu",
    )
    # Bare domains with no leading IP address.
    DOMAINS = ("twitter.com", "facebook.com", "google.com", "foo.bar.edu")
    COMMENT = " # comments here galore"

    def test_strip_exactly_two(self):
        for line in self.RULES:
            self.assertEqual(strip_rule(line), line)

    def test_strip_more_than_two(self):
        for line in self.RULES:
            commented = line + self.COMMENT
            self.assertEqual(strip_rule(commented), commented)

    def test_strip_raw(self):
        for line in self.DOMAINS:
            self.assertEqual(strip_rule(line), line)

    def test_strip_raw_with_comment(self):
        for line in self.DOMAINS:
            commented = line + self.COMMENT
            self.assertEqual(strip_rule(commented), commented)
1017
1018
1019
class TestWriteOpeningHeader(BaseMockDir):
    """Tests for `write_opening_header` writing into an in-memory file."""

    # Static host entries that must be absent whenever static hosts are skipped.
    _STATIC_LINES = (
        "127.0.0.1 localhost",
        "127.0.0.1 local",
        "127.0.0.53",
        "127.0.1.1",
    )

    def setUp(self):
        super(TestWriteOpeningHeader, self).setUp()
        self.final_file = BytesIO()

    def tearDown(self):
        super(TestWriteOpeningHeader, self).tearDown()
        self.final_file.close()

    @staticmethod
    def _header_kwargs(**overrides):
        # Baseline keyword arguments; tests override individual entries.
        # Key order matches the original literal dicts.
        kwargs = dict(
            extensions="",
            outputsubfolder="",
            numberofrules=5,
            skipstatichosts=True,
            nounifiedhosts=False,
        )
        kwargs.update(overrides)
        return kwargs

    @staticmethod
    def _common_lines(numberofrules):
        # Lines every generated header is expected to contain.
        return (
            "# This hosts file is a merged collection",
            "# with a dash of crowd sourcing via GitHub",
            "# Number of unique domains: {count}".format(count=numberofrules),
            "Fetch the latest version of this file:",
            "Project home page: https://github.com/StevenBlack/hosts",
        )

    def _contents(self):
        # Decode everything written to the in-memory file so far.
        return self.final_file.getvalue().decode("UTF-8")

    def _check(self, contents, present, absent):
        # Assert that every `present` snippet appears and no `absent` one does.
        for needle in present:
            self.assertIn(needle, contents)
        for needle in absent:
            self.assertNotIn(needle, contents)

    def test_missing_keyword(self):
        kwargs = self._header_kwargs(skipstatichosts=False)

        # Dropping any single keyword must raise KeyError.
        for missing in list(kwargs):
            partial = {key: value for key, value in kwargs.items() if key != missing}
            self.assertRaises(
                KeyError, write_opening_header, self.final_file, **partial
            )

    def test_basic(self):
        kwargs = self._header_kwargs()
        write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=self._common_lines(kwargs["numberofrules"]),
            absent=("# Extensions added to this file:",) + self._STATIC_LINES,
        )

    def test_basic_include_static_hosts(self):
        kwargs = self._header_kwargs(skipstatichosts=False)
        with self.mock_property("platform.system") as system:
            system.return_value = "Windows"
            write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=("127.0.0.1 local", "127.0.0.1 localhost")
            + self._common_lines(kwargs["numberofrules"]),
            absent=("# Extensions added to this file:", "127.0.0.53", "127.0.1.1"),
        )

    def test_basic_include_static_hosts_linux(self):
        kwargs = self._header_kwargs(skipstatichosts=False)
        with self.mock_property("platform.system") as system:
            system.return_value = "Linux"

            with self.mock_property("socket.gethostname") as hostname:
                hostname.return_value = "steven-hosts"
                write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=(
                "127.0.1.1",
                "127.0.0.53",
                "steven-hosts",
                "127.0.0.1 local",
                "127.0.0.1 localhost",
            )
            + self._common_lines(kwargs["numberofrules"]),
            absent=("# Extensions added to this file:",),
        )

    def test_extensions(self):
        kwargs = self._header_kwargs(extensions=["epsilon", "gamma", "mu", "phi"])
        write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=(
                ", ".join(kwargs["extensions"]),
                "# Extensions added to this file:",
            )
            + self._common_lines(kwargs["numberofrules"]),
            absent=self._STATIC_LINES,
        )

    def test_no_unified_hosts(self):
        kwargs = self._header_kwargs(
            extensions=["epsilon", "gamma"], nounifiedhosts=True
        )
        write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=(
                ", ".join(kwargs["extensions"]),
                "# The unified hosts file was not used while generating this file.",
                "# Extensions used to generate this file:",
            )
            + self._common_lines(kwargs["numberofrules"]),
            absent=self._STATIC_LINES,
        )

    def _check_preamble(self, check_copy):
        # Write a marker preamble into the temp dir that stands in for
        # BASEDIR_PATH: "myhosts", or "myhosts.example" when check_copy is true.
        preamble_path = os.path.join(self.test_dir, "myhosts")
        if check_copy:
            preamble_path += ".example"

        with open(preamble_path, "w") as handle:
            handle.write("peter-piper-picked-a-pepper")

        kwargs = self._header_kwargs()

        with self.mock_property("updateHostsFile.BASEDIR_PATH"):
            updateHostsFile.BASEDIR_PATH = self.test_dir
            write_opening_header(self.final_file, **kwargs)

        self._check(
            self._contents(),
            present=("peter-piper-picked-a-pepper",)
            + self._common_lines(kwargs["numberofrules"]),
            absent=("# Extensions added to this file:",) + self._STATIC_LINES,
        )

    def test_preamble_exists(self):
        self._check_preamble(True)

    def test_preamble_copy(self):
        self._check_preamble(False)
1243
1244
1245
class TestUpdateReadmeData(BaseMockDir):
    """Tests for `update_readme_data` against a temporary JSON file."""

    def setUp(self):
        super(TestUpdateReadmeData, self).setUp()
        self.readme_file = os.path.join(self.test_dir, "readmeData.json")

    def _seed(self, payload):
        # Overwrite the readme-data file with the given JSON payload.
        with open(self.readme_file, "w") as handle:
            json.dump(payload, handle)

    def _load(self):
        # Read back whatever JSON is currently stored in the readme-data file.
        with open(self.readme_file, "r") as handle:
            return json.load(handle)

    def _location(self, subfolder):
        # Expected "location" value: "/" suffix on Windows, self.sep otherwise.
        if platform.system().lower() == "windows":
            return subfolder + "/"
        return subfolder + self.sep

    def _entry(self, nounifiedhosts):
        # Expected record written for outputsubfolder "foo", 5 rules, "hosts".
        return {
            "location": self._location("foo"),
            "nounifiedhosts": nounifiedhosts,
            "sourcesdata": "hosts",
            "entries": 5,
        }

    def _update(self, extensions, nounifiedhosts):
        update_readme_data(
            self.readme_file,
            extensions=extensions,
            outputsubfolder="foo",
            numberofrules=5,
            sourcesdata="hosts",
            nounifiedhosts=nounifiedhosts,
        )

    def test_missing_keyword(self):
        kwargs = dict(
            extensions="",
            outputsubfolder="",
            numberofrules="",
            sourcesdata="",
            nounifiedhosts=False,
        )

        # Dropping any single keyword must raise KeyError.
        for missing in list(kwargs):
            partial = {key: value for key, value in kwargs.items() if key != missing}
            self.assertRaises(
                KeyError, update_readme_data, self.readme_file, **partial
            )

    def test_add_fields(self):
        self._seed({"foo": "bar"})
        self._update(None, False)

        self.assertEqual(
            self._load(), {"base": self._entry(False), "foo": "bar"}
        )

    def test_modify_fields(self):
        self._seed({"base": "soprano"})
        self._update(None, False)

        self.assertEqual(self._load(), {"base": self._entry(False)})

    def test_set_extensions(self):
        self._seed({})
        self._update(["com", "org"], False)

        self.assertEqual(self._load(), {"com-org": self._entry(False)})

    def test_set_no_unified_hosts(self):
        self._seed({})
        self._update(["com", "org"], True)

        self.assertEqual(self._load(), {"com-org-only": self._entry(True)})
1359
1360
1361
class TestMoveHostsFile(BaseStdout):
    """Tests for `move_hosts_file_into_place` with a mocked platform."""

    def _move_on(self, system_name):
        # Run the move while platform.system() reports `system_name` and
        # return everything printed to the captured stdout.
        with self.mock_property("platform.system") as system:
            system.return_value = system_name

            # NOTE: Mock(name=...) only names the mock in its repr; it does
            # not set a `.name` attribute, so `mock_file.name` is a child Mock.
            mock_file = mock.Mock(name="foo")
            move_hosts_file_into_place(mock_file)

        return sys.stdout.getvalue()

    @mock.patch("os.path.abspath", side_effect=lambda f: f)
    def test_move_hosts_no_name(self, _):  # TODO: Create test which tries to move actual file
        self.assertIn("does not exist", self._move_on("foo"))

    @mock.patch("os.path.abspath", side_effect=lambda f: f)
    def test_move_hosts_windows(self, _):
        # NOTE(review): "" is contained in every string, so this assertion can
        # never fail; it only checks that the call completes without raising.
        self.assertIn("", self._move_on("Windows"))

    @mock.patch("os.path.abspath", side_effect=lambda f: f)
    @mock.patch("subprocess.call", return_value=0)
    def test_move_hosts_posix(self, *_):  # TODO: create test which tries to move an actual file
        self.assertIn("does not exist.", self._move_on("Linux"))

    @mock.patch("os.path.abspath", side_effect=lambda f: f)
    @mock.patch("subprocess.call", return_value=1)
    def test_move_hosts_posix_fail(self, *_):
        self.assertIn("does not exist.", self._move_on("Linux"))
1412
1413
1414
class TestFlushDnsCache(BaseStdout):
    """Tests for `flush_dns_cache` across mocked operating systems."""

    def _flush_on(self, system_name, os_name=None):
        # Call flush_dns_cache() with platform.system() patched (and os.name
        # too when given) and return everything printed.
        with self.mock_property("platform.system") as system:
            system.return_value = system_name

            if os_name is None:
                flush_dns_cache()
            else:
                with self.mock_property("os.name"):
                    os.name = os_name
                    flush_dns_cache()

        return sys.stdout.getvalue()

    @mock.patch("subprocess.call", return_value=0)
    def test_flush_darwin(self, _):
        message = (
            "Flushing the DNS cache to utilize new hosts "
            "file...\nFlushing the DNS cache requires "
            "administrative privileges. You might need to "
            "enter your password."
        )
        self.assertIn(message, self._flush_on("Darwin"))

    @mock.patch("subprocess.call", return_value=1)
    def test_flush_darwin_fail(self, _):
        self.assertIn("Flushing the DNS cache failed.", self._flush_on("Darwin"))

    def test_flush_windows(self):
        message = (
            "Automatically flushing the DNS cache is "
            "not yet supported.\nPlease copy and paste "
            "the command 'ipconfig /flushdns' in "
            "administrator command prompt after running "
            "this script."
        )
        self.assertIn(message, self._flush_on("win32", os_name="nt"))

    @mock.patch("os.path.isfile", return_value=False)
    def test_flush_no_tool(self, _):
        self.assertIn(
            "Unable to determine DNS management tool.",
            self._flush_on("Linux", os_name="posix"),
        )

    @mock.patch("os.path.isfile", side_effect=[True] + [False] * 11)
    @mock.patch("subprocess.call", return_value=0)
    def test_flush_posix(self, *_):
        self.assertIn(
            "Flushing the DNS cache by restarting nscd succeeded",
            self._flush_on("Linux", os_name="posix"),
        )

    @mock.patch("os.path.isfile", side_effect=[True] + [False] * 11)
    @mock.patch("subprocess.call", return_value=1)
    def test_flush_posix_fail(self, *_):
        self.assertIn(
            "Flushing the DNS cache by restarting nscd failed",
            self._flush_on("Linux", os_name="posix"),
        )

    @mock.patch("os.path.isfile", side_effect=[True, False, False, True] + [False] * 10)
    @mock.patch("subprocess.call", side_effect=[1, 0, 0])
    def test_flush_posix_fail_then_succeed(self, *_):
        printed = self._flush_on("Linux", os_name="posix")

        for message in (
            "Flushing the DNS cache by restarting nscd failed",
            "Flushing the DNS cache by restarting NetworkManager.service succeeded",
        ):
            self.assertIn(message, printed)
1518
1519
1520
class TestRemoveOldHostsFile(BaseMockDir):
    """Tests for `remove_old_hosts_file` inside a temporary directory."""

    def setUp(self):
        super(TestRemoveOldHostsFile, self).setUp()
        self.hostsfile = "hosts"
        self.full_hosts_path = os.path.join(self.test_dir, "hosts")

    def _write_existing(self):
        # Pre-populate the hosts file so there is something to remove/back up.
        with open(self.full_hosts_path, "w") as handle:
            handle.write("foo")

    @staticmethod
    def _read(path):
        with open(path, "r") as handle:
            return handle.read()

    def _remove_and_count(self, backup):
        # Run the removal and return how many directory entries it added.
        before = self.dir_count
        remove_old_hosts_file(self.test_dir, self.hostsfile, backup=backup)
        return self.dir_count - before

    def test_remove_hosts_file(self):
        self.assertEqual(self._remove_and_count(backup=False), 1)
        self.assertEqual(self._read(self.full_hosts_path), "")

    def test_remove_hosts_file_exists(self):
        self._write_existing()

        self.assertEqual(self._remove_and_count(backup=False), 0)
        self.assertEqual(self._read(self.full_hosts_path), "")

    @mock.patch("time.strftime", return_value="new")
    def test_remove_hosts_file_backup(self, _):
        self._write_existing()

        self.assertEqual(self._remove_and_count(backup=True), 1)
        self.assertEqual(self._read(self.full_hosts_path), "")

        # The backup name is the hosts path plus "-" and the mocked timestamp.
        self.assertEqual(self._read(self.full_hosts_path + "-new"), "foo")
1574
1575
1576
# End File Logic
1577
1578
1579
class DomainToIDNA(Base):
    """Tests for `domain_to_idna`, which IDNA-encodes the domain in a line.

    Only the domain part may change: IP prefixes, separators (spaces or
    tabs, single or repeated) and trailing comments must be preserved.
    """

    def __init__(self, *args, **kwargs):
        super(DomainToIDNA, self).__init__(*args, **kwargs)

        # UTF-8 encoded internationalized domains and their expected IDNA
        # (punycode) forms, index-aligned.
        self.domains = [b"\xc9\xa2oogle.com", b"www.huala\xc3\xb1e.cl"]
        self.expected_domains = ["xn--oogle-wmc.com", "www.xn--hualae-0wa.cl"]

    def _assert_wrapped(self, prefix, suffix=""):
        # prefix + domain + suffix must convert to prefix + idna + suffix.
        for raw, encoded in zip(self.domains, self.expected_domains):
            line = prefix + raw.decode("utf-8") + suffix
            self.assertEqual(domain_to_idna(line), prefix + encoded + suffix)

    def test_empty_line(self):
        for empty in ("", "\r", "\n"):
            self.assertEqual(domain_to_idna(empty), empty)

    def test_commented_line(self):
        data = "# Hello World"
        self.assertEqual(domain_to_idna(data), data)

    def test_simple_line(self):
        # A single space, then a single tab, as separator.
        self._assert_wrapped("0.0.0.0 ")
        self._assert_wrapped("0.0.0.0\t")

    def test_multiple_space_as_separator(self):
        # Several spaces between the IP and the domain (the previous version
        # used a single space here, duplicating test_simple_line).
        self._assert_wrapped("0.0.0.0     ")

    def test_multiple_tabs_as_separator(self):
        # Several tabs between the IP and the domain.
        self._assert_wrapped("0.0.0.0\t\t\t\t\t\t")

    def test_line_with_comment_at_the_end(self):
        comment = "# Hello World"

        # Space separator, space before the comment.
        self._assert_wrapped("0.0.0.0 ", " " + comment)

        # Tab separator, space before the comment.
        self._assert_wrapped("0.0.0.0\t", " " + comment)

        # Tab separator, tab + space before the comment.
        self._assert_wrapped("0.0.0.0\t", "\t " + comment)

        # Space separator, space + tab + space before the comment.
        self._assert_wrapped("0.0.0.0 ", " \t " + comment)

        # Multiple spaces as domain separator; multiple spaces and a tab
        # before the comment.
        self._assert_wrapped("0.0.0.0     ", "     \t " + comment)

        # Multiple tabs as domain separator; space + tab + space before the
        # comment.
        self._assert_wrapped("0.0.0.0\t\t\t", " \t " + comment)

    def test_line_without_prefix(self):
        # A bare domain with no IP and no separator.
        self._assert_wrapped("")
1715
1716
1717
class GetFileByUrl(BaseStdout):
    """Tests for `get_file_by_url` with `requests.get` mocked out."""

    @staticmethod
    def _response(content):
        # Build a bare requests.Response whose body is the given raw bytes.
        response = requests.Response()
        response.__setstate__({"_content": content})
        return response

    def _assert_fetch_error(self, url):
        # With requests.get raising ConnectionError, the function returns None
        # and prints exactly one error line naming the URL.
        with mock.patch(
            "requests.get", side_effect=requests.exceptions.ConnectionError
        ):
            result = get_file_by_url(url)

        self.assertIsNone(result)
        self.assertEqual(
            sys.stdout.getvalue(), "Error retrieving data from {}\n".format(url)
        )

    def test_basic(self):
        body = "hello, ".encode("ascii") + "world".encode("utf-8")

        with mock.patch("requests.get", return_value=self._response(body)):
            fetched = get_file_by_url("www.test-url.com")

        self.assertEqual("hello, world", fetched)

    def test_with_idna(self):
        body = b"www.huala\xc3\xb1e.cl"

        with mock.patch("requests.get", return_value=self._response(body)):
            fetched = get_file_by_url("www.test-url.com")

        self.assertEqual("www.xn--hualae-0wa.cl", fetched)

    def test_connect_unknown_domain(self):
        # An unresolvable host; a real request would raise ConnectionError.
        self._assert_fetch_error("http://doesnotexist.google.com")

    def test_invalid_url(self):
        # An unbracketed IPv6 host. A real request reportedly raises
        # InvalidURL, but the mock raises ConnectionError, so only the generic
        # error path is exercised here.
        self._assert_fetch_error("http://fe80::5054:ff:fe5a:fc0")
1767
1768
1769
class TestWriteData(Base):
    """`write_data` should write the encoded text to a binary stream."""

    @staticmethod
    def _written(text):
        # Write `text` into a fresh BytesIO and return the resulting bytes.
        sink = BytesIO()
        write_data(sink, text)
        return sink.getvalue()

    def test_write_basic(self):
        self.assertEqual(self._written("foo"), b"foo")

    def test_write_unicode(self):
        self.assertEqual(self._written(u"foo"), b"foo")
1791
1792
1793
class TestQueryYesOrNo(BaseStdout):
    """Tests for the interactive `query_yes_no` prompt with mocked input."""

    # Accepted spellings; each one is fed to a separate query.
    _NO_REPLIES = ("no", "NO", "N", "n", "No", "nO")
    _YES_REPLIES = ("yes", "YES", "Y", "yeS", "y", "YeS", "yES", "YEs")

    def test_invalid_default(self):
        # Defaults other than None, "yes" and "no" are rejected.
        for invalid_default in ["foo", "bar", "baz", 1, 2, 3]:
            self.assertRaises(ValueError, query_yes_no, "?", invalid_default)

    @mock.patch("updateHostsFile.input", side_effect=["yes"] * 3)
    def test_valid_default(self, _):
        # The prompt shows the default choice capitalized.
        for valid_default, expected in [
            (None, "[y/n]"),
            ("yes", "[Y/n]"),
            ("no", "[y/N]"),
        ]:
            self.assertTrue(query_yes_no("?", valid_default))

            output = sys.stdout.getvalue()
            sys.stdout = StringIO()

            self.assertIn(expected, output)

    @mock.patch("updateHostsFile.input", side_effect=([""] * 2))
    def test_use_valid_default(self, _):
        # An empty reply falls back to the default.
        for valid_default in ["yes", "no"]:
            expected = valid_default == "yes"
            actual = query_yes_no("?", valid_default)

            self.assertEqual(actual, expected)

    @mock.patch("updateHostsFile.input", side_effect=_NO_REPLIES)
    def test_valid_no(self, mock_input):
        # Query once per queued reply so every spelling is exercised, not
        # only the first one.
        for _ in self._NO_REPLIES:
            self.assertFalse(query_yes_no("?", None))

        self.assertEqual(mock_input.call_count, len(self._NO_REPLIES))

    @mock.patch("updateHostsFile.input", side_effect=_YES_REPLIES)
    def test_valid_yes(self, mock_input):
        # Query once per queued reply so every spelling is exercised, not
        # only the first one.
        for _ in self._YES_REPLIES:
            self.assertTrue(query_yes_no("?", None))

        self.assertEqual(mock_input.call_count, len(self._YES_REPLIES))

    @mock.patch("updateHostsFile.input", side_effect=["foo", "yes", "foo", "no"])
    def test_invalid_then_valid(self, _):
        expected = "Please respond with 'yes' or 'no'"

        # The first time, we respond "yes"
        self.assertTrue(query_yes_no("?", None))

        output = sys.stdout.getvalue()
        self.assertIn(expected, output)

        sys.stdout = StringIO()

        # The second time, we respond "no"
        self.assertFalse(query_yes_no("?", None))

        output = sys.stdout.getvalue()
        self.assertIn(expected, output)
1848
1849
1850
class TestIsValidUserProvidedDomainFormat(BaseStdout):
    """Tests for validating a domain typed in by the user."""

    def _validate(self, domain, valid):
        # Assert the validity verdict for `domain`, then return what was
        # printed and start a fresh capture buffer.
        verdict = is_valid_user_provided_domain_format(domain)
        if valid:
            self.assertTrue(verdict)
        else:
            self.assertFalse(verdict)

        printed = sys.stdout.getvalue()
        sys.stdout = StringIO()
        return printed

    def test_empty_domain(self):
        printed = self._validate("", valid=False)
        self.assertIn("You didn't enter a domain. Try again.", printed)

    def test_invalid_domain(self):
        warning = "Do not include www.domain.com or http(s)://domain.com. Try again."

        for bad in ("www.subdomain.domain", "https://github.com", "http://www.google.com"):
            self.assertIn(warning, self._validate(bad, valid=False))

    def test_valid_domain(self):
        # Valid domains are accepted silently.
        for good in ("github.com", "travis.org", "twitter.com"):
            self.assertEqual(self._validate(good, valid=True), "")
1882
1883
1884
def mock_walk(stem):
    """
    Mock replacement for `os.walk`.

    Returns a list (not a generator) of ``(root, dirs, files)`` triples,
    one ``("", "_", [path])`` triple per fake file that lives under `stem`.
    A stem of ``"."`` or an empty stem selects every fake file; otherwise a
    file is selected when its path starts with ``stem + "/"``.
    """
    fake_paths = (
        "foo.txt",
        "bar.bat",
        "baz.py",
        "foo/foo.c",
        "foo/bar.doc",
        "foo/baz/foo.py",
        "bar/foo/baz.c",
        "bar/bar/foo.bat",
    )

    prefix = "" if stem == "." else stem

    return [
        ("", "_", [path])
        for path in fake_paths
        if not prefix or path.startswith(prefix + "/")
    ]
1913
1914
1915
class TestRecursiveGlob(Base):
    """Tests for `recursive_glob` with `os.walk` replaced by `mock_walk`."""

    @staticmethod
    def sorted_recursive_glob(stem, file_pattern):
        matches = recursive_glob(stem, file_pattern)
        matches.sort()
        return matches

    def _assert_glob(self, stem, file_pattern, expected):
        self.assertListEqual(self.sorted_recursive_glob(stem, file_pattern), expected)

    @mock.patch("os.walk", side_effect=mock_walk)
    def test_all_match(self, _):
        with self.mock_property("sys.version_info"):
            sys.version_info = (2, 6)

            self._assert_glob(
                "*",
                "*",
                [
                    "bar.bat",
                    "bar/bar/foo.bat",
                    "bar/foo/baz.c",
                    "baz.py",
                    "foo.txt",
                    "foo/bar.doc",
                    "foo/baz/foo.py",
                    "foo/foo.c",
                ],
            )
            self._assert_glob("bar", "*", ["bar/bar/foo.bat", "bar/foo/baz.c"])
            self._assert_glob(
                "foo", "*", ["foo/bar.doc", "foo/baz/foo.py", "foo/foo.c"]
            )

    @mock.patch("os.walk", side_effect=mock_walk)
    def test_file_ending(self, _):
        with self.mock_property("sys.version_info"):
            sys.version_info = (2, 6)

            self._assert_glob("foo", "*.py", ["foo/baz/foo.py"])
            self._assert_glob("*", "*.c", ["bar/foo/baz.c", "foo/foo.c"])
            self._assert_glob("*", ".xlsx", [])
1965
1966
1967
def mock_path_join(*_):
    """
    Mock replacement for `os.path.join` that always fails.

    Ignores every positional argument and raises a ``UnicodeDecodeError``
    (encoding "foo", empty object, span 1-5, reason "foo") to simulate a
    path that cannot be decoded.
    """
    encoding, bad_bytes, start, end, reason = "foo", b"", 1, 5, "foo"
    raise UnicodeDecodeError(encoding, bad_bytes, start, end, reason)
1976
1977
1978
class TestPathJoinRobust(Base):
    """Tests for `path_join_robust`."""

    def _assert_joins(self, segment):
        # Joining "path1" with 1..3 copies of `segment` yields the
        # separator-delimited concatenation.
        for count in range(1, 4):
            joined = path_join_robust("path1", *([segment] * count))
            self.assertEqual(joined, "path1" + (self.sep + "pathNew") * count)

    def test_basic(self):
        for single in ("path1", u"path1"):
            self.assertEqual(path_join_robust(single), "path1")

    def test_join(self):
        self._assert_joins("pathNew")

    def test_join_unicode(self):
        self._assert_joins(u"pathNew")

    @mock.patch("os.path.join", side_effect=mock_path_join)
    def test_join_error(self, _):
        # A decode failure inside os.path.join surfaces as locale.Error.
        self.assertRaises(locale.Error, path_join_robust, "path")
2006
2007
2008
# Colors
2009
class TestSupportsColor(BaseStdout):
    """Tests for `supports_color` under patched platform/terminal state."""

    def _color_with_tty(self, isatty):
        # Evaluate supports_color(), patching sys.stdout.isatty when a value
        # is supplied.
        if isatty is None:
            return supports_color()

        with self.mock_property("sys.stdout.isatty") as tty:
            tty.return_value = isatty
            return supports_color()

    def _color_on(self, platform_name, environ=None, isatty=None):
        # Evaluate supports_color() with sys.platform set to `platform_name`
        # and, optionally, os.environ replaced and isatty patched.
        with self.mock_property("sys.platform"):
            sys.platform = platform_name

            if environ is None:
                return self._color_with_tty(isatty)

            with self.mock_property("os.environ"):
                os.environ = environ
                return self._color_with_tty(isatty)

    def test_posix(self):
        self.assertTrue(self._color_on("Linux", isatty=True))

    def test_pocket_pc(self):
        self.assertFalse(self._color_on("Pocket PC"))

    def test_windows_no_ansicon(self):
        self.assertFalse(self._color_on("win32", environ=[]))

    def test_windows_ansicon(self):
        self.assertTrue(self._color_on("win32", environ=["ANSICON"], isatty=True))

    def test_no_isatty_attribute(self):
        with self.mock_property("sys.platform"):
            sys.platform = "Linux"

            # Swap stdout for an object that has no isatty() at all.
            with self.mock_property("sys.stdout"):
                sys.stdout = list()
                self.assertFalse(supports_color())

    def test_no_isatty(self):
        self.assertFalse(self._color_on("Linux", isatty=False))
2058
2059
2060
class TestColorize(Base):
    """Tests for `colorize` with and without terminal color support."""

    def setUp(self):
        self.text = "house"
        self.colors = ["red", "orange", "yellow", "green", "blue", "purple"]

    @mock.patch("updateHostsFile.supports_color", return_value=False)
    def test_colorize_no_support(self, _):
        # Without color support the text is returned untouched.
        for color in self.colors:
            self.assertEqual(colorize(self.text, color), self.text)

    @mock.patch("updateHostsFile.supports_color", return_value=True)
    def test_colorize_support(self, _):
        # With color support the text is wrapped in the color code and ENDC.
        for color in self.colors:
            wrapped = "".join((color, self.text, Colors.ENDC))
            self.assertEqual(colorize(self.text, color), wrapped)
2080
2081
2082
class TestPrintSuccess(BaseStdout):
    """Tests for print_success(), checked against the captured stdout buffer."""

    def setUp(self):
        # BaseStdout.setUp swaps sys.stdout for a StringIO we can inspect.
        super().setUp()
        self.text = "house"

    @mock.patch("updateHostsFile.supports_color", return_value=False)
    def test_print_success_no_support(self, _):
        # Without color support the text is printed unadorned.
        print_success(self.text)
        printed = sys.stdout.getvalue()

        self.assertEqual(printed, self.text + "\n")

    @mock.patch("updateHostsFile.supports_color", return_value=True)
    def test_print_success_support(self, _):
        # With color support the text is wrapped in SUCCESS ... ENDC codes.
        print_success(self.text)
        printed = sys.stdout.getvalue()

        wrapped = Colors.SUCCESS + self.text + Colors.ENDC
        self.assertEqual(printed, wrapped + "\n")
2104
2105
2106
class TestPrintFailure(BaseStdout):
    """Tests for print_failure(), checked against the captured stdout buffer."""

    def setUp(self):
        # BaseStdout.setUp swaps sys.stdout for a StringIO we can inspect.
        super().setUp()
        self.text = "house"

    @mock.patch("updateHostsFile.supports_color", return_value=False)
    def test_print_failure_no_support(self, _):
        # Without color support the text is printed unadorned.
        print_failure(self.text)
        printed = sys.stdout.getvalue()

        self.assertEqual(printed, self.text + "\n")

    @mock.patch("updateHostsFile.supports_color", return_value=True)
    def test_print_failure_support(self, _):
        # With color support the text is wrapped in FAIL ... ENDC codes.
        print_failure(self.text)
        printed = sys.stdout.getvalue()

        wrapped = Colors.FAIL + self.text + Colors.ENDC
        self.assertEqual(printed, wrapped + "\n")
2128
2129
2130
# End Helper Functions
2131
2132
2133
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner, which collects the TestCase classes in this module.
    unittest.main()
2135
2136