1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright 2023 Ulvetanna Inc.

use super::{Error, VerificationError};
use crate::{
	oracle::{CompositePolyOracle, OracleId},
	polynomial::MultilinearComposite,
	protocols::{
		abstract_sumcheck::{
			AbstractSumcheckClaim, AbstractSumcheckProof, AbstractSumcheckReductor,
			AbstractSumcheckRound, AbstractSumcheckRoundClaim, AbstractSumcheckWitness,
		},
		evalcheck::EvalcheckClaim,
	},
};
use binius_field::{Field, PackedField};
use binius_math::evaluate_univariate;
use binius_utils::bail;

/// Round message of this sumcheck protocol; an alias of [`AbstractSumcheckRound`].
pub type SumcheckRound<F> = AbstractSumcheckRound<F>;
/// Proof type of this sumcheck protocol; an alias of [`AbstractSumcheckProof`].
pub type SumcheckProof<F> = AbstractSumcheckProof<F>;

/// Bundle returned alongside a sumcheck proof: the proof itself plus an evalcheck claim.
#[derive(Debug)]
pub struct SumcheckProveOutput<F: Field> {
	/// Evaluation claim accompanying the proof.
	// NOTE(review): presumably the claim the sumcheck reduces to — confirm against the prover.
	pub evalcheck_claim: EvalcheckClaim<F>,
	/// The sumcheck proof transcript.
	pub sumcheck_proof: SumcheckProof<F>,
}

/// A sumcheck claim: summing `poly` over every point of the boolean hypercube
/// `{0, 1}^n_vars` yields `sum`.
#[derive(Debug, Clone)]
pub struct SumcheckClaim<F: Field> {
	/// Composite polynomial oracle whose hypercube sum is claimed.
	pub poly: CompositePolyOracle<F>,
	/// Claimed value of the sum over the hypercube.
	pub sum: F,
}

/// Exposes a [`SumcheckClaim`] through the generic sumcheck claim interface by delegating
/// the shape queries to the underlying composite oracle.
impl<F: Field> AbstractSumcheckClaim<F> for SumcheckClaim<F> {
	/// Number of variables of the claimed composite polynomial.
	fn n_vars(&self) -> usize {
		let Self { poly, .. } = self;
		poly.n_vars()
	}

	/// Maximum degree in any single variable of the claimed composite polynomial.
	fn max_individual_degree(&self) -> usize {
		let Self { poly, .. } = self;
		poly.max_individual_degree()
	}

	/// The value the composite polynomial is claimed to sum to.
	fn sum(&self) -> F {
		let Self { sum, .. } = self;
		*sum
	}
}

/// Default sumcheck witness type: simply a [`MultilinearComposite`].
pub type SumcheckWitness<P, C, M> = MultilinearComposite<P, C, M>;

/// Per-round claim (partial evaluation point and current claimed sum); an alias of
/// [`AbstractSumcheckRoundClaim`].
pub type SumcheckRoundClaim<F> = AbstractSumcheckRoundClaim<F>;

/// Round reductor for plain sumcheck claims (implements `AbstractSumcheckReductor`).
pub struct SumcheckReductor {
	/// Maximum individual degree `d` of the round polynomials. Each round message must carry
	/// exactly `d` coefficients: the leading coefficient is omitted by the prover and
	/// reconstructed during round reduction.
	pub max_individual_degree: usize,
}

impl<F: Field> AbstractSumcheckReductor<F> for SumcheckReductor {
	type Error = Error;

	/// Checks that a round message carries exactly `max_individual_degree` coefficients.
	///
	/// A degree-`d` round polynomial has `d + 1` coefficients, but the leading one is not
	/// sent; it is recovered from the running claimed sum when the round is reduced.
	///
	/// # Errors
	///
	/// Returns `VerificationError::NumberOfCoefficients` on a length mismatch.
	fn validate_round_proof_shape(
		&self,
		_round: usize,
		proof: &AbstractSumcheckRound<F>,
	) -> Result<(), Self::Error> {
		let expected = self.max_individual_degree;
		if proof.coeffs.len() == expected {
			Ok(())
		} else {
			Err(VerificationError::NumberOfCoefficients { expected }.into())
		}
	}

	/// Reduces the current round claim using the verifier's `challenge` and the prover's
	/// round message; the round index is not needed for plain sumcheck.
	fn reduce_round_claim(
		&self,
		_round: usize,
		claim: AbstractSumcheckRoundClaim<F>,
		challenge: F,
		round_proof: AbstractSumcheckRound<F>,
	) -> Result<AbstractSumcheckRoundClaim<F>, Self::Error> {
		reduce_intermediate_round_claim_helper(claim, challenge, round_proof)
	}
}

/// Reduces a round claim to the claim for the next round.
///
/// The prover's round message lists the coefficients `a_0, ..., a_{d-1}` of the round
/// polynomial `r(X) = a_0 + a_1 X + ... + a_d X^d` and leaves out the leading coefficient
/// `a_d`. An honest round polynomial satisfies `s = r(0) + r(1)`, where `s` is the current
/// claimed sum, `r(0) = a_0`, and `r(1) = a_0 + a_1 + ... + a_d`. The unique missing
/// coefficient is therefore
///
/// ```text
/// a_d = s - a_0 - (a_0 + a_1 + ... + a_{d-1})
/// ```
///
/// When no coefficients were sent, `a_0` is taken to be zero.
///
/// Omitting `a_d` is an optimization. In the unoptimized protocol the verifier receives the
/// full polynomial and rejects it if `s != r(0) + r(1)`. Here that identity holds by
/// construction.
///
/// The next claimed sum is `r(challenge)`, and `challenge` is appended to the partial
/// evaluation point. This function always returns `Ok`.
fn reduce_intermediate_round_claim_helper<F: Field>(
	claim: SumcheckRoundClaim<F>,
	challenge: F,
	proof: SumcheckRound<F>,
) -> Result<SumcheckRoundClaim<F>, Error> {
	let SumcheckRoundClaim {
		partial_point: mut point,
		current_round_sum: claimed_sum,
	} = claim;
	let SumcheckRound {
		coeffs: mut round_coeffs,
	} = proof;

	// a_0 (zero if nothing was sent) and the sum of every coefficient that was sent.
	let constant_term = match round_coeffs.first() {
		Some(&a_0) => a_0,
		None => F::ZERO,
	};
	let sent_sum = round_coeffs.iter().sum::<F>();

	// Append the reconstructed leading coefficient a_d, completing r(X).
	round_coeffs.push(claimed_sum - constant_term - sent_sum);

	let next_sum = evaluate_univariate(&round_coeffs, challenge);
	point.push(challenge);

	Ok(SumcheckRoundClaim {
		partial_point: point,
		current_round_sum: next_sum,
	})
}

pub fn validate_witness<F, PW, W>(claim: &SumcheckClaim<F>, witness: W) -> Result<(), Error>
where
	F: Field,
	PW: PackedField<Scalar: From<F> + Into<F>>,
	W: AbstractSumcheckWitness<PW, MultilinearId = OracleId>,
{
	let log_size = claim.n_vars();
	let oracle_ids = claim.poly.inner_polys_oracle_ids().collect::<Vec<_>>();
	let multilinears = witness
		.multilinears(0, oracle_ids.as_slice())?
		.into_iter()
		.map(|(_, multilinear)| multilinear)
		.collect::<Vec<_>>();

	let witness = MultilinearComposite::new(log_size, witness.composition(), multilinears)?;

	let sum = (0..(1 << log_size))
		.try_fold(PW::Scalar::ZERO, |acc, i| witness.evaluate_on_hypercube(i).map(|res| res + acc));

	if sum? == claim.sum().into() {
		Ok(())
	} else {
		bail!(Error::NaiveValidation)
	}
}