Path: blob/main/scripts/python_to_sage.py
483 views
unlisted
#!/usr/bin/env python3
"""Replace Python stdlib calls with SageMath equivalents in Jupyter notebooks.

Targets:
1. import time + time.time() -> walltime()
2. import random + random.randint(a,b) -> randint(a,b)
3. random.shuffle(L) -> shuffle(L) [available in SageMath]
4. random.sample(L, k) -> sample(L, k) [available in SageMath]
5. random.choice(L) -> choice(L) [imported from sage.misc.prandom]
6. import math + math.X() -> X() for known builtins
7. from itertools import combinations -> Combinations (SageMath)
8. from itertools import product -> (keep, no clean SageMath equiv)
9. import numpy as np -> remove (plots already converted)

An ``import`` line is only dropped when no reference to that module survives
the rewrite (e.g. ``math.sin`` keeps ``import math``), so a cleaned notebook
never loses an import it still depends on.

Usage: run the script directly; pass ``--dry-run`` to report fixes without
writing any notebook.
"""

import json
import re
import sys
from pathlib import Path


# random.X helpers that get an explicit sage.misc.prandom import per cell
PRANDOM_FUNCS = {'shuffle', 'sample', 'choice'}

# math.X -> SageMath equivalent
MATH_REPLACEMENTS = {
    'math.sqrt': 'sqrt',
    'math.log2': 'log2',
    'math.log': 'log',
    'math.ceil': 'ceil',
    'math.floor': 'floor',
    'math.pi': 'pi',
    'math.e': 'e',
    'math.gcd': 'gcd',
    'math.factorial': 'factorial',
    'math.isqrt': 'isqrt',
    'math.inf': 'Infinity',
}

# Plain "import X" statements to drop, mapped to the name the code uses to
# reach the module (the alias for "import numpy as np").
SIMPLE_IMPORTS = {
    'import math': 'math',
    'import numpy as np': 'np',
    'import numpy': 'numpy',
}

# A module name only counts when it is not the tail of a longer identifier or
# attribute chain: this keeps `datetime.time()`, `np.random.randint(` and
# `cmath.sqrt` from being mistaken for `time`, `random` and `math`.
_NOT_ATTR = r'(?<![\w.])'

_IMPORT_TIME_RE = re.compile(r'import time\s*\\?')
_TIME_CALL_RE = re.compile(_NOT_ATTR + r'time\.time\(\)')
_RANDOM_CALL_RES = {
    name: re.compile(_NOT_ATTR + rf'random\.{name}\(')
    for name in ('randint', 'shuffle', 'sample', 'choice')
}
# Captures the whole attribute name, so `math.e` can never match `math.exp`
# and `math.log` can never match `math.log2` (no dict-ordering tricks needed).
_MATH_ATTR_RE = re.compile(_NOT_ATTR + r'math\.[A-Za-z_]\w*')
_COMBINATIONS_CALL_RE = re.compile(_NOT_ATTR + r'combinations\(')
_COMBINATIONS_IMPORT_RE = re.compile(
    r'^[ \t]*from itertools import combinations[ \t\r]*$', re.MULTILINE
)


def _replace_math_attr(match: re.Match) -> str:
    """Map a ``math.X`` reference to its replacement, or leave it untouched."""
    text = match.group(0)
    return MATH_REPLACEMENTS.get(text, text)


# (pattern, replacement) pairs applied in order; each pair that changes a line
# counts as one fix for that line.
_CALL_REWRITES = (
    [(_TIME_CALL_RE, 'walltime()')]
    + [(pattern, f'{name}(') for name, pattern in _RANDOM_CALL_RES.items()]
    + [(_MATH_ATTR_RE, _replace_math_attr)]
)


def _rewrite_calls(text: str) -> tuple[str, int]:
    """Apply every time/random/math rewrite to *text*.

    Returns (rewritten_text, fix_count) where fix_count is the number of
    rewrite categories that actually changed the text.
    """
    fixes = 0
    for pattern, replacement in _CALL_REWRITES:
        rewritten = pattern.sub(replacement, text)
        if rewritten != text:
            fixes += 1
            text = rewritten
    return text, fixes


def _module_still_used(rewritten_source: str, name: str) -> bool:
    """Return True if ``name.<attr>`` still appears in the rewritten source."""
    pattern = _NOT_ATTR + re.escape(name) + r'\.[A-Za-z_]'
    return re.search(pattern, rewritten_source) is not None


def _source_lines(source) -> list[str]:
    """Normalise a cell ``source`` (string or list of strings) to a list."""
    if isinstance(source, str):
        return source.splitlines(keepends=True)
    return list(source)


def clean_cell_source(
    source_lines: list[str], scope_source: str = ''
) -> tuple[list[str], int]:
    """Clean Python imports and stdlib calls in a single cell's source lines.

    Args:
        source_lines: The cell's source, one string per line.
        scope_source: Text searched to decide whether an import is still
            needed and whether ``combinations`` was imported. Pass the source
            of the whole notebook so an import is kept when *another* cell
            still relies on it. Defaults to this cell's own source.

    Returns (cleaned_lines, fix_count).
    """
    fixes = 0
    cleaned = []

    full_source = ''.join(source_lines)
    scope = scope_source or full_source
    # What the scope looks like after rewriting: any module reference that
    # survives here is one the rewrite cannot replace, so its import stays.
    rewritten_scope = _rewrite_calls(scope)[0]

    # Which prandom helpers this cell calls (beyond randint)
    needs_prandom = sorted(
        name for name in PRANDOM_FUNCS
        if _RANDOM_CALL_RES[name].search(full_source)
    )

    # Only rewrite combinations( when the standalone itertools import exists
    # in scope, i.e. when that import line is one this function removes.
    rewrite_combinations = bool(_COMBINATIONS_IMPORT_RE.search(scope))

    for line in source_lines:
        stripped = line.strip()

        # --- import time (optionally followed by a line continuation) ---
        if _IMPORT_TIME_RE.fullmatch(stripped):
            if _module_still_used(rewritten_scope, 'time'):
                cleaned.append(line)
            else:
                fixes += 1
            continue

        # --- import random ---
        if stripped == 'import random':
            statements = []
            if needs_prandom:
                funcs = ', '.join(needs_prandom)
                statements.append(f'from sage.misc.prandom import {funcs}')
            if _module_still_used(rewritten_scope, 'random'):
                statements.append('import random')
            if statements != ['import random']:
                fixes += 1
            indent = line[:len(line) - len(line.lstrip())]
            ending = '\n' if line.endswith('\n') else ''
            for position, statement in enumerate(statements, start=1):
                # Only the final statement inherits the original line ending.
                tail = ending if position == len(statements) else '\n'
                cleaned.append(f'{indent}{statement}{tail}')
            continue

        # --- import math / import numpy [as np] ---
        if stripped in SIMPLE_IMPORTS:
            if _module_still_used(rewritten_scope, SIMPLE_IMPORTS[stripped]):
                cleaned.append(line)
            else:
                fixes += 1
            continue

        # --- from itertools import combinations ---
        if stripped == 'from itertools import combinations':
            fixes += 1
            continue  # remove; calls are rewritten to Combinations below

        # from collections import Counter / defaultdict: kept on purpose.

        # --- time.time(), random.X(), math.X rewrites ---
        line, call_fixes = _rewrite_calls(line)
        fixes += call_fixes

        # --- combinations( -> Combinations( ---
        if rewrite_combinations:
            rewritten = _COMBINATIONS_CALL_RE.sub('Combinations(', line)
            if rewritten != line:
                fixes += 1
                line = rewritten

        cleaned.append(line)

    return cleaned, fixes


def process_notebook(path: Path, dry_run: bool = False) -> int:
    """Process a single notebook. Returns number of fixes.

    The notebook is rewritten in place unless *dry_run* is true or nothing
    needed fixing.
    """
    # Notebooks are UTF-8 JSON; do not depend on the platform default encoding.
    with open(path, encoding='utf-8') as f:
        nb = json.load(f)

    code_cells = [
        cell for cell in nb.get('cells', [])
        if cell.get('cell_type') == 'code' and cell.get('source')
    ]
    # A cell's source may be one string instead of a list of lines.
    cell_lines = [_source_lines(cell['source']) for cell in code_cells]
    notebook_source = '\n'.join(''.join(lines) for lines in cell_lines)

    total_fixes = 0
    for cell, lines in zip(code_cells, cell_lines):
        cleaned, fixes = clean_cell_source(lines, scope_source=notebook_source)
        if fixes > 0:
            total_fixes += fixes
            if not dry_run:
                cell['source'] = cleaned

    if total_fixes > 0 and not dry_run:
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(nb, f, indent=1, ensure_ascii=False)
            f.write('\n')

    return total_fixes


def main():
    """Clean every notebook under the repository root and print a summary."""
    dry_run = '--dry-run' in sys.argv
    repo = Path(__file__).parent.parent

    notebooks = sorted(repo.glob('**/*.ipynb'))
    notebooks = [nb for nb in notebooks if '.ipynb_checkpoints' not in str(nb)]

    total = 0
    changed_files = 0

    for nb_path in notebooks:
        fixes = process_notebook(nb_path, dry_run=dry_run)
        if fixes > 0:
            rel = nb_path.relative_to(repo)
            print(f' {rel}: {fixes} fixes')
            total += fixes
            changed_files += 1

    mode = 'DRY RUN' if dry_run else 'APPLIED'
    print(f'\n{mode}: {total} fixes across {changed_files} files')


if __name__ == '__main__':
    main()