use std::any::TypeId;
use bytemuck::Pod;
use crate::{
packed::get_packed_slice, AESTowerField128b, AESTowerField16b, AESTowerField32b,
AESTowerField64b, AESTowerField8b, BinaryField128b, BinaryField128bPolyval, BinaryField16b,
BinaryField32b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b,
ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, Field,
PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b,
PackedAESBinaryField1x128b, PackedAESBinaryField1x16b, PackedAESBinaryField1x32b,
PackedAESBinaryField1x64b, PackedAESBinaryField1x8b, PackedAESBinaryField2x128b,
PackedAESBinaryField2x16b, PackedAESBinaryField2x32b, PackedAESBinaryField2x64b,
PackedAESBinaryField2x8b, PackedAESBinaryField32x16b, PackedAESBinaryField32x8b,
PackedAESBinaryField4x128b, PackedAESBinaryField4x16b, PackedAESBinaryField4x32b,
PackedAESBinaryField4x64b, PackedAESBinaryField4x8b, PackedAESBinaryField64x8b,
PackedAESBinaryField8x16b, PackedAESBinaryField8x64b, PackedAESBinaryField8x8b,
PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b,
PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b,
PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b,
PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x32b,
PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b,
PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b,
PackedBinaryField2x32b, PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b,
PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b,
PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b,
PackedBinaryField4x16b, PackedBinaryField4x2b, PackedBinaryField4x32b, PackedBinaryField4x4b,
PackedBinaryField4x64b, PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b,
PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b, PackedBinaryField8x16b,
PackedBinaryField8x1b, PackedBinaryField8x2b, PackedBinaryField8x32b, PackedBinaryField8x4b,
PackedBinaryField8x64b, PackedBinaryField8x8b, PackedBinaryPolyval1x128b,
PackedBinaryPolyval2x128b, PackedBinaryPolyval4x128b, PackedField,
};
#[allow(unused)]
unsafe trait SequentialBytes: Pod {}
unsafe impl SequentialBytes for BinaryField8b {}
unsafe impl SequentialBytes for BinaryField16b {}
unsafe impl SequentialBytes for BinaryField32b {}
unsafe impl SequentialBytes for BinaryField64b {}
unsafe impl SequentialBytes for BinaryField128b {}
unsafe impl SequentialBytes for PackedBinaryField8x1b {}
unsafe impl SequentialBytes for PackedBinaryField16x1b {}
unsafe impl SequentialBytes for PackedBinaryField32x1b {}
unsafe impl SequentialBytes for PackedBinaryField64x1b {}
unsafe impl SequentialBytes for PackedBinaryField128x1b {}
unsafe impl SequentialBytes for PackedBinaryField256x1b {}
unsafe impl SequentialBytes for PackedBinaryField512x1b {}
unsafe impl SequentialBytes for PackedBinaryField4x2b {}
unsafe impl SequentialBytes for PackedBinaryField8x2b {}
unsafe impl SequentialBytes for PackedBinaryField16x2b {}
unsafe impl SequentialBytes for PackedBinaryField32x2b {}
unsafe impl SequentialBytes for PackedBinaryField64x2b {}
unsafe impl SequentialBytes for PackedBinaryField128x2b {}
unsafe impl SequentialBytes for PackedBinaryField256x2b {}
unsafe impl SequentialBytes for PackedBinaryField2x4b {}
unsafe impl SequentialBytes for PackedBinaryField4x4b {}
unsafe impl SequentialBytes for PackedBinaryField8x4b {}
unsafe impl SequentialBytes for PackedBinaryField16x4b {}
unsafe impl SequentialBytes for PackedBinaryField32x4b {}
unsafe impl SequentialBytes for PackedBinaryField64x4b {}
unsafe impl SequentialBytes for PackedBinaryField128x4b {}
unsafe impl SequentialBytes for PackedBinaryField1x8b {}
unsafe impl SequentialBytes for PackedBinaryField2x8b {}
unsafe impl SequentialBytes for PackedBinaryField4x8b {}
unsafe impl SequentialBytes for PackedBinaryField8x8b {}
unsafe impl SequentialBytes for PackedBinaryField16x8b {}
unsafe impl SequentialBytes for PackedBinaryField32x8b {}
unsafe impl SequentialBytes for PackedBinaryField64x8b {}
unsafe impl SequentialBytes for PackedBinaryField1x16b {}
unsafe impl SequentialBytes for PackedBinaryField2x16b {}
unsafe impl SequentialBytes for PackedBinaryField4x16b {}
unsafe impl SequentialBytes for PackedBinaryField8x16b {}
unsafe impl SequentialBytes for PackedBinaryField16x16b {}
unsafe impl SequentialBytes for PackedBinaryField32x16b {}
unsafe impl SequentialBytes for PackedBinaryField1x32b {}
unsafe impl SequentialBytes for PackedBinaryField2x32b {}
unsafe impl SequentialBytes for PackedBinaryField4x32b {}
unsafe impl SequentialBytes for PackedBinaryField8x32b {}
unsafe impl SequentialBytes for PackedBinaryField16x32b {}
unsafe impl SequentialBytes for PackedBinaryField1x64b {}
unsafe impl SequentialBytes for PackedBinaryField2x64b {}
unsafe impl SequentialBytes for PackedBinaryField4x64b {}
unsafe impl SequentialBytes for PackedBinaryField8x64b {}
unsafe impl SequentialBytes for PackedBinaryField1x128b {}
unsafe impl SequentialBytes for PackedBinaryField2x128b {}
unsafe impl SequentialBytes for PackedBinaryField4x128b {}
unsafe impl SequentialBytes for AESTowerField8b {}
unsafe impl SequentialBytes for AESTowerField16b {}
unsafe impl SequentialBytes for AESTowerField32b {}
unsafe impl SequentialBytes for AESTowerField64b {}
unsafe impl SequentialBytes for AESTowerField128b {}
unsafe impl SequentialBytes for PackedAESBinaryField1x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField8x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField16x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField32x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField64x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField1x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField8x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField16x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField32x16b {}
unsafe impl SequentialBytes for PackedAESBinaryField1x32b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x32b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x32b {}
unsafe impl SequentialBytes for PackedAESBinaryField16x32b {}
unsafe impl SequentialBytes for PackedAESBinaryField1x64b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x64b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x64b {}
unsafe impl SequentialBytes for PackedAESBinaryField8x64b {}
unsafe impl SequentialBytes for PackedAESBinaryField1x128b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x128b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x128b {}
unsafe impl SequentialBytes for BinaryField128bPolyval {}
unsafe impl SequentialBytes for PackedBinaryPolyval1x128b {}
unsafe impl SequentialBytes for PackedBinaryPolyval2x128b {}
unsafe impl SequentialBytes for PackedBinaryPolyval4x128b {}
/// Returns `true` iff `T` implements `SequentialBytes`, without requiring that bound on `T`.
///
/// This is a stable-Rust stand-in for specialization. `X<U>` has a `Clone` impl that always
/// produces a `false` flag, but `X<U>` is `Copy` only when `U: SequentialBytes`. Cloning the
/// one-element array `[X<T>; 1]` goes through the standard library's array `Clone` impl. That
/// impl internally uses a bitwise copy for `Copy` element types, which keeps the stored `true`.
/// Otherwise it calls `X::clone` per element, which yields `false`.
///
/// NOTE(review): this relies on an internal implementation detail of `<[T; N] as Clone>` in
/// `std`, not on a documented guarantee; the `test_sequential_bits` test pins the expected
/// results so a behavioural change in `std` would surface there.
#[inline(always)]
// `value.clone()` below looks redundant to clippy because `value` is never used afterwards,
// but the clone itself is the probe this function depends on.
#[allow(clippy::redundant_clone)] pub fn is_sequential_bytes<T>() -> bool {
    // Carries a flag that survives cloning only if the clone is a bitwise copy.
    struct X<U>(bool, std::marker::PhantomData<U>);
    // Manual `Clone` always resets the flag to `false`.
    impl<U> Clone for X<U> {
        fn clone(&self) -> Self {
            Self(false, std::marker::PhantomData)
        }
    }
    // `X<U>` becomes `Copy` only for `SequentialBytes` types, enabling the bitwise-copy path.
    impl<U: SequentialBytes> Copy for X<U> {}
    let value = [X::<T>(true, std::marker::PhantomData)];
    // Array clone: bitwise copy (flag stays `true`) if `X<T>: Copy`, else `X::clone` (`false`).
    let cloned = value.clone();
    cloned[0].0
}
/// Reports whether [`iterate_bytes`] supports the packed field type `P`.
///
/// `P` qualifies when its raw memory can be streamed directly (it implements
/// `SequentialBytes`) or when it is one of the byte-sliced AES packed types, which
/// `iterate_bytes` walks element by element.
pub fn can_iterate_bytes<P: PackedField>() -> bool {
    let byte_sliced_types = [
        TypeId::of::<ByteSlicedAES32x128b>(),
        TypeId::of::<ByteSlicedAES32x64b>(),
        TypeId::of::<ByteSlicedAES32x32b>(),
        TypeId::of::<ByteSlicedAES32x16b>(),
        TypeId::of::<ByteSlicedAES32x8b>(),
    ];
    is_sequential_bytes::<P>() || byte_sliced_types.contains(&TypeId::of::<P>())
}
/// Streams every byte of `$data: &[P]` to `$callback`, where the caller has determined (by a
/// `TypeId` comparison) that `P` is the byte-sliced type `$packed_type`.
///
/// Each element contributes `<$packed_type>::BYTES` bytes, read via `get_byte_unchecked` in
/// increasing index order. The expansion names the caller's generic parameter `P`, so it can
/// only be used inside a function that is generic over `P`.
macro_rules! iterate_byte_sliced {
    ($packed_type:ty, $data:ident, $callback:ident) => {
        // Hard runtime check (not a debug assertion): the cast below is only sound if the
        // concrete type really is `P`.
        assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::<P>());
        // SAFETY: `P` and `$packed_type` are the same type (asserted above), so reinterpreting
        // the pointer keeps size, alignment and element count unchanged.
        let typed: &[$packed_type] = unsafe {
            std::slice::from_raw_parts($data.as_ptr().cast::<$packed_type>(), $data.len())
        };
        let bytes = typed.iter().flat_map(|element| {
            (0..<$packed_type>::BYTES).map(move |byte_index| {
                // SAFETY: `byte_index < BYTES`. NOTE(review): assumes that bound is the full
                // precondition of `get_byte_unchecked` — confirm against its definition.
                unsafe { element.get_byte_unchecked(byte_index) }
            })
        });
        $callback.call(bytes);
    };
}
/// Receives the byte stream produced by [`iterate_bytes`].
///
/// This is a trait with a generic method rather than a closure parameter because
/// `iterate_bytes` passes a different concrete iterator type depending on the packed field
/// type, and a closure cannot be generic over its argument type.
pub trait ByteIteratorCallback {
    /// Consumes `iter`, which yields the bytes of the data passed to `iterate_bytes`.
    fn call(&mut self, iter: impl Iterator<Item = u8>);
}
/// Feeds every byte of `data` to `callback` as a single iterator.
///
/// Types implementing `SequentialBytes` are streamed straight from memory. The byte-sliced AES
/// types are walked element by element through `get_byte_unchecked`. Use
/// [`can_iterate_bytes`] to check beforehand whether `P` is supported.
///
/// # Panics
///
/// Panics if `P` is neither a `SequentialBytes` type nor one of the `ByteSlicedAES32x*`
/// types dispatched below.
#[inline(always)]
pub fn iterate_bytes<P: PackedField>(data: &[P], callback: &mut impl ByteIteratorCallback) {
    if is_sequential_bytes::<P>() {
        // SAFETY: `is_sequential_bytes` returned true, so `P: SequentialBytes` and therefore
        // `P: Pod`: there is no padding and every byte is initialized. `u8` has alignment 1,
        // and `size_of_val(data)` is exactly the byte length of the slice.
        let raw: &[u8] = unsafe {
            std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), std::mem::size_of_val(data))
        };
        callback.call(raw.iter().copied());
        return;
    }

    let type_id = TypeId::of::<P>();
    if type_id == TypeId::of::<ByteSlicedAES32x128b>() {
        iterate_byte_sliced!(ByteSlicedAES32x128b, data, callback);
    } else if type_id == TypeId::of::<ByteSlicedAES32x64b>() {
        iterate_byte_sliced!(ByteSlicedAES32x64b, data, callback);
    } else if type_id == TypeId::of::<ByteSlicedAES32x32b>() {
        iterate_byte_sliced!(ByteSlicedAES32x32b, data, callback);
    } else if type_id == TypeId::of::<ByteSlicedAES32x16b>() {
        iterate_byte_sliced!(ByteSlicedAES32x16b, data, callback);
    } else if type_id == TypeId::of::<ByteSlicedAES32x8b>() {
        iterate_byte_sliced!(ByteSlicedAES32x8b, data, callback);
    } else {
        unreachable!("packed field doesn't support byte iteration");
    }
}
/// An indexable, fixed-length collection of values of type `T`.
///
/// Callers such as `create_partial_sums_lookup_tables` call `get(i)` for every index
/// `i < len()`.
pub trait ScalarsCollection<T> {
    /// Number of values in the collection.
    fn len(&self) -> usize;

    /// Returns the value at index `i`.
    fn get(&self, i: usize) -> T;

    /// Returns `true` when the collection holds no values.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// Lets a plain slice of scalars be used wherever a [`ScalarsCollection`] is expected.
impl<F: Field> ScalarsCollection<F> for &[F] {
    #[inline(always)]
    fn len(&self) -> usize {
        // Dereference down to the slice itself: a plain `self.len()` would resolve back to
        // this trait method (receiver `&&[F]`) and recurse forever.
        (**self).len()
    }

    #[inline(always)]
    fn get(&self, i: usize) -> F {
        (**self)[i]
    }
}
/// A view of the first `len` scalars stored in a slice of packed field elements.
///
/// NOTE(review): `new` does not check that `len` fits within the scalars actually held by
/// `slice`; presumably callers guarantee this — confirm.
pub struct PackedSlice<'a, P: PackedField> {
    /// Backing packed storage.
    slice: &'a [P],
    /// Number of logical scalars reported through `ScalarsCollection::len`.
    len: usize,
}

impl<'a, P: PackedField> PackedSlice<'a, P> {
    /// Wraps `slice`, exposing scalars at indices `0..len`.
    #[inline(always)]
    pub const fn new(slice: &'a [P], len: usize) -> Self {
        PackedSlice { slice, len }
    }
}
/// Exposes the scalars at indices `0..len` of the packed slice, in index order.
impl<P: PackedField> ScalarsCollection<P::Scalar> for PackedSlice<'_, P> {
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }

    /// Returns scalar `i` of the underlying packed storage.
    ///
    /// `get_packed_slice` only sees the packed slice, not the logical length, so an index
    /// `i >= len` that still falls inside the last packed element would silently return
    /// whatever scalar occupies that lane. Debug builds catch such out-of-range reads here.
    #[inline(always)]
    fn get(&self, i: usize) -> P::Scalar {
        debug_assert!(
            i < self.len,
            "index {} out of bounds for packed slice of length {}",
            i,
            self.len
        );
        get_packed_slice(self.slice, i)
    }
}
/// Builds a 256-entry subset-sum table for every consecutive group of 8 values.
///
/// For chunk `c` (values `8c..8c + 8`), entry `256 * c + i` of the result is the sum of
/// `values.get(8c + j)` over all bits `j` set in `i`. Entry `256 * c` is zero. The result
/// therefore has `len / 8 * 256` elements.
///
/// Each entry is derived from the already-computed entry with its lowest set bit cleared,
/// plus one value. That is one addition per entry instead of up to eight, and each value is
/// fetched once per chunk instead of 256 times. Field addition is exact, commutative and
/// associative, so the table is identical to summing each subset from scratch.
///
/// # Panics
///
/// Panics if `values.len()` is not a multiple of 8.
pub fn create_partial_sums_lookup_tables<P: PackedField>(
    values: impl ScalarsCollection<P>,
) -> Vec<P> {
    let len = values.len();
    assert!(len % 8 == 0, "number of values must be a multiple of 8, got {}", len);

    // `len / 8 * 256` rather than `len * 256 / 8` avoids an intermediate overflow.
    let mut result = Vec::with_capacity(len / 8 * 256);
    for offset in (0..len).step_by(8) {
        // Fetch the chunk once; `get` may be non-trivial (e.g. unpacking a packed slice).
        let chunk: [P; 8] = std::array::from_fn(|j| values.get(offset + j));

        let base = result.len();
        result.push(P::zero());
        for subset in 1..256usize {
            // sum(subset) = sum(subset without its lowest bit) + value at that lowest bit.
            let mut sum = result[base + (subset & (subset - 1))];
            sum += chunk[subset.trailing_zeros() as usize];
            result.push(sum);
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b};

    /// Asserts that `is_sequential_bytes` holds for every listed type.
    macro_rules! assert_sequential {
        ($($ty:ty),+ $(,)?) => {
            $(assert!(is_sequential_bytes::<$ty>());)+
        };
    }

    /// Asserts that `is_sequential_bytes` does not hold for any listed type.
    macro_rules! assert_not_sequential {
        ($($ty:ty),+ $(,)?) => {
            $(assert!(!is_sequential_bytes::<$ty>());)+
        };
    }

    #[test]
    fn test_sequential_bits() {
        // Binary tower scalars and their packed forms.
        assert_sequential!(
            BinaryField8b,
            BinaryField16b,
            BinaryField32b,
            BinaryField64b,
            BinaryField128b,
            PackedBinaryField8x1b,
            PackedBinaryField16x1b,
            PackedBinaryField32x1b,
            PackedBinaryField64x1b,
            PackedBinaryField128x1b,
            PackedBinaryField256x1b,
            PackedBinaryField512x1b,
            PackedBinaryField4x2b,
            PackedBinaryField8x2b,
            PackedBinaryField16x2b,
            PackedBinaryField32x2b,
            PackedBinaryField64x2b,
            PackedBinaryField128x2b,
            PackedBinaryField256x2b,
            PackedBinaryField2x4b,
            PackedBinaryField4x4b,
            PackedBinaryField8x4b,
            PackedBinaryField16x4b,
            PackedBinaryField32x4b,
            PackedBinaryField64x4b,
            PackedBinaryField128x4b,
            PackedBinaryField1x8b,
            PackedBinaryField2x8b,
            PackedBinaryField4x8b,
            PackedBinaryField8x8b,
            PackedBinaryField16x8b,
            PackedBinaryField32x8b,
            PackedBinaryField64x8b,
            PackedBinaryField1x16b,
            PackedBinaryField2x16b,
            PackedBinaryField4x16b,
            PackedBinaryField8x16b,
            PackedBinaryField16x16b,
            PackedBinaryField32x16b,
            PackedBinaryField1x32b,
            PackedBinaryField2x32b,
            PackedBinaryField4x32b,
            PackedBinaryField8x32b,
            PackedBinaryField16x32b,
            PackedBinaryField1x64b,
            PackedBinaryField2x64b,
            PackedBinaryField4x64b,
            PackedBinaryField8x64b,
            PackedBinaryField1x128b,
            PackedBinaryField2x128b,
            PackedBinaryField4x128b,
        );

        // AES tower scalars and their packed forms.
        assert_sequential!(
            AESTowerField8b,
            AESTowerField16b,
            AESTowerField32b,
            AESTowerField64b,
            AESTowerField128b,
            PackedAESBinaryField1x8b,
            PackedAESBinaryField2x8b,
            PackedAESBinaryField4x8b,
            PackedAESBinaryField8x8b,
            PackedAESBinaryField16x8b,
            PackedAESBinaryField32x8b,
            PackedAESBinaryField64x8b,
            PackedAESBinaryField1x16b,
            PackedAESBinaryField2x16b,
            PackedAESBinaryField4x16b,
            PackedAESBinaryField8x16b,
            PackedAESBinaryField16x16b,
            PackedAESBinaryField32x16b,
            PackedAESBinaryField1x32b,
            PackedAESBinaryField2x32b,
            PackedAESBinaryField4x32b,
            PackedAESBinaryField16x32b,
            PackedAESBinaryField1x64b,
            PackedAESBinaryField2x64b,
            PackedAESBinaryField4x64b,
            PackedAESBinaryField8x64b,
            PackedAESBinaryField1x128b,
            PackedAESBinaryField2x128b,
            PackedAESBinaryField4x128b,
        );

        // POLYVAL scalar and packed forms.
        assert_sequential!(
            BinaryField128bPolyval,
            PackedBinaryPolyval1x128b,
            PackedBinaryPolyval2x128b,
            PackedBinaryPolyval4x128b,
        );

        // Sub-byte packings and byte-sliced layouts must not be treated as raw bytes.
        assert_not_sequential!(
            PackedBinaryField1x1b,
            PackedBinaryField2x1b,
            PackedBinaryField4x1b,
            ByteSlicedAES32x128b,
            ByteSlicedAES32x64b,
            ByteSlicedAES32x32b,
            ByteSlicedAES32x16b,
            ByteSlicedAES32x8b,
        );
    }
}