Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/internal/wrap_try_accept.py
1191 views
1
from glob import glob
2
import nbformat
3
4
# Top-level module names that are treated as already importable: imports of
# these modules are never wrapped in try/except + "%pip install" by
# wrap_try_accept_in_code (which reads them through get_installed_modules).
# Callers may supply a different collection via
# get_installed_modules(installed_packages=...).
# NOTE(review): the set mixes stdlib modules, third-party packages and
# platform-specific entries (e.g. the _sysconfigdata_* names), so it looks like
# a snapshot of the modules importable in one (conda) environment, presumably
# produced by the "get_installed_packegs" helper that appears in the list.
# Regenerate it rather than hand-editing, and confirm it matches the runtime.
INSTALLED_MODULES = {
    "hashlib",
    "argparse",
    "dataclasses",
    "pytest_forked",
    "bleach",
    "_testimportmultiple",
    "imaplib",
    "IPython",
    "_pickle",
    "widgetsnbextension",
    "__future__",
    "uuid",
    "lzma",
    "webbrowser",
    "decimal",
    "backcall",
    "sysconfig",
    "nntplib",
    "sre_compile",
    "site",
    "asyncore",
    "blackd",
    "requests",
    "string",
    "fcntl",
    "weakref",
    "copyreg",
    "resource",
    "PIL",
    "xdrlib",
    "wheel",
    "typing_extensions",
    "_strptime",
    "platform",
    "six",
    "threadpoolctl",
    "_testmultiphase",
    "codecs",
    "ensurepip",
    "_ssl",
    "html",
    "_json",
    "sndhdr",
    "_multibytecodec",
    "nbclient",
    "kiwisolver",
    "graphviz",
    "charset_normalizer",
    "imp",
    "_compat_pickle",
    "doctest",
    "colorsys",
    "curses",
    "_multiprocessing",
    "psutil",
    "fileinput",
    "termios",
    "argon2",
    "contextvars",
    "_bz2",
    "logging",
    "xmlrpc",
    "mypy_extensions",
    "pandocfilters",
    "distutils",
    "numbers",
    "sre_parse",
    "webencodings",
    "_bootlocale",
    "macpath",
    "black",
    "smtpd",
    "traitlets",
    "zipapp",
    "ipython_genutils",
    "gzip",
    "keyword",
    "ipywidgets",
    "symbol",
    "tomli",
    "fastjsonschema",
    "mmap",
    "_dummy_thread",
    "pyparsing",
    "stringprep",
    "modulefinder",
    "binascii",
    "_osx_support",
    "gettext",
    "pydoc",
    "re",
    "pipes",
    "dis",
    "operator",
    "_markupbase",
    "execnet",
    "ftplib",
    "wcwidth",
    "_collections_abc",
    "netrc",
    "crypt",
    "_sysconfigdata_x86_64_conda_cos7_linux_gnu",
    "entrypoints",
    "nbconvert",
    "asynchat",
    "test",
    "warnings",
    "_codecs_hk",
    "send2trash",
    "enum",
    "threading",
    "plistlib",
    "concurrent",
    "_sysconfigdata_powerpc64le_conda_cos7_linux_gnu",
    "pydoc_data",
    "xxlimited",
    "_pytest",
    "profile",
    "blib2to3",
    "typed_ast",
    "datetime",
    "ipaddress",
    "posixpath",
    "_testbuffer",
    "opcode",
    "pytest_timeout",
    "_sysconfigdata_s390x_conda_cos7_linux_gnu",
    "pathspec",
    "cmd",
    "tracemalloc",
    "jupyterlab_pygments",
    "parso",
    "numpy",
    "matplotlib",
    "locale",
    "pvectorc",
    "aifc",
    "jax",
    "pylab",
    "pytest",
    "urllib",
    "_pyio",
    "os",
    "telnetlib",
    "tty",
    "compileall",
    "pyrsistent",
    "_sysconfigdata_i686_conda_cos6_linux_gnu",
    "_sysconfigdata_x86_64_conda_cos6_linux_gnu",
    "_sha512",
    "packaging",
    "socket",
    "wave",
    "jsonschema",
    "pickle",
    "binhex",
    "opt_einsum",
    "_datetime",
    "_sysconfigdata_m_linux_x86_64-linux-gnu",
    "importlib_metadata",
    "click",
    "readline",
    "statistics",
    "_weakrefset",
    "pstats",
    "antigravity",
    "_ctypes",
    "flatbuffers",
    "select",
    "uu",
    "syslog",
    "cffi",
    "pkgutil",
    "xml",
    "soupsieve",
    "pygments",
    "_lsprof",
    "calendar",
    "mistune",
    "fontTools",
    "ipykernel_launcher",
    "signal",
    "dummy_threading",
    "attr",
    "struct",
    "wsgiref",
    "selectors",
    "_sysconfigdata_aarch64_conda_linux_gnu",
    "pkg_resources",
    "_ctypes_test",
    "venv",
    "cgi",
    "cgitb",
    "defusedxml",
    "csv",
    "email",
    "tokenize",
    "abc",
    "2dd510b5c3364608e57a__mypyc",
    "qtconsole",
    "testbook",
    "pickletools",
    "_sitebuiltins",
    "pprint",
    "_sqlite3",
    "inspect",
    "socketserver",
    "jedi",
    "tornado",
    "pycparser",
    "timeit",
    "pathlib",
    "ssl",
    "mailcap",
    "http",
    "random",
    "cProfile",
    "_sha1",
    "pickleshare",
    "difflib",
    "idna",
    "_sysconfigdata_x86_64_apple_darwin13_4_0",
    "optparse",
    "token",
    "tempfile",
    "_sysconfigdata_aarch64_conda_cos7_linux_gnu",
    "shelve",
    "functools",
    "iniconfig",
    "_py_abc",
    "runpy",
    "jaxlib",
    "contextlib",
    "ptyprocess",
    "traceback",
    "configparser",
    "tabnanny",
    "debugpy",
    "bisect",
    "urllib3",
    "py_compile",
    "_codecs_tw",
    "prometheus_client",
    "collections",
    "pip",
    "typing",
    "symtable",
    "_threading_local",
    "fractions",
    "glob",
    "pyclbr",
    "platformdirs",
    "secrets",
    "zipfile",
    "_sha3",
    "textwrap",
    "reprlib",
    "certifi",
    "decorator",
    "scipy",
    "py",
    "importlib_resources",
    "array",
    "ctypes",
    "matplotlib_inline",
    "ntpath",
    "poplib",
    "trace",
    "_crypt",
    "tarfile",
    "types",
    "xdist",
    "_xxtestfuzz",
    "dateutil",
    "jinja2",
    "genericpath",
    "_hashlib",
    "grp",
    "formatter",
    "spwd",
    "notebook",
    "nis",
    "_codecs_cn",
    "cycler",
    "jupyter_client",
    "cmath",
    "_testcapi",
    "getpass",
    "_csv",
    "base64",
    "mimetypes",
    "_black_version",
    "lib2to3",
    "qtpy",
    "sunau",
    "rlcompleter",
    "_struct",
    "setuptools",
    "smtplib",
    "queue",
    "sqlite3",
    "_sysconfigdata_s390x_conda_linux_gnu",
    "filecmp",
    "_socket",
    "heapq",
    "_cffi_backend",
    "pexpect",
    "tinycss2",
    "parser",
    "_opcode",
    "tokenize_rt",
    "_codecs_iso2022",
    "jupyter_core",
    "dbm",
    "fnmatch",
    "pyexpat",
    "bs4",
    "_distutils_hack",
    "_lzma",
    "pytz",
    "getopt",
    "_elementtree",
    "_sysconfigdata_x86_64_conda_linux_gnu",
    "_queue",
    "_codecs_kr",
    "sched",
    "_blake2",
    "jupyter_console",
    "unicodedata",
    "quopri",
    "_asyncio",
    "shlex",
    "h5py",
    "json",
    "tkinter",
    "stat",
    "_curses_panel",
    "nturl2path",
    "asyncio",
    "subprocess",
    "get_installed_packegs",
    "unittest",
    "bdb",
    "turtledemo",
    "_codecs_jp",
    "chunk",
    "sre_constants",
    "importlib",
    "_contextvars",
    "imghdr",
    "zipp",
    "_curses",
    "_bisect",
    "attrs",
    "_heapq",
    "ossaudiodev",
    "ipykernel",
    "pluggy",
    "encodings",
    "pdb",
    "audioop",
    "seaborn",
    "_pyrsistent_version",
    "codeop",
    "jupyterlab_widgets",
    "mailbox",
    "math",
    "_tkinter",
    "bz2",
    "prompt_toolkit",
    "terminado",
    "_compression",
    "jupyter",
    "pty",
    "idlelib",
    "joblib",
    "hmac",
    "_sha256",
    "markupsafe",
    "copy",
    "turtle",
    "_md5",
    "io",
    "cached_property",
    "this",
    "multiprocessing",
    "pandas",
    "zmq",
    "ast",
    "zlib",
    "code",
    "shutil",
    "absl",
    "_argon2_cffi_bindings",
    "nbformat",
    "_random",
    "_pydecimal",
    "nest_asyncio",
    "sklearn",
    "_posixsubprocess",
    "linecache",
    "_decimal",
}
409
410
411
def get_installed_modules(installed_packages=INSTALLED_MODULES):
    """Return the names of all modules that count as already available.

    Args:
        installed_packages: iterable of top-level module names. Defaults to the
            module-level ``INSTALLED_MODULES`` set; note that the default is bound
            once, when the function is defined.

    Returns:
        A new ``set`` holding ``installed_packages`` plus a few special-cased
        names that ``INSTALLED_MODULES`` does not list.
    """
    available = {"mpl_toolkits", "itertools", "time", "sys", "d2l", "augmax"}
    available.update(installed_packages)
    return available
415
416
417
def get_try_except_module(line):
    """Return the top-level module imported by an *indented* import line.

    Indented imports presumably sit inside an existing ``try`` block (or another
    nested scope). ``wrap_try_accept_in_code`` treats their modules as already
    handled and does not wrap them again.

    Args:
        line: a single line of notebook source.

    Returns:
        The top-level package name, e.g. ``"jax"`` for ``"    import jax.numpy"``.
        Returns ``""`` when no module name can be extracted, as with relative
        imports like ``"    from . import x"``. Returns ``None`` when the line is
        not an indented ``import ...`` / ``from ... import ...`` statement.
    """
    line = line.rstrip()
    stripped = line.lstrip()
    # Blank or unindented lines are not handled here;
    # get_simply_imported_module covers top-level imports.
    if not stripped or stripped == line:
        return None
    words = stripped.split()
    keyword = words[0]
    # Compare whole words, so that e.g. "importlib.reload(m)" is not mistaken
    # for an import statement (a prefix match used to yield the module "ib").
    if keyword != "import" and not (keyword == "from" and "import" in stripped):
        return None
    if len(words) < 2:
        return ""
    # Take the first module of "import a, b" without its trailing comma, then
    # keep only the top-level package.
    return words[1].split(",")[0].split(".")[0]
429
430
431
def get_simply_imported_module(line):
    """Return the top-level module imported by an unindented import line.

    Handles ``import pkg.sub as alias`` and ``from pkg.sub import name``. For a
    comma-separated ``import a, b`` only the first module, ``a``, is returned.

    Args:
        line: a single line of notebook source.

    Returns:
        The top-level package name, e.g. ``"numpy"`` for ``"import numpy.linalg"``.
        Returns ``""`` for relative imports such as ``"from . import x"``, and
        ``None`` when the line is not an unindented import statement.
    """
    line = line.rstrip()
    if line.startswith("import "):
        target = line[len("import "):]
    elif line.startswith("from ") and "import" in line:
        target = line[len("from "):]
    else:
        return None
    # split() without arguments tolerates repeated spaces ("import  numpy").
    words = target.split()
    if not words:
        return ""
    # Cut at the first comma: "import os, sys" / "import os,sys" must yield
    # "os". Otherwise a bogus "os," would be wrapped and pip-installed.
    return words[0].split(",")[0].split(".")[0]
443
444
445
def wrap_line_with_try_accept(line, module):
    """Wrap an import line so that a missing package is pip-installed on demand.

    The generated notebook code first attempts the import. On
    ``ModuleNotFoundError`` it runs the IPython magic ``%pip install -qq
    <package>`` and then repeats the import. Because ``%pip`` is a magic, the
    result only runs under IPython/Jupyter. The line is wrapped
    unconditionally; choosing which lines to wrap is the caller's job.

    Args:
        line: the import statement to guard. Trailing whitespace is dropped.
        module: top-level module name imported by ``line``.

    Returns:
        A multi-line string of the form::

            try:
             <line>
            except ModuleNotFoundError:
             %pip install -qq <package>
             <line>

        where ``<package>`` is the pip name for ``module``. It equals
        ``module`` itself unless listed in the mapping below.
    """
    # Import names whose pip distribution name differs. The
    # tensorflow_datasets entry deliberately installs tensorflow as well.
    pip_names = {
        "PIL": "pillow",
        "tensorflow_probability": "tensorflow-probability",
        "sklearn": "scikit-learn",
        "pl_bolts": "lightning-bolts",
        "skimage": "scikit-image",
        "cv2": "opencv-python",
        "tensorflow_datasets": "tensorflow tensorflow_datasets",
        "umap": "umap-learn",
    }
    package = pip_names.get(module, module)
    line = line.rstrip()
    return f"try:\n {line}\nexcept ModuleNotFoundError:\n %pip install -qq {package}\n {line}"
466
467
468
def wrap_try_accept_in_code(code):
    """Guard top-level imports of unknown modules with try/except + pip install.

    A module counts as already available, and its import is left alone, when
    either holds:

    * it is in ``get_installed_modules(...)``, or
    * it is imported on some indented line of ``code`` (per
      ``get_try_except_module``).

    Every other module is wrapped via ``wrap_line_with_try_accept`` at its
    first top-level import only. Each wrapped snippet is also printed.

    Args:
        code: source text of one notebook cell.

    Returns:
        The cell source with the selected import lines replaced.
    """
    lines = code.split("\n")
    known = {get_try_except_module(text) for text in lines}.union(
        get_installed_modules(installed_packages=INSTALLED_MODULES)
    )

    result = []
    for text in lines:
        module = get_simply_imported_module(text)
        if module and module not in known:
            guarded = wrap_line_with_try_accept(text, module)
            print(guarded)
            # Remember the module so a repeated import in this cell stays as-is.
            known.add(module)
            result.append(guarded)
        else:
            result.append(text)
    return "\n".join(result)
482
483
484
def remove_superimport(code):
    """Delete ``import superimport`` statements from a cell's source.

    Only the statement text is removed; newlines are kept. A line that held just
    the import therefore becomes an empty (or whitespace-only) line, and line
    numbering is preserved. Word boundaries stop the pattern from matching inside
    longer names such as ``import superimport_utils`` or
    ``from m import superimportance``. A plain substring replace would have
    mangled those.

    Args:
        code: source text of one notebook cell.

    Returns:
        The source with every standalone ``import superimport`` removed.
    """
    import re

    return re.sub(r"\bimport superimport\b", "", code)
488
489
def remove_pyprobml(code):
    """Rewrite legacy ``pyprobml_utils`` references to the ``probml_utils`` package.

    The substitutions are plain substring replacements, applied in order:

    * ``from pyprobml_utils import save_fig`` -> ``from probml_utils import savefig``
    * ``%pip install pyprobml_utils`` -> a ``%pip install`` from the probml-utils
      GitHub repository
    * ``import pyprobml_utils as pml`` -> ``import probml_utils as pml``

    NOTE(review): the first rewrite renames ``save_fig`` to ``savefig`` in the
    import line only. Any later ``save_fig(...)`` calls in the same source are not
    touched here, so confirm those were migrated separately.

    Args:
        code: source text of one notebook cell.

    Returns:
        The rewritten source.
    """
    substitutions = (
        ("from pyprobml_utils import save_fig", "from probml_utils import savefig"),
        ("%pip install pyprobml_utils", "%pip install git+https://github.com/probml/probml-utils.git"),
        ("import pyprobml_utils as pml", "import probml_utils as pml"),
    )
    for legacy, current in substitutions:
        code = code.replace(legacy, current)
    return code
494
495
def apply_fun_to_notebook(notebook, fun, cell_types=("code",)):
    """Apply a source-to-source transform to the cells of a notebook, in place.

    Args:
        notebook: path of the ``.ipynb`` file. It is read as nbformat v4 and,
            if any cell changed, written back to the same path.
        fun: callable taking one argument, a cell's source string, and returning
            the new source string.
        cell_types: cell types to transform. The default covers only ``"code"``
            cells, so markdown/raw prose is never rewritten as if it were
            Python. Pass ``None`` to transform every cell.
    """
    nb = nbformat.read(notebook, as_version=4)
    changed = False
    for cell in nb.cells:
        if cell_types is not None and cell.get("cell_type") not in cell_types:
            continue
        code = cell["source"]
        updated_code = fun(code)
        if updated_code != code:
            cell["source"] = updated_code
            changed = True
    # Leave untouched notebooks alone. Re-serializing them could otherwise
    # introduce formatting-only diffs.
    if changed:
        nbformat.write(nb, notebook)
506
507
508
if __name__ == "__main__":
    import os

    # Collect the notebooks of both books, sorted for a reproducible order.
    notebooks = sorted(
        glob("notebooks/book1/*/*.ipynb") + glob("notebooks/book2/*/*.ipynb")
    )

    # File names of notebooks to skip. Entries may be paths; only the part
    # after the last "/" is kept. Blank lines are ignored.
    with open("internal/ignored_notebooks.txt", encoding="utf-8") as fp:
        ignored_names = {
            entry.strip().split("/")[-1] for entry in fp if entry.strip()
        }

    # os.path.basename copes with whatever separator glob returns on this
    # platform, unlike a hard-coded split("/").
    selected = [nb for nb in notebooks if os.path.basename(nb) not in ignored_names]
    print(f"{len(notebooks) - len(selected)} notebooks ignored")

    def clean_cell(code):
        """Drop ``import superimport``, then guard remaining unknown imports."""
        return wrap_try_accept_in_code(remove_superimport(code))

    # Both transforms work cell by cell, so a single read/transform/write pass
    # per notebook gives the same result as two separate passes.
    for notebook in selected:
        print(f"******* {notebook} *******")
        apply_fun_to_notebook(notebook, clean_cell)
532
533