// binius_core/reed_solomon/reed_solomon.rs

use std::marker::PhantomData;
use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension};
use binius_maybe_rayon::prelude::*;
use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings};
use binius_utils::{bail, checked_arithmetics::checked_log_2};
use getset::CopyGetters;
use tracing::instrument;
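
/// A Reed–Solomon code over a binary field, encoded with an additive NTT.
///
/// The code has dimension (message length) `2^log_dimension` and block length
/// `2^(log_dimension + log_inv_rate)`, i.e. rate `2^-log_inv_rate`. Messages
/// are interpreted as coefficients of a polynomial in the novel polynomial
/// basis of [LCH14], and a codeword is the vector of evaluations of that
/// polynomial over a linear subspace of the field.
///
/// # Example
///
/// A minimal construction sketch, not taken from the original source: the
/// packed field type `PackedBinaryField8x16b` is an assumed name from
/// `binius_field`, and `precompute_twiddles` is assumed to be a boolean flag.
///
/// ```ignore
/// use binius_field::PackedBinaryField8x16b;
/// use binius_ntt::{NTTOptions, ThreadingSettings};
///
/// let rs_code = ReedSolomonCode::<PackedBinaryField8x16b>::new(
///     16, // log_dimension: 2^16 message symbols
///     2,  // log_inv_rate: rate 1/4, block length 2^18
///     &NTTOptions {
///         thread_settings: ThreadingSettings::SingleThreaded,
///         precompute_twiddles: false, // assumed boolean flag
///     },
/// )
/// .expect("valid code parameters");
/// assert_eq!(rs_code.dim(), 1 << 16);
/// assert_eq!(rs_code.len(), 1 << 18);
/// assert_eq!(rs_code.inv_rate(), 4);
/// ```
///
/// [LCH14]: <https://arxiv.org/abs/1404.3458>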
#[derive(Debug, CopyGetters)]
pub struct ReedSolomonCode<P>
where
P: PackedField,
P::Scalar: BinaryField,
{
ntt: DynamicDispatchNTT<P::Scalar>,
log_dimension: usize,
#[getset(get_copy = "pub")]
log_inv_rate: usize,
multithreaded: bool,
_p_marker: PhantomData<P>,
}
impl<P> ReedSolomonCode<P>
where
P: PackedField<Scalar: BinaryField>,
{
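    /// Constructs a code with dimension `2^log_dimension` and inverse rate
    /// `2^log_inv_rate`, backed by an NTT over the full block-length domain
    /// of `log_dimension + log_inv_rate` bits.
    ///
    /// Encoding parallelizes across the `2^log_inv_rate` cosets of the
    /// message domain, so the configured thread count is split between the
    /// cosets: each inner NTT gets `log_threads - log_inv_rate` log-threads,
    /// saturating at zero.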
pub fn new(
log_dimension: usize,
log_inv_rate: usize,
ntt_options: &NTTOptions,
) -> Result<Self, Error> {
let ntt_log_threads = ntt_options
.thread_settings
.log_threads_count()
.saturating_sub(log_inv_rate);
let ntt = DynamicDispatchNTT::new(
log_dimension + log_inv_rate,
&NTTOptions {
thread_settings: ThreadingSettings::ExplicitThreadsCount {
log_threads: ntt_log_threads,
},
precompute_twiddles: ntt_options.precompute_twiddles,
},
)?;
let multithreaded =
!matches!(ntt_options.thread_settings, ThreadingSettings::SingleThreaded);
Ok(Self {
ntt,
log_dimension,
log_inv_rate,
multithreaded,
_p_marker: PhantomData,
})
}
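
    /// Returns the additive NTT used for encoding.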
pub const fn get_ntt(&self) -> &impl AdditiveNTT<P> {
&self.ntt
}
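
    /// The dimension of the code, i.e. the message length `2^log_dim`.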
pub const fn dim(&self) -> usize {
1 << self.dim_bits()
}
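
    /// The base-2 logarithm of the dimension.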
pub const fn log_dim(&self) -> usize {
self.log_dimension
}
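
    /// The base-2 logarithm of the block length.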
pub const fn log_len(&self) -> usize {
self.log_dimension + self.log_inv_rate
}
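
    /// The block length of the code, `2^log_len`.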
#[allow(clippy::len_without_is_empty)]
pub const fn len(&self) -> usize {
1 << (self.log_dimension + self.log_inv_rate)
}
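
    /// Alias of [`Self::log_dim`].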
const fn dim_bits(&self) -> usize {
self.log_dimension
}
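
    /// The reciprocal of the code rate, `2^log_inv_rate`.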
pub const fn inv_rate(&self) -> usize {
1 << self.log_inv_rate
}
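
    /// Encodes a batch of `2^log_batch_size` interleaved messages in place.
    ///
    /// On input, the first `(dim() / P::WIDTH) << log_batch_size` packed
    /// elements of `code` hold the interleaved messages; on success, the
    /// buffer holds the full interleaved codeword. Fails if the buffer is
    /// too small or if `P::WIDTH` does not divide the dimension.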
fn encode_batch_inplace(&self, code: &mut [P], log_batch_size: usize) -> Result<(), Error> {
let _scope = tracing::trace_span!(
"Reed–Solomon encode",
log_len = self.log_len(),
log_batch_size = log_batch_size,
symbol_bits = P::Scalar::N_BITS,
)
.entered();
        // The buffer stores `code.len() * P::WIDTH` scalar symbols, and the
        // interleaved codeword occupies `len() << log_batch_size` of them.
        if code.len() * P::WIDTH < self.len() << log_batch_size {
            bail!(Error::BufferTooSmall {
                log_code_len: self.log_len(),
            });
        }
if self.dim() % P::WIDTH != 0 {
bail!(Error::PackingWidthMustDivideDimension);
}
let msgs_len = (self.dim() / P::WIDTH) << log_batch_size;
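        // Replicate the interleaved messages into each of the remaining
        // 2^log_inv_rate - 1 chunks of the buffer.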
for i in 1..(1 << self.log_inv_rate) {
code.copy_within(0..msgs_len, i * msgs_len);
}
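
        // Transform each chunk under a distinct coset of the NTT domain;
        // the concatenation of the outputs is the interleaved codeword.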
if self.multithreaded {
(0..(1 << self.log_inv_rate))
.into_par_iter()
.zip(code.par_chunks_exact_mut(msgs_len))
.try_for_each(|(i, data)| self.ntt.forward_transform(data, i, log_batch_size))
} else {
(0..(1 << self.log_inv_rate))
.zip(code.chunks_exact_mut(msgs_len))
.try_for_each(|(i, data)| self.ntt.forward_transform(data, i, log_batch_size))
}
}
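
    /// Encodes a batch of interleaved messages of extension field elements
    /// in place.
    ///
    /// Each extension element is reinterpreted as `DEGREE` base-field
    /// symbols, so this reduces to a base-field encoding with the batch size
    /// scaled up by the extension degree.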
#[instrument(skip_all, level = "debug")]
pub fn encode_ext_batch_inplace<PE>(
&self,
code: &mut [PE],
log_batch_size: usize,
) -> Result<(), Error>
where
PE: RepackedExtension<P>,
PE::Scalar: ExtensionField<<P as PackedField>::Scalar>,
{
let log_degree = checked_log_2(PE::Scalar::DEGREE);
self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree)
}
}