1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
// Copyright 2024 Ulvetanna Inc.

use binius_field::{square_transpose, ExtensionField, Field, PackedExtension};
use binius_utils::checked_arithmetics::checked_log_2;
use std::{
	iter::Sum,
	marker::PhantomData,
	mem,
	ops::{Add, AddAssign, Sub, SubAssign},
};

/// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields.
///
/// A tensor algebra element is a length $D$ vector of `FE` field elements, where $D$ is the degree
/// of `FE` as an extension of `F`. The algebra has a "vertical subring" and a "horizontal subring",
/// which are both isomorphic to `FE` as a field.
///
/// See [DP24] Section 2 for further details.
///
/// [DP24]: <https://eprint.iacr.org/2024/504>
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	pub elems: Vec<FE>,
	_marker: PhantomData<F>,
}

impl<F, FE> Default for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	fn default() -> Self {
		Self {
			elems: vec![FE::default(); FE::DEGREE],
			_marker: PhantomData,
		}
	}
}

impl<F, FE> TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	/// Constructs an element from a vector of vertical subring elements.
	///
	/// ## Preconditions
	///
	/// * `elems` must have length `FE::DEGREE`, otherwise this will pad or truncate.
	pub fn new(mut elems: Vec<FE>) -> Self {
		elems.resize(FE::DEGREE, FE::ZERO);
		Self {
			elems,
			_marker: PhantomData,
		}
	}

	/// Returns $\kappa$, the base-2 logarithm of the extension degree.
	pub const fn kappa() -> usize {
		checked_log_2(FE::DEGREE)
	}

	/// Returns the byte size of an element.
	pub fn byte_size() -> usize {
		mem::size_of::<FE>() << Self::kappa()
	}

	/// Returns the multiplicative identity element, one.
	pub fn one() -> Self {
		let mut one = Self::default();
		one.elems[0] = FE::ONE;
		one
	}

	/// Returns a slice of the vertical subfield elements composing the tensor algebra element.
	pub fn vertical_elems(&self) -> &[FE] {
		&self.elems
	}

	/// Tensor product of a vertical subring element and a horizontal subring element.
	pub fn tensor(vertical: FE, horizontal: FE) -> Self {
		let elems = horizontal
			.iter_bases()
			.map(|base| vertical * base)
			.collect();
		Self {
			elems,
			_marker: PhantomData,
		}
	}

	/// Constructs a [`TensorAlgebra`] in the vertical subring.
	pub fn from_vertical(x: FE) -> Self {
		let mut elems = vec![FE::ZERO; FE::DEGREE];
		elems[0] = x;
		Self {
			elems,
			_marker: PhantomData,
		}
	}

	/// If the algebra element lives in the vertical subring, this returns it as a field element.
	pub fn try_extract_vertical(&self) -> Option<FE> {
		self.elems
			.iter()
			.skip(1)
			.all(|&elem| elem == FE::ZERO)
			.then_some(self.elems[0])
	}

	/// Multiply by an element from the vertical subring.
	pub fn scale_vertical(mut self, scalar: FE) -> Self {
		for elem_i in self.elems.iter_mut() {
			*elem_i *= scalar;
		}
		self
	}
}

impl<F, FE> TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F> + PackedExtension<F>,
	FE::Scalar: ExtensionField<F>,
{
	/// Multiply by an element from the vertical subring.
	///
	/// Internally, this performs a transpose, vertical scaling, then transpose sequence. If
	/// multiple horizontal scaling operations are required and performance is a concern, it may be
	/// better for the caller to do the transposes directly and amortize their cost.
	pub fn scale_horizontal(self, scalar: FE) -> Self {
		self.transpose().scale_vertical(scalar).transpose()
	}

	/// Transposes the algebra element.
	///
	/// A transpose flips the vertical and horizontal subring elements.
	pub fn transpose(mut self) -> Self {
		square_transpose(Self::kappa(), FE::cast_bases_mut(&mut self.elems))
			.expect("transpose dimensions are square by struct invariant");
		self
	}
}

impl<F, FE> Add<&Self> for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	type Output = Self;

	fn add(mut self, rhs: &Self) -> Self {
		self.add_assign(rhs);
		self
	}
}

impl<F, FE> Sub<&Self> for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	type Output = Self;

	fn sub(mut self, rhs: &Self) -> Self {
		self.sub_assign(rhs);
		self
	}
}

impl<F, FE> AddAssign<&Self> for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	fn add_assign(&mut self, rhs: &Self) {
		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
			*self_i += *rhs_i;
		}
	}
}

impl<F, FE> SubAssign<&Self> for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	fn sub_assign(&mut self, rhs: &Self) {
		for (self_i, rhs_i) in self.elems.iter_mut().zip(rhs.elems.iter()) {
			*self_i -= *rhs_i;
		}
	}
}

impl<'a, F, FE> Sum<&'a Self> for TensorAlgebra<F, FE>
where
	F: Field,
	FE: ExtensionField<F>,
{
	fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
		iter.fold(Self::default(), |sum, item| sum + item)
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use binius_field::{BinaryField128b, BinaryField8b};
	use rand::{rngs::StdRng, SeedableRng};

	#[test]
	fn test_tensor_product() {
		type F = BinaryField8b;
		type FE = BinaryField128b;

		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let hztl = FE::random(&mut rng);

		let expected = TensorAlgebra::<F, _>::from_vertical(vert).scale_horizontal(hztl);
		assert_eq!(TensorAlgebra::tensor(vert, hztl), expected);
	}

	#[test]
	fn test_try_extract_vertical() {
		type F = BinaryField8b;
		type FE = BinaryField128b;

		let mut rng = StdRng::seed_from_u64(0);

		let vert = FE::random(&mut rng);
		let elem = TensorAlgebra::<F, _>::from_vertical(vert);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// Scale horizontally by an extension element, and we should no longer be vertical.
		let hztl = FE::new(1111);
		let elem = elem.scale_horizontal(hztl);
		assert_eq!(elem.try_extract_vertical(), None);

		// Scale back by the inverse to get back to the vertical subring.
		let hztl_inv = hztl.invert().unwrap();
		let elem = elem.scale_horizontal(hztl_inv);
		assert_eq!(elem.try_extract_vertical(), Some(vert));

		// If we scale horizontally by an F element, we should remain in the vertical subring.
		let hztl_subfield = FE::from(F::new(7));
		let elem = elem.scale_horizontal(hztl_subfield);
		assert_eq!(elem.try_extract_vertical(), Some(vert * hztl_subfield));
	}
}