binius_core/protocols/sumcheck/prove/
common.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Copyright 2024-2025 Irreducible Inc.

use binius_field::{packed::get_packed_slice, Field, PackedFieldIndexable};
use binius_hal::ComputationBackend;
use binius_maybe_rayon::prelude::*;
use tracing::instrument;

use crate::protocols::utils::packed_from_fn_with_offset;

/// Folds `partial_eq_ind_evals` by summing each adjacent pair of scalar evaluations.
///
/// Viewing the buffer as a flat sequence of scalars `evals`, scalar `k` of the result is
/// `evals[2k] + evals[2k + 1]`. The number of meaningful evaluations therefore drops from
/// `2^n_vars` to `2^(n_vars - 1)`.
///
/// * `n_vars` — number of variables the buffer currently covers; `0` makes this a no-op.
/// * `partial_eq_ind_evals` — packed evaluations. They are expected to occupy
///   `2^(n_vars - P::LOG_WIDTH)` packed elements (saturating at one). This is only checked
///   in debug builds.
///
/// If the buffer is a single packed element, the fold is done in place on its scalar lanes.
/// Every lane from `2^(n_vars - 1)` onward is then set to zero. Otherwise, a new buffer half
/// as long is computed in parallel and replaces the old one.
#[instrument(skip_all, level = "debug")]
pub fn fold_partial_eq_ind<P, Backend>(n_vars: usize, partial_eq_ind_evals: &mut Backend::Vec<P>)
where
	P: PackedFieldIndexable,
	Backend: ComputationBackend,
{
	debug_assert_eq!(1 << n_vars.saturating_sub(P::LOG_WIDTH), partial_eq_ind_evals.len());

	// Nothing to sum out.
	if n_vars == 0 {
		return;
	}

	if partial_eq_ind_evals.len() != 1 {
		// Several packed elements: build the halved buffer in parallel, one output packed
		// element per task. NOTE(review): assumes `packed_from_fn_with_offset` hands the
		// closure each lane's global scalar index — confirm against `protocols::utils`.
		let src = &*partial_eq_ind_evals;
		let folded = (0..src.len() / 2)
			.into_par_iter()
			.map(|packed_idx| {
				packed_from_fn_with_offset(packed_idx, |scalar_idx| {
					let pair_start = scalar_idx << 1;
					get_packed_slice(src, pair_start) + get_packed_slice(src, pair_start + 1)
				})
			})
			.collect();

		*partial_eq_ind_evals = Backend::to_hal_slice(folded);
		return;
	}

	// A single packed element holds every evaluation, so fold its scalar lanes in place.
	// Overwriting lane `k` is safe here. Only lanes `0..k` have been written so far, while
	// this step reads lanes `2k` and `2k + 1`, both of which are >= `k`.
	let lanes = P::unpack_scalars_mut(partial_eq_ind_evals);
	let half = 1 << (n_vars - 1);
	for k in 0..half {
		let sum = lanes[2 * k] + lanes[2 * k + 1];
		lanes[k] = sum;
	}
	// Zero the lanes that no longer carry a meaningful evaluation.
	lanes[half..].fill(P::Scalar::ZERO);
}