# Path: chemistry-simulations/aux_files/vqe_and_gqe/run_gqe_h2.py
"""Run the Generative Quantum Eigensolver (GQE) on H2 with CUDA-Q Solvers.

Builds an operator pool from the Pauli strings of the molecule's UCCSD
operators, runs GQE against the molecular Hamiltonian, and plots the sampled
energies and the training loss recorded in the GQE trajectory file.
"""
import argparse
import cudaq
import os
import sys
import torch
import json
import matplotlib.pyplot as plt
import numpy as np
import cudaq_solvers as solvers
from cudaq import spin
from lightning.fabric.loggers import CSVLogger
from cudaq_solvers.gqe_algorithm.gqe import get_default_config
from typing import List

# ---------------------------------------------------------------------------
# Argument Parsing
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Run GQE for LiH/H2.")
parser.add_argument('--mpi', action='store_true', help='Enable MPI distribution.')
parser.add_argument('--max_iters', type=int, default=75, help='Maximum number of GQE iterations.')
parser.add_argument('--ngates', type=int, default=40, help='Number of gates (ansatz depth).')
parser.add_argument('--num_samples', type=int, default=10, help='Number of samples (population size).')
parser.add_argument('--temperature', type=float, default=5.0, help='Temperature for sampling.')
parser.add_argument('--lr', type=float, default=1e-7, help='Learning Rate.')
parser.add_argument('--output_file', type=str, default='gqe_convergence.png',
                    help='Filename to save the convergence plot.')

args = parser.parse_args()

# ---------------------------------------------------------------------------
# Deterministic setup
# ---------------------------------------------------------------------------
# torch.use_deterministic_algorithms requires this cuBLAS workspace setting.
# It is placed in the environment *before* a GPU target is selected below, so
# it is already present when CUDA libraries get initialized.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.manual_seed(3047)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------------------------------------------------------------------
# Simulation target
# ---------------------------------------------------------------------------
if args.mpi:
    try:
        # 'mqpu' exposes several simulated QPUs for asynchronous observe calls.
        cudaq.set_target('nvidia', option='mqpu')
        cudaq.mpi.initialize()
    except RuntimeError:
        print('Warning: NVIDIA GPUs or MPI not available. Skipping...')
        # Status 0: the run is treated as skipped rather than failed.
        sys.exit(0)
else:
    try:
        cudaq.set_target('nvidia', option='fp64')
    except RuntimeError:
        print('Warning: NVIDIA GPU target not available. Falling back to qpp-cpu.')
        cudaq.set_target('qpp-cpu')
# ---------------------------------------------------------------------------
# 1. Update Geometry and Get FCI
# ---------------------------------------------------------------------------
geometry = [('H', (0., 0., 0.)), ('H', (0., 0., 0.74))]
molecule = solvers.create_molecule(geometry, '6-31g', 0, 0, nele_cas=2, norb_cas=3, casci=True)
spin_ham = molecule.hamiltonian
n_qubits = molecule.n_orbitals * 2
n_electrons = molecule.n_electrons

# Extract FCI Energy (R-CASCI).
# Used only when the solver does not report an 'R-CASCI' energy.
# NOTE(review): hard-coded reference value for this geometry/basis — confirm its origin.
FALLBACK_FCI_ENERGY = -0.9886377190903441
energies = molecule.energies
fci_energy = energies['R-CASCI'] if 'R-CASCI' in energies else FALLBACK_FCI_ENERGY

# Only print preamble on rank 0
if not args.mpi or cudaq.mpi.rank() == 0:
    print(f"Full CI Energy (R-CASCI): {fci_energy}")
    print(f"Configuration: max_iters={args.max_iters}, ngates={args.ngates}, "
          f"num_samples={args.num_samples}, temperature={args.temperature}")

# ---------------------------------------------------------------------------
# 2. Build Operator Pool
# ---------------------------------------------------------------------------

# Single-qubit operator constructors keyed by the characters of a Pauli word.
_PAULI_FACTORIES = {'I': spin.i, 'X': spin.x, 'Y': spin.y, 'Z': spin.z}


def get_identity(n_qubits: int) -> cudaq.SpinOperator:
    """Return the identity operator on ``n_qubits`` qubits, scaled by 1.0."""
    identity = spin.i(0)
    for q in range(1, n_qubits):
        identity = identity * spin.i(q)
    return 1.0 * cudaq.SpinOperator(identity)


def _pauli_word_to_operator(pauli_word: str):
    """Return the product operator for ``pauli_word``; character k acts on qubit k.

    Returns None for an empty word.

    Raises:
        ValueError: if the word contains a character other than I, X, Y or Z.
            Skipping it silently would drop that qubit from the operator.
    """
    pauli_op = None
    for qubit_idx, pauli_char in enumerate(pauli_word):
        try:
            factory = _PAULI_FACTORIES[pauli_char]
        except KeyError:
            raise ValueError(
                f"Unexpected Pauli character {pauli_char!r} in word {pauli_word!r}") from None
        gate = factory(qubit_idx)
        pauli_op = gate if pauli_op is None else pauli_op * gate
    return pauli_op


def get_gqe_pauli_pool(num_qubits: int, num_electrons: int, params: List[float]) -> List[cudaq.SpinOperator]:
    """Build the GQE operator pool from the Pauli strings of the UCCSD operators.

    The pool is the identity operator followed by every Pauli string that
    appears in the UCCSD operator pool. The original coefficients are
    discarded. Each string is scaled by every value in ``params``, giving one
    entry per (string, param) pair, with strings in the outer loop.

    Args:
        num_qubits: Number of qubits the operators act on.
        num_electrons: Electron count passed to the UCCSD pool generator.
        params: Scale factors applied to each Pauli string.

    Returns:
        The list of pool operators. Their order defines the operator indices.

    Raises:
        ValueError: if a Pauli word contains an unexpected character.
    """
    uccsd_operators = solvers.get_operator_pool("uccsd", num_qubits=num_qubits, num_electrons=num_electrons)
    pool = [get_identity(num_qubits)]

    individual_terms = []
    for op in uccsd_operators:
        for term in op:
            pauli_op = _pauli_word_to_operator(term.get_pauli_word(num_qubits))
            if pauli_op is not None:
                individual_terms.append(cudaq.SpinOperator(pauli_op))

    pool.extend(param * term_op for term_op in individual_terms for param in params)
    return pool


params = [0.003125, -0.003125, 0.00625, -0.00625, 0.0125, -0.0125,
          0.025, -0.025, 0.05, -0.05, 0.1, -0.1]
op_pool = get_gqe_pauli_pool(n_qubits, n_electrons, params)


def term_coefficients(op):
    """Return the evaluated coefficient of every term in ``op``."""
    return [term.evaluate_coefficient() for term in op]


def term_words(op):
    """Return the Pauli word of every term in ``op`` over the global ``n_qubits``."""
    return [term.get_pauli_word(n_qubits) for term in op]


# Ansatz circuit:
# - apply X to the first ``n_electrons`` qubits (a reference state; assumes
#   cudaq_solvers' orbital ordering puts occupied orbitals first — confirm);
# - then apply exp_pauli(coeffs[i], q, words[i]) for each term in order.
@cudaq.kernel
def kernel(n_qubits: int, n_electrons: int, coeffs: list[float], words: list[cudaq.pauli_word]):
    q = cudaq.qvector(n_qubits)
    for i in range(n_electrons):
        x(q[i])
    for i in range(len(coeffs)):
        exp_pauli(coeffs[i], q, words[i])


def cost(sampled_ops, **kwargs):
    """Return the energy of the circuit built from ``sampled_ops``.

    The terms of all sampled operators are concatenated into one exp_pauli
    sequence, using the real part of each term coefficient as the parameter.
    The expectation value of ``spin_ham`` is then measured.

    With ``--mpi`` the observation runs asynchronously on QPU
    ``kwargs['qpu_id']``, and a ``(handle, extractor)`` pair is returned.
    Otherwise the expectation value is returned directly.
    """
    full_coeffs = []
    full_words = []
    for op in sampled_ops:
        full_coeffs += [c.real for c in term_coefficients(op)]
        full_words += term_words(op)

    if args.mpi:
        handle = cudaq.observe_async(kernel, spin_ham, n_qubits, n_electrons,
                                     full_coeffs, full_words, qpu_id=kwargs['qpu_id'])
        return handle, lambda res: res.get().expectation()
    return cudaq.observe(kernel, spin_ham, n_qubits, n_electrons,
                         full_coeffs, full_words).expectation()


# ---------------------------------------------------------------------------
# Configure GQE
# ---------------------------------------------------------------------------
cfg = get_default_config()
# NOTE(review): fabric logging is disabled, yet a logger is attached below.
# Confirm whether gqe uses cfg.fabric_logger when use_fabric_logging is False.
cfg.use_fabric_logging = False
# Lightning's CSVLogger takes a root *directory* as its first argument.
# "gqe.csv" therefore ends up as a directory name, not a file.
logger = CSVLogger("gqe_lih_logs/gqe.csv")
cfg.fabric_logger = logger
cfg.save_trajectory = True
cfg.verbose = True
cfg.del_temperature = 0.05
cfg.max_iters = args.max_iters
cfg.ngates = args.ngates
cfg.num_samples = args.num_samples
cfg.temperature = args.temperature
cfg.lr = args.lr

# Run GQE
minE, best_ops = solvers.gqe(cost, op_pool, config=cfg)
# ---------------------------------------------------------------------------
# 3. Process Results & Plot (Rank 0 only)
# ---------------------------------------------------------------------------
DEFAULT_TRAJECTORY_FILE = "gqe_logs/gqe_trajectory.json"


def _load_trajectory(path):
    """Parse a GQE trajectory file containing one JSON object per line.

    A line contributes its ``'energies'`` list when that is a non-empty list.
    Otherwise it contributes ``[entry['energy']]`` when that key exists. Any
    other line, and blank or invalid-JSON lines, is skipped.

    Returns:
        ``(iterations, all_energies, losses, global_min_val, global_min_iter)``.
        ``all_energies[k]`` holds the energies logged at ``iterations[k]``, and
        ``losses[k]`` is that entry's ``'loss'`` or None. The global minimum is
        ``inf`` at iteration 0 when nothing was parsed.
    """
    iterations = []
    all_energies = []  # one list of energies per kept entry
    losses = []
    global_min_val = float('inf')
    global_min_iter = 0

    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue

            current_iter = entry.get('iter', len(iterations))

            # Energies of the whole population if present, else a single value.
            if isinstance(entry.get('energies'), list) and entry['energies']:
                epoch_energies = entry['energies']
            elif 'energy' in entry:
                epoch_energies = [entry['energy']]
            else:
                continue

            all_energies.append(epoch_energies)
            batch_min = min(epoch_energies)
            if batch_min < global_min_val:
                global_min_val = batch_min
                global_min_iter = current_iter

            losses.append(entry.get('loss', None))
            iterations.append(current_iter)

    return iterations, all_energies, losses, global_min_val, global_min_iter


def _plot_convergence(iterations, all_energies, losses, global_min_val, global_min_iter,
                      fci_energy, output_img):
    """Plot sampled energies (top) and loss (bottom) versus iteration; save to ``output_img``."""
    plt.style.use('dark_background')

    # Setup 2 Subplots sharing X axis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True,
                                   gridspec_kw={'height_ratios': [2, 1]})

    nvidia_green = '#76B900'
    bright_red = '#FF4444'

    # --- Plot 1: Energy Scatter ---
    # One point per sampled energy, placed at its iteration.
    scatter_x = [it for it, epoch in zip(iterations, all_energies) for _ in epoch]
    scatter_y = [energy for epoch in all_energies for energy in epoch]
    ax1.scatter(scatter_x, scatter_y, color=nvidia_green, alpha=0.7, s=15, label='Sampled Energies')

    # Plot FCI Line
    ax1.axhline(y=fci_energy, color=bright_red, linestyle='--', linewidth=1.5,
                label=f'FCI Energy = {fci_energy:.5f} Ha')

    # Circle the Global Minimum
    ax1.plot(global_min_iter, global_min_val, marker='o', markersize=20,
             markeredgecolor='white', markerfacecolor='none', markeredgewidth=2,
             label='Global Minimum')

    # Annotation: Delta E
    energy_diff = abs(global_min_val - fci_energy)
    ax1.text(0.5, 0.95, f'ΔE (GQE - FCI) = {energy_diff:.2e} Ha',
             transform=ax1.transAxes, ha='center', va='top',
             fontsize=14, fontweight='bold', color='white',
             bbox=dict(facecolor='#333333', alpha=0.9, edgecolor='gray', boxstyle='round'))

    # Formatting Plot 1
    ax1.set_ylabel('Energy (Hartree)', fontweight='bold', color='white')
    # Fixed range with a top of -0.70 Ha. If the bottom would lie above that,
    # the top is left unchanged (top=None) instead of inverting the axis.
    y_bottom = fci_energy - 0.02
    ax1.set_ylim(bottom=y_bottom, top=-0.70 if y_bottom < -0.70 else None)
    ax1.grid(True, which='both', linestyle='--', alpha=0.3, color='gray')
    ax1.legend(loc='upper right', facecolor='#333333', edgecolor='white')
    ax1.set_title(f'GQE Convergence (Gates={args.ngates}, Samples={args.num_samples})', color='white')

    # --- Plot 2: Loss ---
    ax2.set_ylabel('Loss', fontweight='bold', color='white')
    ax2.set_xlabel('Iteration', fontweight='bold', color='white')
    ax2.set_ylim(0, 10)  # Fixed range 0 to 10; values outside are clipped.
    ax2.grid(True, which='both', linestyle='--', alpha=0.3, color='gray')

    valid_loss = [(i, l) for i, l in zip(iterations, losses) if l is not None]
    if valid_loss:
        lx, ly = zip(*valid_loss)
        # Using White for Loss line to pop against dark background
        ax2.plot(lx, ly, color='white', marker='x', linestyle='-', linewidth=1, label='Loss')
        # Only add a legend when there is a labelled line to show.
        ax2.legend(loc='upper right', facecolor='#333333', edgecolor='white')

    # Save
    plt.tight_layout()
    plt.savefig(output_img, dpi=300, facecolor='black')
    plt.close(fig)


if not args.mpi or cudaq.mpi.rank() == 0:
    print(f'Ground Energy = {minE}')
    print('Ansatz Ops')
    for idx in best_ops:
        print(op_pool[idx])

    print("\nProcessing trajectory and generating plot...")
    # Use the path GQE was configured to write to, falling back to the default location.
    trajectory_file = getattr(cfg, 'trajectory_file_path', None) or DEFAULT_TRAJECTORY_FILE

    if os.path.exists(trajectory_file):
        try:
            (iterations, all_energies, losses,
             global_min_val, global_min_iter) = _load_trajectory(trajectory_file)

            if all_energies:
                output_img = args.output_file
                _plot_convergence(iterations, all_energies, losses, global_min_val,
                                  global_min_iter, fci_energy, output_img)
                print(f"Plot saved successfully to: {output_img}")
                print(f"Global Minimum: {global_min_val:.6f} at iter {global_min_iter}")
            else:
                print("No energy data found.")

        except Exception as e:
            # Plotting is best-effort: report the failure and let the script
            # reach MPI finalization below.
            print(f"Error plotting: {e}")
            import traceback
            traceback.print_exc()
    else:
        print(f"Error: {trajectory_file} not found.")

if args.mpi:
    cudaq.mpi.finalize()