Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
NVIDIA
GitHub Repository: NVIDIA/cuda-q-academic
Path: blob/main/qec101/Images/decoder/bpwidget.py
1121 views
1
# graph_visualizer.py
2
import networkx as nx
3
import matplotlib.pyplot as plt
4
import ipywidgets as widgets
5
from IPython.display import display, clear_output
6
7
def create_visualization():
    """Build an interactive widget showing how beliefs spread through a Tanner graph.

    The graph is bipartite:
    - six check qubits ("1".."6") sit on the top row;
    - seven data qubits ("A".."G") sit on the bottom row;
    - each check is connected to two neighbouring data qubits.

    An integer slider selects an iteration. The nodes listed for that iteration
    in ``highlight_dict`` are drawn in red. All other checks are purple and all
    other data qubits are green.

    Returns:
        ipywidgets.interactive: the iteration slider bound to the drawing
        callback, ready to be displayed in a notebook.
    """
    # Graph setup: directed edges run from each check to the data qubits it touches.
    G = nx.DiGraph()
    top_nodes = [str(i) for i in range(1, 7)]
    bottom_nodes = list('ABCDEFG')
    G.add_nodes_from(top_nodes, layer=0)
    G.add_nodes_from(bottom_nodes, layer=1)

    connections = [
        ('1', 'A'), ('1', 'B'),
        ('2', 'B'), ('2', 'C'),
        ('3', 'C'), ('3', 'D'),
        ('4', 'D'), ('4', 'E'),
        ('5', 'E'), ('5', 'F'),
        ('6', 'F'), ('6', 'G'),
    ]
    G.add_edges_from(connections)

    # Position setup: centre each row around x = 0; checks at y = 1, data at y = -1.
    pos = {}
    for i, node in enumerate(top_nodes):
        pos[node] = (i - len(top_nodes) / 2 + 0.5, 1)
    for i, node in enumerate(bottom_nodes):
        pos[node] = (i - len(bottom_nodes) / 2 + 0.5, -1)

    # Put the row labels half a unit left of the leftmost node
    # (x = -3.5 for this layout) so they track the layout if it changes.
    label_x = min(x for x, _ in pos.values()) - 0.5

    # Marker area (points^2) used for every node. The same value is also passed
    # to draw_networkx_edges. networkx uses it to shorten arrows so the heads
    # stop at the node boundary; its default (300) would hide the arrowheads
    # under these larger markers.
    node_size = 2000

    # Highlight configuration: iteration -> (highlighted checks, highlighted data
    # qubits). Each step adds the neighbours of the previously highlighted nodes.
    highlight_dict = {
        0: (['3'], []),
        1: (['3'], ['C', 'D']),
        2: (['2', '3', '4'], ['C', 'D']),
        3: (['2', '3', '4'], ['B', 'C', 'D', 'E']),
        4: (['1', '2', '3', '4', '5'], ['B', 'C', 'D', 'E']),
        5: (['1', '2', '3', '4', '5'], ['A', 'B', 'C', 'D', 'E', 'F']),
    }

    def update_graph(iteration):
        """Redraw the graph with the nodes reached at `iteration` shown in red."""
        clear_output(wait=True)
        highlight_top, highlight_bottom = highlight_dict[iteration]

        plt.figure(figsize=(12, 6))
        plt.title("Propagation of Qubit 3's Initial Beliefs", fontsize=14, pad=20)

        # Draw check qubits (top row)
        nx.draw_networkx_nodes(
            G, pos, nodelist=top_nodes,
            node_color=['red' if n in highlight_top else 'purple' for n in top_nodes],
            node_size=node_size,
        )

        # Draw data qubits (bottom row)
        nx.draw_networkx_nodes(
            G, pos, nodelist=bottom_nodes,
            node_color=['red' if n in highlight_bottom else '#76B900' for n in bottom_nodes],  # NVIDIA Green
            node_size=node_size,
        )

        # node_size must match the drawn markers so arrowheads remain visible.
        nx.draw_networkx_edges(
            G, pos, edgelist=G.edges(), arrowstyle='-|>', arrowsize=20,
            node_size=node_size,
        )
        nx.draw_networkx_labels(G, pos, font_size=16, font_weight='bold')

        # Add row labels
        plt.text(label_x, 1, "Check Qubits", fontsize=12, weight='bold', ha='right')
        plt.text(label_x, -1, "Data Qubits", fontsize=12, weight='bold', ha='right')

        plt.axis('off')
        plt.show()

    # Derive the slider range from the highlight table so the slider can never
    # select an iteration that has no entry (which would raise KeyError).
    first_iteration = min(highlight_dict)
    slider = widgets.IntSlider(
        value=first_iteration,
        min=first_iteration,
        max=max(highlight_dict),
        step=1,
        description='Iteration:',
        continuous_update=False,
    )

    return widgets.interactive(update_graph, iteration=slider)
85
86
87