Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
duyuefeng0708
GitHub Repository: duyuefeng0708/Cryptography-From-First-Principle
Path: blob/main/scripts/python_to_sage.py
483 views
unlisted
1
#!/usr/bin/env python3
2
"""Replace Python stdlib calls with SageMath equivalents in Jupyter notebooks.

Targets:
1. import time + time.time() -> walltime()
2. import random + random.randint(a,b) -> randint(a,b)
3. random.shuffle(L) -> shuffle(L) [available in SageMath]
4. random.sample(L, k) -> sample(L, k) [available in SageMath]
5. random.choice(L) -> L[randint(0, len(L)-1)]
6. import math + math.X() -> X() for known builtins
7. from itertools import combinations -> Combinations (SageMath)
8. from itertools import product -> (keep, no clean SageMath equiv)
9. import numpy as np -> remove (plots already converted)
"""
15
16
import json
17
import re
18
import sys
19
from pathlib import Path
20
21
22
# Functions from Python's ``random`` module that are rewritten to the
# same-named function imported from sage.misc.prandom.
PRANDOM_FUNCS = {'shuffle', 'sample', 'choice'}

# math.X -> SageMath equivalent
MATH_REPLACEMENTS = {
    'math.sqrt': 'sqrt',
    'math.log2': 'log2',
    'math.log': 'log',
    'math.ceil': 'ceil',
    'math.floor': 'floor',
    'math.pi': 'pi',
    'math.e': 'e',
    'math.gcd': 'gcd',
    'math.factorial': 'factorial',
    'math.isqrt': 'isqrt',
    'math.inf': 'Infinity',
}

# Import statements that are dropped outright (after stripping whitespace).
_REMOVED_IMPORTS = frozenset({
    'import math',
    'import numpy as np',
    'import numpy',
    'from itertools import combinations',
})


def clean_cell_source(source_lines: list[str]) -> tuple[list[str], int]:
    """Rewrite Python stdlib usage in one notebook cell to SageMath calls.

    Per line, the following happens:

    * ``import time``, ``import math``, ``import numpy``/``import numpy as np``
      and ``from itertools import combinations`` lines are removed.
    * ``import random`` is removed, or replaced by
      ``from sage.misc.prandom import ...`` when the cell calls any of the
      functions listed in ``PRANDOM_FUNCS`` through ``random.``.
    * ``time.time()`` becomes ``walltime()``; ``random.randint(`` becomes
      ``randint(``; ``random.<f>(`` becomes ``<f>(`` for ``PRANDOM_FUNCS``.
    * Every key of ``MATH_REPLACEMENTS`` is replaced by its value.
    * ``combinations(`` becomes ``Combinations(`` when the cell contains the
      text ``from itertools import combinations``.

    Other imports (e.g. ``collections.Counter``/``defaultdict``) are left
    untouched.

    Args:
        source_lines: The cell's source, one string per line.

    Returns:
        ``(cleaned_lines, fix_count)``. Each removed or replaced import counts
        as one fix, and each kind of rewrite applied to a line counts as one
        fix for that line (all math replacements on a line count once).
    """
    fixes = 0
    cleaned = []

    # Cell-wide facts, computed once before walking the lines.
    full_source = ''.join(source_lines)
    needs_prandom = sorted(
        func for func in PRANDOM_FUNCS if f'random.{func}(' in full_source
    )
    rewrite_combinations = 'from itertools import combinations' in full_source

    # Longest keys first so 'math.log2' is never clobbered by 'math.log',
    # independent of the insertion order of MATH_REPLACEMENTS.
    math_subs = sorted(
        MATH_REPLACEMENTS.items(), key=lambda kv: len(kv[0]), reverse=True
    )

    for line in source_lines:
        stripped = line.strip()

        # --- import time (optionally followed by one trailing backslash) ---
        bare = stripped[:-1].rstrip() if stripped.endswith('\\') else stripped
        if bare == 'import time':
            fixes += 1
            continue

        # --- import random ---
        if stripped == 'import random':
            fixes += 1
            if needs_prandom:
                # Keep the original indentation of the import statement.
                indent = line[:len(line) - len(line.lstrip())]
                funcs = ', '.join(needs_prandom)
                cleaned.append(f'{indent}from sage.misc.prandom import {funcs}\n')
            continue

        # --- imports that are simply dropped ---
        if stripped in _REMOVED_IMPORTS:
            fixes += 1
            continue

        # --- time.time() -> walltime() ---
        if 'time.time()' in line:
            line = line.replace('time.time()', 'walltime()')
            fixes += 1

        # --- random.randint(a, b) -> randint(a, b) ---
        if 'random.randint(' in line:
            line = line.replace('random.randint(', 'randint(')
            fixes += 1

        # --- random.shuffle / random.sample / random.choice ---
        for func in needs_prandom:
            call = f'random.{func}('
            if call in line:
                line = line.replace(call, f'{func}(')
                fixes += 1

        # --- math.X -> X (counted once per line) ---
        before_math = line
        for py_name, sage_name in math_subs:
            if py_name in line:
                line = line.replace(py_name, sage_name)
        if line != before_math:
            fixes += 1

        # --- combinations( -> Combinations( ---
        # Only when the itertools import appears somewhere in this cell.
        if (rewrite_combinations
                and 'combinations(' in line
                and 'Combinations(' not in line):
            renamed = re.sub(r'\bcombinations\(', 'Combinations(', line)
            if renamed != line:
                fixes += 1
                line = renamed

        cleaned.append(line)

    return cleaned, fixes
159
160
161
def process_notebook(path: Path, dry_run: bool = False) -> int:
    """Clean every code cell of one notebook file.

    The notebook is read as UTF-8 JSON, each non-empty code cell's source is
    passed through ``clean_cell_source``, and the file is rewritten in place
    (indent of 1, non-ASCII kept as-is, trailing newline) when at least one
    fix was made.

    Args:
        path: Location of the ``.ipynb`` file.
        dry_run: When true, fixes are counted but neither the in-memory
            cells nor the file on disk are changed.

    Returns:
        Total number of fixes reported for the notebook's code cells.
    """
    # Explicit UTF-8: the file is written back with ensure_ascii=False, so
    # relying on the locale's default encoding could fail or corrupt text.
    with open(path, encoding='utf-8') as f:
        nb = json.load(f)

    total_fixes = 0

    for cell in nb.get('cells', []):
        if cell.get('cell_type') != 'code':
            continue

        source = cell.get('source', [])
        if not source:
            continue

        # A cell source may be stored as one string instead of a list of
        # lines; split it so the cleaner sees real lines, and restore the
        # same shape when writing the result back.
        is_string = isinstance(source, str)
        lines = source.splitlines(keepends=True) if is_string else source

        cleaned, fixes = clean_cell_source(lines)
        if fixes > 0:
            total_fixes += fixes
            if not dry_run:
                cell['source'] = ''.join(cleaned) if is_string else cleaned

    if total_fixes > 0 and not dry_run:
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(nb, f, indent=1, ensure_ascii=False)
            f.write('\n')

    return total_fixes
190
191
192
def main():
    """Run the conversion over every notebook under the repository root.

    The root is taken as the parent of this script's directory. Notebooks
    whose path contains ``.ipynb_checkpoints`` are skipped. Passing
    ``--dry-run`` on the command line reports fixes without writing files.
    A line is printed per notebook with fixes, followed by a summary.
    """
    repo = Path(__file__).parent.parent
    dry_run = '--dry-run' in sys.argv

    total = 0
    changed_files = 0

    for nb_path in sorted(repo.glob('**/*.ipynb')):
        # Skip Jupyter's autosave copies.
        if '.ipynb_checkpoints' in str(nb_path):
            continue

        fixes = process_notebook(nb_path, dry_run=dry_run)
        if fixes <= 0:
            continue

        print(f' {nb_path.relative_to(repo)}: {fixes} fixes')
        total += fixes
        changed_files += 1

    mode = 'APPLIED' if not dry_run else 'DRY RUN'
    print(f'\n{mode}: {total} fixes across {changed_files} files')


if __name__ == '__main__':
    main()
216
217