1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright 2024 Ulvetanna Inc.

use crate::{
	polynomial::{CompositionPoly, MultilinearPoly},
	protocols::abstract_sumcheck::{
		self, AbstractSumcheckBatchProof, AbstractSumcheckBatchProveOutput, AbstractSumcheckClaim,
		ReducedClaim,
	},
};
use binius_field::{ExtensionField, Field, PackedExtension};
use binius_hal::ComputationBackend;
use binius_math::EvaluationDomainFactory;
use binius_utils::bail;
use p3_challenger::{CanObserve, CanSample};
use tracing::instrument;

use super::{
	gkr_sumcheck::{GkrSumcheckClaim, GkrSumcheckReductor, GkrSumcheckWitness},
	prove::GkrSumcheckProversState,
	Error,
};

/// Proof produced by [`batch_prove`] and consumed by [`batch_verify`].
///
/// This is a direct alias of the generic [`AbstractSumcheckBatchProof`].
pub type GkrSumcheckBatchProof<F> = AbstractSumcheckBatchProof<F>;
/// Output of [`batch_prove`].
///
/// This is a direct alias of the generic [`AbstractSumcheckBatchProveOutput`].
pub type GkrSumcheckBatchProveOutput<F> = AbstractSumcheckBatchProveOutput<F>;

/// Prove a batched GkrSumcheck instance.
///
/// See module documentation for details.
#[instrument(skip_all, name = "gkr_sumcheck::batch_prove", level = "debug")]
pub fn batch_prove<F, PW, DomainField, CW, M, CH, Backend>(
	gkr_sumchecks: impl IntoIterator<Item = (GkrSumcheckClaim<F>, GkrSumcheckWitness<PW, CW, M>)>,
	evaluation_domain_factory: impl EvaluationDomainFactory<DomainField>,
	switchover_fn: impl Fn(usize) -> usize + 'static,
	challenger: CH,
	backend: Backend,
) -> Result<GkrSumcheckBatchProveOutput<F>, Error>
where
	F: Field,
	DomainField: Field,
	PW: PackedExtension<DomainField, Scalar: From<F> + Into<F> + ExtensionField<DomainField>>,
	CW: CompositionPoly<PW>,
	M: MultilinearPoly<PW> + Clone + Send + Sync,
	CH: CanObserve<F> + CanSample<F>,
	Backend: ComputationBackend,
{
	let gkr_sumchecks = gkr_sumchecks.into_iter().collect::<Vec<_>>();
	let n_vars = gkr_sumchecks
		.iter()
		.map(|(claim, _)| claim.n_vars())
		.max()
		.unwrap_or(0);

	let gkr_round_challenge = gkr_sumchecks
		.first()
		.map(|(claim, _)| claim.r.clone())
		.ok_or(Error::EmptyClaimsArray)?;

	let mut provers_state = GkrSumcheckProversState::<F, PW, DomainField, _, _, _, _>::new(
		n_vars,
		evaluation_domain_factory,
		gkr_round_challenge.as_slice(),
		switchover_fn,
		backend,
	)?;

	abstract_sumcheck::batch_prove(gkr_sumchecks, &mut provers_state, challenger)
}

/// Verify a batched GkrSumcheck instance.
///
/// See module documentation for details.
#[instrument(skip_all, name = "gkr_sumcheck::batch_verify", level = "debug")]
pub fn batch_verify<F, CH>(
	claims: impl IntoIterator<Item = GkrSumcheckClaim<F>>,
	proof: GkrSumcheckBatchProof<F>,
	challenger: CH,
) -> Result<Vec<ReducedClaim<F>>, Error>
where
	F: Field,
	CH: CanSample<F> + CanObserve<F>,
{
	let claims_vec = claims.into_iter().collect::<Vec<_>>();
	if claims_vec.is_empty() {
		bail!(Error::EmptyClaimsArray);
	}

	let gkr_challenge_point = claims_vec[0].r.clone();

	// Ensure all claims have the same gkr_challenge
	if !claims_vec
		.iter()
		.all(|claim| claim.r == gkr_challenge_point)
	{
		bail!(Error::MismatchedGkrChallengeInClaimsBatch);
	}

	let max_individual_degree = claims_vec
		.iter()
		.map(|claim| claim.max_individual_degree())
		.max()
		.unwrap_or(0);

	let reductor = GkrSumcheckReductor {
		max_individual_degree,
		gkr_challenge_point: &gkr_challenge_point,
	};

	abstract_sumcheck::batch_verify(claims_vec, proof, reductor, challenger)
}