use super::{Error, VerificationError};
use crate::{
oracle::{CompositePolyOracle, OracleId},
polynomial::{CompositionPoly, MultilinearComposite, MultilinearPoly},
protocols::evalcheck::EvalcheckClaim,
};
use auto_impl::auto_impl;
use binius_field::{Field, PackedField};
use binius_math::EvaluationDomain;
use binius_utils::bail;
use std::hash::Hash;
/// One round message of an abstract sumcheck proof.
///
/// NOTE(review): `coeffs` presumably holds the coefficients of the univariate round
/// polynomial; which coefficients are sent, and in which basis, is decided by the
/// concrete [`AbstractSumcheckReductor`] that consumes the message — confirm against
/// implementors.
#[derive(Debug, Clone)]
pub struct AbstractSumcheckRound<F> {
/// Field elements sent by the prover for this round.
pub coeffs: Vec<F>,
}
/// A complete abstract sumcheck proof: the list of per-round prover messages.
#[derive(Debug, Clone)]
pub struct AbstractSumcheckProof<F> {
/// One message per round (presumably stored in round order — TODO confirm).
pub rounds: Vec<AbstractSumcheckRound<F>>,
}
/// The claim carried from one sumcheck round to the next.
///
/// Converting it into a [`ReducedClaim`] (via `From`) reuses `partial_point` as the
/// evaluation point and `current_round_sum` as the claimed evaluation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AbstractSumcheckRoundClaim<F: Field> {
/// Challenges accumulated so far (presumably one per completed round — confirm).
pub partial_point: Vec<F>,
/// The claimed sum for the current round.
pub current_round_sum: F,
}
/// The claim left after all sumcheck rounds: a polynomial evaluates to `eval` at
/// `eval_point`.
///
/// `finalize_evalcheck_claim` turns it into an `EvalcheckClaim` after checking that
/// `eval_point` has exactly one coordinate per variable of the oracle.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReducedClaim<F: Field> {
/// Point at which the evaluation is claimed.
pub eval_point: Vec<F>,
/// Claimed evaluation at `eval_point`.
pub eval: F,
}
impl<F: Field> From<AbstractSumcheckRoundClaim<F>> for ReducedClaim<F> {
fn from(claim: AbstractSumcheckRoundClaim<F>) -> Self {
Self {
eval_point: claim.partial_point,
eval: claim.current_round_sum,
}
}
}
/// Logic for reducing a sumcheck claim one round at a time (presumably the
/// verifier-side half of the protocol — confirm against callers).
pub trait AbstractSumcheckReductor<F: Field> {
/// Error type for this reductor; must be constructible from the shared sumcheck [`Error`].
type Error: std::error::Error + From<Error>;
/// Checks that the message for round `round` has the shape this reductor expects.
/// What "shape" means (e.g. number of coefficients) is implementation-defined.
fn validate_round_proof_shape(
&self,
round: usize,
proof: &AbstractSumcheckRound<F>,
) -> Result<(), Self::Error>;
/// Consumes the current round claim, the round `challenge`, and the prover's
/// `round_proof`, and returns the claim for the next round.
fn reduce_round_claim(
&self,
round: usize,
claim: AbstractSumcheckRoundClaim<F>,
challenge: F,
round_proof: AbstractSumcheckRound<F>,
) -> Result<AbstractSumcheckRoundClaim<F>, Self::Error>;
}
/// Minimal interface a sumcheck claim exposes to the generic sumcheck machinery.
pub trait AbstractSumcheckClaim<F: Field> {
/// Number of variables the claim ranges over.
fn n_vars(&self) -> usize;
/// Maximum degree in any single variable. `check_evaluation_domain` requires an
/// evaluation domain with exactly one more point than this.
fn max_individual_degree(&self) -> usize;
/// The claimed sum.
fn sum(&self) -> F;
}
/// Prover-side witness for a sumcheck claim: a composition polynomial together with
/// the multilinears it is evaluated over.
///
/// `#[auto_impl(&)]` also implements this trait for `&T` wherever `T` implements it.
#[auto_impl(&)]
pub trait AbstractSumcheckWitness<PW: PackedField> {
/// Identifier used to pair witness multilinears with the claim's multilinears.
type MultilinearId: Clone + Hash + Eq + Sync;
/// The composition polynomial type.
type Composition: CompositionPoly<PW>;
/// Concrete multilinear polynomial type yielded by `multilinears`.
type Multilinear: MultilinearPoly<PW> + Send + Sync;
/// Returns the composition polynomial.
fn composition(&self) -> &Self::Composition;
/// Returns the witness multilinears, each paired with its identifier.
///
/// How `seq_id` and `claim_multilinear_ids` are used is up to the implementor. An
/// `Err` indicates the witness could not be matched to the claim's identifiers.
fn multilinears(
&self,
seq_id: usize,
claim_multilinear_ids: &[Self::MultilinearId],
) -> Result<impl IntoIterator<Item = (Self::MultilinearId, Self::Multilinear)>, Error>;
}
/// Shared state that creates sumcheck provers and drives them round by round.
///
/// NOTE(review): the intended call order is not visible here. Presumably it is
/// `pre_execute_rounds` once per round, then `prover_execute_round` for each prover,
/// and finally `prover_finalize` — confirm against the driving protocol code.
pub trait AbstractSumcheckProversState<F: Field> {
/// Error type; must be constructible from the shared sumcheck [`Error`].
type Error: std::error::Error + From<Error>;
/// Packed field the witness lives in; its scalar converts to and from `F`.
type PackedWitnessField: PackedField<Scalar: From<F> + Into<F>>;
/// Claim type handled by the provers.
type Claim: AbstractSumcheckClaim<F>;
/// Witness type accompanying each claim, over `PackedWitnessField`.
type Witness: AbstractSumcheckWitness<Self::PackedWitnessField>;
/// Per-claim prover type produced by `new_prover`.
type Prover;
/// Hook run before the provers execute a round. It receives the previous round's
/// challenge (presumably `None` only before the first round).
fn pre_execute_rounds(&mut self, prev_rd_challenge: Option<F>) -> Result<(), Self::Error>;
/// Creates a prover for `claim` backed by `witness`. `seq_id` is presumably the
/// index of this claim in the batch — TODO confirm.
fn new_prover(
&mut self,
claim: Self::Claim,
witness: Self::Witness,
seq_id: usize,
) -> Result<Self::Prover, Self::Error>;
/// Advances `prover` by one round and returns that round's message.
fn prover_execute_round(
&self,
prover: &mut Self::Prover,
prev_rd_challenge: Option<F>,
) -> Result<AbstractSumcheckRound<F>, Self::Error>;
/// Consumes `prover`, given the last round's challenge, and returns its reduced
/// evaluation claim.
fn prover_finalize(
prover: Self::Prover,
prev_rd_challenge: Option<F>,
) -> Result<ReducedClaim<F>, Self::Error>;
}
/// A `MultilinearComposite` is directly usable as a sumcheck witness. Its multilinears
/// are paired positionally with the claim's oracle ids.
impl<P, C, M> AbstractSumcheckWitness<P> for MultilinearComposite<P, C, M>
where
	P: PackedField,
	C: CompositionPoly<P>,
	M: MultilinearPoly<P> + Clone + Send + Sync,
{
	type MultilinearId = OracleId;
	type Composition = C;
	type Multilinear = M;

	/// Borrows the composite's composition polynomial.
	fn composition(&self) -> &C {
		&self.composition
	}

	/// Zips the claim's oracle ids with clones of this composite's multilinears, in order.
	///
	/// `_seq_id` is not used by this implementation. Fails with
	/// `Error::ProverClaimWitnessMismatch` when the number of ids differs from the
	/// number of multilinears.
	fn multilinears(
		&self,
		_seq_id: usize,
		claim_multilinear_ids: &[OracleId],
	) -> Result<impl IntoIterator<Item = (OracleId, M)>, Error> {
		let witness_polys = &self.multilinears;
		if witness_polys.len() != claim_multilinear_ids.len() {
			bail!(Error::ProverClaimWitnessMismatch);
		}
		let id_poly_pairs = claim_multilinear_ids
			.iter()
			.copied()
			.zip(witness_polys.iter().cloned());
		Ok(id_poly_pairs)
	}
}
/// Checks that `domain` suits round polynomials of individual degree
/// `max_individual_degree`.
///
/// A suitable domain has exactly `max_individual_degree + 1` points, and its first two
/// points are `0` and `1`, in that order. A degree of zero is always rejected.
/// NOTE(review): the `0`/`1` requirement is presumably so the values needed for the
/// sumcheck round check are read directly from the domain — confirm.
///
/// # Errors
///
/// Fails with `Error::EvaluationDomainMismatch` if any condition above is violated.
pub fn check_evaluation_domain<F: Field>(
	max_individual_degree: usize,
	domain: &EvaluationDomain<F>,
) -> Result<(), Error> {
	// Rejecting degree zero first means a domain that passes the size check has at
	// least two points, so the indexing below cannot go out of bounds.
	let has_expected_size =
		max_individual_degree != 0 && domain.size() == max_individual_degree + 1;
	let starts_with_zero_and_one = has_expected_size && {
		let points = domain.points();
		points[0] == F::ZERO && points[1] == F::ONE
	};
	if !starts_with_zero_and_one {
		bail!(Error::EvaluationDomainMismatch);
	}
	Ok(())
}
/// Checks that a previous-round challenge is supplied exactly when `round > 0`.
///
/// # Errors
///
/// - `Error::PreviousRoundChallengePresent` if `round == 0` but a challenge was given.
/// - `Error::PreviousRoundChallengeAbsent` if `round > 0` but no challenge was given.
pub fn validate_rd_challenge<F: Field>(
	prev_rd_challenge: Option<F>,
	round: usize,
) -> Result<(), Error> {
	let is_first_round = round == 0;
	let has_challenge = prev_rd_challenge.is_some();
	// The two failure conditions are mutually exclusive, so checking them as
	// independent guards is equivalent to an if/else-if chain.
	if is_first_round && has_challenge {
		bail!(Error::PreviousRoundChallengePresent);
	}
	if !is_first_round && !has_challenge {
		bail!(Error::PreviousRoundChallengeAbsent);
	}
	Ok(())
}
/// Converts the reduced claim left by sumcheck into an `EvalcheckClaim` on `poly_oracle`.
///
/// The resulting claim clones `poly_oracle` and is marked `is_random_point: true`.
///
/// # Errors
///
/// Returns `VerificationError::NumberOfRounds` (converted into `Error`) when the
/// evaluation point's length differs from `poly_oracle.n_vars()`.
pub fn finalize_evalcheck_claim<F: Field>(
	poly_oracle: &CompositePolyOracle<F>,
	reduced_claim: ReducedClaim<F>,
) -> Result<EvalcheckClaim<F>, Error> {
	// The evaluation point must have one coordinate per oracle variable.
	if reduced_claim.eval_point.len() != poly_oracle.n_vars() {
		return Err(VerificationError::NumberOfRounds.into());
	}
	Ok(EvalcheckClaim {
		poly: poly_oracle.clone(),
		eval_point: reduced_claim.eval_point,
		eval: reduced_claim.eval,
		is_random_point: true,
	})
}
/// Builds the standard heuristic for choosing a switchover round from an extension degree.
///
/// The returned closure maps `extension_degree` to
/// `max(floor(log2(extension_degree)) + k, 1)`, so the result is always at least 1.
/// The addition saturates. Without that, a very large `k` would overflow: in debug
/// builds it would panic, and in release builds it would wrap to a negative value and
/// be clamped to 1.
///
/// # Panics
///
/// The returned closure panics when called with `extension_degree == 0`, because the
/// base-2 logarithm of zero is undefined.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
	move |extension_degree: usize| {
		// `ilog2` of a `usize` is always below `usize::BITS`, so this cast cannot truncate.
		let log_degree = extension_degree.ilog2() as isize;
		let switchover_round = log_degree.saturating_add(k);
		// After clamping to at least 1 the value is positive, so the cast is lossless.
		switchover_round.max(1) as usize
	}
}