Path: blob/main/shared/cryptolab/plot.py
483 views
unlisted
"""Matplotlib visualization wrappers replacing SageMath plotting primitives.

Provides: Cayley graphs, cycle diagrams, subgroup lattices,
multiplication heatmaps, coset coloring, and graphics arrays.

All functions use matplotlib (Pyodide-compatible).
"""

import math
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyArrowPatch
import numpy as np


# ---------------------------------------------------------------------------
# Shared helpers for the circular layouts
# ---------------------------------------------------------------------------

def _check_op(op):
    """Raise ValueError unless *op* is one of the supported group operations."""
    if op not in ('add', 'mul'):
        raise ValueError(f"op must be 'add' or 'mul', got {op!r}")


def _circle_positions(elements):
    """Map each element (converted to int) to a point on the unit circle.

    Elements are spaced evenly in the given order, starting at angle -pi/2
    (the bottom of the circle) and proceeding counter-clockwise.
    """
    num = len(elements)
    positions = {}
    for i, elem in enumerate(elements):
        angle = 2 * math.pi * i / num - math.pi / 2
        positions[int(elem)] = (math.cos(angle), math.sin(angle))
    return positions


def _draw_arrow(ax, start, end, shrink, color):
    """Draw an arrow from *start* to *end*, trimmed by *shrink* at both ends.

    Trimming keeps the arrow heads from disappearing under the node circles.
    Arrows between (nearly) coincident points, e.g. self-loops, are skipped.
    """
    (px, py), (cx, cy) = start, end
    dx, dy = cx - px, cy - py
    dist = math.hypot(dx, dy)
    if dist <= 0.01:
        return
    ux, uy = dx / dist, dy / dist
    ax.annotate(
        '', xy=(cx - shrink * ux, cy - shrink * uy),
        xytext=(px + shrink * ux, py + shrink * uy),
        arrowprops=dict(arrowstyle='->', color=color, lw=1.5),
        zorder=2,
    )


def _finish_circle_axes(ax, title, fontsize):
    """Apply the square, axis-free framing shared by the circular layouts."""
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.axis('off')
    if title:
        ax.set_title(title, fontsize=fontsize)


# ---------------------------------------------------------------------------
# Cayley graph (replaces DiGraph().plot(layout='circular'))
# ---------------------------------------------------------------------------

def cayley_graph(n, generator, op='add', figsize=5, vertex_color='lightblue',
                 edge_color='steelblue', vertex_size=500, title=None, ax=None):
    """
    Draw a Cayley graph for Z/nZ with the given generator.

    op='add': additive group, vertices 0..n-1, arrow from a to
              (a + generator) mod n.
    op='mul': multiplicative group, vertices are the units a in 1..n-1 with
              gcd(a, n) == 1, arrow from a to (a * generator) mod n. Edges
              whose target is not one of the drawn vertices are omitted.

    vertex_size is accepted for signature compatibility but is not used.

    Returns the matplotlib Figure, or None if ax was provided (in that case
    the caller is responsible for showing the figure).

    Raises ValueError if op is neither 'add' nor 'mul'.
    """
    _check_op(op)
    show = ax is None
    if show:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))

    if op == 'add':
        elements = list(range(n))
    else:
        # Only units belong to the multiplicative group.
        elements = [a for a in range(1, n) if math.gcd(a, n) == 1]

    positions = _circle_positions(elements)

    # Edges: one arrow per vertex, pointing at its image under the generator.
    for a in elements:
        target = (a + generator) % n if op == 'add' else (a * generator) % n
        if target in positions:
            _draw_arrow(ax, positions[a], positions[target], 0.12, edge_color)

    # Nodes
    radius = 0.08
    for elem, (ex, ey) in positions.items():
        ax.add_patch(plt.Circle((ex, ey), radius, facecolor=vertex_color,
                                edgecolor='gray', linewidth=1, zorder=3))
        ax.text(ex, ey, str(elem), ha='center', va='center',
                fontsize=10, zorder=4)

    _finish_circle_axes(ax, title, 11)

    if show:
        plt.tight_layout()
        plt.show()
        return fig
    return None


# ---------------------------------------------------------------------------
# Cycle diagram (replaces Graphics() + circle + arrow + text in 01d)
# ---------------------------------------------------------------------------

def _generator_orbit(generator, modulus, op):
    """Return the successive powers (op='mul') or multiples (op='add').

    Starts after the identity (1 or 0) and records each value produced by
    repeatedly applying the generator modulo *modulus*. Stops once the
    identity is reached again or a value repeats (which happens for a
    non-unit under multiplication). At most *modulus* steps are taken.
    """
    identity = 1 if op == 'mul' else 0
    orbit = []
    seen = set()
    val = identity
    for _ in range(modulus):
        if op == 'mul':
            val = (val * generator) % modulus
        else:
            val = (val + generator) % modulus
        orbit.append(val)
        if val == identity or val in seen:
            break
        seen.add(val)
    return orbit


def cycle_diagram(n, elements, generator, op='mul', figsize=5,
                  highlight_color='steelblue', title=None, ax=None):
    """
    Draw a cycle diagram showing the power/addition cycle of a generator.

    elements: the full list of group elements (placed on the circle)
    generator: the element whose cycle to highlight
    op: 'mul' for multiplicative, 'add' for additive

    Highlighted nodes are in the cycle, gray nodes are not reached. Arrows
    run identity -> g -> g^2 -> ... and are drawn only between elements
    present in `elements`.

    Returns the Figure (or None if ax was provided).
    Raises ValueError if op is neither 'add' nor 'mul'.
    """
    _check_op(op)
    show = ax is None
    if show:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))

    positions = _circle_positions(elements)
    cycle = _generator_orbit(generator, n, op)
    cycle_set = set(cycle)

    # Draw all nodes (gray for non-cycle, highlighted for cycle)
    radius = 0.1
    for elem in elements:
        e = int(elem)
        ex, ey = positions[e]
        in_cycle = e in cycle_set
        ax.add_patch(plt.Circle(
            (ex, ey), radius,
            facecolor=highlight_color if in_cycle else 'lightgray',
            edgecolor='white' if in_cycle else 'gray',
            linewidth=1.5, zorder=5))
        ax.text(ex, ey, str(e), ha='center', va='center', fontsize=10,
                fontweight='bold' if in_cycle else 'normal',
                color='white' if in_cycle else 'black', zorder=6)

    # Draw arrows along the cycle, starting from the identity element.
    prev_elem = 1 if op == 'mul' else 0
    for curr_elem in cycle:
        if prev_elem in positions and curr_elem in positions:
            _draw_arrow(ax, positions[prev_elem], positions[curr_elem],
                        0.13, highlight_color)
        prev_elem = curr_elem

    _finish_circle_axes(ax, title, 10)

    if show:
        plt.tight_layout()
        plt.show()
        return fig
    return None


# ---------------------------------------------------------------------------
# Subgroup lattice (replaces Poset().plot())
# ---------------------------------------------------------------------------

def _prime_factor_count(m):
    """Return the number of prime factors of m >= 1, counted with multiplicity."""
    count = 0
    p = 2
    while p * p <= m:
        while m % p == 0:
            m //= p
            count += 1
        p += 1
    if m > 1:
        count += 1
    return count


def subgroup_lattice(n, figsize=6):
    """
    Draw the subgroup lattice of Z/nZ.

    There is one node per divisor d of n: the subgroup of order d, generated
    by n // d. Each node is labeled with the generator and the subgroup size.
    Nodes are stacked by rank (number of prime factors of d, counted with
    multiplicity), so subgroups of equal rank are spread horizontally.
    Lines connect covering pairs (direct inclusion, d2 / d1 prime).
    Returns the Figure.
    """
    from .number_theory import divisors as get_divisors

    divs = get_divisors(n)

    # Rank in the divisor lattice. Distinct divisors can share a rank
    # (e.g. 2 and 3 for n=6), which is what lets siblings sit side by side.
    rank = {d: _prime_factor_count(d) for d in divs}
    max_rank = max(rank.values(), default=0)
    # Spread the ranks over a height of about 4, but keep at least 1.0
    # between levels so the radius-0.3 node circles never overlap.
    gap = max(1.0, 4.0 / max_rank) if max_rank else 0.0

    levels = {}
    for d in divs:
        levels.setdefault(rank[d], []).append(d)

    positions = {}
    for r, ds in levels.items():
        count = len(ds)
        for i, d in enumerate(sorted(ds)):
            positions[d] = ((i - (count - 1) / 2) * 1.5, r * gap)

    fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))

    # Draw edges for direct containment only. For d1 | d2 the ranks differ
    # by exactly 1 iff d2 / d1 is prime, i.e. no divisor lies strictly between.
    for d1 in divs:
        for d2 in divs:
            if d2 % d1 != 0 or rank[d2] - rank[d1] != 1:
                continue
            x1, y1 = positions[d1]
            x2, y2 = positions[d2]
            ax.plot([x1, x2], [y1, y2], color='gray', linewidth=1, zorder=1)

    # Draw nodes
    for d in divs:
        x, y = positions[d]
        gen = n // d
        circ = plt.Circle((x, y), 0.3, facecolor='lightyellow',
                          edgecolor='black', linewidth=1.5, zorder=3)
        ax.add_patch(circ)
        ax.text(x, y + 0.07, f'<{gen}>', ha='center', va='center',
                fontsize=9, fontweight='bold', zorder=4)
        ax.text(x, y - 0.1, f'|{d}|', ha='center', va='center',
                fontsize=8, color='gray', zorder=4)

    margin = 1.0
    xs = [p[0] for p in positions.values()]
    ys = [p[1] for p in positions.values()]
    ax.set_xlim(min(xs) - margin, max(xs) + margin)
    ax.set_ylim(min(ys) - margin, max(ys) + margin)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(f'Subgroup lattice of (Z/{n}Z, +)', fontsize=12)

    plt.tight_layout()
    plt.show()
    return fig


# ---------------------------------------------------------------------------
# Multiplication heatmap (replaces matrix_plot())
# ---------------------------------------------------------------------------

def multiplication_heatmap(table, labels=None, cmap='viridis', figsize=5,
                           title=None):
    """
    Draw a heatmap of a multiplication (or addition) table.

    table: 2D list or numpy array of values; the row count (shape[0]) is
           used for both axes' ticks, so a square table is expected
    labels: row/column labels (defaults to indices)
    cmap: matplotlib colormap name
    Returns the Figure.
    """
    arr = np.array(table)
    n = arr.shape[0]
    if labels is None:
        labels = list(range(n))
    tick_labels = [str(label) for label in labels]

    fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    im = ax.imshow(arr, cmap=cmap, aspect='equal', origin='upper')

    ax.set_xticks(range(n))
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(range(n))
    ax.set_yticklabels(tick_labels)

    # Column labels on top, like a printed Cayley table.
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')

    if title:
        ax.set_title(title, fontsize=12, pad=15)

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()
    return fig


# ---------------------------------------------------------------------------
# Coset coloring (replaces Graphics() + circle + text + parametric_plot)
# ---------------------------------------------------------------------------

def coset_coloring(n, subgroup_elements, figsize=5, title=None,
                   colors=None):
    """
    Draw elements of Z/nZ on a circle, colored by their coset membership.

    subgroup_elements: list of ints forming the subgroup H (reduced mod n)
    colors: color cycle, one color per coset (wraps around if there are
            more cosets than colors)
    Returns the Figure.
    Raises ValueError if subgroup_elements is empty.
    """
    if colors is None:
        colors = ['royalblue', 'orangered', 'forestgreen', 'mediumorchid',
                  'goldenrod', 'crimson', 'teal', 'slateblue']

    H_set = {int(h) % n for h in subgroup_elements}
    if not H_set:
        # An empty H yields empty cosets and leaves every element uncolored.
        raise ValueError("subgroup_elements must contain at least one element")

    # Assign cosets a + H, walking representatives in increasing order.
    covered = set()
    cosets = []
    for a in range(n):
        if a in covered:
            continue
        coset = sorted({(a + h) % n for h in H_set})
        cosets.append((a, coset))
        covered.update(coset)

    # Color mapping
    element_color = {}
    for idx, (_rep, coset) in enumerate(cosets):
        c = colors[idx % len(colors)]
        for elem in coset:
            element_color[elem] = c

    fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))

    # Draw faint outline circle
    theta = np.linspace(0, 2 * math.pi, 100)
    ax.plot(np.cos(theta), np.sin(theta), color='lightgray', linewidth=0.5,
            zorder=1)

    # Draw elements
    radius = 0.1
    for i, (cx, cy) in _circle_positions(list(range(n))).items():
        circ = plt.Circle((cx, cy), radius, facecolor=element_color[i],
                          edgecolor='white', linewidth=2, zorder=3)
        ax.add_patch(circ)
        ax.text(cx, cy, str(i), ha='center', va='center',
                fontsize=11, fontweight='bold', color='white', zorder=4)

    _finish_circle_axes(ax, title, 12)

    plt.tight_layout()
    plt.show()
    return fig


# ---------------------------------------------------------------------------
# Graphics array (replaces SageMath's graphics_array())
# ---------------------------------------------------------------------------

def graphics_array(plot_funcs, rows, cols, figsize=None):
    """
    Arrange multiple plots in a grid.

    plot_funcs: iterable of callables, each taking an ax parameter.
        Each function should call one of the plot functions above
        with the ax= parameter, e.g.:
            lambda ax: cayley_graph(6, 1, ax=ax)
        Functions are placed in row-major order; any beyond rows * cols
        are ignored, and unused cells are hidden.

    Returns the Figure.
    """
    plot_funcs = list(plot_funcs)
    if figsize is None:
        figsize = (4 * cols, 4 * rows)

    # squeeze=False always yields a 2D (rows, cols) array of Axes, including
    # the single-column case that np.atleast_2d used to transpose.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()

    for ax, func in zip(flat_axes, plot_funcs):
        func(ax)

    # Hide unused axes
    for ax in flat_axes[len(plot_funcs):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    return fig