binius_core/ring_switch/
common.rsuse std::sync::Arc;
use binius_field::{Field, TowerField};
use binius_utils::sparse_index::SparseIndex;
use super::error::Error;
use crate::{
oracle::{MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant},
piop::CommitMeta,
protocols::evalcheck::EvalcheckMultilinearClaim,
};
/// One distinct evaluation-point prefix, shared by every claim whose point starts with it.
///
/// `prefix` holds the leading `kappa` coordinates of an evaluation point. Here `kappa` is
/// `F::TOWER_LEVEL` minus the tower level of the claimed oracle, as computed in
/// `group_claims_by_eval_point`.
#[derive(Debug)]
pub struct EvalClaimPrefixDesc<F>
where
	F: Field,
{
	pub prefix: Vec<F>,
}

impl<F> EvalClaimPrefixDesc<F>
where
	F: Field,
{
	/// Returns `kappa`, i.e. the number of coordinates stored in this prefix.
	pub fn kappa(&self) -> usize {
		let Self { prefix } = self;
		prefix.len()
	}
}
/// One distinct evaluation-point suffix together with the split point `kappa` that produced it.
///
/// Built by `group_claims_by_eval_point`. Two claims share a suffix descriptor only when
/// both their suffix coordinates and their `kappa` are equal.
#[derive(Debug)]
pub struct EvalClaimSuffixDesc<F: Field> {
/// Coordinates of the evaluation point that remain after the first `kappa` entries.
pub suffix: Arc<[F]>,
/// Number of leading coordinates split off into the prefix, computed as
/// `F::TOWER_LEVEL - tower_level` of the claimed oracle.
pub kappa: usize,
}
/// Per-claim descriptor that links an evaluation claim to its committed polynomial
/// and to its evaluation-point suffix.
#[derive(Debug)]
pub struct PIOPSumcheckClaimDesc<'a, F: Field> {
/// Entry of the `oracle_to_commit_index` map (passed to `EvalClaimSystem::new`) for the
/// claimed oracle. Presumably its position among the committed polynomials; confirm.
pub committed_idx: usize,
/// Index into `EvalClaimSystem::suffix_descs` of this claim's evaluation-point suffix.
pub suffix_desc_idx: usize,
/// The evaluation claim this descriptor was built from.
pub eval_claim: &'a EvalcheckMultilinearClaim<F>,
}
/// Evaluation claims on committed oracles, grouped by shared evaluation-point prefixes
/// and suffixes.
///
/// Per-claim vectors (`sumcheck_claim_descs`, `eval_claim_to_prefix_desc_index`) are
/// indexed in the order produced by `EvalClaimSystem::new`. That function sorts the
/// claims by ascending `n_vars + tower_level` of the claimed oracle.
#[derive(Debug)]
pub struct EvalClaimSystem<'a, F: Field> {
/// Commitment metadata supplied by the caller, stored as-is.
pub commit_meta: &'a CommitMeta,
/// Distinct evaluation-point prefixes, in first-seen order.
pub prefix_descs: Vec<EvalClaimPrefixDesc<F>>,
/// Distinct (suffix, kappa) pairs, in first-seen order.
pub suffix_descs: Vec<EvalClaimSuffixDesc<F>>,
/// One descriptor per evaluation claim.
pub sumcheck_claim_descs: Vec<PIOPSumcheckClaimDesc<'a, F>>,
/// For each evaluation claim, the index of its prefix within `prefix_descs`.
pub eval_claim_to_prefix_desc_index: Vec<usize>,
}
impl<'a, F: TowerField> EvalClaimSystem<'a, F> {
pub fn new(
oracles: &MultilinearOracleSet<F>,
commit_meta: &'a CommitMeta,
oracle_to_commit_index: &SparseIndex<usize>,
eval_claims: &'a [EvalcheckMultilinearClaim<F>],
) -> Result<Self, Error> {
let mut eval_claims = eval_claims.iter().collect::<Vec<_>>();
eval_claims.sort_by_key(|claim| match oracles.oracle(claim.id) {
MultilinearPolyOracle {
n_vars,
tower_level,
variant: MultilinearPolyVariant::Committed,
..
} => n_vars + tower_level,
_ => 0,
});
let (
prefix_descs,
eval_claim_to_prefix_desc_index,
suffix_descs,
eval_claim_to_suffix_desc_index,
) = group_claims_by_eval_point(oracles, &eval_claims)?;
let sumcheck_claim_descs = eval_claims
.into_iter()
.enumerate()
.map(|(i, eval_claim)| {
let oracle = oracles.oracle(eval_claim.id);
if !matches!(oracle.variant, MultilinearPolyVariant::Committed) {
return Err(Error::EvalcheckClaimForDerivedPoly { id: eval_claim.id });
}
let committed_idx = oracle_to_commit_index
.get(oracle.id())
.copied()
.ok_or_else(|| Error::OracleToCommitIndexMissingEntry { id: eval_claim.id })?;
let suffix_desc_idx = eval_claim_to_suffix_desc_index[i];
Ok(PIOPSumcheckClaimDesc {
committed_idx,
suffix_desc_idx,
eval_claim,
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
commit_meta,
prefix_descs,
suffix_descs,
sumcheck_claim_descs,
eval_claim_to_prefix_desc_index,
})
}
pub fn max_claim_kappa(&self) -> usize {
self.prefix_descs
.iter()
.map(|desc| desc.kappa())
.max()
.unwrap_or(0)
}
}
#[allow(clippy::type_complexity)]
fn group_claims_by_eval_point<F: TowerField>(
oracles: &MultilinearOracleSet<F>,
claims: &[&EvalcheckMultilinearClaim<F>],
) -> Result<(Vec<EvalClaimPrefixDesc<F>>, Vec<usize>, Vec<EvalClaimSuffixDesc<F>>, Vec<usize>), Error>
{
let mut prefix_descs = Vec::<EvalClaimPrefixDesc<F>>::new();
let mut suffix_descs = Vec::<EvalClaimSuffixDesc<F>>::new();
let mut claim_to_prefix_index = Vec::with_capacity(claims.len());
let mut claim_to_suffix_index = Vec::with_capacity(claims.len());
for claim in claims {
let MultilinearPolyOracle {
id,
tower_level,
variant: MultilinearPolyVariant::Committed,
..
} = oracles.oracle(claim.id)
else {
return Err(Error::EvalcheckClaimForDerivedPoly { id: claim.id });
};
let kappa = F::TOWER_LEVEL.checked_sub(tower_level).ok_or_else(|| {
Error::OracleTowerLevelTooHigh {
id,
max: F::TOWER_LEVEL,
}
})?;
let (prefix, suffix) = claim.eval_point.split_at(kappa);
let prefix_id = prefix_descs
.iter()
.position(|desc| desc.prefix == prefix)
.unwrap_or_else(|| {
let index = prefix_descs.len();
prefix_descs.push(EvalClaimPrefixDesc {
prefix: prefix.to_vec(),
});
index
});
claim_to_prefix_index.push(prefix_id);
let suffix_id = suffix_descs
.iter()
.position(|desc| &*desc.suffix == suffix && desc.kappa == kappa)
.unwrap_or_else(|| {
let index = suffix_descs.len();
suffix_descs.push(EvalClaimSuffixDesc {
suffix: suffix.to_vec().into(),
kappa,
});
index
});
claim_to_suffix_index.push(suffix_id);
}
Ok((prefix_descs, claim_to_prefix_index, suffix_descs, claim_to_suffix_index))
}