r"""
Honest Quantum Prover for Classical Verification of Quantum Learning.
Implements the prover side of the interactive verification protocol from
Caro et al. (ITCS 2024), Theorems 8, 10, 12, and 15.
**Protocol overview** (prover's role):
1. **Quantum Fourier Sampling** (Theorem 5): Given copies of the MoS
   state :math:`\rho_D`, apply :math:`H^{\otimes(n+1)}`, measure, and
   post-select on the label qubit being 1. Each accepted shot yields
   a sample :math:`s \in \{0,1\}^n` from the distribution

   .. math::

      \Pr[s \mid b{=}1]
      = \frac{1 - \mathbb{E}[\tilde\phi_{\text{eff}}(x)^2]}{2^n}
      + \hat{\tilde\phi}_{\text{eff}}(s)^2

2. **Empirical spectrum approximation** (Corollary 5 via Lemma 3 / DKW):
   From :math:`m` post-selected QFS samples, build the empirical
   distribution :math:`\tilde{q}_m` over :math:`\{0,1\}^n`. By
   DKW, :math:`m = O(\log(1/\delta)/\varepsilon^4)` samples suffice
   for :math:`\|\tilde{q}_m - q\|_\infty \leq \varepsilon^2/8` with
   probability :math:`\geq 1 - \delta/2`.
3. **Heavy coefficient extraction**: Identify the list

   .. math::

      L = \{s \in \{0,1\}^n : \tilde{q}_m(s,1) \geq \varepsilon^2/4\}

   By the analysis in Corollary 5:

   - If :math:`|\hat{\tilde\phi}(s)| \geq \varepsilon`, then :math:`s \in L`.
   - If :math:`s \in L`, then :math:`|\hat{\tilde\phi}(s)| \geq \varepsilon/4`.

4. **Fourier coefficient estimation** (optional): For each :math:`s \in L`,
   estimate :math:`\hat{\tilde\phi}(s)` from classical samples of
   :math:`D` (obtained by computational-basis measurement of
   :math:`\rho_D`, per Lemma 1).
5. **Send** :math:`L` (and optionally the estimates) to the verifier.
The prover is *honest*: it follows the protocol faithfully. Soundness
holds against *any* prover — the verifier's checks ensure correctness
regardless.
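**Copy complexity**: The prover uses :math:`O(\log(1/(\delta\varepsilon^2))/\varepsilon^4)`
copies of :math:`\rho_D` for QFS (Corollary 5), plus
:math:`O(\log(1/(\delta\varepsilon^2))/\varepsilon^2)` copies for classical
estimation (via computational-basis measurement). The classical bound
comes from Hoeffding with a union bound over the :math:`|L| = O(1/\varepsilon^2)`
candidates, using one shared batch of samples for all of them.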
References
----------
- Caro et al., "Classical Verification of Quantum Learning", ITCS 2024.
  §5.1 (Corollary 5), §6 (Theorems 7–15).
- Lemma 3 (DKW-based empirical approximation).
"""
from dataclasses import dataclass
from typing import Optional
import numpy as np
from numpy.random import Generator, default_rng
from mos import MoSState
from mos.sampler import QuantumFourierSampler, QFSResult
# ===================================================================
# Result containers
# ===================================================================
@dataclass(frozen=True)
class SpectrumApproximation:
    r"""
    Succinct approximation to the Fourier spectrum (Corollary 5).

    Instances are immutable (``frozen=True``). Note that the ``entries``
    dict itself is still a mutable object.

    Attributes
    ----------
    entries : dict[int, float]
        Sparse representation: maps frequency index :math:`s` to
        the estimated squared-coefficient-related quantity
        :math:`\tilde{q}_m(s)`. Only entries above the extraction
        threshold are stored.
    threshold : float
        The extraction threshold used (typically :math:`\vartheta^2/4`).
    n : int
        Number of input bits.
    num_qfs_samples : int
        Number of post-selected QFS samples used to build the
        empirical distribution.
    total_qfs_shots : int
        Total QFS shots consumed (before post-selection).
    """

    # Sparse map: frequency index s -> empirical weight (above threshold only).
    entries: dict[int, float]
    # Extraction threshold applied when populating ``entries``.
    threshold: float
    # Number of input bits.
    n: int
    # Post-selected QFS samples behind the empirical distribution.
    num_qfs_samples: int
    # Total QFS shots consumed, before post-selection.
    total_qfs_shots: int
@dataclass(frozen=True)
class ProverMessage:
    r"""
    The message sent from the honest prover to the classical verifier.

    This implements the communication in Step 2 of the verification
    protocols (Theorems 7–15): a list :math:`L` of candidate heavy
    Fourier coefficient indices, optionally with estimated coefficient
    values.

    Attributes
    ----------
    L : list[int]
        List of frequency indices :math:`s` identified as having
        non-negligible Fourier weight. Sorted by estimated weight
        (descending).
    estimates : dict[int, float]
        For each :math:`s \in L`, an estimate of
        :math:`\hat{\tilde\phi}(s)` obtained from classical samples.
        Empty if ``estimate_coefficients=False`` was used.
    n : int
        Number of input bits.
    epsilon : float
        Accuracy parameter used by the prover.
    theta : float
        Fourier coefficient resolution threshold :math:`\vartheta`.
    spectrum_approx : SpectrumApproximation
        The intermediate Fourier spectrum approximation (for diagnostics).
    qfs_result : QFSResult
        Raw QFS result (for diagnostics / post-hoc analysis).
    num_classical_samples : int
        Number of classical samples used for coefficient estimation.
    """

    L: list[int]
    estimates: dict[int, float]
    n: int
    epsilon: float
    theta: float
    spectrum_approx: SpectrumApproximation
    qfs_result: QFSResult
    num_classical_samples: int

    @property
    def list_size(self) -> int:
        """Number of candidate heavy coefficients."""
        return len(self.L)

    @property
    def total_copies_used(self) -> int:
        """Total MoS copies consumed (QFS shots + classical estimation samples)."""
        return self.spectrum_approx.total_qfs_shots + self.num_classical_samples

    def summary(self) -> str:
        """
        Human-readable summary of the prover's message.

        Reports the parameters, the list size next to the bound
        :math:`\\lceil 16/\\vartheta^2 \\rceil`, copy counts, and either
        the per-coefficient estimates or a note that none were computed.
        """
        # 16/theta^2: list-size bound printed alongside |L| for comparison.
        parseval_bound = int(np.ceil(16 / self.theta**2))
        lines = [
            "Prover Message (§6 protocol)",
            f" n = {self.n}",
            f" epsilon = {self.epsilon:.4f}, theta = {self.theta:.4f}",
            f" |L| = {self.list_size} (Parseval bound: {parseval_bound})",
            f" QFS copies used: {self.spectrum_approx.total_qfs_shots}",
            f" Post-selected QFS samples: {self.spectrum_approx.num_qfs_samples}",
            f" Classical samples for estimation: {self.num_classical_samples}",
            f" Total copies: {self.total_copies_used}",
        ]
        if self.estimates:
            lines.append(" Estimated coefficients:")
            for s in self.L:
                bits = format(s, f"0{self.n}b")
                # A missing estimate is shown as nan rather than raising.
                est = self.estimates.get(s, float("nan"))
                lines.append(f" s={s} ({bits}): est={est:+.6f}")
        else:
            lines.append(" Coefficient estimates: not computed")
        lines.append(f" L = {self.L}")
        return "\n".join(lines)
# ===================================================================
# Prover
# ===================================================================
class MoSProver:
    r"""
    Honest quantum prover for the verification protocol.

    Follows the prover side of Theorems 8/10/12/15: uses MoS copies
    for Quantum Fourier Sampling, builds a succinct Fourier spectrum
    approximation, extracts heavy coefficients, and optionally
    estimates their values from classical samples.

    Parameters
    ----------
    mos_state : MoSState
        The MoS state to work with. The prover has quantum access
        to copies of :math:`\rho_D`.
    seed : int, optional
        Random seed for reproducibility. It seeds the prover's own
        generator, which draws the seed of every
        :class:`QuantumFourierSampler` and drives classical sampling.
    noise_model : object, optional
        Forwarded unchanged to :class:`QuantumFourierSampler`.

    Notes
    -----
    The prover's computational complexity is dominated by QFS
    (Section 5.1): :math:`O(n \cdot m)` single-qubit gates and
    :math:`\tilde{O}(n \cdot m)` classical processing, where
    :math:`m = O(\log(1/\delta)/\vartheta^4)` is the number of
    QFS copies (:math:`\vartheta` defaults to :math:`\varepsilon`).
    """

    # Lower bound applied to the *computed* classical sample count.
    # An explicit ``classical_samples`` override is used exactly as given.
    _MIN_CLASSICAL_SAMPLES = 100

    def __init__(
        self,
        mos_state: MoSState,
        seed: Optional[int] = None,
        noise_model: Optional[object] = None,
    ):
        self.state = mos_state
        self.n = mos_state.n
        self._seed = seed
        self._rng: Generator = default_rng(seed)
        self._noise_model = noise_model

    # ------------------------------------------------------------------
    # Main protocol entry point
    # ------------------------------------------------------------------

    def run_protocol(
        self,
        epsilon: float,
        delta: float = 0.1,
        theta: Optional[float] = None,
        estimate_coefficients: bool = True,
        qfs_mode: str = "statevector",
        qfs_shots: Optional[int] = None,
        classical_samples: Optional[int] = None,
    ) -> ProverMessage:
        r"""
        Execute the prover's side of the verification protocol.

        1. Perform QFS to obtain post-selected samples.
        2. Build the empirical spectrum approximation
           (Corollary 5 / Lemma 3).
        3. Extract the heavy coefficient list :math:`L`.
        4. (Optional) Estimate the Fourier coefficients for each
           :math:`s \in L` using classical samples.
        5. Package and return the message.

        Parameters
        ----------
        epsilon : float
            Accuracy parameter :math:`\varepsilon \in (0, 1)`. It sets the
            accuracy of the classical coefficient estimates.
        delta : float
            Confidence parameter :math:`\delta \in (0, 1)`.
            The protocol succeeds with probability :math:`\geq 1 - \delta`.
        theta : float, optional
            Fourier coefficient resolution threshold :math:`\vartheta`.
            It drives both the QFS sample count and the extraction
            threshold :math:`\vartheta^2/4`. If not provided, defaults to
            ``epsilon`` (appropriate for the functional agnostic case per
            Theorem 8). For the distributional case (Theorem 12), it should
            be set according to the promise on the distribution class.
        estimate_coefficients : bool
            If True (default), estimate :math:`\hat{\tilde\phi}(s)`
            for each :math:`s \in L` using classical samples
            (computational-basis measurement of :math:`\rho_D`).
            This is needed for the verifier's Fourier weight check.
        qfs_mode : str
            QFS simulation mode (``"statevector"`` or ``"circuit"``).
        qfs_shots : int, optional
            Override the total number of QFS shots (must be ``>= 1``).
            If not provided, it is computed from the DKW bound (Lemma 3)
            with :math:`\tau = \vartheta^2/8`:

            - the number of post-selected samples is
              :math:`m = \lceil 2\log(4/\delta) / \tau^2 \rceil`;
            - the total shot count is :math:`\lceil 2.5\,m \rceil`. The
              factor 2 accounts for post-selection succeeding with
              probability about 1/2, and the extra 25% is a safety margin.
        classical_samples : int, optional
            Override the number of classical samples for coefficient
            estimation (must be ``>= 1``; used exactly as given). If not
            provided, it is computed from Hoeffding with a union bound
            over :math:`L`:
            :math:`m_2 = \max(100, \lceil 2\log(4|L|/\delta)/\varepsilon^2 \rceil)`.
            One batch of samples is shared by all :math:`s \in L`.

        Returns
        -------
        ProverMessage
            The prover's message to the verifier.

        Raises
        ------
        ValueError
            If ``epsilon``, ``delta`` or ``theta`` lies outside (0, 1), or if
            ``qfs_shots`` / ``classical_samples`` is given and is < 1.
        """
        # --- Parameter validation (fail fast, before any sampling) ---
        if not 0 < epsilon < 1:
            raise ValueError(f"epsilon must be in (0, 1), got {epsilon}")
        if not 0 < delta < 1:
            raise ValueError(f"delta must be in (0, 1), got {delta}")
        if theta is None:
            theta = epsilon
        if not 0 < theta < 1:
            raise ValueError(f"theta must be in (0, 1), got {theta}")
        if qfs_shots is not None and qfs_shots < 1:
            raise ValueError(f"qfs_shots must be a positive integer, got {qfs_shots}")
        if classical_samples is not None and classical_samples < 1:
            raise ValueError(
                f"classical_samples must be a positive integer, got {classical_samples}"
            )

        # --- Step 1: Compute required QFS shots ---
        # DKW (Lemma 3) with failure probability delta/2:
        #   2*exp(-m*tau^2/2) <= delta/2  <=>  m >= 2*log(4/delta)/tau^2,
        # where tau = theta^2/8 (theta defaults to epsilon).
        tau = theta**2 / 8.0
        if qfs_shots is None:
            m_postselected = int(np.ceil(2.0 * np.log(4.0 / delta) / tau**2))
            # Post-selection keeps ~1/2 of the shots (factor 2); the extra
            # 25% is a margin for finite-sample fluctuation -> 2.5 * m.
            qfs_shots = int(np.ceil(2.5 * m_postselected))

        # --- Step 2: Run QFS (sampler seeded from the prover's RNG) ---
        sampler = QuantumFourierSampler(
            self.state,
            seed=int(self._rng.integers(0, 2**31)),
            noise_model=self._noise_model,
        )
        qfs_result = sampler.sample(shots=qfs_shots, mode=qfs_mode)

        # --- Step 3: Build empirical spectrum approximation ---
        spectrum_approx = self._build_spectrum_approximation(
            qfs_result=qfs_result,
            theta=theta,
        )

        # --- Step 4: Extract heavy coefficient list ---
        L = self._extract_heavy_list(
            spectrum_approx=spectrum_approx,
            theta=theta,
        )

        # --- Step 5: Estimate Fourier coefficients from classical samples ---
        estimates: dict[int, float] = {}
        num_classical = 0
        if estimate_coefficients and len(L) > 0:
            estimates, num_classical = self._estimate_coefficients(
                L=L,
                epsilon=epsilon,
                delta=delta,
                num_samples_override=classical_samples,
            )

        return ProverMessage(
            L=L,
            estimates=estimates,
            n=self.n,
            epsilon=epsilon,
            theta=theta,
            spectrum_approx=spectrum_approx,
            qfs_result=qfs_result,
            num_classical_samples=num_classical,
        )

    # ------------------------------------------------------------------
    # Step 3: Empirical spectrum approximation (Corollary 5 / Lemma 3)
    # ------------------------------------------------------------------

    def _build_spectrum_approximation(
        self,
        qfs_result: QFSResult,
        theta: float,
    ) -> SpectrumApproximation:
        r"""
        Build a succinct empirical approximation to the QFS distribution.

        From the post-selected QFS samples, compute the empirical
        distribution :math:`\tilde{q}_m(s, 1)` for each observed frequency
        :math:`s`, keeping only entries at or above :math:`\vartheta^2/4`.

        By Lemma 3 (DKW), with :math:`m` post-selected samples:

        .. math::

            \|\tilde{q}_m - q\|_\infty \leq \tau

        with probability :math:`\geq 1 - 2\exp(-m\tau^2/2)`.

        The empirical distribution relates to the Fourier coefficients via:

        .. math::

            \tilde{q}_m(s, 1) \approx q(s, 1)
            = \frac{1}{2}\Bigl[
                \frac{1 - \mathbb{E}[\tilde\phi^2]}{2^n}
                + \hat{\tilde\phi}(s)^2
            \Bigr]

        The stored values are ``count / m``. They are empirical estimates of
        the *conditional* distribution :math:`\Pr[s \mid b=1] = 2\,q(s, 1)`,
        which is why the threshold on them is :math:`\vartheta^2/4` rather
        than :math:`\vartheta^2/8`.

        Parameters
        ----------
        qfs_result : QFSResult
            Output from the QFS procedure. Its ``postselected_counts`` map
            n-bit strings to counts.
        theta : float
            Resolution threshold :math:`\vartheta`.

        Returns
        -------
        SpectrumApproximation
        """
        ps_total = qfs_result.postselected_shots

        # Threshold on the conditional distribution Pr[s | b=1]; it equals
        # theta^2/8 on the joint (n+1)-bit distribution q(s, 1).
        extraction_threshold = theta**2 / 4.0

        entries: dict[int, float] = {}
        if ps_total > 0:  # no accepted shots -> empty approximation
            for bitstring, count in qfs_result.postselected_counts.items():
                empirical_prob = count / ps_total
                if empirical_prob >= extraction_threshold:
                    entries[int(bitstring, 2)] = empirical_prob

        return SpectrumApproximation(
            entries=entries,
            threshold=extraction_threshold,
            n=self.n,
            num_qfs_samples=ps_total,
            total_qfs_shots=qfs_result.total_shots,
        )

    # ------------------------------------------------------------------
    # Step 4: Extract heavy coefficient list
    # ------------------------------------------------------------------

    def _extract_heavy_list(
        self,
        spectrum_approx: SpectrumApproximation,
        theta: float,
    ) -> list[int]:
        r"""
        Extract the candidate list :math:`L` of heavy frequencies.

        Returns every frequency whose empirical (post-selected) weight is at
        least :math:`\vartheta^2/4`. The list is sorted by that weight in
        descending order, with ties broken by ascending ``s`` so the output
        is deterministic.

        Parameters
        ----------
        spectrum_approx : SpectrumApproximation
            Output of :meth:`_build_spectrum_approximation`.
        theta : float
            Resolution threshold :math:`\vartheta`.

        Returns
        -------
        list[int]
            Candidate heavy frequency indices.
        """
        threshold = theta**2 / 4.0
        weights = spectrum_approx.entries
        heavy = [s for s, w in weights.items() if w >= threshold]
        heavy.sort(key=lambda s: (-weights[s], s))
        return heavy

    # ------------------------------------------------------------------
    # Step 5: Classical coefficient estimation
    # ------------------------------------------------------------------

    def _estimate_coefficients(
        self,
        L: list[int],
        epsilon: float,
        delta: float,
        num_samples_override: Optional[int] = None,
    ) -> tuple[dict[int, float], int]:
        r"""
        Estimate Fourier coefficients for each :math:`s \in L`
        from classical random examples.

        By Lemma 1, computational-basis measurement of :math:`\rho_D`
        yields classical samples from :math:`D`. For each :math:`s`,

        .. math::

            \hat{\tilde\phi}(s) = \mathbb{E}_{(x,y) \sim D}
            [(1 - 2y)(-1)^{s \cdot x}]

        so this expectation is estimated by a sample mean. By Hoeffding
        with a union bound, :math:`m_2 = O(\log(|L|/\delta) / \varepsilon^2)`
        samples (one batch shared by all :math:`s`) suffice for simultaneous
        :math:`\varepsilon`-accuracy across :math:`L`.

        Parameters
        ----------
        L : list[int]
            Frequency indices to estimate (non-negative).
        epsilon : float
            Desired accuracy per coefficient.
        delta : float
            Overall confidence parameter.
        num_samples_override : int, optional
            Number of samples to use instead of the computed count
            (must be >= 1; not subject to the minimum floor).

        Returns
        -------
        estimates : dict[int, float]
            ``estimates[s]`` is the empirical estimate of
            :math:`\hat{\tilde\phi}(s)` for each :math:`s \in L`.
        num_samples : int
            Number of classical samples used.

        Raises
        ------
        ValueError
            If ``num_samples_override`` is given and is < 1.
        """
        L_size = len(L)
        if L_size == 0:
            return {}, 0

        if num_samples_override is not None:
            if num_samples_override < 1:
                raise ValueError(
                    "num_samples_override must be a positive integer, "
                    f"got {num_samples_override}"
                )
            num_samples = int(num_samples_override)
        else:
            # Hoeffding: (1-2y)(-1)^{s.x} lies in [-1, 1] (range 2), so
            #   P[|mean - E| > eps] <= 2*exp(-m2*eps^2/2).
            # Requiring <= delta/(2|L|) gives m2 >= (2/eps^2)*log(4|L|/delta).
            num_samples = int(np.ceil(2.0 / epsilon**2 * np.log(4.0 * L_size / delta)))
            num_samples = max(num_samples, self._MIN_CLASSICAL_SAMPLES)

        # Draw classical samples via computational-basis measurement.
        xs, ys = self.state.sample_classical_batch(
            num_samples=num_samples,
            rng=self._rng,
        )
        x_arr = np.asarray(xs, dtype=np.int64)
        signed_labels = 1.0 - 2.0 * np.asarray(ys, dtype=np.float64)  # (1 - 2y)

        estimates: dict[int, float] = {}
        for s in L:
            chi_s = self._character_values(x_arr, s)
            est = float(np.mean(signed_labels * chi_s))
            # A mean of +-1 values already lies in [-1, 1]; clip guards rounding.
            estimates[s] = float(np.clip(est, -1.0, 1.0))
        return estimates, num_samples

    @staticmethod
    def _character_values(xs: np.ndarray, s: int) -> np.ndarray:
        r"""
        Evaluate :math:`\chi_s(x) = (-1)^{\mathrm{popcount}(s \wedge x)}`
        elementwise over the integer array ``xs``.

        The parity is accumulated one set bit of ``s`` at a time with
        vectorized shifts. This costs ``popcount(s)`` array operations
        instead of a Python-level loop over every sample.

        Raises
        ------
        ValueError
            If ``s`` is negative.
        """
        mask = int(s)
        if mask < 0:
            raise ValueError(f"frequency index must be non-negative, got {s}")
        parity = np.zeros(xs.shape, dtype=np.int64)
        bit = 0
        while mask:
            if mask & 1:
                parity ^= (xs >> bit) & 1
            mask >>= 1
            bit += 1
        return 1.0 - 2.0 * parity.astype(np.float64)

    # ------------------------------------------------------------------
    # Convenience: direct access to exact quantities (for validation)
    # ------------------------------------------------------------------

    def exact_heavy_coefficients(
        self,
        theta: float,
        *,
        effective: bool = True,
    ) -> list[tuple[int, float]]:
        r"""
        Return the exact list of heavy Fourier coefficients
        (for validation / comparison with the empirical list).

        Parameters
        ----------
        theta : float
            Threshold: returns all :math:`s` with
            :math:`|\hat{\tilde\phi}(s)| \geq \vartheta`.
        effective : bool
            If True, use noise-adjusted coefficients.

        Returns
        -------
        heavy : list[tuple[int, float]]
            Pairs :math:`(s, \hat{\tilde\phi}(s))`, sorted by
            absolute value descending.
        """
        spectrum = self.state.fourier_spectrum(effective=effective)
        heavy = [
            (s, float(spectrum[s]))
            for s in range(self.state.dim_x)
            if abs(spectrum[s]) >= theta
        ]
        heavy.sort(key=lambda t: abs(t[1]), reverse=True)
        return heavy