1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright 2024 Ulvetanna Inc.

use super::{
	common::{GreedyEvalcheckProof, GreedyEvalcheckProveOutput},
	error::Error,
};
use crate::{
	challenger::{CanObserve, CanSample},
	oracle::MultilinearOracleSet,
	protocols::{
		evalcheck::{EvalcheckClaim, EvalcheckProver},
		test_utils::{
			make_non_same_query_pcs_sumchecks, prove_bivariate_sumchecks_with_switchover,
		},
	},
	witness::MultilinearExtensionIndex,
};
use binius_field::{
	as_packed_field::PackScalar, underlier::WithUnderlier, ExtensionField, PackedExtension,
	PackedFieldIndexable, TowerField,
};
use binius_hal::ComputationBackend;
use binius_math::EvaluationDomainFactory;

/// Greedily reduces a set of evalcheck claims to one same-query PCS claim per committed batch.
///
/// The reduction runs in three phases:
///
/// 1. Every claim in `claims` is proven with an [`EvalcheckProver`]; the resulting proofs are
///    stored in `initial_evalcheck_proofs`.
/// 2. While the prover reports newly created sumcheck claims (for virtual polynomial openings),
///    they are proven together with `prove_bivariate_sumchecks_with_switchover`, and the
///    evalcheck claims that come out of that sumcheck are proven in turn. Each such round appends
///    one `(sumcheck_proof, evalcheck_proofs)` entry to `virtual_opening_proofs`.
/// 3. For each committed batch of `oracles`, a same-query claim is extracted from the
///    accumulated committed evaluation claims. If none can be extracted directly, the batch's
///    claims are taken out, reduced with an extra sumcheck, and extraction is retried.
///    `batch_opening_proof` receives exactly one entry per batch: `None` for a direct extraction,
///    `Some((sumcheck_proof, evalcheck_proofs))` when the extra sumcheck was needed.
///
/// Every sumcheck is run with clones of `switchover_fn`, `domain_factory` and `backend`, and
/// is given `challenger`.
///
/// # Errors
///
/// Propagates any error from proving evalcheck claims, from the sumcheck prover, from
/// `make_non_same_query_pcs_sumchecks`, or from the committed-claim bookkeeping
/// (`try_extract_same_query_pcs_claim`, `take_claims`).
///
/// # Panics
///
/// Panics if a batch still yields no same-query claim after its extra sumcheck reduction, or if
/// the committed-batch phase leaves new sumcheck claims behind.
pub fn prove<F, PW, DomainField, Challenger, Backend>(
	oracles: &mut MultilinearOracleSet<F>,
	witness_index: &mut MultilinearExtensionIndex<PW::Underlier, PW::Scalar>,
	claims: impl IntoIterator<Item = EvalcheckClaim<F>>,
	switchover_fn: impl Fn(usize) -> usize + Clone + 'static,
	challenger: &mut Challenger,
	domain_factory: impl EvaluationDomainFactory<DomainField>,
	backend: Backend,
) -> Result<GreedyEvalcheckProveOutput<F>, Error>
where
	F: TowerField + From<PW::Scalar>,
	PW: PackedFieldIndexable + PackedExtension<DomainField> + WithUnderlier,
	PW::Scalar: TowerField + From<F> + ExtensionField<DomainField>,
	PW::Underlier: PackScalar<PW::Scalar, Packed = PW>,
	DomainField: TowerField,
	Challenger: CanObserve<F> + CanSample<F>,
	Backend: ComputationBackend,
{
	// Query the committed batches before `oracles` is lent to the prover below.
	let committed_batches = oracles.committed_batches();
	let mut proof = GreedyEvalcheckProof::default();
	let mut prover = EvalcheckProver::<F, PW, _>::new(oracles, witness_index, backend.clone());

	// Phase 1: prove the caller-supplied evalcheck claims.
	let mut initial_proofs = Vec::new();
	for initial_claim in claims {
		initial_proofs.push(prover.prove(initial_claim)?);
	}
	proof.initial_evalcheck_proofs = initial_proofs;

	// Phase 2: keep discharging the sumcheck claims produced for virtual polynomial openings
	// until proving the resulting evalcheck claims stops creating new ones.
	loop {
		let pending_sumchecks = prover.take_new_sumchecks();
		if pending_sumchecks.is_empty() {
			break;
		}

		let (round_sumcheck_proof, reduced_claims) =
			prove_bivariate_sumchecks_with_switchover::<_, _, DomainField, _, _>(
				pending_sumchecks,
				challenger,
				switchover_fn.clone(),
				domain_factory.clone(),
				backend.clone(),
			)?;

		let mut round_evalcheck_proofs = Vec::new();
		for reduced_claim in reduced_claims {
			round_evalcheck_proofs.push(prover.prove(reduced_claim)?);
		}

		proof
			.virtual_opening_proofs
			.push((round_sumcheck_proof, round_evalcheck_proofs));
	}

	// Phase 3: only committed-polynomial claims remain. Collapse each committed batch to a single
	// same-query claim, recording one `batch_opening_proof` entry per batch.
	let same_query_claims = committed_batches
		.into_iter()
		.map(|batch| {
			let direct_claim = prover
				.batch_committed_eval_claims()
				.try_extract_same_query_pcs_claim(batch.id)?;

			let same_query_claim = match direct_claim {
				Some(extracted) => {
					// Extraction succeeded outright; no extra proof is needed for this batch.
					proof.batch_opening_proof.push(None);
					extracted
				}
				None => {
					// No same-query claim was available, so pull this batch's claims out and
					// reduce them with an additional sumcheck before retrying extraction.
					let batch_claims = prover
						.batch_committed_eval_claims_mut()
						.take_claims(batch.id)?;

					let batch_sumchecks = make_non_same_query_pcs_sumchecks(
						&mut prover,
						&batch_claims,
						backend.clone(),
					)?;

					let (batch_sumcheck_proof, reduced_claims) =
						prove_bivariate_sumchecks_with_switchover::<_, _, DomainField, _, _>(
							batch_sumchecks,
							challenger,
							switchover_fn.clone(),
							domain_factory.clone(),
							backend.clone(),
						)?;

					let mut batch_evalcheck_proofs = Vec::new();
					for reduced_claim in reduced_claims {
						batch_evalcheck_proofs.push(prover.prove(reduced_claim)?);
					}

					proof
						.batch_opening_proof
						.push(Some((batch_sumcheck_proof, batch_evalcheck_proofs)));

					prover
						.batch_committed_eval_claims_mut()
						.try_extract_same_query_pcs_claim(batch.id)?
						.expect(
							"by construction, we must be left with a same query eval claim for the \
							 batch",
						)
				}
			};

			Ok((batch.id, same_query_claim))
		})
		.collect::<Result<_, Error>>()?;

	// Reducing the committed batches must not have generated any further sumcheck claims.
	assert!(prover.take_new_sumchecks().is_empty());

	Ok(GreedyEvalcheckProveOutput {
		proof,
		same_query_claims,
	})
}