Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
duyuefeng0708
GitHub Repository: duyuefeng0708/Cryptography-From-First-Principle
Path: blob/main/shared/cryptolab/plot.py
483 views
unlisted
1
"""
2
Matplotlib visualization wrappers replacing SageMath plotting primitives.
3
4
Provides: Cayley graphs, cycle diagrams, subgroup lattices,
5
multiplication heatmaps, coset coloring, and graphics arrays.
6
7
All functions use matplotlib (Pyodide-compatible).
8
"""
9
10
import math
11
import matplotlib.pyplot as plt
12
import matplotlib.patches as mpatches
13
from matplotlib.patches import FancyArrowPatch
14
import numpy as np
15
16
17
# ---------------------------------------------------------------------------
18
# Cayley graph (replaces DiGraph().plot(layout='circular'))
19
# ---------------------------------------------------------------------------
20
21
def cayley_graph(n, generator, op='add', figsize=5, vertex_color='lightblue',
                 edge_color='steelblue', vertex_size=500, title=None, ax=None):
    """
    Draw the Cayley graph of Z/nZ (or of its unit group) for one generator.

    op='add': vertices are 0..n-1, each with an arrow a -> (a + generator) mod n
    op='mul' (or any value other than 'add'): vertices are the units
        1 <= a < n with gcd(a, n) == 1, each with an arrow
        a -> (a * generator) mod n
    Vertices are spaced evenly on the unit circle, the first one at the
    bottom and the rest counter-clockwise. Arrows whose target is not a
    vertex, and arrows between (nearly) coincident points such as
    self-loops, are not drawn. vertex_size is accepted for compatibility
    but is not used.

    Returns the matplotlib Figure, or None if ax was provided (in which case
    nothing is shown and the caller owns layout/display).
    """
    external_axes = ax is not None
    if external_axes:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))

    if op == 'add':
        vertices = list(range(n))

        def step(v):
            return (v + generator) % n
    else:
        # Multiplicative case: only the units of Z/nZ are vertices.
        from .number_theory import gcd
        vertices = [v for v in range(1, n) if gcd(v, n) == 1]

        def step(v):
            return (v * generator) % n

    count = len(vertices)
    coords = {}
    for idx, v in enumerate(vertices):
        theta = 2 * math.pi * idx / count - math.pi / 2
        coords[v] = (math.cos(theta), math.sin(theta))

    # Edges first (zorder 2) so the vertex discs drawn later sit on top.
    gap = 0.12  # pull arrow ends back from the vertex centres
    for src in vertices:
        dst = step(src)
        if dst not in coords:
            continue
        (x0, y0), (x1, y1) = coords[src], coords[dst]
        ux, uy = x1 - x0, y1 - y0
        length = math.sqrt(ux * ux + uy * uy)
        if length <= 0.01:
            continue
        ax.annotate(
            '', xy=(x1 - gap * ux / length, y1 - gap * uy / length),
            xytext=(x0 + gap * ux / length, y0 + gap * uy / length),
            arrowprops=dict(arrowstyle='->', color=edge_color, lw=1.5),
            zorder=2
        )

    for v in vertices:
        x, y = coords[v]
        ax.add_patch(plt.Circle((x, y), 0.08, facecolor=vertex_color,
                                edgecolor='gray', linewidth=1, zorder=3))
        ax.text(x, y, str(v), ha='center', va='center',
                fontsize=10, zorder=4)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.axis('off')
    if title:
        ax.set_title(title, fontsize=11)

    if external_axes:
        return None
    plt.tight_layout()
    plt.show()
    return fig
93
94
95
# ---------------------------------------------------------------------------
96
# Cycle diagram (replaces Graphics() + circle + arrow + text in 01d)
97
# ---------------------------------------------------------------------------
98
99
def cycle_diagram(n, elements, generator, op='mul', figsize=5,
                  highlight_color='steelblue', title=None, ax=None):
    """
    Draw the cycle traced by repeatedly applying a generator in Z/nZ.

    elements: values placed on the circle (each converted with int()); the
              first sits at the bottom, the rest follow counter-clockwise.
    generator: the element whose orbit is highlighted.
    op: 'mul' -> orbit is generator^1, generator^2, ... (mod n), stopping at
                 the first 1 or after len(elements) steps;
        anything else -> orbit is generator, 2*generator, ... (mod n),
                 stopping at the first 0 or after n steps.

    Orbit members are filled with highlight_color and labelled in bold
    white; all other elements are light gray. Arrows run from the identity
    (1 for 'mul', 0 otherwise) through the orbit in order; an arrow is
    skipped when either endpoint is not among the placed elements.
    Returns the Figure (or None if ax was provided, in which case nothing
    is shown).
    """
    standalone = ax is None
    if standalone:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    else:
        fig = ax.figure

    total = len(elements)
    coords = {}
    for slot, raw in enumerate(elements):
        theta = 2 * math.pi * slot / total - math.pi / 2
        coords[int(raw)] = (math.cos(theta), math.sin(theta))

    # Pick the group operation, its identity and the iteration cap.
    if op == 'mul':
        identity, max_steps = 1, total

        def step(v):
            return (v * generator) % n
    else:
        identity, max_steps = 0, n

        def step(v):
            return (v + generator) % n

    orbit = []
    value = identity
    for _ in range(max_steps):
        value = step(value)
        orbit.append(value)
        if value == identity:
            break
    reached = set(orbit)

    # Nodes: highlighted if reached by the orbit, gray otherwise.
    for raw in elements:
        key = int(raw)
        x, y = coords[key]
        if key in reached:
            face, ink, weight, rim = highlight_color, 'white', 'bold', 'white'
        else:
            face, ink, weight, rim = 'lightgray', 'black', 'normal', 'gray'
        ax.add_patch(plt.Circle((x, y), 0.1, facecolor=face, edgecolor=rim,
                                linewidth=1.5, zorder=5))
        ax.text(x, y, str(key), ha='center', va='center',
                fontsize=10, fontweight=weight, color=ink, zorder=6)

    # Arrows: identity -> orbit[0] -> orbit[1] -> ...
    pad = 0.13
    for src, dst in zip([identity, *orbit], orbit):
        if src not in coords or dst not in coords:
            continue
        (x0, y0), (x1, y1) = coords[src], coords[dst]
        ux, uy = x1 - x0, y1 - y0
        length = math.sqrt(ux * ux + uy * uy)
        if length > 0.01:
            ax.annotate(
                '', xy=(x1 - pad * ux / length, y1 - pad * uy / length),
                xytext=(x0 + pad * ux / length, y0 + pad * uy / length),
                arrowprops=dict(arrowstyle='->', color=highlight_color,
                                lw=1.5),
                zorder=2
            )

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.axis('off')
    if title:
        ax.set_title(title, fontsize=10)

    if not standalone:
        return None
    plt.tight_layout()
    plt.show()
    return fig
196
197
198
# ---------------------------------------------------------------------------
199
# Subgroup lattice (replaces Poset().plot())
200
# ---------------------------------------------------------------------------
201
202
def subgroup_lattice(n, figsize=6, ax=None):
    """
    Draw the subgroup lattice of Z/nZ as a Hasse diagram.

    Z/nZ has exactly one subgroup of order d for each divisor d of n, and it
    is generated by n/d. Each node shows that generator reduced mod n (so the
    trivial subgroup appears as <0>) above the subgroup order |d|.

    Nodes are placed in rows by rank: the rank of d is the length of the
    longest divisor chain 1 | ... | d, i.e. the number of prime factors of d
    counted with multiplicity. Subgroups of equal rank share a row and are
    spread horizontally in increasing order. A line joins d1 < d2 when d1
    divides d2 and their ranks differ by one, which means exactly the direct
    (maximal-subgroup) inclusions.

    figsize: side length in inches of a newly created figure
    ax: optional Axes to draw onto; when given, no new figure is created and
        nothing is shown.
    Returns the Figure (or None if ax was provided).
    Raises ValueError if divisors(n) is empty.
    """
    from .number_theory import divisors as get_divisors

    divs = sorted(set(get_divisors(n)))
    if not divs:
        raise ValueError(f"no divisors found for n={n!r}; expected n >= 1")

    # Rank by longest divisor chain. Ascending order guarantees that every
    # proper divisor of d is ranked before d. (Ranking by log2(d) gives
    # every divisor its own row, which stacks all nodes on one vertical line
    # with edges passing through unrelated nodes.)
    rank = {}
    for d in divs:
        below = [rank[e] for e in rank if d % e == 0]
        rank[d] = max(below) + 1 if below else 0

    rows = {}
    for d in divs:
        rows.setdefault(rank[d], []).append(d)

    spacing = 1.5
    positions = {}
    for level, members in rows.items():
        offset = (len(members) - 1) / 2  # centre each row on x = 0
        for i, d in enumerate(members):
            positions[d] = ((i - offset) * spacing, level * spacing)

    show = ax is None
    if show:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    else:
        fig = ax.figure

    # Covering relations: d1 | d2 with ranks differing by one, which in the
    # divisor lattice is the same as d2 / d1 being prime.
    for i, d1 in enumerate(divs):
        for d2 in divs[i + 1:]:
            if d2 % d1 == 0 and rank[d2] == rank[d1] + 1:
                (x1, y1), (x2, y2) = positions[d1], positions[d2]
                ax.plot([x1, x2], [y1, y2], color='gray', linewidth=1,
                        zorder=1)

    for d in divs:
        x, y = positions[d]
        gen = (n // d) % n  # canonical representative in {0, ..., n-1}
        ax.add_patch(plt.Circle((x, y), 0.3, facecolor='lightyellow',
                                edgecolor='black', linewidth=1.5, zorder=3))
        ax.text(x, y + 0.07, f'<{gen}>', ha='center', va='center',
                fontsize=9, fontweight='bold', zorder=4)
        ax.text(x, y - 0.1, f'|{d}|', ha='center', va='center',
                fontsize=8, color='gray', zorder=4)

    margin = 1.0
    xs = [p[0] for p in positions.values()]
    ys = [p[1] for p in positions.values()]
    ax.set_xlim(min(xs) - margin, max(xs) + margin)
    ax.set_ylim(min(ys) - margin, max(ys) + margin)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(f'Subgroup lattice of (Z/{n}Z, +)', fontsize=12)

    if show:
        plt.tight_layout()
        plt.show()
        return fig
    return None
269
270
271
# ---------------------------------------------------------------------------
272
# Multiplication heatmap (replaces matrix_plot())
273
# ---------------------------------------------------------------------------
274
275
def multiplication_heatmap(table, labels=None, cmap='viridis', figsize=5,
                           title=None, ax=None):
    """
    Draw a heatmap of a multiplication (or addition) table.

    table: 2D list or numpy array of values; entry [i][j] is drawn in row i,
           column j, with row 0 at the top.
    labels: tick labels applied to both rows and columns (defaults to the
            indices 0, 1, ... of each axis). Because one list serves both
            axes, supply labels only for square tables.
    cmap: matplotlib colormap name
    figsize: side length in inches of a newly created figure
    title: optional title, placed above the top tick labels
    ax: optional Axes to draw onto; when given, no new figure is created and
        nothing is shown.

    Column ticks are drawn along the top edge and a colorbar is attached
    beside the axes.
    Returns the Figure (or None if ax was provided).
    Raises ValueError if table is not two-dimensional.
    """
    arr = np.asarray(table)
    if arr.ndim != 2:
        raise ValueError(
            f"table must be two-dimensional, got shape {arr.shape}")
    # Tick rows and columns separately so non-square tables are labelled
    # correctly.
    n_rows, n_cols = arr.shape
    row_labels = list(range(n_rows)) if labels is None else labels
    col_labels = list(range(n_cols)) if labels is None else labels

    show = ax is None
    if show:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    else:
        fig = ax.figure
    im = ax.imshow(arr, cmap=cmap, aspect='equal', origin='upper')

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([str(lab) for lab in col_labels])
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([str(lab) for lab in row_labels])

    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')

    if title:
        ax.set_title(title, fontsize=12, pad=15)

    # Attach the colorbar to the figure that owns ax, not pyplot's
    # "current" figure, which may differ when ax is supplied by the caller.
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    if show:
        plt.tight_layout()
        plt.show()
        return fig
    return None
308
309
310
# ---------------------------------------------------------------------------
311
# Coset coloring (replaces Graphics() + circle + text + parametric_plot)
312
# ---------------------------------------------------------------------------
313
314
def coset_coloring(n, subgroup_elements, figsize=5, title=None,
                   colors=None, ax=None):
    """
    Draw elements of Z/nZ on a circle, colored by their coset membership.

    subgroup_elements: ints forming the subgroup H (reduced mod n); must
                       contain 0, the identity.
    colors: per-coset palette; None or empty selects a built-in 8-color
            palette. The k-th coset (cosets are ordered by their smallest
            element) gets colors[k % len(colors)], so colors repeat when
            there are more cosets than colors.
    ax: optional Axes to draw onto; when given, no new figure is created and
        nothing is shown.

    Element i is placed at angle 2*pi*i/n - pi/2 (0 at the bottom,
    increasing counter-clockwise) and labelled in bold white.
    Returns the Figure (or None if ax was provided).
    Raises ValueError if n < 1 or if H does not contain 0. Without the
    identity, the sets a + H leave some elements uncovered and those
    elements would have no color.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    if not colors:
        colors = ['royalblue', 'orangered', 'forestgreen', 'mediumorchid',
                  'goldenrod', 'crimson', 'teal', 'slateblue']

    H = {int(h) % n for h in subgroup_elements}
    if 0 not in H:
        raise ValueError(
            "subgroup_elements must contain 0 (mod n); without the identity "
            "the cosets a + H do not cover every element of Z/nZ")

    # Partition Z/nZ into cosets a + H, visiting representatives in order.
    covered = set()
    cosets = []
    for a in range(n):
        if a in covered:
            continue
        coset = sorted({(a + h) % n for h in H})
        cosets.append(coset)
        covered.update(coset)

    element_color = {}
    for idx, coset in enumerate(cosets):
        c = colors[idx % len(colors)]
        for elem in coset:
            element_color[elem] = c

    show = ax is None
    if show:
        fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    else:
        fig = ax.figure

    # Faint outline circle behind the element discs.
    theta = np.linspace(0, 2 * math.pi, 100)
    ax.plot(np.cos(theta), np.sin(theta), color='lightgray', linewidth=0.5,
            zorder=1)

    radius = 0.1
    for i in range(n):
        angle = 2 * math.pi * i / n - math.pi / 2
        cx, cy = math.cos(angle), math.sin(angle)
        ax.add_patch(plt.Circle((cx, cy), radius, facecolor=element_color[i],
                                edgecolor='white', linewidth=2, zorder=3))
        ax.text(cx, cy, str(i), ha='center', va='center',
                fontsize=11, fontweight='bold', color='white', zorder=4)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.axis('off')
    if title:
        ax.set_title(title, fontsize=12)

    if show:
        plt.tight_layout()
        plt.show()
        return fig
    return None
373
374
375
# ---------------------------------------------------------------------------
376
# Graphics array (replaces SageMath's graphics_array())
377
# ---------------------------------------------------------------------------
378
379
def graphics_array(plot_funcs, rows, cols, figsize=None):
    """
    Arrange multiple plots in a rows x cols grid of subplots.

    plot_funcs: iterable of callables, each taking a single matplotlib Axes.
                Each should call a plotting function from this module with
                the ax= parameter, e.g.:
                    lambda ax: cayley_graph(6, 1, ax=ax)
                Callables fill the grid in row-major order; any beyond
                rows * cols are ignored.
    figsize: (width, height) in inches; defaults to 4 inches per cell.

    Cells without a callable have their axes turned off.
    Shows and returns the Figure.
    """
    plot_funcs = list(plot_funcs)
    if figsize is None:
        figsize = (4 * cols, 4 * rows)

    # squeeze=False always yields a 2-D (rows, cols) array. Squeezing and
    # then applying np.atleast_2d would turn a single column of shape
    # (rows,) into (1, rows), making axes[r, 0] fail for r >= 1.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    capacity = rows * cols
    for idx, func in enumerate(plot_funcs[:capacity]):
        r, c = divmod(idx, cols)
        func(axes[r, c])

    # Hide unused axes
    for idx in range(len(plot_funcs), capacity):
        r, c = divmod(idx, cols)
        axes[r, c].axis('off')

    plt.tight_layout()
    plt.show()
    return fig
411
412