Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
NVIDIA
GitHub Repository: NVIDIA/cuda-q-academic
Path: blob/main/chemistry-simulations/aux_files/vqe_and_gqe/run_gqe_h2.py
1166 views
1
import argparse
2
import cudaq
3
import os
4
import torch
5
import json
6
import matplotlib.pyplot as plt
7
import numpy as np
8
import cudaq_solvers as solvers
9
from cudaq import spin
10
from lightning.fabric.loggers import CSVLogger
11
from cudaq_solvers.gqe_algorithm.gqe import get_default_config
12
from typing import List
13
14
# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
# All tunables for the GQE run; defaults reproduce the reference configuration.
parser = argparse.ArgumentParser(description="Run GQE for LiH/H2.")
parser.add_argument(
    '--mpi',
    action='store_true',
    help='Enable MPI distribution.',
)
parser.add_argument(
    '--max_iters',
    type=int,
    default=75,
    help='Maximum number of GQE iterations.',
)
parser.add_argument(
    '--ngates',
    type=int,
    default=40,
    help='Number of gates (ansatz depth).',
)
parser.add_argument(
    '--num_samples',
    type=int,
    default=10,
    help='Number of samples (population size).',
)
parser.add_argument(
    '--temperature',
    type=float,
    default=5.0,
    help='Temperature for sampling.',
)
parser.add_argument(
    '--lr',
    type=float,
    default=1e-7,
    help='Learning Rate.',
)
parser.add_argument(
    '--output_file',
    type=str,
    default='gqe_convergence.png',
    help='Filename to save the convergence plot.',
)

args = parser.parse_args()
27
28
# sys.exit is the supported way to leave a program; the bare exit() builtin is
# injected by the `site` module for interactive use only and may be missing
# (e.g. under `python -S` or in embedded/frozen interpreters).
import sys

if args.mpi:
    # Multi-QPU GPU target: in MPI mode cost() dispatches each evaluation with
    # cudaq.observe_async(..., qpu_id=...), which needs the 'mqpu' option.
    try:
        cudaq.set_target('nvidia', option='mqpu')
        cudaq.mpi.initialize()
    except RuntimeError:
        # Missing GPUs/MPI is treated as "nothing to do", not as a failure.
        print('Warning: NVIDIA GPUs or MPI not available. Skipping...')
        sys.exit(0)
else:
    # Prefer the double-precision GPU simulator; fall back to the CPU
    # simulator when the 'nvidia' target cannot be selected.
    try:
        cudaq.set_target('nvidia', option='fp64')
    except RuntimeError:
        cudaq.set_target('qpp-cpu')
40
41
# Reproducibility: fixed RNG seed plus deterministic PyTorch/cuDNN kernels.
# PyTorch's reproducibility notes require a fixed cuBLAS workspace
# configuration (via the environment) for deterministic cuBLAS results.
os.environ.update({'CUBLAS_WORKSPACE_CONFIG': ':4096:8'})
torch.manual_seed(3047)
torch.use_deterministic_algorithms(True)
_cudnn = torch.backends.cudnn
_cudnn.deterministic = True
_cudnn.benchmark = False  # autotuning may pick different kernels run to run
47
48
# ---------------------------------------------------------------------------
# 1. Build the H2 molecule and obtain the reference (CASCI) energy
# ---------------------------------------------------------------------------
geometry = [
    ('H', (0., 0., 0.)),
    ('H', (0., 0., 0.74)),
]
# Basis '6-31g'; positional 0, 0 are passed through to create_molecule
# (presumably spin and charge — confirm against the cudaq_solvers API).
# Active space: 2 electrons in 3 orbitals, solved with CASCI.
molecule = solvers.create_molecule(
    geometry,
    '6-31g',
    0,
    0,
    nele_cas=2,
    norb_cas=3,
    casci=True,
)
spin_ham = molecule.hamiltonian
n_qubits = 2 * molecule.n_orbitals  # two spin-orbitals (qubits) per orbital
n_electrons = molecule.n_electrons

# Reference energy reported by the solver under 'R-CASCI'; when the key is
# missing, fall back to a hard-coded value (presumably a precomputed CASCI
# energy for this geometry — verify).
energies = molecule.energies
if 'R-CASCI' in energies:
    fci_energy = energies['R-CASCI']
else:
    fci_energy = -0.9886377190903441

# Without MPI every (single) process prints; with MPI only rank 0 does.
if args.mpi and cudaq.mpi.rank() != 0:
    pass
else:
    print(f"Full CI Energy (R-CASCI): {fci_energy}")
    print(f"Configuration: max_iters={args.max_iters}, ngates={args.ngates}, "
          f"num_samples={args.num_samples}, temperature={args.temperature}")
66
67
# ---------------------------------------------------------------------------
# 2. Build Operator Pool
# ---------------------------------------------------------------------------


def get_identity(n_qubits: int) -> cudaq.SpinOperator:
    """Return the identity operator spanning qubits 0 .. n_qubits-1.

    The product I(0) * I(1) * ... * I(n_qubits-1) is wrapped in a
    ``cudaq.SpinOperator`` and scaled by 1.0. For ``n_qubits <= 1`` only
    I(0) is used.
    """
    identity = cudaq.spin.i(0)
    qubit = 1
    while qubit < n_qubits:
        identity = identity * cudaq.spin.i(qubit)
        qubit += 1
    return 1.0 * cudaq.SpinOperator(identity)
76
77
def get_gqe_pauli_pool(num_qubits: int, num_electrons: int, params: List[float]) -> List[cudaq.SpinOperator]:
    """Build the GQE operator pool from the Pauli strings of the UCCSD pool.

    Every UCCSD generator is split into its individual Pauli-string terms.
    Each term is rebuilt from its Pauli word with a unit coefficient (the
    original term coefficient is discarded) and then scaled by every value
    in ``params``.

    Args:
        num_qubits: Number of qubits the operators act on.
        num_electrons: Electron count forwarded to the UCCSD pool builder.
        params: Scale factors applied to each Pauli term.

    Returns:
        The identity on ``num_qubits`` qubits, followed by
        ``len(params)`` scaled copies of each Pauli term, grouped term by
        term in the order the terms were encountered.
    """
    letter_to_gate = {'I': spin.i, 'X': spin.x, 'Y': spin.y, 'Z': spin.z}

    generators = solvers.get_operator_pool("uccsd", num_qubits=num_qubits, num_electrons=num_electrons)
    pool = [get_identity(num_qubits)]

    unit_terms = []
    for generator in generators:
        for term in generator:
            product = None
            for qubit, letter in enumerate(term.get_pauli_word(num_qubits)):
                make_gate = letter_to_gate.get(letter)
                if make_gate is None:
                    continue  # ignore characters other than I/X/Y/Z
                factor = make_gate(qubit)
                product = factor if product is None else product * factor
            if product is not None:
                unit_terms.append(cudaq.SpinOperator(product))

    pool.extend(scale * unit_term for unit_term in unit_terms for scale in params)
    return pool
101
102
# Coefficient scales attached to every Pauli term in the pool:
# +/-0.003125 * 2**k for k = 0..5, i.e. +/-0.003125, +/-0.00625, ..., +/-0.1,
# positive value first for each magnitude. Scaling by a power of two is exact
# in binary floating point, so these equal the literal decimal values.
base_step = 0.003125
params = [sign * base_step * 2 ** k for k in range(6) for sign in (1, -1)]
op_pool = get_gqe_pauli_pool(n_qubits, n_electrons, params)
105
106
def term_coefficients(op):
    """Return the evaluated coefficient of every term in ``op``, in order."""
    coefficients = []
    for pauli_term in op:
        coefficients.append(pauli_term.evaluate_coefficient())
    return coefficients


def term_words(op):
    """Return the Pauli word of every term in ``op`` over ``n_qubits`` qubits."""
    words = []
    for pauli_term in op:
        words.append(pauli_term.get_pauli_word(n_qubits))
    return words
108
109
# CUDA-Q kernel evaluated by cost():
#   1. flip the first n_electrons qubits to |1> (reference occupation),
#   2. apply exp_pauli(coeffs[k], ..., words[k]) for every k, in order.
# coeffs and words are parallel lists.
@cudaq.kernel
def kernel(n_qubits: int, n_electrons: int, coeffs: list[float], words: list[cudaq.pauli_word]):
    register = cudaq.qvector(n_qubits)
    for occ in range(n_electrons):
        x(register[occ])
    for k in range(len(coeffs)):
        exp_pauli(coeffs[k], register, words[k])
114
115
def cost(sampled_ops, **kwargs):
    """Energy of the circuit assembled from ``sampled_ops``.

    The terms of all sampled operators are concatenated into one list of
    coefficients (real part only) and one list of Pauli words, which are fed
    to ``kernel``; the expectation of ``spin_ham`` is then measured.

    Returns:
        Without MPI, the expectation value. With MPI, a pair
        ``(async_handle, extractor)`` where ``extractor(handle)`` yields the
        expectation; ``kwargs['qpu_id']`` selects the QPU (presumably
        supplied by ``solvers.gqe`` in distributed mode — confirm).
    """
    coefficients = []
    pauli_words = []
    for operator in sampled_ops:
        coefficients.extend(value.real for value in term_coefficients(operator))
        pauli_words.extend(term_words(operator))

    observe_args = (kernel, spin_ham, n_qubits, n_electrons, coefficients, pauli_words)
    if not args.mpi:
        return cudaq.observe(*observe_args).expectation()

    pending = cudaq.observe_async(*observe_args, qpu_id=kwargs['qpu_id'])
    return pending, lambda result: result.get().expectation()
127
128
# ---------------------------------------------------------------------------
# Configure and run GQE
# ---------------------------------------------------------------------------
cfg = get_default_config()
logger = CSVLogger("gqe_lih_logs/gqe.csv")

# Overrides applied on top of the default configuration. save_trajectory
# presumably produces the trajectory file read in section 3 — confirm path.
gqe_overrides = {
    'use_fabric_logging': False,
    'fabric_logger': logger,
    'save_trajectory': True,
    'verbose': True,
    'del_temperature': 0.05,
    'max_iters': args.max_iters,
    'ngates': args.ngates,
    'num_samples': args.num_samples,
    'temperature': args.temperature,
    'lr': args.lr,
}
for option_name, option_value in gqe_overrides.items():
    setattr(cfg, option_name, option_value)

# Returns the best energy found and the pool indices of its operators.
minE, best_ops = solvers.gqe(cost, op_pool, config=cfg)
146
147
# ---------------------------------------------------------------------------
# 3. Process Results & Plot (Rank 0 only)
# ---------------------------------------------------------------------------


def _load_trajectory(path):
    """Parse the JSON-lines GQE trajectory at *path* into plottable series.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped. An object contributes one data point when it has a
    non-empty ``'energies'`` list (all energies sampled at that iteration)
    or, failing that, a scalar ``'energy'``; objects with neither are
    ignored. The iteration number is read from ``'iter'`` and defaults to
    the number of data points collected so far.

    Returns:
        ``(iterations, all_energies, losses, global_min_val, global_min_iter)``.
        The first three are parallel lists; ``losses`` holds ``None`` where
        an entry had no ``'loss'``. ``global_min_val`` is the smallest energy
        seen (``inf`` if none) and ``global_min_iter`` the earliest iteration
        at which it occurred (0 if none).
    """
    iterations = []
    all_energies = []  # one list of sampled energies per data point
    losses = []
    global_min_val = float('inf')
    global_min_iter = 0

    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(entry, dict):
                continue  # e.g. a bare number or list: nothing to extract

            current_iter = entry.get('iter', len(iterations))

            batch = entry.get('energies')
            if isinstance(batch, list) and batch:
                epoch_energies = batch
            elif 'energy' in entry:
                epoch_energies = [entry['energy']]
            else:
                continue

            all_energies.append(epoch_energies)
            batch_min = min(epoch_energies)
            # Strict '<' keeps the earliest iteration when values tie.
            if batch_min < global_min_val:
                global_min_val = batch_min
                global_min_iter = current_iter

            losses.append(entry.get('loss'))
            iterations.append(current_iter)

    return iterations, all_energies, losses, global_min_val, global_min_iter


def _plot_convergence(iterations, all_energies, losses, global_min_val,
                      global_min_iter, fci_energy, title, output_img):
    """Plot sampled energies (top) and loss (bottom), save to *output_img*."""
    plt.style.use('dark_background')
    nvidia_green = '#76B900'
    bright_red = '#FF4444'

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True,
                                   gridspec_kw={'height_ratios': [2, 1]})
    try:
        # --- Plot 1: Energy Scatter ---
        # One x value per sampled energy, so each iteration is a column of points.
        scatter_x = [it for it, epoch in zip(iterations, all_energies) for _ in epoch]
        scatter_y = [energy for epoch in all_energies for energy in epoch]
        ax1.scatter(scatter_x, scatter_y, color=nvidia_green, alpha=0.7, s=15,
                    label='Sampled Energies')

        ax1.axhline(y=fci_energy, color=bright_red, linestyle='--', linewidth=1.5,
                    label=f'FCI Energy = {fci_energy:.5f} Ha')

        # Circle the global minimum.
        ax1.plot(global_min_iter, global_min_val, marker='o', markersize=20,
                 markeredgecolor='white', markerfacecolor='none', markeredgewidth=2,
                 label='Global Minimum')

        energy_diff = abs(global_min_val - fci_energy)
        ax1.text(0.5, 0.95, f'ΔE (GQE - FCI) = {energy_diff:.2e} Ha',
                 transform=ax1.transAxes, ha='center', va='top',
                 fontsize=14, fontweight='bold', color='white',
                 bbox=dict(facecolor='#333333', alpha=0.9, edgecolor='gray', boxstyle='round'))

        ax1.set_ylabel('Energy (Hartree)', fontweight='bold', color='white')
        # Fixed range; the -0.70 upper bound appears tuned for this H2 setup.
        ax1.set_ylim(bottom=fci_energy - 0.02, top=-0.70)
        ax1.grid(True, which='both', linestyle='--', alpha=0.3, color='gray')
        ax1.legend(loc='upper right', facecolor='#333333', edgecolor='white')
        ax1.set_title(title, color='white')

        # --- Plot 2: Loss ---
        valid_loss = [(it, loss) for it, loss in zip(iterations, losses) if loss is not None]
        if valid_loss:
            lx, ly = zip(*valid_loss)
            # White line stands out against the dark background.
            ax2.plot(lx, ly, color='white', marker='x', linestyle='-', linewidth=1, label='Loss')
            # Only build a legend when a labelled line exists; otherwise
            # matplotlib warns "No artists with labels found".
            ax2.legend(loc='upper right', facecolor='#333333', edgecolor='white')

        ax2.set_ylabel('Loss', fontweight='bold', color='white')
        ax2.set_xlabel('Iteration', fontweight='bold', color='white')
        ax2.set_ylim(0, 10)  # Fixed range 0 to 10
        ax2.grid(True, which='both', linestyle='--', alpha=0.3, color='gray')

        fig.tight_layout()
        fig.savefig(output_img, dpi=300, facecolor='black')
    finally:
        plt.close(fig)  # release the figure even if drawing/saving failed


try:
    if not args.mpi or cudaq.mpi.rank() == 0:
        print(f'Ground Energy = {minE}')
        print('Ansatz Ops')
        for idx in best_ops:
            print(op_pool[idx])

        print("\nProcessing trajectory and generating plot...")
        # Presumably where GQE writes its trajectory when cfg.save_trajectory
        # is True — confirm against the cudaq_solvers GQE implementation.
        trajectory_file = "gqe_logs/gqe_trajectory.json"

        if os.path.exists(trajectory_file):
            # Plotting is best-effort post-processing: report and continue.
            try:
                (iterations, all_energies, losses,
                 global_min_val, global_min_iter) = _load_trajectory(trajectory_file)

                if all_energies:
                    output_img = args.output_file
                    _plot_convergence(
                        iterations, all_energies, losses,
                        global_min_val, global_min_iter, fci_energy,
                        f'GQE Convergence (Gates={args.ngates}, Samples={args.num_samples})',
                        output_img,
                    )
                    print(f"Plot saved successfully to: {output_img}")
                    print(f"Global Minimum: {global_min_val:.6f} at iter {global_min_iter}")
                else:
                    print("No energy data found.")

            except Exception as e:
                print(f"Error plotting: {e}")
                import traceback
                traceback.print_exc()
        else:
            print(f"Error: {trajectory_file} not found.")
finally:
    # Always shut MPI down, even if rank-0 post-processing raised.
    if args.mpi:
        cudaq.mpi.finalize()
279