Path: blob/main/qec101/Images/decoder/bpwidget.py
1121 views
# graph_visualizer.py
import networkx as nx
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output


def create_visualization():
    """Return a slider widget that steps through belief propagation on a small graph.

    The graph has six nodes '1'-'6' in the top row (labelled "Check Qubits")
    and seven nodes 'A'-'G' in the bottom row (labelled "Data Qubits"); top
    node i has directed edges to the i-th and (i+1)-th bottom nodes. For each
    slider position ("iteration") the nodes listed for that step in
    ``highlight_dict`` are coloured red; each step adds the neighbours of the
    previous step's nodes, spreading outward from top node '3'.

    Returns
    -------
    ipywidgets.interactive
        Widget combining the "Iteration:" slider with the plot output; display
        it in a notebook cell to show it.
    """
    # Graph setup: directed edges run from each top-row node to the two
    # bottom-row nodes it connects to.
    G = nx.DiGraph()
    top_nodes = [str(i) for i in range(1, 7)]
    bottom_nodes = list('ABCDEFG')
    G.add_nodes_from(top_nodes, layer=0)
    G.add_nodes_from(bottom_nodes, layer=1)

    connections = [
        ('1', 'A'), ('1', 'B'),
        ('2', 'B'), ('2', 'C'),
        ('3', 'C'), ('3', 'D'),
        ('4', 'D'), ('4', 'E'),
        ('5', 'E'), ('5', 'F'),
        ('6', 'F'), ('6', 'G'),
    ]
    G.add_edges_from(connections)

    # Position setup: each row is centred on x=0 with unit spacing;
    # top row at y=1, bottom row at y=-1.
    pos = {}
    for i, node in enumerate(top_nodes):
        pos[node] = (i - len(top_nodes) / 2 + 0.5, 1)
    for i, node in enumerate(bottom_nodes):
        pos[node] = (i - len(bottom_nodes) / 2 + 0.5, -1)

    # Row labels sit half a unit left of the left-most node (x=-3.5 for this
    # layout) so they stay clear of the nodes if a row changes length.
    label_x = min(x for x, _ in pos.values()) - 0.5

    # Marker size shared by every node/edge drawing call below.
    node_size = 2000

    # Highlight configuration: iteration -> (highlighted top nodes,
    # highlighted bottom nodes). Each step adds the neighbours of the nodes
    # highlighted in the previous step, starting from top node '3'.
    highlight_dict = {
        0: (['3'], []),
        1: (['3'], ['C', 'D']),
        2: (['2', '3', '4'], ['C', 'D']),
        3: (['2', '3', '4'], ['B', 'C', 'D', 'E']),
        4: (['1', '2', '3', '4', '5'], ['B', 'C', 'D', 'E']),
        5: (['1', '2', '3', '4', '5'], ['A', 'B', 'C', 'D', 'E', 'F']),
    }
    # The slider range follows the configured steps. Keys must be consecutive
    # integers because the slider moves in steps of 1.
    first_step, last_step = min(highlight_dict), max(highlight_dict)

    def update_graph(iteration):
        """Redraw the graph with the nodes for ``iteration`` highlighted in red."""
        clear_output(wait=True)
        highlight_top, highlight_bottom = highlight_dict[iteration]

        _, ax = plt.subplots(figsize=(12, 6))
        # NOTE(review): the title calls node '3' "Qubit 3", but it is drawn in
        # the "Check Qubits" row -- confirm the intended wording.
        ax.set_title("Propagation of Qubit 3's Initial Beliefs", fontsize=14, pad=20)

        # Draw check qubits (top row)
        nx.draw_networkx_nodes(
            G, pos, nodelist=top_nodes, ax=ax,
            node_color=['red' if n in highlight_top else 'purple' for n in top_nodes],
            node_size=node_size,
        )

        # Draw data qubits (bottom row)
        nx.draw_networkx_nodes(
            G, pos, nodelist=bottom_nodes, ax=ax,
            node_color=['red' if n in highlight_bottom else '#76B900'  # NVIDIA Green
                        for n in bottom_nodes],
            node_size=node_size,
        )

        # draw_networkx_edges uses node_size (default 300) to stop each arrow
        # at the target node's boundary. Without the real size the arrowheads
        # end inside the larger circles and are hidden behind the nodes.
        nx.draw_networkx_edges(
            G, pos, edgelist=G.edges(), ax=ax,
            arrowstyle='-|>', arrowsize=20, node_size=node_size,
        )
        nx.draw_networkx_labels(G, pos, ax=ax, font_size=16, font_weight='bold')

        # Add row labels
        ax.text(label_x, 1, "Check Qubits", fontsize=12, weight='bold', ha='right')
        ax.text(label_x, -1, "Data Qubits", fontsize=12, weight='bold', ha='right')

        ax.axis('off')
        plt.show()

    # Create and return widget
    slider = widgets.IntSlider(
        value=first_step,
        min=first_step,
        max=last_step,
        step=1,
        description='Iteration:',
        continuous_update=False,
    )

    return widgets.interactive(update_graph, iteration=slider)