use std::iter;
use binius_maybe_rayon::prelude::*;
use binius_utils::checked_arithmetics::checked_int_div;
use crate::{packed::get_packed_slice_unchecked, ExtensionField, Field, PackedField};
pub fn inner_product_unchecked<F, FE>(
a: impl IntoIterator<Item = FE>,
b: impl IntoIterator<Item = F>,
) -> FE
where
F: Field,
FE: ExtensionField<F>,
{
iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum()
}
/// Computes the inner product `Σ xᵢ·yᵢ` of two packed slices, viewing each slice as the flat
/// sequence of its scalars.
///
/// `xs` holds extension-field scalars (`FX`) and `ys` holds subfield scalars (`PY::Scalar`).
///
/// - If `xs` holds fewer scalars than `ys`, the trailing `ys` scalars are ignored and the
///   product is computed sequentially via [`inner_product_unchecked`].
/// - If both hold the same number of scalars and `ys` is short (fewer than `16 * CHUNK_SIZE`
///   packed elements), a single sequential loop is used.
/// - Otherwise `ys` is split into chunks with `par_chunks` (parallel when `binius_maybe_rayon`
///   is backed by rayon) and the per-chunk partial sums are added together.
///
/// # Panics
///
/// Panics if `xs` holds more scalars than `ys`.
pub fn inner_product_par<FX, PX, PY>(xs: &[PX], ys: &[PY]) -> FX
where
    PX: PackedField<Scalar = FX>,
    PY: PackedField,
    FX: ExtensionField<PY::Scalar>,
{
    let n_x_scalars = PX::WIDTH * xs.len();
    let n_y_scalars = PY::WIDTH * ys.len();
    assert!(
        n_x_scalars <= n_y_scalars,
        "Y elements has to be at least as wide as X elements"
    );

    if n_x_scalars < n_y_scalars {
        return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys));
    }

    // Accumulates Σ xs_scalar[j * PY::WIDTH + k] * ys[j][k] over the packed elements of `ys`.
    // Precondition: `xs` holds at least `PY::WIDTH * ys.len()` scalars.
    let calc_product_by_ys = |xs: &[PX], ys: &[PY]| {
        debug_assert!(PX::WIDTH * xs.len() >= PY::WIDTH * ys.len());
        let mut result = FX::ZERO;
        for (j, y_packed) in ys.iter().enumerate() {
            for (k, y) in y_packed.iter().enumerate() {
                // SAFETY: `j < ys.len()` and `k < PY::WIDTH`, so the scalar index is below
                // `PY::WIDTH * ys.len()`, which both call sites below keep `<=` the number of
                // scalars in the `xs` slice they pass.
                result += unsafe { get_packed_slice_unchecked(xs, j * PY::WIDTH + k) } * y;
            }
        }
        result
    };

    const CHUNK_SIZE: usize = 64;
    if ys.len() < 16 * CHUNK_SIZE {
        calc_product_by_ys(xs, ys)
    } else {
        // Every chunk of `ys` must start on a packed-element boundary of `xs`, so
        // `ys_chunk_size * PY::WIDTH` must be a multiple of `PX::WIDTH`. With a fixed
        // `CHUNK_SIZE` that fails whenever `PX::WIDTH > CHUNK_SIZE * PY::WIDTH`, so the chunk
        // is grown to span at least one whole `PX` element. This relies on the widths being
        // powers of two (NOTE(review): assumed, not visible here); if they are not,
        // `checked_int_div` still rejects the inexact division instead of mis-aligning.
        let ys_chunk_size = CHUNK_SIZE.max(PX::WIDTH / PY::WIDTH);
        let xs_chunk_size = checked_int_div(ys_chunk_size * PY::WIDTH, PX::WIDTH);
        ys.par_chunks(ys_chunk_size)
            .enumerate()
            .map(|(i, ys_chunk)| calc_product_by_ys(&xs[i * xs_chunk_size..], ys_chunk))
            .sum()
    }
}
/// Evaluates the equality indicator `eq(x, y) = x·y + (1 − x)·(1 − y)`.
///
/// On Boolean inputs (`x, y ∈ {0, 1}`) this is `1` when `x == y` and `0` otherwise.
/// For arbitrary field elements it is the multilinear extension of that indicator.
#[inline(always)]
pub fn eq<F: Field>(x: F, y: F) -> F {
    let both_one = x * y;
    let both_zero = (F::ONE - x) * (F::ONE - y);
    both_one + both_zero
}
/// Returns the sequence of powers `1, val, val², val³, …`.
///
/// The iterator never terminates on its own; bound it with `take` or by zipping it with a
/// finite iterator.
pub fn powers<F: Field>(val: F) -> impl Iterator<Item = F> {
    let next_power = move |prev: &F| Some(*prev * val);
    iter::successors(Some(F::ONE), next_power)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PackedBinaryField4x32b;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    /// Wraps each value in its own packed element using `P::set_single`.
    fn pack_each(values: &[u32]) -> Vec<P> {
        values.iter().map(|&v| P::set_single(F::new(v))).collect()
    }

    /// Runs `inner_product_par` on `pack_each(x_vals)` / `pack_each(y_vals)` and compares it
    /// with a scalar reference that multiplies the values pairwise (up to the shorter length)
    /// and sums the products.
    fn assert_matches_scalar_reference(x_vals: &[u32], y_vals: &[u32]) {
        let xs = pack_each(x_vals);
        let ys = pack_each(y_vals);
        let expected: F = x_vals
            .iter()
            .zip(y_vals)
            .map(|(&x, &y)| F::new(x) * F::new(y))
            .sum();
        assert_eq!(inner_product_par::<F, P, P>(&xs, &ys), expected);
    }

    /// Checks `xs = [0, 1, …, len − 1]` against `ys = [1, 2, …, len]`.
    fn assert_consecutive_run(len: u32) {
        let x_vals: Vec<u32> = (0..len).collect();
        let y_vals: Vec<u32> = (1..=len).collect();
        assert_matches_scalar_reference(&x_vals, &y_vals);
    }

    #[test]
    fn test_inner_product_par_equal_length() {
        assert_matches_scalar_reference(&[1, 2], &[3, 4]);
    }

    #[test]
    fn test_inner_product_par_unequal_length() {
        // The reference pairs the single x value with the first y value only.
        assert_matches_scalar_reference(&[1], &[2, 3]);
    }

    #[test]
    fn test_inner_product_par_large_input_single_threaded() {
        // 256 packed elements: below the `16 * CHUNK_SIZE` threshold, so no chunking.
        assert_consecutive_run(256);
    }

    #[test]
    fn test_inner_product_par_large_input_par() {
        // 2000 packed elements: above the `16 * CHUNK_SIZE` threshold, so `ys` is chunked.
        assert_consecutive_run(2000);
    }

    #[test]
    fn test_inner_product_par_empty() {
        assert_eq!(inner_product_par::<F, P, P>(&[], &[]), F::ZERO);
    }
}