Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
NVIDIA
GitHub Repository: NVIDIA/cuda-q-academic
Path: blob/main/qec101/Images/decoder/bp.py
1121 views
1
import numpy as np
2
from scipy.sparse import csr_matrix
3
import cudaq_qec as qec
4
import json
5
import time
6
7
# For fetching data
8
import requests
9
import bz2
10
import os
11
12
# Note: running this script will automatically download data if necessary.
13
14
### Helper functions ###
15
16
17
def parse_csr_mat(j, dims, mat_name):
    """
    Reconstruct a binary matrix stored in CSR form inside a JSON object.

    The object `j` holds the row pointers under "<mat_name>_indptr" and the
    column indices under "<mat_name>_indices"; every stored entry is
    implicitly 1. The matrix is returned as a dense uint8 numpy array of
    shape `dims`.
    """
    assert len(dims) == 2, "dims must be a tuple of two integers"

    num_rows, num_cols = dims

    # Pull the CSR structure arrays out of the JSON object.
    row_ptr = np.array(j[f"{mat_name}_indptr"], dtype=int)
    col_idx = np.array(j[f"{mat_name}_indices"], dtype=int)

    # Sanity-check the CSR layout before handing it to SciPy.
    assert len(row_ptr) == num_rows + 1, "indptr length must equal dims[0] + 1"
    assert np.all(
        col_idx < num_cols), "All column indices must be less than dims[1]"

    # All stored entries are 1; row_ptr[-1] is the total number of entries.
    ones = np.ones(row_ptr[-1], dtype=np.uint8)

    # Assemble the sparse matrix, then densify for the caller.
    sparse = csr_matrix((ones, col_idx, row_ptr), shape=dims, dtype=np.uint8)
    return sparse.toarray()
def parse_H_csr(j, dims):
    """
    Parse a CSR-style parity check matrix from an input file in JSON format.

    Reads the "H_indptr"/"H_indices" fields of the JSON object `j` and
    returns the parity check matrix as a dense uint8 numpy array of shape
    `dims`.
    """
    return parse_csr_mat(j, dims, "H")
def parse_obs_csr(j, dims):
    """
    Parse a CSR-style observable matrix from an input file in JSON format.

    Reads the "obs_mat_indptr"/"obs_mat_indices" fields of the JSON object
    `j` and returns the observable matrix as a dense uint8 numpy array of
    shape `dims`.
    """
    return parse_csr_mat(j, dims, "obs_mat")
### Main decoder loop ###
55
56
57
def _is_logical_error(obs_mat, dec_result, obs_truth, print_output):
    """
    Return True when the decoder's predicted observable disagrees with the
    truth data for a single shot. Optionally prints both observables.
    """
    # See if this prediction flipped the observable
    predicted_observable = obs_mat.T @ dec_result % 2
    if print_output:
        print(f"predicted_observable: {predicted_observable}")

    # See if the observable was actually flipped according to the truth data
    actual_observable = np.array(obs_truth, dtype=np.uint8)
    if print_output:
        print(f"actual_observable: {actual_observable}")

    return np.sum(predicted_observable != actual_observable) > 0


def run_decoder(filename, num_shots, run_as_batched, print_output=False, osd=0):
    """
    Load a JSON file and decode "num_shots" syndromes.

    Parameters:
        filename: path to a JSON file holding the parity check matrix
            (CSR fields + "shape"), "error_rate_vec", the observable matrix
            (CSR fields + "obs_mat_shape"), "num_trials", and per-trial
            "syndrome_truth"/"obs_truth" data under "trials".
        num_shots: number of syndromes to decode; clamped to the number of
            trials available in the file.
        run_as_batched: if True, decode all syndromes through the decoder's
            batched API; otherwise decode one syndrome at a time.
        print_output: if True, print predicted and actual observables for
            every shot.
        osd: OSD post-processing method (0=Off, 1=OSD-0, 2=Exhaustive,
            3=Combination Sweep).

    Prints the logical error count, BP convergence count, and average
    per-shot decoding time. Exits the process if the nv-qldpc-decoder is
    unavailable.
    """
    t_load_begin = time.time()
    with open(filename, "r") as f:
        j = json.load(f)

    dims = j["shape"]
    assert len(dims) == 2

    # Read the Parity Check Matrix
    H = parse_H_csr(j, dims)
    syndrome_length, block_length = dims
    t_load_end = time.time()

    #print(f"(unknown) parsed in {1e3 * (t_load_end-t_load_begin)} ms")

    error_rate_vec = np.array(j["error_rate_vec"])
    assert len(error_rate_vec) == block_length
    obs_mat_dims = j["obs_mat_shape"]
    obs_mat = parse_obs_csr(j, obs_mat_dims)
    assert dims[1] == obs_mat_dims[0]
    file_num_trials = j["num_trials"]
    num_shots = min(num_shots, file_num_trials)
    print(
        f'Your JSON file has {file_num_trials} shots. Running {num_shots} now.')

    # osd_method: 0=Off, 1=OSD-0, 2=Exhaustive, 3=Combination Sweep
    osd_method = osd

    # When osd_method is:
    # 2) there are 2^osd_order additional error mechanisms checked.
    # 3) there are an additional k + osd_order*(osd_order-1)/2 error
    #    mechanisms checked.
    # Ref: https://arxiv.org/pdf/2005.07016
    osd_order = 0

    # Maximum number of BP iterations before attempting OSD (if necessary)
    max_iter = 50

    nv_dec_args = {
        "max_iterations": max_iter,
        "error_rate_vec": error_rate_vec,
        "use_sparsity": True,
        "use_osd": osd_method > 0,
        "osd_order": osd_order,
        "osd_method": osd_method
    }

    if run_as_batched:
        # Perform BP processing for up to 1000 syndromes per batch. If there
        # are more than 1000 syndromes, the decoder will chunk them up and
        # process each batch sequentially under the hood.
        nv_dec_args['bp_batch_size'] = min(1000, num_shots)

    try:
        nv_dec_gpu_and_cpu = qec.get_decoder("nv-qldpc-decoder", H,
                                             **nv_dec_args)
    except Exception:
        # The decoder plugin is optional; bail out gracefully rather than
        # crash with a traceback.
        print(
            'The nv-qldpc-decoder is not available with your current CUDA-Q ' +
            'QEC installation.')
        exit(0)

    decoding_time = 0
    bp_converged_flags = []
    num_logical_errors = 0

    # Batched API
    if run_as_batched:
        syndrome_list = []
        obs_truth_list = []
        for i in range(num_shots):
            syndrome_list.append(j["trials"][i]["syndrome_truth"])
            obs_truth_list.append(j["trials"][i]["obs_truth"])
        t0 = time.time()
        results = nv_dec_gpu_and_cpu.decode_batch(syndrome_list)
        t1 = time.time()
        decoding_time += t1 - t0
        for r, obs_truth in zip(results, obs_truth_list):
            bp_converged_flags.append(r.converged)
            dec_result = np.array(r.result, dtype=np.uint8)
            if _is_logical_error(obs_mat, dec_result, obs_truth, print_output):
                num_logical_errors += 1

    # Non-batched API
    else:
        for i in range(num_shots):
            syndrome = j["trials"][i]["syndrome_truth"]
            obs_truth = j["trials"][i]["obs_truth"]

            t0 = time.time()
            bp_converged, dec_result, *_ = nv_dec_gpu_and_cpu.decode(syndrome)
            t1 = time.time()
            decoding_time += t1 - t0

            dec_result = np.array(dec_result, dtype=np.uint8)
            bp_converged_flags.append(bp_converged)
            if _is_logical_error(obs_mat, dec_result, obs_truth, print_output):
                num_logical_errors += 1

    # Count how many shots the decoder failed to correct the errors
    print(f"{num_logical_errors} logical errors in {num_shots} shots")
    print(
        f"Number of shots that converged with BP processing: {np.sum(np.array(bp_converged_flags))}"
    )
    # Guard against a file with zero trials to avoid ZeroDivisionError.
    if num_shots > 0:
        print(
            f"Average decoding time for {num_shots} shots was {1e3 * decoding_time / num_shots} ms per shot"
        )
194
195
196
197
198