Path: blob/main/quantum-applications-to-finance/helper.py
1147 views
from collections import defaultdict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np


def top_k_bitstrings(
    shots, Q: np.ndarray, k: int = 5
) -> List[Tuple[str, float, float]]:
    """
    List the k most-probable bit-strings in `shots` and their QUBO values.

    Each returned entry is also printed on its own line.

    Parameters
    ----------
    shots
        A plain ``{bitstring: count}`` dict, an object exposing ``items()``
        that yields ``(bitstring, count)`` pairs (e.g. ``cudaq.SampleResult``),
        or an object whose ``counts`` attribute -- or zero-argument method --
        returns such a dict.
    Q
        Square QUBO matrix; the cost of a bit vector ``x`` is ``x @ Q @ x``.
        Every selected bitstring must have exactly ``Q.shape[0]`` characters.
    k
        Maximum number of bitstrings to return; must be non-negative.
        ``None`` returns every observed bitstring.

    Returns
    -------
    [(bitstring, probability, qubo_cost), ...]
        Ordered by descending count; ties keep the container's iteration order
        (``sorted`` is stable).

    Raises
    ------
    TypeError
        If `shots` is not one of the supported container types.
    ValueError
        If `k` is negative, the container holds zero shots, `Q` is not a
        square 2-D matrix, or a selected bitstring's length differs from
        ``Q.shape[0]``.
    """
    # Negative k would silently turn the [:k] slice into "all but the last |k|".
    if k is not None and k < 0:
        raise ValueError(f"k must be non-negative, got {k}.")

    # ------------------------------------------------------------------
    # 1) extract a counts-dict {str -> int}
    # ------------------------------------------------------------------
    if isinstance(shots, dict):                      # plain counts dict
        counts_dict = shots
    elif hasattr(shots, "items"):                    # cudaq.SampleResult
        counts_dict = dict(shots.items())
    elif hasattr(shots, "counts"):                   # older naming
        counts = shots.counts
        # Some result objects expose ``counts`` as a method, not a dict.
        counts_dict = counts() if callable(counts) else counts
    else:
        raise TypeError("Unrecognised shot container type.")

    total_shots = sum(counts_dict.values())
    if total_shots == 0:
        raise ValueError("No shots in sample result.")

    Q = np.asarray(Q)
    if Q.ndim != 2 or Q.shape[0] != Q.shape[1]:
        raise ValueError(f"Q must be a square 2-D matrix, got shape {Q.shape}.")
    n = Q.shape[0]

    # ------------------------------------------------------------------
    # 2) sort by frequency and keep top-k
    # ------------------------------------------------------------------
    top = sorted(counts_dict.items(), key=lambda kv: kv[1], reverse=True)[:k]

    # ------------------------------------------------------------------
    # 3) compute QUBO value for every string
    # ------------------------------------------------------------------
    results = []
    for bitstr, cnt in top:
        # np.fromiter(count=n) would silently truncate a longer bitstring,
        # so check the length explicitly before building the bit vector.
        if len(bitstr) != n:
            raise ValueError(
                f"Bitstring {bitstr!r} has {len(bitstr)} bits but Q is {n}x{n}."
            )
        x = np.fromiter(map(int, bitstr), dtype=int, count=n)
        cost = float(x @ (Q @ x))
        prob = cnt / total_shots
        results.append((bitstr, prob, cost))
        print(f"{bitstr} prob = {prob:.3f} QUBO = {cost:.6f}")

    return results


def plot_samples_histogram(
    sample1, sample2, solutions_data, title="Portfolio Comparison", *, threshold=0.0
):
    """
    Plot a histogram comparing two CUDAQ sample objects.

    Only bit patterns listed in `solutions_data` get a bar; sample bitstrings
    that do not appear there are not plotted.

    Args:
        sample1: First CUDAQ sample object (anything with ``items()`` yielding
            ``(bitstring, count)`` pairs); drawn in black as 'Initial State'.
        sample2: Second CUDAQ sample object; drawn in NVIDIA green as
            'Final State'.
        solutions_data: Iterable of ``((bit0, bit1, ...), objective_value)``.
            The bit pattern may be any sequence of 0/1-like values (tuple,
            list, array, bools); it is normalised to a tuple of ints. Bit i is
            matched against character i of the sample bitstrings.
        title: Plot title.
        threshold: Solutions with ``objective_value < threshold`` are labelled
            "Good Portfolios", all others "Bad Portfolios".

    Returns:
        ``(fig, ax)`` of the created matplotlib figure.
    """
    # Normalise every bit pattern to a tuple of Python ints so that list/array
    # inputs are hashable and compare equal to the sample-derived keys, then
    # sort by objective value (stable, like the rest of this module).
    solutions = sorted(
        ((tuple(int(b) for b in bits), val) for bits, val in solutions_data),
        key=lambda item: item[1],
    )
    all_bitstrings = [bits for bits, _ in solutions]

    def _counts_by_bits(sample):
        # Map each sampled bitstring (e.g. '0110') to a tuple of ints.
        return {
            tuple(int(b) for b in bitstring): count
            for bitstring, count in sample.items()
        }

    counts1 = _counts_by_bits(sample1)
    counts2 = _counts_by_bits(sample2)

    # x-axis labels: bitstring plus its objective value
    x_labels = [
        f"{''.join(map(str, bits))}\n(obj: {val:.2f})" for bits, val in solutions
    ]

    fig, ax = plt.subplots(figsize=(14, 8))

    x = np.arange(len(all_bitstrings))
    width = 0.35

    nvidia_green = '#76B900'  # NVIDIA green color
    counts1_values = [counts1.get(bits, 0) for bits in all_bitstrings]
    counts2_values = [counts2.get(bits, 0) for bits in all_bitstrings]

    ax.bar(x - width / 2, counts1_values, width, label='Initial State', color='black')
    ax.bar(x + width / 2, counts2_values, width, label='Final State', color=nvidia_green)

    # `not val < threshold` (rather than `val >= threshold`) keeps NaN
    # objective values in the "bad" group.
    good_portfolios = [i for i, (_, val) in enumerate(solutions) if val < threshold]
    bad_portfolios = [i for i, (_, val) in enumerate(solutions) if not val < threshold]

    # Group annotations sit at 95% of the tallest bar, centred on the middle
    # bar of each group; skipped when nothing was sampled.
    max_count = max(counts1_values + counts2_values, default=0)
    if max_count > 0:
        if good_portfolios:
            mid_good = good_portfolios[len(good_portfolios) // 2]
            ax.text(mid_good, max_count * 0.95, "Good Portfolios",
                    ha='center', va='center', fontsize=12, fontweight='bold')
        if bad_portfolios:
            mid_bad = bad_portfolios[len(bad_portfolios) // 2]
            ax.text(mid_bad, max_count * 0.95, "Bad Portfolios",
                    ha='center', va='center', fontsize=12, fontweight='bold')

    ax.set_xlabel('Bitstring Configuration (Objective Value)', fontsize=12)
    ax.set_ylabel('Sample Count', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels, rotation=45, ha='right')

    ax.legend(loc='upper right', frameon=True, framealpha=0.9, fontsize=10)

    ax.grid(axis='y', linestyle='--', alpha=0.3)
    fig.tight_layout()

    return fig, ax