// binius_math/tensor_prod_eq_ind.rs
use std::cmp::max;
use binius_field::{Field, PackedField};
use binius_maybe_rayon::prelude::*;
use binius_utils::bail;
use bytemuck::zeroed_vec;
use crate::Error;
/// Extends a table of `2^log_n_values` scalars, in place, by one tensor factor
/// `(1 - r, r)` for every entry `r` of `extra_query_coordinates`.
///
/// The coordinates are consumed in order. With `m` denoting the number of scalars
/// populated before a round, each scalar `x` at index `h < m` is replaced by
/// `x - x * r`, and index `m + h` is overwritten with `x * r`. Every round therefore
/// doubles the populated prefix of `packed_values`.
///
/// # Errors
///
/// Returns [`Error::InvalidPackedValuesLength`] when `packed_values.len()` is not
/// `max(1, 2^(log_n_values + extra_query_coordinates.len()) / P::WIDTH)`.
///
/// # Panics
///
/// Asserts that, once the populated prefix spans whole packed elements, the prefix is
/// no longer than the remainder of the slice. With the length check above this is
/// expected to always hold.
pub fn tensor_prod_eq_ind<P: PackedField>(
    log_n_values: usize,
    packed_values: &mut [P],
    extra_query_coordinates: &[P::Scalar],
) -> Result<(), Error> {
    let total_n_vars = log_n_values + extra_query_coordinates.len();
    let required_len = max(1, (1 << total_n_vars) / P::WIDTH);
    if packed_values.len() != required_len {
        bail!(Error::InvalidPackedValuesLength);
    }

    for (round, coordinate) in extra_query_coordinates.iter().enumerate() {
        // Number of scalars already populated before this round.
        let populated = 1 << (log_n_values + round);

        if populated >= P::WIDTH {
            // The populated prefix covers whole packed elements: work on packed values,
            // pairing the j-th populated element with the j-th element after the prefix.
            let populated_packed = populated / P::WIDTH;
            let coordinate_packed = P::broadcast(*coordinate);
            let (lower, upper) = packed_values.split_at_mut(populated_packed);
            assert!(lower.len() <= upper.len());

            // `with_min_len(64)` asks the parallel iterator not to split the work into
            // pieces smaller than 64 pairs.
            lower
                .par_iter_mut()
                .zip(upper.par_iter_mut())
                .with_min_len(64)
                .for_each(|(lo, hi)| {
                    // lo * (1 - r) == lo - lo * r, so one multiplication serves both halves.
                    let scaled = (*lo) * coordinate_packed;
                    *lo -= scaled;
                    *hi = scaled;
                });
        } else {
            // Everything still fits inside the first packed element: update it scalar
            // by scalar.
            let first = &mut packed_values[0];
            for idx in 0..populated {
                let current = first.get(idx);
                let scaled = current * coordinate;
                first.set(idx, current - scaled);
                // `populated` is a power of two greater than `idx`, so the OR below
                // addresses position `populated + idx`.
                first.set(populated | idx, scaled);
            }
        }
    }

    Ok(())
}
/// Builds the table produced by [`tensor_prod_eq_ind`] for the coordinates in `point`,
/// starting from a single scalar equal to one.
///
/// The returned vector holds `2^(point.len() - P::LOG_WIDTH)` packed elements, or a
/// single packed element when `point.len()` does not exceed `P::LOG_WIDTH`. The buffer
/// starts zeroed, so scalar positions that the expansion never writes remain zero.
///
/// # Panics
///
/// Panics if [`tensor_prod_eq_ind`] rejects the buffer length, which is not expected
/// because the buffer is sized here to match its check.
pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> Vec<P> {
    // `saturating_sub` keeps at least one packed element when the whole table fits in one.
    let log_packed_len = point.len().saturating_sub(P::LOG_WIDTH);
    let mut evals = zeroed_vec::<P>(1 << log_packed_len);

    // Seed the expansion: the table over zero variables is the single value 1.
    evals[0].set(0, P::Scalar::ONE);

    tensor_prod_eq_ind(0, &mut evals, point)
        .expect("buffer is allocated with the correct length");
    evals
}
#[cfg(test)]
mod tests {
    use binius_field::{packed::set_packed_slice, Field, PackedBinaryField4x32b};
    use itertools::Itertools;

    use super::*;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Flattens packed elements into the scalars they contain, in iteration order.
    fn unpack(packed: &[P]) -> Vec<F> {
        PackedField::iter_slice(packed).collect_vec()
    }

    /// Expanding the seed value 1 with two coordinates yields the four tensor products.
    #[test]
    fn test_tensor_prod_eq_ind() {
        let (v0, v1) = (F::new(1), F::new(2));
        let query = vec![v0, v1];

        let mut packed = vec![P::default(); 1 << (query.len() - P::LOG_WIDTH)];
        set_packed_slice(&mut packed, 0, F::ONE);
        tensor_prod_eq_ind(0, &mut packed, &query).unwrap();

        let (not_v0, not_v1) = (F::ONE - v0, F::ONE - v1);
        let expected = vec![not_v0 * not_v1, v0 * not_v1, not_v0 * v1, v0 * v1];
        assert_eq!(unpack(&packed), expected);
    }

    /// With no coordinates the result is one packed element carrying only the seed 1.
    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let evals = eq_ind_partial_eval::<P>(&[]);
        assert_eq!(evals, vec![P::set_single(F::ONE)]);
    }

    /// One coordinate fills two scalars; the remaining scalars of the element stay zero.
    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);

        let evals = unpack(&eq_ind_partial_eval::<P>(&[r0]));

        let expected = vec![F::ONE - r0, r0, F::ZERO, F::ZERO];
        assert_eq!(evals, expected);
    }

    /// Two coordinates fill all four scalars of a single packed element.
    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let (r0, r1) = (F::new(2), F::new(3));

        let evals = unpack(&eq_ind_partial_eval::<P>(&[r0, r1]));

        let (not_r0, not_r1) = (F::ONE - r0, F::ONE - r1);
        let expected = vec![not_r0 * not_r1, r0 * not_r1, not_r0 * r1, r0 * r1];
        assert_eq!(evals, expected);
    }

    /// Three coordinates produce eight scalars, with the first coordinate varying fastest.
    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let (r0, r1, r2) = (F::new(2), F::new(3), F::new(5));

        let evals = unpack(&eq_ind_partial_eval::<P>(&[r0, r1, r2]));

        let (not_r0, not_r1, not_r2) = (F::ONE - r0, F::ONE - r1, F::ONE - r2);
        let expected = vec![
            not_r0 * not_r1 * not_r2,
            r0 * not_r1 * not_r2,
            not_r0 * r1 * not_r2,
            r0 * r1 * not_r2,
            not_r0 * not_r1 * r2,
            r0 * not_r1 * r2,
            not_r0 * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(evals, expected);
    }
}