1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
// Copyright 2024 Ulvetanna Inc.

use crate::polynomial::{Error, MultilinearExtension, MultivariatePoly};
use binius_field::{Field, PackedField, TowerField};
use binius_utils::bail;
use std::marker::PhantomData;

/// Represents the $\mathcal{T}_{\iota}$-basis of $\mathcal{T}_{\iota+k}$.
///
/// Recall that $\mathcal{T}_{\iota}$ is defined as
/// * Let $\mathbb{F} := \mathbb{F}_2[X_0, \ldots, X_{\iota-1}]$
/// * Let $\mathcal{J} := (X_0^2 + X_0 + 1, \ldots, X_{\iota-1}^2 + X_{\iota-1}X_{\iota-2} + 1)$
/// * $\mathcal{T}_{\iota} := \mathbb{F} / \mathcal{J}$
///
/// and $\mathcal{T}_{\iota}$ has the following $\mathbb{F}_2$-basis:
/// * $1, X_0, X_1, X_0X_1, X_2, \ldots, X_0 X_1 \ldots X_{\iota-1}$
///
/// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$:
/// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$
///
/// Viewed as a polynomial, this basis has `k` variables. Its hypercube evaluations are the
/// $2^k$ basis elements listed above.
#[derive(Debug, Copy, Clone)]
pub struct TowerBasis<F: Field> {
	/// Number of variables; the basis has $2^k$ elements.
	k: usize,
	/// Tower level $\iota$ of the subfield over which the basis is taken.
	/// `TowerBasis::new` requires `iota + k <= F::TOWER_LEVEL`.
	iota: usize,
	_marker: PhantomData<F>,
}

impl<F: TowerField> TowerBasis<F> {
	/// Creates the $\mathcal{T}_{\iota}$-basis of $\mathcal{T}_{\iota+k}$ inside the field `F`.
	///
	/// # Errors
	///
	/// Returns [`Error::ArgumentRangeError`] if `iota + k` exceeds `F::TOWER_LEVEL`, i.e. if
	/// $\mathcal{T}_{\iota+k}$ does not fit inside `F`.
	pub fn new(k: usize, iota: usize) -> Result<Self, Error> {
		// Equivalent to `iota + k > F::TOWER_LEVEL`, but written so that pathologically large
		// arguments cannot overflow the addition.
		if k > F::TOWER_LEVEL || iota > F::TOWER_LEVEL - k {
			bail!(Error::ArgumentRangeError {
				arg: "iota + k".into(),
				range: 0..F::TOWER_LEVEL + 1,
			});
		}
		Ok(Self {
			k,
			iota,
			_marker: Default::default(),
		})
	}

	/// Returns the multilinear extension whose `2^k` hypercube evaluations are the basis elements.
	///
	/// The evaluation at hypercube index `i` is `TowerField::basis(iota, i)`. Scalars are packed
	/// `P::WIDTH` at a time, in index order.
	///
	/// # Errors
	///
	/// Returns [`Error::ArgumentRangeError`] if `k < P::LOG_WIDTH`, because the `2^k` basis
	/// elements would not fill even a single packed element. Also propagates any error from
	/// `TowerField::basis` or `MultilinearExtension::from_values`.
	pub fn multilinear_extension<P: PackedField<Scalar = F>>(
		&self,
	) -> Result<MultilinearExtension<P>, Error> {
		// Without this guard the division below would truncate to zero packed values and the
		// failure would surface as an unrelated error from `from_values`.
		if self.k < P::LOG_WIDTH {
			bail!(Error::ArgumentRangeError {
				arg: "k".into(),
				range: P::LOG_WIDTH..F::TOWER_LEVEL - self.iota + 1,
			});
		}

		// Both sides are powers of two and `2^k >= P::WIDTH`, so this division is exact.
		let n_values = (1 << self.k) / P::WIDTH;
		let values = (0..n_values)
			.map(|i| {
				let mut packed_value = P::default();
				for j in 0..P::WIDTH {
					let basis_idx = i * P::WIDTH + j;
					let value = TowerField::basis(self.iota, basis_idx)?;
					packed_value.set(j, value);
				}
				Ok(packed_value)
			})
			.collect::<Result<Vec<_>, Error>>()?;

		MultilinearExtension::from_values(values)
	}
}

impl<F> MultivariatePoly<F> for TowerBasis<F>
where
	F: TowerField,
{
	fn n_vars(&self) -> usize {
		self.k
	}

	fn degree(&self) -> usize {
		self.k
	}

	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
		if query.len() != self.k {
			bail!(Error::IncorrectQuerySize { expected: self.k });
		}

		let mut result = F::ONE;
		for (i, query_i) in query.iter().enumerate() {
			let r_comp = F::ONE - query_i;
			let basis_elt = <F as TowerField>::basis(self.iota + i, 1)?;
			result *= r_comp + *query_i * basis_elt;
		}
		Ok(result)
	}

	fn binary_tower_level(&self) -> usize {
		self.iota + self.k
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use crate::polynomial::multilinear_query::MultilinearQuery;
	use binius_field::{BinaryField128b, BinaryField32b, PackedBinaryField4x32b};
	use binius_hal::make_portable_backend;
	use rand::{rngs::StdRng, SeedableRng};
	use std::iter::repeat_with;

	/// Checks, at a random point, that `TowerBasis::evaluate` agrees with evaluating the unpacked
	/// multilinear extension for the given `(iota, k)` in `BinaryField128b`.
	fn test_consistency(iota: usize, k: usize) {
		type F = BinaryField128b;
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();

		let basis = TowerBasis::<F>::new(k, iota).unwrap();
		let challenge = repeat_with(|| <F as Field>::random(&mut rng))
			.take(k)
			.collect::<Vec<_>>();

		let eval1 = basis.evaluate(&challenge).unwrap();
		// `backend` is not used again, so it can be moved rather than cloned.
		let multilin_query =
			MultilinearQuery::<F, _>::with_full_query(&challenge, backend).unwrap();
		let mle = basis.multilinear_extension::<F>().unwrap();
		let eval2 = mle.evaluate(&multilin_query).unwrap();

		assert_eq!(eval1, eval2);
	}

	/// Same consistency check, but with the multilinear extension built over a packed field.
	#[test]
	fn test_consistency_packing() {
		let iota = 2;
		let kappa = 3;
		type F = BinaryField32b;
		type P = PackedBinaryField4x32b;
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();

		let basis = TowerBasis::<F>::new(kappa, iota).unwrap();
		let challenge = repeat_with(|| <F as Field>::random(&mut rng))
			.take(kappa)
			.collect::<Vec<_>>();
		let eval1 = basis.evaluate(&challenge).unwrap();
		let multilin_query =
			MultilinearQuery::<F, _>::with_full_query(&challenge, backend).unwrap();
		let mle = basis.multilinear_extension::<P>().unwrap();
		let eval2 = mle.evaluate(&multilin_query).unwrap();
		assert_eq!(eval1, eval2);
	}

	/// Runs `test_consistency` for every `(iota, k)` with `iota + k <= TOWER_LEVEL`, the full
	/// range accepted by `TowerBasis::new` for the field used there.
	#[test]
	fn test_consistency_all() {
		let max_level = <BinaryField128b as TowerField>::TOWER_LEVEL;
		for iota in 0..=max_level {
			for k in 0..=(max_level - iota) {
				test_consistency(iota, k);
			}
		}
	}
}