Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/tests/test_imports.py
1191 views
1
import os
import pkgutil
import sys
from glob import glob

import nbformat
import pytest

# Global variables
TIMEOUT = 120  # NOTE(review): not referenced in this file; confirm it is still needed

# Load notebooks.  The glob patterns (and the ignore-file path below) are
# relative to the current working directory, so this module must be imported
# from the directory that contains notebooks/ and internal/.  Results are
# sorted so the parametrized test order is deterministic across runs/machines.
notebooks1 = sorted(glob("notebooks/book1/*/*.ipynb"))
notebooks2 = sorted(glob("notebooks/book2/*/*.ipynb"))
notebooks = notebooks1 + notebooks2

# Get IGNORE_LIST of notebooks: the ignore file holds one notebook path per
# line; only the file-name part is kept because notebooks are matched by name.
# Blank lines are skipped so they do not contribute an empty "" entry.
with open("internal/ignored_notebooks.txt", encoding="utf-8") as fp:
    ignored_notebooks = fp.readlines()
IGNORE_LIST = [
    entry.strip().split("/")[-1] for entry in ignored_notebooks if entry.strip()
]

def in_ignore_list(nb_path, ignore_list=None):
    """Return True if the notebook at *nb_path* is in the ignore list.

    Notebooks are matched by file name only; the directory part of *nb_path*
    is discarded before the lookup.

    Args:
        nb_path: Path to a notebook file, e.g. as returned by ``glob``.
        ignore_list: Container of notebook file names to skip.  Defaults to
            the module-level ``IGNORE_LIST``.

    Returns:
        bool: whether the notebook's file name is in *ignore_list*.
    """
    if ignore_list is None:
        ignore_list = IGNORE_LIST
    # os.path.basename instead of split("/"): on Windows, glob joins the
    # wildcard-matched parts with "\\", which split("/") would not separate.
    nb_name = os.path.basename(nb_path)
    return nb_name in ignore_list


notebooks = [nb for nb in notebooks if not in_ignore_list(nb)]

# Load installed modules: every top-level module/package discoverable on
# sys.path, plus modules compiled into the interpreter itself (e.g. "gc"),
# which pkgutil.iter_modules() does not report.
all_modules = {module_info.name for module_info in pkgutil.iter_modules()}
all_modules |= set(sys.builtin_module_names)

# Special cases: names always accepted, even if the discovery above does not
# report them.
special_modules = {
    "mpl_toolkits",
    "itertools",
    "time",
    "sys",
    "d2l",
    "augmax",
    "google",
}
all_modules = all_modules.union(special_modules)


def get_simply_imported_module(line):
    """Return the top-level module named by an unindented import statement.

    Handles ``import a.b``, ``import a as b``, ``import a, b`` / ``import a,b``
    and ``from a.b import c``.  For a comma-separated import only the first
    module is returned.  A relative import (``from . import x``) yields ``""``.

    Returns:
        The module name, ``""`` for a relative import, or None when *line* is
        not an import statement starting in column 0.
    """
    # Indented lines are not considered here (see get_try_except_module).
    if not line or line[0].isspace():
        return None

    # split() tolerates repeated whitespace ("import  numpy"), which the old
    # split(" ") turned into an empty name.
    tokens = line.split()
    if len(tokens) < 2:
        return None

    keyword, target = tokens[0], tokens[1]
    if keyword == "import" or (keyword == "from" and "import" in line):
        # Cut at "," so "import os, sys" gives "os" rather than "os,", then at
        # "." to keep only the top-level package.
        return target.split(",")[0].split(".")[0]
    return None


def get_try_except_module(line):
    """Return the top-level module named by an indented import statement.

    Counterpart of ``get_simply_imported_module`` for lines that start with
    whitespace, e.g. imports inside a ``try:`` block or a function body.

    Returns:
        The module name, ``""`` for a relative import, or None when *line* is
        not an indented import statement.
    """
    # Only indented lines are considered (space or tab).
    if not line[:1].isspace():
        return None

    # Match on whole tokens: the old prefix test accepted "importlib..." or
    # "imports = []" as imports and then sliced out garbage names such as "b".
    tokens = line.split()
    if len(tokens) < 2:
        return None

    keyword, target = tokens[0], tokens[1]
    if keyword == "import" or (keyword == "from" and "import" in line):
        # "," cut handles "import os, sys"; "." cut keeps the top-level package.
        return target.split(",")[0].split(".")[0]
    return None


# Parametrize over every notebook that is not in the ignore list
@pytest.mark.parametrize("notebook", notebooks)
def test_run_notebooks(notebook):
    """Check that every module a notebook imports is available.

    Only code cells are scanned; markdown/raw cells are skipped so prose or
    fenced snippets such as "from X import Y" are not mistaken for imports.
    Each unindented import must name a module in ``all_modules``, unless the
    same module is also imported on an indented line (e.g. inside a
    ``try:`` block), in which case it is exempted.
    """
    nb = nbformat.read(notebook, as_version=4)
    code_sources = (cell["source"] for cell in nb.cells if cell.get("cell_type") == "code")
    lines = "\n".join(code_sources).split("\n")

    # Drop None (non-import lines) and "" (relative imports) from both sets.
    try_except_modules = {name for name in map(get_try_except_module, lines) if name}
    modules = {name for name in map(get_simply_imported_module, lines) if name}

    missing_modules = modules - all_modules - try_except_modules
    assert not missing_modules, f"Missing {missing_modules} in {notebook}"


if __name__ == "__main__":
    # Run the import check on each collected notebook without pytest; the
    # first notebook with a missing module raises AssertionError.
    for nb_path in notebooks:
        test_run_notebooks(nb_path)