1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright 2024 Ulvetanna Inc.

//! Batch proving and verification of the sumcheck protocol.
//!
//! The sumcheck protocol can be batched over multiple instances by taking random linear
//! combinations over the claimed sums and polynomials. When the sumcheck instances are not all
//! over polynomials with the same number of variables, we can still batch them together, sharing
//! later round challenges. Importantly, the verifier samples mixing challenges "just-in-time".
//! That is, the verifier samples mixing challenges for new sumcheck claims over n variables only
//! after the last sumcheck round message has been sent by the prover.

use super::{error::Error, prove::SumcheckProversState, sumcheck::SumcheckReductor, SumcheckClaim};
use crate::{
	challenger::{CanObserve, CanSample},
	oracle::OracleId,
	protocols::{
		abstract_sumcheck::{
			self, finalize_evalcheck_claim, AbstractSumcheckBatchProof,
			AbstractSumcheckBatchProveOutput, AbstractSumcheckClaim, AbstractSumcheckWitness,
		},
		evalcheck::EvalcheckClaim,
	},
};
use binius_field::{ExtensionField, Field, PackedExtension};
use binius_hal::ComputationBackend;
use binius_math::EvaluationDomainFactory;

/// Batched sumcheck proof produced by [`batch_prove`] and consumed by [`batch_verify`].
///
/// This is an alias for [`AbstractSumcheckBatchProof`].
pub type SumcheckBatchProof<F> = AbstractSumcheckBatchProof<F>;

/// Output of [`batch_prove`]: the batched proof together with the resulting evalcheck claims.
#[derive(Debug)]
pub struct SumcheckBatchProveOutput<F: Field> {
	/// Evalcheck claims finalized from the reduced sumcheck claims, paired by position with
	/// the input sumcheck claims.
	pub evalcheck_claims: Vec<EvalcheckClaim<F>>,
	/// The batched sumcheck proof, to be checked with [`batch_verify`].
	pub proof: SumcheckBatchProof<F>,
}

/// Prove a batch of sumcheck claims in a single batched interaction.
///
/// All claims share one prover state. That state is sized for the largest `n_vars` among the
/// claims, or zero for an empty batch. Each reduced claim returned by the abstract batched
/// prover is paired, by position, with the polynomial oracle of the corresponding input claim.
/// The pair is then finalized into an evalcheck claim.
///
/// See module documentation for details.
///
/// # Errors
///
/// Returns an error if the abstract batched sumcheck prover fails, or if a reduced claim cannot
/// be finalized into an evalcheck claim.
pub fn batch_prove<F, PW, DomainField, CH, Backend>(
	sumchecks: impl IntoIterator<
		Item = (SumcheckClaim<F>, impl AbstractSumcheckWitness<PW, MultilinearId = OracleId>),
	>,
	evaluation_domain_factory: impl EvaluationDomainFactory<DomainField>,
	switchover_fn: impl Fn(usize) -> usize + 'static,
	challenger: CH,
	backend: Backend,
) -> Result<SumcheckBatchProveOutput<F>, Error>
where
	F: Field,
	DomainField: Field,
	PW: PackedExtension<DomainField, Scalar: From<F> + Into<F> + ExtensionField<DomainField>>,
	CH: CanSample<F> + CanObserve<F>,
	Backend: ComputationBackend,
{
	let sumchecks: Vec<_> = sumchecks.into_iter().collect();

	// One pass over the claims tracks the largest variable count and copies each claim's
	// oracle. The copies are needed because the claims themselves are moved into the abstract
	// prover below.
	let mut n_vars = 0;
	let mut oracles = Vec::with_capacity(sumchecks.len());
	for (claim, _witness) in &sumchecks {
		n_vars = std::cmp::max(n_vars, claim.n_vars());
		oracles.push(claim.poly.clone());
	}

	let mut provers_state = SumcheckProversState::<F, PW, DomainField, _, _, _>::new(
		n_vars,
		evaluation_domain_factory,
		switchover_fn,
		backend,
	);

	let AbstractSumcheckBatchProveOutput {
		proof,
		reduced_claims: reduced,
	} = abstract_sumcheck::batch_prove(sumchecks, &mut provers_state, challenger)?;

	// Finalize each reduced claim against the oracle of the input claim at the same position.
	let mut evalcheck_claims = Vec::with_capacity(oracles.len());
	for (oracle, reduced_claim) in oracles.iter().zip(reduced) {
		evalcheck_claims.push(finalize_evalcheck_claim(oracle, reduced_claim)?);
	}

	Ok(SumcheckBatchProveOutput {
		evalcheck_claims,
		proof,
	})
}

/// Verify a batch of sumcheck claims against a batched proof.
///
/// The reductor is configured with the largest `max_individual_degree` over all claims, or
/// zero for an empty batch. Each reduced claim returned by the abstract batched verifier is
/// paired, by position, with the polynomial oracle of the corresponding input claim. The pair
/// is then finalized into an evalcheck claim.
///
/// See module documentation for details.
///
/// # Errors
///
/// Returns an error if the abstract batched sumcheck verification fails, or if a reduced claim
/// cannot be finalized into an evalcheck claim.
pub fn batch_verify<F, CH>(
	claims: impl IntoIterator<Item = SumcheckClaim<F>>,
	proof: SumcheckBatchProof<F>,
	challenger: CH,
) -> Result<Vec<EvalcheckClaim<F>>, Error>
where
	F: Field,
	CH: CanSample<F> + CanObserve<F>,
{
	let claims: Vec<SumcheckClaim<F>> = claims.into_iter().collect();

	let max_individual_degree = claims.iter().fold(0, |degree, claim| {
		std::cmp::max(degree, claim.max_individual_degree())
	});

	// Only the oracles are needed after verification, so copy just those. The claims
	// themselves are handed over to the abstract verifier.
	let oracles: Vec<_> = claims.iter().map(|claim| claim.poly.clone()).collect();

	let reduced = abstract_sumcheck::batch_verify(
		claims,
		proof,
		SumcheckReductor {
			max_individual_degree,
		},
		challenger,
	)?;

	// Finalize each reduced claim against the oracle of the input claim at the same position.
	let mut evalcheck_claims = Vec::with_capacity(oracles.len());
	for (oracle, reduced_claim) in oracles.iter().zip(reduced) {
		evalcheck_claims.push(finalize_evalcheck_claim(oracle, reduced_claim)?);
	}

	Ok(evalcheck_claims)
}