Source code for mos.sampler

r"""
Quantum Fourier Sampling from MoS states — Theorem 5 of Caro et al.

Implements the QFS procedure: given a copy of the MoS state
:math:`\rho_D`, apply :math:`H^{\otimes(n+1)}`, measure all qubits in
the computational basis, and post-select on the label qubit being 1.

**Theorem 5** (Distributional agnostic approximate quantum Fourier sampling).
Conditioned on observing outcome 1 for the last qubit (which occurs with
probability 1/2), the first :math:`n` qubits output :math:`s \in \{0,1\}^n`
with probability

.. math::

    \Pr[s \mid b{=}1]
    = \frac{1}{2^n}
      \bigl(1 - \mathbb{E}_{x \sim U_n}[(\tilde\phi_{\text{eff}}(x))^2]\bigr)
    + \bigl(\hat{\tilde\phi}_{\text{eff}}(s)\bigr)^2

This module provides two simulation strategies:

- **statevector** (default): For each sampled :math:`f \sim F_D`, constructs
  :math:`|\psi_f\rangle` as a Statevector, applies :math:`H^{\otimes(n+1)}`,
  and samples from the resulting probability distribution.  Exact (no
  shot noise beyond finite sampling).  Cost: :math:`O(2^n)` per copy.
  Practical for :math:`n \leq 20`.

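- **circuit**: Builds a Qiskit circuit (Hadamard layer + oracle) for each
  :math:`f` and executes it via ``StatevectorSampler``.  Without a gate-level
  noise model it produces the same distribution as statevector mode but goes
  through the Qiskit primitives pipeline, validating circuit construction.
  If the sampler was given a ``noise_model``, the circuits are instead
  transpiled and run on ``AerSimulator`` with that noise.  Practical for
  :math:`n \leq 12` due to multi-controlled gate overhead.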

Both modes return raw measurement counts (before post-selection) so that
the caller can verify the 1/2 label-qubit marginal and inspect rejection
rates.

References:

- Caro et al., "Classical Verification of Quantum Learning", ITCS 2024.
  Theorem 5 (Section 5.1), Lemma 2 (Section 4.1), Corollary 5 (Section 5.1).
- Bernstein & Vazirani (1997) for the original QFS idea.
"""

import warnings
from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.random import Generator, default_rng

from qiskit import QuantumCircuit
from qiskit.primitives import StatevectorSampler

# ---------------------------------------------------------------------------
# Import MoSState — assumes mos_state.py is importable
# ---------------------------------------------------------------------------
from mos import MoSState


# ===================================================================
# Result container
# ===================================================================


@dataclass(frozen=True)
class QFSResult:
    r"""
    Immutable record of one QFS run: the raw counts plus their
    label-qubit post-selected view.

    Attributes
    ----------
    raw_counts : dict[str, int]
        Counts over the full :math:`(n+1)`-bit outcomes.  Keys follow the
        Qiskit little-endian convention (rightmost character = qubit 0).
    postselected_counts : dict[str, int]
        Counts over :math:`n`-bit frequency strings of the input register,
        keeping only the shots in which the label qubit read 1.
    total_shots : int
        MoS copies consumed (= circuits executed).
    postselected_shots : int
        Shots that survived post-selection (label qubit = 1).
    n : int
        Number of input qubits.
    mode : str
        Simulation mode that produced the counts (``"statevector"`` or
        ``"circuit"``).
    """

    raw_counts: dict[str, int]
    postselected_counts: dict[str, int]
    total_shots: int
    postselected_shots: int
    n: int
    mode: str

    # ---- derived quantities ----

    @property
    def postselection_rate(self) -> float:
        r"""
        Share of shots kept by post-selection (0.0 when no shots were taken).

        Theorem 5(i) predicts this concentrates around 1/2.
        """
        consumed = self.total_shots
        return 0.0 if consumed == 0 else self.postselected_shots / consumed

    def empirical_distribution(self) -> np.ndarray:
        r"""
        Normalised histogram over :math:`\{0,1\}^n` built from the
        post-selected counts.

        Returns
        -------
        dist : np.ndarray of shape :math:`(2^n,)`
            Entry ``s`` is the observed frequency of bitstring ``s`` divided
            by ``postselected_shots``.  All zeros when nothing survived
            post-selection.
        """
        histogram = np.zeros(2**self.n, dtype=np.float64)
        if self.postselected_shots == 0:
            return histogram
        for freq_bits, tally in self.postselected_counts.items():
            # Keys are MSB-first binary strings, so base-2 parsing yields
            # the integer index of frequency s.
            histogram[int(freq_bits, 2)] += tally
        histogram /= self.postselected_shots
        return histogram
# =================================================================== # Main class # ===================================================================
class QuantumFourierSampler:
    r"""
    Approximate quantum Fourier sampling from MoS states (Theorem 5).

    Consumes copies of :math:`\rho_D`, applies :math:`H^{\otimes(n+1)}`,
    measures all :math:`n+1` qubits in the computational basis, and
    post-selects on the label qubit (qubit :math:`n`) being 1.

    **Protocol** (one copy):

    1. Sample :math:`f \sim F_D` using :math:`\phi_{\text{eff}}`.
    2. Prepare :math:`|\psi_{U_n,f}\rangle = 2^{-n/2}\sum_x |x,f(x)\rangle`.
    3. Apply :math:`H^{\otimes(n+1)}`.
    4. Measure all qubits → outcome :math:`(s, b) \in \{0,1\}^n \times \{0,1\}`.
    5. If :math:`b = 1`, record :math:`s`.

    By Theorem 5, the conditional distribution is

    .. math::

        \Pr[s \mid b{=}1]
        = \frac{1 - \mathbb{E}_x[\tilde\phi_{\text{eff}}(x)^2]}{2^n}
        + \hat{\tilde\phi}_{\text{eff}}(s)^2

    Parameters
    ----------
    mos_state : MoSState
        The MoS state to sample from.  Defines :math:`n`, :math:`\phi`,
        and the label-noise model (:math:`\eta`).
    seed : int, optional
        Random seed for reproducibility.
    noise_model : object, optional
        Gate-level noise model handed to ``AerSimulator(noise_model=...)``
        (presumably a ``qiskit_aer`` ``NoiseModel``).  Only honoured by
        ``mode="circuit"``.  ``mode="statevector"`` ignores it and emits a
        ``UserWarning``.
    """

    # valid mode names
    _MODES = {"statevector", "circuit"}

    def __init__(
        self,
        mos_state: "MoSState",
        seed: Optional[int] = None,
        noise_model: Optional[object] = None,
    ):
        self.state = mos_state
        self.n = mos_state.n
        self._seed = seed
        # The sampled functions f, the statevector-mode outcome draws and
        # the simulator child seeds are all taken from this one generator,
        # so a single ``seed`` controls them.
        self._rng: Generator = default_rng(seed)
        self._noise_model = noise_model

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def sample(
        self,
        shots: int,
        mode: str = "statevector",
    ) -> "QFSResult":
        r"""
        Execute the QFS protocol and return raw + post-selected counts.

        Each shot consumes one independent copy of :math:`\rho_D`.  Without
        gate-level noise, the label qubit marginal should be close to 1/2
        (Theorem 5(i)); any significant deviation indicates a bug.

        Parameters
        ----------
        shots : int
            Number of MoS copies to consume (:math:`\geq 1`).
        mode : str
            Simulation strategy:

            - ``"statevector"`` — direct Statevector computation per copy.
              A ``noise_model`` passed to the constructor is ignored here and
              a ``UserWarning`` is emitted.
            - ``"circuit"`` — Qiskit circuit per copy, executed on
              ``StatevectorSampler`` or, when a ``noise_model`` was given,
              on ``AerSimulator`` with that noise model.

        Returns
        -------
        QFSResult
            Raw and post-selected measurement counts.

        Raises
        ------
        ValueError
            If *shots* < 1 or *mode* is unrecognised.
        """
        if shots < 1:
            raise ValueError(f"shots must be >= 1, got {shots}")
        if mode not in self._MODES:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {sorted(self._MODES)}"
            )
        if mode == "statevector" and self._noise_model is not None:
            # The statevector backend never reads the noise model; without
            # this warning the requested noise would be dropped silently.
            warnings.warn(
                "noise_model is ignored in statevector mode; "
                "use mode='circuit' to simulate gate-level noise.",
                stacklevel=2,
            )

        dispatch = {
            "statevector": self._sample_statevector,
            "circuit": self._sample_circuit,
        }
        raw_counts = dispatch[mode](shots)
        ps_counts, ps_shots = self._postselect(raw_counts)

        return QFSResult(
            raw_counts=raw_counts,
            postselected_counts=ps_counts,
            total_shots=shots,
            postselected_shots=ps_shots,
            n=self.n,
            mode=mode,
        )

    # ------------------------------------------------------------------
    # Theoretical reference quantities (delegated to MoSState)
    # ------------------------------------------------------------------

    def theoretical_distribution(self) -> np.ndarray:
        r"""
        Exact :math:`\Pr[s \mid b{=}1]` from Theorem 5.

        .. math::

            \Pr[s \mid b{=}1]
            = \frac{1 - \mathbb{E}_x[\tilde\phi_{\text{eff}}(x)^2]}{2^n}
            + \hat{\tilde\phi}_{\text{eff}}(s)^2

        Always uses the effective (noise-adjusted) spectrum, since this is
        what the physical QFS circuit produces.

        Returns
        -------
        dist : np.ndarray of shape :math:`(2^n,)`
        """
        return self.state.qfs_distribution()

    def fourier_coefficient(
        self,
        s: int,
        effective: bool = True,
    ) -> float:
        r"""
        Exact Fourier coefficient for validation.

        Parameters
        ----------
        s : int
            Frequency index in :math:`\{0, \ldots, 2^n - 1\}`.
        effective : bool
            If True (default), return
            :math:`\hat{\tilde\phi}_{\text{eff}}(s) = (1-2\eta)\hat{\tilde\phi}(s)`.
            If False, return the noiseless :math:`\hat{\tilde\phi}(s)`.

        Returns
        -------
        float
        """
        return self.state.fourier_coefficient(s, effective=effective)

    # ------------------------------------------------------------------
    # Private: simulation backends
    # ------------------------------------------------------------------

    def _sample_statevector(self, shots: int) -> dict[str, int]:
        r"""
        Statevector mode: for each copy, build
        :math:`H^{\otimes(n+1)}|\psi_f\rangle` and sample.

        For every sampled :math:`f`, the Hadamard-transformed statevector has
        closed-form amplitudes (proof of Lemma 2):

        .. math::

            H^{\otimes(n+1)}|\psi_f\rangle
            = \frac{1}{\sqrt{2}}|0\rangle^{\otimes(n+1)}
            + \frac{1}{\sqrt{2}}\sum_s \hat{g}_f(s)\,|s,1\rangle

        where :math:`g_f = (-1)^f`.  Rather than using this closed form, we
        apply the Hadamard layer via :meth:`Statevector.evolve` and draw one
        outcome per copy from :meth:`Statevector.probabilities` with the
        sampler's own generator (not :meth:`Statevector.sample_counts`), so
        the draws are governed by the constructor ``seed``.

        The gate-level ``noise_model`` is not applied in this mode.

        Returns
        -------
        dict[str, int]
            :math:`(n+1)`-bit counts; the leftmost character is qubit
            :math:`n` (the label qubit).
        """
        n = self.n
        dim_total = self.state.dim_total
        counts: dict[str, int] = {}

        # The (n+1)-qubit Hadamard layer is identical for every copy.
        h_circuit = QuantumCircuit(n + 1, name="H_all")
        for q in range(n + 1):
            h_circuit.h(q)

        for _ in range(shots):
            f = self.state.sample_f(rng=self._rng)
            psi_f = self.state.statevector_f(f)
            psi_h = psi_f.evolve(h_circuit)

            # Draw one measurement outcome using the internal RNG
            probs = psi_h.probabilities()
            idx = self._rng.choice(dim_total, p=probs)
            # MSB-first binary of the basis index gives the same string as
            # Qiskit count keys (rightmost character = qubit 0).
            bitstring = format(idx, f"0{n + 1}b")
            counts[bitstring] = counts.get(bitstring, 0) + 1

        return counts

    def _sample_circuit(self, shots: int) -> dict[str, int]:
        r"""
        Circuit mode: build a full Qiskit circuit per copy and execute it.

        Each circuit is:

        .. math::

            |0\rangle^{\otimes(n+1)}
            \;\xrightarrow{H^{\otimes n}\otimes I}\;
            |{+}\rangle^{\otimes n}|0\rangle
            \;\xrightarrow{U_f}\;
            |\psi_f\rangle
            \;\xrightarrow{H^{\otimes(n+1)}}\;
            \text{measure}

        Without a noise model, all circuits are batched into a single
        ``StatevectorSampler`` call.  With a noise model, each circuit is
        transpiled for, and run (one shot) on, ``AerSimulator`` with that
        noise model, so gate errors act on the decomposed gates.

        This validates the circuit-construction pipeline (oracle gates,
        Hadamard layer, measurement) against the statevector mode.
        Practical for :math:`n \leq 12` due to multi-controlled gate overhead
        in :meth:`MoSState._circuit_oracle_f`.

        Returns
        -------
        dict[str, int]
            :math:`(n+1)`-bit counts with any register-separating spaces
            removed.
        """
        if self.n > 12:
            warnings.warn(
                f"Circuit mode with n={self.n} will be very slow "
                f"(up to 2^n multi-controlled gates per copy).",
                # Skip this helper and sample() so the warning points at
                # the user's call site.
                stacklevel=3,
            )

        n = self.n
        counts: dict[str, int] = {}

        # Build all circuits up front
        circuits = []
        for _ in range(shots):
            f = self.state.sample_f(rng=self._rng)
            qc = self.state.circuit_prepare_f(f)
            for q in range(n + 1):
                qc.h(q)
            qc.measure_all()
            circuits.append(qc)

        if self._noise_model is not None:
            # Gate-level noise: use AerSimulator with the noise model.
            # The circuit is transpiled so that MCX gates decompose
            # into basis gates (CX, H, X, etc.) to which the noise
            # model's depolarising channels apply.
            from qiskit_aer import AerSimulator
            from qiskit import transpile

            backend = AerSimulator(noise_model=self._noise_model)
            for qc in circuits:
                child_seed = int(self._rng.integers(0, 2**31))
                qc_t = transpile(qc, backend)
                result = backend.run(
                    qc_t, shots=1, seed_simulator=child_seed
                ).result()
                for bitstring, cnt in result.get_counts().items():
                    # AerSimulator may include spaces in bitstrings;
                    # strip them for consistency.
                    bs = bitstring.replace(" ", "")
                    counts[bs] = counts.get(bs, 0) + cnt
        else:
            # Batch all circuits in a single sampler.run() call.
            # Pass a Generator *object* (not an int) so the sampler
            # holds a mutable reference whose state advances between
            # PUBs — otherwise identical circuits get identical draws
            # (Qiskit bug #13047).
            child_rng = default_rng(int(self._rng.integers(0, 2**31)))
            sampler = StatevectorSampler(seed=child_rng)
            job = sampler.run(circuits, shots=1)
            for pub_result in job.result():
                # measure_all() above created the register named "meas".
                for bitstring, cnt in pub_result.data.meas.get_counts().items():
                    counts[bitstring] = counts.get(bitstring, 0) + cnt

        return counts

    # ------------------------------------------------------------------
    # Private: post-selection
    # ------------------------------------------------------------------

    def _postselect(
        self,
        raw_counts: dict[str, int],
    ) -> tuple[dict[str, int], int]:
        r"""
        Extract the :math:`n`-bit frequency distribution conditioned on the
        label qubit :math:`b = 1`.

        In the raw bitstrings (Qiskit convention: highest qubit index on the
        left), the label qubit (qubit :math:`n`) is the **leftmost**
        character.  We keep only bitstrings where this character is ``'1'``,
        then strip it to obtain the :math:`n`-bit frequency string.

        Parameters
        ----------
        raw_counts : dict[str, int]
            Full :math:`(n+1)`-bit counts from measurement.

        Returns
        -------
        ps_counts : dict[str, int]
            :math:`n`-bit post-selected counts.
        ps_total : int
            Total shots surviving post-selection.
        """
        ps_counts: dict[str, int] = {}
        ps_total = 0

        for bitstring, count in raw_counts.items():
            # bitstring[0] = highest qubit = qubit n = label qubit
            if bitstring[0] == "1":
                s_bits = bitstring[1:]  # remaining n bits
                ps_counts[s_bits] = ps_counts.get(s_bits, 0) + count
                ps_total += count

        return ps_counts, ps_total