# Source code for experiments.harness.worker

"""Trial specification, worker functions, and parallel dispatch."""

import os
import signal
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.random import default_rng

from experiments.harness.results import TrialResult


@dataclass
class TrialSpec:
    r"""Fully serialisable specification for one trial.

    :class:`~concurrent.futures.ProcessPoolExecutor` requires picklable
    arguments.  Since :class:`~mos.MoSState` contains Qiskit objects that
    do not always pickle cleanly, we instead pass this plain dataclass to
    worker processes.  Each worker reconstructs the
    :class:`~mos.MoSState` from these fields, runs the protocol, and
    returns a :class:`TrialResult`.
    """

    #: Number of input bits.
    n: int
    #: Label-bias function :math:`\varphi(x) = \Pr[y{=}1 \mid x]`,
    #: stored as a plain list (not ``ndarray``) for pickle safety.
    #: Length :math:`2^n`.
    phi: list[float]
    #: Label-flip noise rate :math:`\eta \in [0, 0.5)`.
    noise_rate: float
    #: The ground-truth heavy parity index
    #: :math:`s^* \in \{0,\ldots,2^n{-}1\}`.
    target_s: int
    #: Accuracy parameter :math:`\varepsilon`.
    epsilon: float
    #: Confidence parameter :math:`\delta`.
    delta: float
    #: Fourier resolution threshold :math:`\vartheta`.
    theta: float
    #: Lower bound :math:`a^2` on
    #: :math:`\mathbb{E}[\tilde\phi(x)^2]` (Definition 14).
    a_sq: float
    #: Upper bound :math:`b^2` on
    #: :math:`\mathbb{E}[\tilde\phi(x)^2]` (Definition 14).
    b_sq: float
    #: Number of MoS copies for QFS (prover).
    qfs_shots: int
    #: Number of classical samples for the prover's coefficient
    #: estimation.
    classical_samples_prover: int
    #: Number of classical samples for the verifier's independent
    #: estimation.
    classical_samples_verifier: int
    #: Per-trial random seed.
    seed: int
    #: Human-readable label for the distribution.
    phi_description: str
    #: If not ``None``, the worker runs a dishonest prover with this
    #: strategy instead of the honest protocol.  Single-element
    #: strategies (used by :func:`run_soundness_experiment`):
    #: ``"random_list"``, ``"wrong_parity"``, ``"partial_list"``,
    #: ``"inflated_list"``.  Multi-element strategies (used by
    #: :func:`run_soundness_multi_experiment`): ``"partial_real"``,
    #: ``"diluted_list"``, ``"shifted_coefficients"``,
    #: ``"subset_plus_noise"``.
    dishonest_strategy: Optional[str] = None
    #: Gate-level depolarising noise rate :math:`p`.  When set, the
    #: worker constructs a Qiskit ``NoiseModel`` with depolarising
    #: channels on H, X (1-qubit) and CX (2-qubit) gates.
    gate_noise_rate: Optional[float] = None
    #: QFS simulation mode passed to ``MoSProver.run_protocol``.
    qfs_mode: str = "statevector"
    #: Fourier sparsity parameter :math:`k`.  When set and ``> 1``,
    #: the worker calls :meth:`~ql.verifier.MoSVerifier.verify_fourier_sparse`
    #: instead of :meth:`~ql.verifier.MoSVerifier.verify_parity`.
    k: Optional[int] = None
    #: Number of fresh samples for misclassification rate estimation.
    #: When ``None``, defaults to 1000.
    misclassification_samples: Optional[int] = None
def _compute_misclassification_rate(state, hypothesis, seed, num_samples=1000):
    """Compute empirical P[h(x) != y] on fresh classical samples.

    Draws *num_samples* labelled examples from *state* with a dedicated
    RNG seeded by *seed*, evaluates *hypothesis* on the batch, and
    returns the fraction of disagreements as a plain ``float``.

    Only :class:`~ql.verifier.FourierSparseHypothesis` takes an ``rng``
    in ``evaluate_batch`` — presumably because its sign predictions are
    randomised; other hypothesis types are evaluated deterministically.
    """
    from ql.verifier import FourierSparseHypothesis

    rng = default_rng(seed)
    xs, ys = state.sample_classical_batch(num_samples=num_samples, rng=rng)
    if isinstance(hypothesis, FourierSparseHypothesis):
        predictions = hypothesis.evaluate_batch(xs, rng=rng)
    else:
        predictions = hypothesis.evaluate_batch(xs)
    return float(np.mean(predictions != ys))


def _run_trial_worker(spec: TrialSpec) -> TrialResult:
    r"""Execute a single trial in a worker process.

    Reconstructs a :class:`~mos.MoSState` from the fields in *spec*,
    runs the honest prover (:class:`~ql.prover.MoSProver`) followed by
    the classical verifier (:class:`~ql.verifier.MoSVerifier`), and
    returns a :class:`TrialResult` capturing every observable quantity.

    All imports are performed inside the function body so that each
    forked process obtains a clean module state with no shared mutable
    objects.

    If ``spec.dishonest_strategy`` is set, the worker delegates to
    :func:`_run_dishonest_trial` instead.

    Parameters
    ----------
    spec : TrialSpec
        Serialisable trial specification.

    Returns
    -------
    TrialResult
    """
    # Late imports — avoid contaminating the parent process and
    # ensure each worker has independent module state.
    from mos import MoSState
    from ql.prover import MoSProver
    from ql.verifier import MoSVerifier

    phi = np.array(spec.phi, dtype=np.float64)
    state = MoSState(n=spec.n, phi=phi, noise_rate=spec.noise_rate, seed=spec.seed)

    if spec.dishonest_strategy is not None:
        return _run_dishonest_trial(spec, state)

    # --- Build gate-level noise model (if requested) ---
    noise_model = None
    if spec.gate_noise_rate is not None and spec.gate_noise_rate > 0:
        from qiskit_aer.noise import NoiseModel, depolarizing_error

        noise_model = NoiseModel()
        # Depolarising channels: 1-qubit on H and X, 2-qubit on CX,
        # all at the same rate ``spec.gate_noise_rate``.
        noise_model.add_all_qubit_quantum_error(
            depolarizing_error(spec.gate_noise_rate, 1), ["h", "x"]
        )
        noise_model.add_all_qubit_quantum_error(
            depolarizing_error(spec.gate_noise_rate, 2), ["cx"]
        )

    # --- Prover ---
    t0 = time.time()
    prover = MoSProver(state, seed=spec.seed, noise_model=noise_model)
    msg = prover.run_protocol(
        epsilon=spec.epsilon,
        delta=spec.delta,
        theta=spec.theta,
        qfs_mode=spec.qfs_mode,
        qfs_shots=spec.qfs_shots,
        classical_samples=spec.classical_samples_prover,
    )
    prover_time = time.time() - t0

    # --- Verifier ---
    # Seed offset keeps the verifier's randomness independent of the
    # prover's while remaining reproducible per trial.
    t1 = time.time()
    verifier = MoSVerifier(state, seed=spec.seed + 1_000_000)
    if spec.k is not None and spec.k > 1:
        result = verifier.verify_fourier_sparse(
            msg,
            epsilon=spec.epsilon,
            k=spec.k,
            delta=spec.delta,
            theta=spec.theta,
            a_sq=spec.a_sq,
            b_sq=spec.b_sq,
            num_samples=spec.classical_samples_verifier,
        )
    else:
        result = verifier.verify_parity(
            msg,
            epsilon=spec.epsilon,
            delta=spec.delta,
            theta=spec.theta,
            a_sq=spec.a_sq,
            b_sq=spec.b_sq,
            num_samples=spec.classical_samples_verifier,
        )
    verifier_time = time.time() - t1

    # --- Extract hypothesis ---
    from ql.verifier import FourierSparseHypothesis

    hyp_s = None
    hyp_coefficients = None
    misclass_rate = None
    correct = False
    if result.accepted and result.hypothesis is not None:
        if isinstance(result.hypothesis, FourierSparseHypothesis):
            hyp_coefficients = dict(result.hypothesis.coefficients)
            # "Correct" for the k-sparse case means the largest-magnitude
            # estimated coefficient sits on the ground-truth index.
            hyp_s = max(hyp_coefficients, key=lambda s: abs(hyp_coefficients[s]))
            correct = hyp_s == spec.target_s
            misclass_rate = _compute_misclassification_rate(
                state,
                result.hypothesis,
                spec.seed + 2_000_000,
                num_samples=spec.misclassification_samples or 1000,
            )
        else:
            # Single-parity hypothesis exposes its index directly.
            hyp_s = result.hypothesis.s
            correct = hyp_s == spec.target_s if hyp_s is not None else False

    return TrialResult(
        n=spec.n,
        seed=spec.seed,
        prover_time_s=prover_time,
        qfs_shots=spec.qfs_shots,
        qfs_postselected=msg.qfs_result.postselected_shots,
        postselection_rate=msg.qfs_result.postselection_rate,
        list_size=msg.list_size,
        prover_found_target=(spec.target_s in msg.L),
        verifier_time_s=verifier_time,
        verifier_samples=spec.classical_samples_verifier,
        outcome=result.outcome.value,
        accepted=result.accepted,
        accumulated_weight=result.accumulated_weight,
        acceptance_threshold=result.acceptance_threshold,
        hypothesis_s=hyp_s,
        hypothesis_correct=correct,
        total_copies=msg.total_copies_used + spec.classical_samples_verifier,
        total_time_s=prover_time + verifier_time,
        epsilon=spec.epsilon,
        theta=spec.theta,
        delta=spec.delta,
        a_sq=spec.a_sq,
        b_sq=spec.b_sq,
        phi_description=spec.phi_description,
        k=spec.k,
        hypothesis_coefficients=hyp_coefficients,
        misclassification_rate=misclass_rate,
    )


def _extract_spectrum(phi: list[float], threshold: float = 0.01) -> list[tuple[int, float]]:
    """Compute Fourier spectrum of phi and return (index, coefficient) pairs above threshold.

    Transforms the bias function to :math:`\\tilde\\varphi = 1 - 2\\varphi`
    before the Walsh-Hadamard transform, and keeps every index whose
    coefficient magnitude exceeds *threshold*.
    """
    from experiments.harness.phi import walsh_hadamard

    phi_arr = np.array(phi)
    tilde_phi = 1.0 - 2.0 * phi_arr
    spectrum = walsh_hadamard(tilde_phi)
    return [(s, float(spectrum[s])) for s in range(len(spectrum)) if abs(spectrum[s]) > threshold]


# ---------------------------------------------------------------------------
# Dishonest-prover strategies.  Every strategy shares the signature
# ``(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs)`` and
# returns a fabricated ``ProverMessage``.  NOTE(review): the positional
# arguments to ``ProverMessage(...)`` differ between the single-element
# strategies (which pass ``epsilon, epsilon``) and the multi-element ones
# (which pass ``epsilon, theta``) — confirm against the ProverMessage
# constructor whether that asymmetry is intentional.
# ---------------------------------------------------------------------------


def _strategy_random_list(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Claim a uniformly random list of up to 5 indices with zero weight.
    from ql.prover import ProverMessage

    L = sorted(rng.choice(2**n, size=min(5, 2**n), replace=False).tolist())
    return ProverMessage(L, {s: 0.0 for s in L}, n, epsilon, epsilon, dummy_sa, dummy_qfs, 0)


def _strategy_wrong_parity(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Claim full weight on a single parity index that is NOT the target;
    # index 0 (the constant parity) is skipped as a candidate.
    from ql.prover import ProverMessage

    wrong_s = (target_s + 1) % (2**n)
    if wrong_s == 0:
        wrong_s = (target_s + 2) % (2**n)
    return ProverMessage([wrong_s], {wrong_s: 1.0}, n, epsilon, epsilon, dummy_sa, dummy_qfs, 0)


def _strategy_partial_list(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Send an empty list — the prover claims to have found nothing.
    from ql.prover import ProverMessage

    return ProverMessage([], {}, n, epsilon, epsilon, dummy_sa, dummy_qfs, 0)


def _strategy_inflated_list(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Pad a list of up to 10 non-target indices, each claiming weight 0.5.
    from ql.prover import ProverMessage

    candidates = [s for s in range(2**n) if s != target_s]
    chosen = sorted(rng.choice(candidates, size=min(10, len(candidates)), replace=False).tolist())
    return ProverMessage(chosen, {s: 0.5 for s in chosen}, n, epsilon, epsilon, dummy_sa, dummy_qfs, 0)


def _strategy_partial_real(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Keep the WEAKER half of the true heavy spectrum (note the slice
    # ``heavy_sorted[n_real:]`` drops the heaviest entries) and pad with
    # up to 3 indices outside the true support.
    from ql.prover import ProverMessage

    heavy = _extract_spectrum(phi)
    heavy_sorted = sorted(heavy, key=lambda x: abs(x[1]), reverse=True)
    n_real = max(1, len(heavy_sorted) // 2)
    real_part = [s for s, _ in heavy_sorted[n_real:]]
    used = {s for s, _ in heavy}
    fake_candidates = [s for s in range(2**n) if s not in used]
    n_fake = min(3, len(fake_candidates))
    fakes = sorted(rng.choice(fake_candidates, size=n_fake, replace=False).tolist()) if n_fake > 0 else []
    L = sorted(real_part + fakes)
    return ProverMessage(L, {s: 0.5 for s in L}, n, epsilon, theta, dummy_sa, dummy_qfs, 0)


def _strategy_diluted_list(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    from ql.prover import ProverMessage

    heavy = _extract_spectrum(phi)
    heavy_sorted = sorted(heavy, key=lambda x: abs(x[1]), reverse=True)
    # Audit fix m4 (audit/soundness_multi.md): the previous formula
    # ``n_keep = max(1, len(heavy_sorted) // 4)`` was always 1 for
    # k <= 4, so the strategy was effectively "keep one weakest" — its
    # docstring described "keep all real" which the code never did.
    # We now keep ``max(1, k // 2)`` of the weakest real coefficients,
    # which matches the "diluted signal" intuition (the prover knows
    # part of the spectrum but not all of it).  The verifier still
    # rejects because accumulated weight on the kept subset is
    # strictly below the true Parseval mass ``pw``.
    n_keep = max(1, len(heavy_sorted) // 2)
    kept_indices = [s for s, _ in heavy_sorted[-n_keep:]]
    used = {s for s, _ in heavy}
    padding_candidates = [s for s in range(2**n) if s not in used]
    n_padding = min(20, len(padding_candidates))
    padding = sorted(rng.choice(padding_candidates, size=n_padding, replace=False).tolist()) if n_padding > 0 else []
    L = sorted(kept_indices + padding)
    return ProverMessage(L, {s: 0.5 for s in L}, n, epsilon, theta, dummy_sa, dummy_qfs, 0)


def _strategy_shifted_coefficients(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Claim as many indices as the true heavy set has, but all of them
    # outside the true support, with inflated weight 0.8.
    from ql.prover import ProverMessage

    heavy = _extract_spectrum(phi)
    used = {s for s, _ in heavy}
    wrong_candidates = [s for s in range(2**n) if s not in used]
    n_wrong = min(len(heavy), len(wrong_candidates))
    chosen = sorted(rng.choice(wrong_candidates, size=max(1, n_wrong), replace=False).tolist())
    return ProverMessage(chosen, {s: 0.8 for s in chosen}, n, epsilon, theta, dummy_sa, dummy_qfs, 0)


def _strategy_subset_plus_noise(n, rng, target_s, epsilon, theta, phi, dummy_sa, dummy_qfs):
    # Keep only the single heaviest true index and bury it among up to 5
    # fake indices, all claiming uniform weight 0.3.
    from ql.prover import ProverMessage

    heavy = _extract_spectrum(phi)
    heavy_sorted = sorted(heavy, key=lambda x: abs(x[1]), reverse=True)
    heaviest_s = heavy_sorted[0][0] if heavy_sorted else 0
    used = {s for s, _ in heavy}
    fake_candidates = [s for s in range(2**n) if s not in used]
    n_fake = min(5, len(fake_candidates))
    fakes = sorted(rng.choice(fake_candidates, size=n_fake, replace=False).tolist()) if n_fake > 0 else []
    L = sorted([heaviest_s] + fakes)
    return ProverMessage(L, {s: 0.3 for s in L}, n, epsilon, theta, dummy_sa, dummy_qfs, 0)


#: Registry mapping ``TrialSpec.dishonest_strategy`` names to strategy
#: functions.  Looked up by :func:`_run_dishonest_trial`.
_DISHONEST_STRATEGIES = {
    "random_list": _strategy_random_list,
    "wrong_parity": _strategy_wrong_parity,
    "partial_list": _strategy_partial_list,
    "inflated_list": _strategy_inflated_list,
    "partial_real": _strategy_partial_real,
    "diluted_list": _strategy_diluted_list,
    "shifted_coefficients": _strategy_shifted_coefficients,
    "subset_plus_noise": _strategy_subset_plus_noise,
}


def _run_dishonest_trial(spec: TrialSpec, state) -> TrialResult:
    r"""Execute a dishonest-prover trial inside a worker process.

    Constructs a fake :class:`~ql.prover.ProverMessage` according to the
    adversarial strategy in ``spec.dishonest_strategy``, then runs the
    verifier against it.  Strategies are registered in
    :data:`_DISHONEST_STRATEGIES`.

    When ``spec.k`` is set and ``> 1``, the verifier uses
    :meth:`~ql.verifier.MoSVerifier.verify_fourier_sparse` (with the
    k-sparse acceptance threshold
    :math:`a^2 - \varepsilon^2/(128k^2)`) instead of
    :meth:`~ql.verifier.MoSVerifier.verify_parity`.

    Parameters
    ----------
    spec : TrialSpec
        Trial specification with ``dishonest_strategy`` set.
    state : MoSState
        Reconstructed MoS state (used only by the verifier for
        classical sampling via Lemma 1).

    Returns
    -------
    TrialResult

    Raises
    ------
    ValueError
        If ``spec.dishonest_strategy`` is not a registered strategy.
    """
    from ql.prover import SpectrumApproximation
    from ql.verifier import FourierSparseHypothesis, MoSVerifier
    from mos.sampler import QFSResult

    n = spec.n
    rng = default_rng(spec.seed)
    target_s = spec.target_s
    epsilon = spec.epsilon

    # Placeholder QFS/spectrum artefacts — the dishonest prover never ran
    # the quantum phase, but ProverMessage requires these fields.
    dummy_qfs = QFSResult({}, {}, 0, 0, n, "statevector")
    dummy_sa = SpectrumApproximation({}, 0.0, n, 0, 0)

    strategy_fn = _DISHONEST_STRATEGIES.get(spec.dishonest_strategy)
    if strategy_fn is None:
        raise ValueError(f"Unknown dishonest strategy: {spec.dishonest_strategy}")
    fake_msg = strategy_fn(n, rng, target_s, epsilon, spec.theta, spec.phi, dummy_sa, dummy_qfs)

    # Same verifier seed offset as the honest path, for comparability.
    verifier = MoSVerifier(state, seed=spec.seed + 1_000_000)
    if spec.k is not None and spec.k > 1:
        vresult = verifier.verify_fourier_sparse(
            fake_msg,
            epsilon=epsilon,
            k=spec.k,
            delta=spec.delta,
            theta=spec.theta,
            a_sq=spec.a_sq,
            b_sq=spec.b_sq,
            num_samples=spec.classical_samples_verifier,
        )
    else:
        vresult = verifier.verify_parity(
            fake_msg,
            epsilon=epsilon,
            delta=spec.delta,
            theta=spec.theta,
            a_sq=spec.a_sq,
            b_sq=spec.b_sq,
            num_samples=spec.classical_samples_verifier,
        )

    # --- Extract hypothesis and compute misclassification if accepted ---
    # (An acceptance here is a soundness failure; we still record how the
    # accepted hypothesis would perform.)
    hyp_s = None
    hyp_coefficients = None
    misclass_rate = None
    if vresult.accepted and vresult.hypothesis is not None:
        if isinstance(vresult.hypothesis, FourierSparseHypothesis):
            hyp_coefficients = dict(vresult.hypothesis.coefficients)
            hyp_s = max(hyp_coefficients, key=lambda s: abs(hyp_coefficients[s]))
            misclass_rate = _compute_misclassification_rate(
                state,
                vresult.hypothesis,
                spec.seed + 2_000_000,
                num_samples=spec.misclassification_samples or 1000,
            )
        else:
            hyp_s = vresult.hypothesis.s
            misclass_rate = _compute_misclassification_rate(
                state,
                vresult.hypothesis,
                spec.seed + 2_000_000,
                num_samples=spec.misclassification_samples or 1000,
            )

    # Prover-side fields are zeroed: no quantum phase was executed.
    return TrialResult(
        n=n,
        seed=spec.seed,
        prover_time_s=0.0,
        qfs_shots=0,
        qfs_postselected=0,
        postselection_rate=0.0,
        list_size=len(fake_msg.L),
        prover_found_target=(target_s in fake_msg.L),
        verifier_time_s=0.0,
        verifier_samples=spec.classical_samples_verifier,
        outcome=vresult.outcome.value,
        accepted=vresult.accepted,
        accumulated_weight=vresult.accumulated_weight,
        acceptance_threshold=vresult.acceptance_threshold,
        hypothesis_s=hyp_s,
        hypothesis_correct=(hyp_s == target_s) if hyp_s is not None else False,
        total_copies=spec.classical_samples_verifier,
        total_time_s=0.0,
        epsilon=epsilon,
        theta=spec.theta,
        delta=spec.delta,
        a_sq=spec.a_sq,
        b_sq=spec.b_sq,
        phi_description=f"soundness_{spec.dishonest_strategy}",
        k=spec.k,
        hypothesis_coefficients=hyp_coefficients,
        misclassification_rate=misclass_rate,
    )
def run_trials_parallel(
    specs: list[TrialSpec],
    max_workers: Optional[int] = None,
    label: str = "",
    shard_index: Optional[int] = None,
    num_shards: Optional[int] = None,
) -> list[TrialResult]:
    r"""Dispatch a batch of trials across worker processes.

    When ``max_workers > 1``, uses
    :class:`~concurrent.futures.ProcessPoolExecutor` with
    :func:`~concurrent.futures.as_completed` for progress reporting.
    Results are stored in an index-mapped list so the output preserves
    the original spec ordering regardless of completion order.

    When ``max_workers == 1``, falls back to sequential execution in the
    main process (useful for debugging — parallel tracebacks are less
    readable).

    When *shard_index* and *num_shards* are both set, only the
    contiguous slice of *specs* assigned to this shard is executed.
    This enables SLURM Job Array distribution where each array task
    regenerates the full spec list but runs only its portion.

    Parameters
    ----------
    specs : list[TrialSpec]
        Trial specifications to execute.
    max_workers : int or None
        Number of worker processes.  ``None`` defaults to
        :func:`os.cpu_count`.
    label : str
        Short label printed in progress output (e.g. ``"scaling"``).
    shard_index : int or None
        0-based index of this shard (requires *num_shards*).
    num_shards : int or None
        Total number of shards (requires *shard_index*).

    Returns
    -------
    list[TrialResult]
        Results in the same order as the (possibly sharded) *specs*.
    """
    if max_workers is None:
        max_workers = os.cpu_count() or 1

    # --- Apply sharding if requested ---
    if shard_index is not None and num_shards is not None:
        from experiments.harness.sharding import shard_specs

        full_count = len(specs)
        specs = shard_specs(specs, shard_index, num_shards)
        label = f"{label} shard {shard_index + 1}/{num_shards}" if label else f"shard {shard_index + 1}/{num_shards}"
        print(
            f"  Shard {shard_index + 1}/{num_shards}: {len(specs)} of {full_count} specs",
            flush=True,
        )

    total = len(specs)
    if total == 0:
        return []

    if max_workers <= 1:
        # Sequential — easier to debug
        results = []
        for i, spec in enumerate(specs, 1):
            t = _run_trial_worker(spec)
            _print_trial_progress(t, i, total, label)
            results.append(t)
        return results

    # Parallel
    # Index-mapped slots so results come back in spec order even though
    # futures complete out of order.
    results: list[Optional[TrialResult]] = [None] * total
    completed = 0
    with ProcessPoolExecutor(max_workers=max_workers) as pool:

        def _terminate_pool(signum, frame):
            """Kill all worker processes on SIGTERM/SIGINT."""
            pool.shutdown(wait=False, cancel_futures=True)
            # NOTE(review): ``pool._processes`` is a private attribute of
            # ProcessPoolExecutor — hard-killing stragglers relies on
            # CPython internals; confirm it still exists on upgrade.
            if hasattr(pool, "_processes"):
                for proc in pool._processes.values():
                    if proc.is_alive():
                        proc.kill()
            sys.exit(128 + signum)

        # Install handlers, remembering the previous ones so they can be
        # restored even if submission or result collection raises.
        prev_term = signal.signal(signal.SIGTERM, _terminate_pool)
        prev_int = signal.signal(signal.SIGINT, _terminate_pool)
        try:
            future_to_idx = {
                pool.submit(_run_trial_worker, spec): idx
                for idx, spec in enumerate(specs)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                t = future.result()
                results[idx] = t
                completed += 1
                _print_trial_progress(t, completed, total, label)
        finally:
            signal.signal(signal.SIGTERM, prev_term)
            signal.signal(signal.SIGINT, prev_int)
    return results  # type: ignore[return-value]
def _print_trial_progress(t: TrialResult, completed: int, total: int, label: str):
    """Print a single-line progress update for a completed trial."""
    # Three-way status marker: correct hypothesis, wrongly accepted,
    # or rejected.
    if t.hypothesis_correct:
        status = "✓"
    elif t.accepted:
        status = "accept_wrong"
    else:
        status = "✗reject"

    if label:
        prefix = f" [{label}] "
    else:
        prefix = " "

    line = (
        f"{prefix}{completed:4d}/{total}: n={t.n:2d} "
        f"|L|={t.list_size:3d} {status:12s} {t.total_time_s:.2f}s "
        f"({t.phi_description})"
    )
    print(line, flush=True)