use core::slice;
use std::{any::TypeId, cmp::min, mem::MaybeUninit};
use binius_field::{
arch::{ArchOptimal, OptimalUnderlier},
byte_iteration::{
can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes,
ByteIteratorCallback, PackedSlice,
},
packed::{get_packed_slice, get_packed_slice_unchecked, set_packed_slice_unchecked},
underlier::{UnderlierWithBitOps, WithUnderlier},
AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField,
Field, PackedField,
};
use binius_utils::bail;
use bytemuck::fill_zeroes;
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;
use crate::Error;
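/// Execute the right fold operation.
///
/// Every consecutive `1 << log_query_size` scalar values of `evals` are dot-producted with the
/// query expansion, i.e. `out[i] = sum_j query[j] * evals[(i << log_query_size) | j]`.
/// Returns an error if the slice lengths are inconsistent with the stated log sizes.
///
/// Dispatches to a byte-iteration fast path for 1-bit evaluations, to `fold_right_lerp` when the
/// query is a size-two linear interpolation `(1 - r, r)`, and to a generic fallback otherwise.
///
/// A usage sketch, mirroring the unit tests below (the concrete packed types are illustrative,
/// not required by the API):
///
/// ```ignore
/// use binius_field::{PackedBinaryField128x1b, PackedBinaryField64x8b, PackedField};
/// use rand::{rngs::StdRng, SeedableRng};
///
/// let mut rng = StdRng::seed_from_u64(0);
/// // 2^7 one-bit evaluations folded with a query of 2^1 eight-bit field elements.
/// let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
/// let query = vec![PackedBinaryField64x8b::random(&mut rng)];
/// let mut out = vec![PackedBinaryField64x8b::zero(); 1];
/// fold_right(&evals, 7, &query, 1, &mut out)?;
/// ```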
pub fn fold_right<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [PE],
) -> Result<(), Error>
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;
if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
&& fold_right_1bit_evals(evals, log_evals_size, query, log_query_size, out)
{
return Ok(());
}
let is_lerp = log_query_size == 1
&& get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE;
if is_lerp {
let lerp_query = get_packed_slice(query, 1);
fold_right_lerp(evals, log_evals_size, lerp_query, out);
} else {
fold_right_fallback(evals, log_evals_size, query, log_query_size, out);
}
Ok(())
}
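/// Execute the left fold operation.
///
/// `evals` is treated as a matrix with `1 << log_query_size` rows; each row is scaled by the
/// corresponding query element and the rows are summed, i.e.
/// `out[i] = sum_j query[j] * evals[i | (j << (log_evals_size - log_query_size))]`.
///
/// `out` is a slice of possibly-uninitialized packed values. On success every element of `out`
/// has been initialized; if an error is returned the memory is left untouched.
///
/// A usage sketch, mirroring `check_fold_left` in the tests below (`packed_out_len`, the number
/// of packed output elements, is an illustrative name):
///
/// ```ignore
/// let mut out = Vec::with_capacity(packed_out_len);
/// fold_left(&evals, log_evals_size, &query, log_query_size, out.spare_capacity_mut())?;
/// // SAFETY: on success `fold_left` has initialized the entire spare capacity passed to it.
/// unsafe { out.set_len(out.capacity()) };
/// ```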
pub fn fold_left<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [MaybeUninit<PE>],
) -> Result<(), Error>
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
check_fold_arguments(evals, log_evals_size, query, log_query_size, out)?;
if TypeId::of::<P::Scalar>() == TypeId::of::<BinaryField1b>()
&& fold_left_1b_128b(evals, log_evals_size, query, log_query_size, out)
{
return Ok(());
}
fold_left_fallback(evals, log_evals_size, query, log_query_size, out);
Ok(())
}
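/// Check that `evals`, `query`, and `out` are large enough for the stated `log_evals_size` and
/// `log_query_size`.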
#[inline]
fn check_fold_arguments<P, PE, POut>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &[POut],
) -> Result<(), Error>
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
if log_evals_size < log_query_size {
bail!(Error::IncorrectQuerySize {
expected: log_evals_size
});
}
	// `evals` must contain at least 2^log_evals_size scalars.
	if evals.len() << P::LOG_WIDTH < 1usize << log_evals_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "evals".into(),
			expected: log_evals_size
		});
	}
	// `query` must contain at least 2^log_query_size scalars.
	if query.len() << PE::LOG_WIDTH < 1usize << log_query_size {
		bail!(Error::IncorrectArgumentLength {
			arg: "query".into(),
			expected: log_query_size
		});
	}
	// `out` must have room for 2^(log_evals_size - log_query_size) scalars.
	if out.len() << PE::LOG_WIDTH < 1usize << (log_evals_size - log_query_size) {
		bail!(Error::IncorrectOutputPolynomialSize {
			expected: log_evals_size - log_query_size
		});
	}
Ok(())
}
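/// Specialized right fold of 1-bit evaluations for a small query (`LOG_QUERY_SIZE < 3`).
///
/// Precomputes all `2^(2^LOG_QUERY_SIZE)` subset sums of the query scalars, then iterates the
/// evaluations byte by byte, using each group of `2^LOG_QUERY_SIZE` bits as an index into the
/// table. Returns `false` if the specialization does not apply.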
fn fold_right_1bit_evals_small_query<P, PE, const LOG_QUERY_SIZE: usize>(
evals: &[P],
query: &[PE],
out: &mut [PE],
) -> bool
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
if LOG_QUERY_SIZE >= 3 {
return false;
}
if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
return false;
}
let cached_table = (0..1 << (1 << LOG_QUERY_SIZE))
.map(|i| {
let mut result = PE::Scalar::ZERO;
for j in 0..1 << LOG_QUERY_SIZE {
if i >> j & 1 == 1 {
result += get_packed_slice(query, j);
}
}
result
})
.collect::<Vec<_>>();
struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
out: &'a mut [PE],
cached_table: &'a [PE::Scalar],
}
impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
for Callback<'_, PE, LOG_QUERY_SIZE>
{
#[inline(always)]
fn call(&mut self, iterator: impl Iterator<Item = u8>) {
let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1;
let values_in_byte = 1 << (3 - LOG_QUERY_SIZE);
let mut current_index = 0;
for byte in iterator {
for k in 0..values_in_byte {
let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask;
unsafe {
set_packed_slice_unchecked(
self.out,
current_index + k,
self.cached_table[index as usize],
);
}
}
current_index += values_in_byte;
}
}
}
let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> {
out,
cached_table: &cached_table,
};
iterate_bytes(evals, &mut callback);
true
}
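/// Specialized right fold of 1-bit evaluations for a medium query (`3 <= LOG_QUERY_SIZE <= 7`).
///
/// Uses partial-sum lookup tables, one 256-entry table per byte of the expanded query, and emits
/// one output scalar for every `2^(LOG_QUERY_SIZE - 3)` input bytes. Returns `false` if the
/// specialization does not apply.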
fn fold_right_1bit_evals_medium_query<P, PE, const LOG_QUERY_SIZE: usize>(
evals: &[P],
query: &[PE],
out: &mut [PE],
) -> bool
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
if LOG_QUERY_SIZE < 3 {
return false;
}
if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH {
return false;
}
let cached_tables =
create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE));
struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> {
out: &'a mut [PE],
cached_tables: &'a [PE::Scalar],
}
impl<PE: PackedField, const LOG_QUERY_SIZE: usize> ByteIteratorCallback
for Callback<'_, PE, LOG_QUERY_SIZE>
{
#[inline(always)]
fn call(&mut self, iterator: impl Iterator<Item = u8>) {
let log_tables_count = LOG_QUERY_SIZE - 3;
let tables_count = 1 << log_tables_count;
let mut current_index = 0;
let mut current_table = 0;
let mut current_value = PE::Scalar::ZERO;
for byte in iterator {
current_value += self.cached_tables[(current_table << 8) + byte as usize];
current_table += 1;
if current_table == tables_count {
unsafe {
set_packed_slice_unchecked(self.out, current_index, current_value);
}
current_index += 1;
current_table = 0;
current_value = PE::Scalar::ZERO;
}
}
}
}
let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> {
out,
cached_tables: &cached_tables,
};
iterate_bytes(evals, &mut callback);
true
}
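/// Try to fold 1-bit evaluations using byte iteration over the packed representation.
///
/// Returns `false` if no specialization applies (the evaluations cannot be iterated as bytes or
/// the query is larger than `2^7`), in which case the caller falls back to the generic path.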
fn fold_right_1bit_evals<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [PE],
) -> bool
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
if log_evals_size < P::LOG_WIDTH {
return false;
}
if !can_iterate_bytes::<P>() {
return false;
}
match log_query_size {
0 => fold_right_1bit_evals_small_query::<P, PE, 0>(evals, query, out),
1 => fold_right_1bit_evals_small_query::<P, PE, 1>(evals, query, out),
2 => fold_right_1bit_evals_small_query::<P, PE, 2>(evals, query, out),
3 => fold_right_1bit_evals_medium_query::<P, PE, 3>(evals, query, out),
4 => fold_right_1bit_evals_medium_query::<P, PE, 4>(evals, query, out),
5 => fold_right_1bit_evals_medium_query::<P, PE, 5>(evals, query, out),
6 => fold_right_1bit_evals_medium_query::<P, PE, 6>(evals, query, out),
7 => fold_right_1bit_evals_medium_query::<P, PE, 7>(evals, query, out),
_ => false,
}
}
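/// Right fold with a size-two linear interpolation ("lerp") query `(1 - r, r)`.
///
/// Only the single scalar `r` is needed: `out[i] = evals[2i] + (evals[2i + 1] - evals[2i]) * r`.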
fn fold_right_lerp<P, PE>(
evals: &[P],
log_evals_size: usize,
lerp_query: PE::Scalar,
out: &mut [PE],
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
assert_eq!(1 << log_evals_size.saturating_sub(PE::LOG_WIDTH + 1), out.len());
out.iter_mut()
.enumerate()
.for_each(|(i, packed_result_eval)| {
for j in 0..min(PE::WIDTH, 1 << (log_evals_size - 1)) {
let index = (i << PE::LOG_WIDTH) | j;
let (eval0, eval1) = unsafe {
(
get_packed_slice_unchecked(evals, index << 1),
get_packed_slice_unchecked(evals, (index << 1) | 1),
)
};
let result_eval =
PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0);
unsafe {
packed_result_eval.set_unchecked(j, result_eval);
}
}
})
}
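/// Generic right fold: each output scalar is the dot product of a window of `2^log_query_size`
/// consecutive evaluations with the expanded query.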
fn fold_right_fallback<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [PE],
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
for (k, packed_result_eval) in out.iter_mut().enumerate() {
for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) {
let index = (k << PE::LOG_WIDTH) | j;
let offset = index << log_query_size;
let mut result_eval = PE::Scalar::ZERO;
for (t, query_expansion) in PackedField::iter_slice(query)
.take(1 << log_query_size)
.enumerate()
{
result_eval += query_expansion * get_packed_slice(evals, t + offset);
}
unsafe {
packed_result_eval.set_unchecked(j, result_eval);
}
}
}
}
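/// Packed type with the optimal throughput for the field `F` on the target architecture.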
type ArchOptimalType<F> = <F as ArchOptimal>::OptimalThroughputPacked;
#[inline(always)]
fn get_arch_optimal_packed_type_id<F: ArchOptimal>() -> TypeId {
	TypeId::of::<ArchOptimalType<F>>()
}
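/// Left fold of 1-bit evaluations by a 128-bit field query, specialized for the
/// architecture-optimal packed types of `BinaryField128b`, `AESTowerField128b`, and
/// `BinaryField128bPolyval`.
///
/// The evaluations are reinterpreted as raw bytes and expanded through the precomputed
/// `LOOKUP_TABLE` of bit masks, so that each set evaluation bit selects the broadcast query
/// element and each clear bit selects zero. Returns `false` if the specialization does not apply.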
fn fold_left_1b_128b<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [MaybeUninit<PE>],
) -> bool
where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
if log_evals_size < P::LOG_WIDTH || !is_sequential_bytes::<P>() {
return false;
}
let log_row_size = log_evals_size - log_query_size;
if log_row_size < 3 {
return false;
}
if PE::LOG_WIDTH > 3 {
return false;
}
let evals_u8: &[u8] = unsafe {
std::slice::from_raw_parts(evals.as_ptr() as *const u8, std::mem::size_of_val(evals))
};
#[inline]
fn try_run_specialization<PE, F>(
lookup_table: &[OptimalUnderlier],
evals_u8: &[u8],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [MaybeUninit<PE>],
) -> bool
where
PE: PackedField,
F: ArchOptimal,
{
if TypeId::of::<PE>() == get_arch_optimal_packed_type_id::<F>() {
			let query = cast_same_type_slice::<_, ArchOptimalType<F>>(query);
			let out = cast_same_type_slice_mut::<_, MaybeUninit<ArchOptimalType<F>>>(out);
fold_left_1b_128b_impl(
lookup_table,
evals_u8,
log_evals_size,
query,
log_query_size,
out,
);
true
} else {
false
}
}
let lookup_table = &*LOOKUP_TABLE;
try_run_specialization::<_, BinaryField128b>(
lookup_table,
evals_u8,
log_evals_size,
query,
log_query_size,
out,
) || try_run_specialization::<_, AESTowerField128b>(
lookup_table,
evals_u8,
log_evals_size,
query,
log_query_size,
out,
) || try_run_specialization::<_, BinaryField128bPolyval>(
lookup_table,
evals_u8,
log_evals_size,
query,
log_query_size,
out,
)
}
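/// Reinterpret a mutable slice of `T` as a slice of `U`, asserting at runtime that the two types
/// are identical. Used to route a generic `PE` into its architecture-optimal concrete type.
/// An immutable variant follows below.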
#[inline(always)]
fn cast_same_type_slice_mut<T: Sized + 'static, U: Sized + 'static>(slice: &mut [T]) -> &mut [U] {
assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
unsafe { slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut U, slice.len()) }
}
#[inline(always)]
fn cast_same_type_slice<T: Sized + 'static, U: Sized + 'static>(slice: &[T]) -> &[U] {
assert_eq!(TypeId::of::<T>(), TypeId::of::<U>());
unsafe { slice::from_raw_parts(slice.as_ptr() as *const U, slice.len()) }
}
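/// Build the mask lookup table used by `fold_left_1b_128b_impl`.
///
/// For every byte value and every group of `U::BITS / 128` bits within that byte, the table
/// stores an underlier whose 128-bit lanes are all-ones exactly where the corresponding bit of
/// the group is set.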
fn init_lookup_table_width<U>() -> Vec<U>
where
U: UnderlierWithBitOps + From<u128>,
{
let items_b128 = U::BITS / u128::BITS as usize;
assert!(items_b128 <= 8);
let items_in_byte = 8 / items_b128;
let mut result = Vec::with_capacity(256 * items_in_byte);
for i in 0..256 {
for j in 0..items_in_byte {
let bits = (i >> (j * items_b128)) & ((1 << items_b128) - 1);
let mut value = U::ZERO;
for k in 0..items_b128 {
if (bits >> k) & 1 == 1 {
unsafe {
value.set_subvalue(k, u128::ONES);
}
}
}
result.push(value);
}
}
result
}
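// Mask lookup table for the architecture-optimal underlier, built on first use.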
lazy_static! {
static ref LOOKUP_TABLE: Vec<OptimalUnderlier> = init_lookup_table_width::<OptimalUnderlier>();
}
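/// Core loop of the 1-bit x 128-bit left fold.
///
/// `evals` is a byte slice of 1-bit evaluations laid out row-major with
/// `2^(log_evals_size - log_query_size)` bits per row. Each row's bytes are expanded through
/// `lookup_table` into masks that are ANDed with the broadcast query element, so every set
/// evaluation bit accumulates that query element into the corresponding output position.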
#[inline]
fn fold_left_1b_128b_impl<PE, U>(
lookup_table: &[U],
evals: &[u8],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [MaybeUninit<PE>],
) where
PE: PackedField + WithUnderlier<Underlier = U>,
U: UnderlierWithBitOps,
{
let out = unsafe { slice_assume_init_mut(out) };
fill_zeroes(out);
let items_in_byte = 8 / PE::WIDTH;
let row_size_bytes = 1 << (log_evals_size - log_query_size - 3);
for (query_val, row_bytes) in PE::iter_slice(query).zip(evals.chunks(row_size_bytes)) {
let query_val = PE::broadcast(query_val).to_underlier();
for (byte_index, byte) in row_bytes.iter().enumerate() {
let mask_offset = *byte as usize * items_in_byte;
let out_offset = byte_index * items_in_byte;
for i in 0..items_in_byte {
let mask = unsafe { lookup_table.get_unchecked(mask_offset + i) };
let multiplied = query_val & *mask;
let out = unsafe { out.get_unchecked_mut(out_offset + i) };
*out += PE::from_underlier(multiplied);
}
}
}
}
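/// Generic left fold: each output scalar is the dot product of the expanded query with the
/// evaluations taken at a stride of `2^(log_evals_size - log_query_size)`.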
fn fold_left_fallback<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [MaybeUninit<PE>],
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
let new_n_vars = log_evals_size - log_query_size;
out.iter_mut()
.enumerate()
.for_each(|(outer_index, out_val)| {
let mut res = PE::default();
for inner_index in 0..min(PE::WIDTH, 1 << new_n_vars) {
res.set(
inner_index,
PackedField::iter_slice(query)
.take(1 << log_query_size)
.enumerate()
.map(|(query_index, basis_eval)| {
let eval_index = (query_index << new_n_vars)
| (outer_index << PE::LOG_WIDTH)
| inner_index;
let subpoly_eval_i = get_packed_slice(evals, eval_index);
basis_eval * subpoly_eval_i
})
.sum(),
);
}
out_val.write(res);
});
}
#[cfg(test)]
mod tests {
use std::iter::repeat_with;
use binius_field::{
packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b,
PackedBinaryField16x8b, PackedBinaryField512x1b, PackedBinaryField64x8b,
};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
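	/// Naive reference implementation of the right fold, used to validate the optimized path.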
fn fold_right_reference<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [PE],
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
for i in 0..1 << (log_evals_size - log_query_size) {
let mut result = PE::Scalar::ZERO;
for j in 0..1 << log_query_size {
result +=
get_packed_slice(query, j) * get_packed_slice(evals, (i << log_query_size) | j);
}
set_packed_slice(out, i, result);
}
}
fn check_fold_right<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
let mut reference_out =
vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
let mut out = reference_out.clone();
fold_right(evals, log_evals_size, query, log_query_size, &mut out).unwrap();
fold_right_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);
for i in 0..1 << (log_evals_size - log_query_size) {
assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
}
}
#[test]
fn test_1b_small_poly_query_log_size_0() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_right(&evals, 0, &query, 0);
}
#[test]
fn test_1b_small_poly_query_log_size_1() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_right(&evals, 2, &query, 1);
}
#[test]
fn test_1b_small_poly_query_log_size_7() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_right(&evals, 7, &query, 7);
}
#[test]
fn test_1b_many_evals() {
const LOG_EVALS_SIZE: usize = 14;
let mut rng = StdRng::seed_from_u64(1);
let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng))
.take(8)
.collect::<Vec<_>>();
for log_query_size in 0..10 {
check_fold_right(
&evals,
LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
&query,
log_query_size,
);
}
}
#[test]
fn test_8b_small_poly() {
const LOG_EVALS_SIZE: usize = 5;
let mut rng = StdRng::seed_from_u64(0);
let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
.take(1 << 8)
.collect::<Vec<_>>();
for log_query_size in 0..8 {
check_fold_right(
&evals,
LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
&query,
log_query_size,
);
}
}
#[test]
fn test_8b_many_evals() {
const LOG_EVALS_SIZE: usize = 13;
let mut rng = StdRng::seed_from_u64(0);
let evals = repeat_with(|| PackedBinaryField16x8b::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = repeat_with(|| PackedBinaryField16x32b::random(&mut rng))
.take(1 << 8)
.collect::<Vec<_>>();
for log_query_size in 0..8 {
check_fold_right(
&evals,
LOG_EVALS_SIZE + PackedBinaryField16x8b::LOG_WIDTH,
&query,
log_query_size,
);
}
}
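	/// Naive reference implementation of the left fold, used to validate the optimized path.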
fn fold_left_reference<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
out: &mut [PE],
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
for i in 0..1 << (log_evals_size - log_query_size) {
let mut result = PE::Scalar::ZERO;
for j in 0..1 << log_query_size {
result += get_packed_slice(query, j)
* get_packed_slice(evals, i | (j << (log_evals_size - log_query_size)));
}
set_packed_slice(out, i, result);
}
}
fn check_fold_left<P, PE>(
evals: &[P],
log_evals_size: usize,
query: &[PE],
log_query_size: usize,
) where
P: PackedField,
PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
{
let mut reference_out =
vec![PE::zero(); (1usize << (log_evals_size - log_query_size)).div_ceil(PE::WIDTH)];
let mut out = reference_out.clone();
out.clear();
fold_left(evals, log_evals_size, query, log_query_size, out.spare_capacity_mut()).unwrap();
unsafe {
out.set_len(out.capacity());
}
fold_left_reference(evals, log_evals_size, query, log_query_size, &mut reference_out);
for i in 0..1 << (log_evals_size - log_query_size) {
assert_eq!(get_packed_slice(&out, i), get_packed_slice(&reference_out, i));
}
}
#[test]
fn test_fold_left_1b_small_poly_query_log_size_0() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_left(&evals, 0, &query, 0);
}
#[test]
fn test_fold_left_1b_small_poly_query_log_size_1() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_left(&evals, 2, &query, 1);
}
#[test]
fn test_fold_left_1b_small_poly_query_log_size_7() {
let mut rng = StdRng::seed_from_u64(0);
let evals = vec![PackedBinaryField128x1b::random(&mut rng)];
let query = vec![PackedBinaryField128x1b::random(&mut rng)];
check_fold_left(&evals, 7, &query, 7);
}
#[test]
fn test_fold_left_1b_many_evals() {
const LOG_EVALS_SIZE: usize = 14;
let mut rng = StdRng::seed_from_u64(1);
let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = vec![PackedBinaryField512x1b::random(&mut rng)];
for log_query_size in 0..10 {
check_fold_left(
&evals,
LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
&query,
log_query_size,
);
}
}
	type B128bOptimal = ArchOptimalType<BinaryField128b>;
#[test]
fn test_fold_left_1b_128b_optimal() {
const LOG_EVALS_SIZE: usize = 14;
let mut rng = StdRng::seed_from_u64(0);
let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = repeat_with(|| B128bOptimal::random(&mut rng))
.take(1 << (10 - B128bOptimal::LOG_WIDTH))
.collect::<Vec<_>>();
for log_query_size in 0..10 {
check_fold_left(
&evals,
LOG_EVALS_SIZE + PackedBinaryField128x1b::LOG_WIDTH,
&query,
log_query_size,
);
}
}
#[test]
fn test_fold_left_128b_128b() {
const LOG_EVALS_SIZE: usize = 14;
let mut rng = StdRng::seed_from_u64(0);
let evals = repeat_with(|| B128bOptimal::random(&mut rng))
.take(1 << LOG_EVALS_SIZE)
.collect::<Vec<_>>();
let query = repeat_with(|| B128bOptimal::random(&mut rng))
.take(1 << 10)
.collect::<Vec<_>>();
for log_query_size in 0..10 {
check_fold_left(&evals, LOG_EVALS_SIZE, &query, log_query_size);
}
}
}