quantica/ml_dsa/ntt.rs
1//! Number Theoretic Transform for ML-DSA (FIPS 204, Algorithms 41-45).
2//!
3//! Implements the forward and inverse NTT over `Z_q[X]/(X^256 + 1)` with
4//! q = 8380417 and zeta = 1753 (a primitive 512th root of unity mod q).
5//! The NTT uses 8-bit reversal (BitRev_8) and goes all the way down to
6//! length-1 butterflies, so pointwise multiplication is simple element-wise
7//! multiplication.
8
9use super::params::{N, N_INV, Q};
10
/// The modulus `Q` widened to `i64`, used wherever a product of two
/// coefficients is formed in 64 bits before being reduced.
const Q64: i64 = Q as i64;

/// Montgomery constant: q^{-1} mod 2^32.
///
/// `montgomery_reduce` multiplies the low 32 bits of its input by this value
/// to find the multiple of q that must be subtracted.
const QINV: i32 = 58728449;
/// R^2 mod q where R = 2^32.
///
/// Montgomery-reducing `a * R_SQ_MOD_Q` gives `a * R mod q`, which is how
/// values are moved into the Montgomery domain in this file.
const R_SQ_MOD_Q: i64 = 2365951;
17
18/// Reduce `a` modulo q to the range [0, q-1].
19#[inline(always)]
20pub fn mod_q(a: i32) -> i32 {
21 let mut r = a % Q;
22 r += Q & (r >> 31);
23 r
24}
25
26/// Modular multiplication: `(a * b) mod q`.
27#[inline(always)]
28pub fn mul_mod_q(a: i32, b: i32) -> i32 {
29 mod_q(((a as i64 * b as i64) % Q64) as i32)
30}
31
32/// Montgomery reduction: a * R^{-1} mod q, where R = 2^32.
33/// Input: |a| < q * R. Output: roughly [-q, q].
34#[inline(always)]
35fn montgomery_reduce(a: i64) -> i32 {
36 let t = (a as i32).wrapping_mul(QINV) as i64;
37 ((a - t * Q64) >> 32) as i32
38}
39
40/// Montgomery multiply: (a * b * R^{-1}) mod q.
41#[inline(always)]
42fn mont_mul(a: i32, b: i32) -> i32 {
43 montgomery_reduce(a as i64 * b as i64)
44}
45
46/// Convert to Montgomery domain: a * R mod q.
47#[inline(always)]
48fn to_mont(a: i32) -> i32 {
49 montgomery_reduce(a as i64 * R_SQ_MOD_Q)
50}
51
52/// Precomputed zetas in Montgomery domain (compile-time).
53const fn compute_zetas_mont() -> [i32; N] {
54 let mut table = [0i32; N];
55 let mut k = 0;
56 while k < N {
57 let rev = bitrev8(k as u8) as u64;
58 let z = pow_mod(1753, rev, Q64);
59 // to_mont(z): montgomery_reduce(z * R_SQ_MOD_Q)
60 let a = z as i64 * R_SQ_MOD_Q;
61 let t = (a as i32).wrapping_mul(QINV) as i64;
62 table[k] = ((a - t * Q64) >> 32) as i32;
63 k += 1;
64 }
65 table
66}
67
68const ZETAS_MONT: [i32; N] = compute_zetas_mont();
69
/// Return `k` with the order of its 8 bits reversed (BitRev_8): bit 0 swaps
/// with bit 7, bit 1 with bit 6, and so on.
const fn bitrev8(k: u8) -> u8 {
    k.reverse_bits()
}
82
/// Modular exponentiation by repeated squaring: `base^exp mod m` (const fn).
///
/// `m` is expected to be positive; `m == 0` is a division by zero. The result
/// is the canonical residue in `[0, m-1]`:
///
/// * a negative `base` is first folded into `[0, m-1]`,
/// * `m == 1` yields 0 for every exponent, including `exp == 0`,
/// * intermediate products are formed in 128 bits, so no modulus that fits in
///   an `i64` can overflow them.
const fn pow_mod(base: i64, mut exp: u64, m: i64) -> i64 {
    let modulus = m as i128;
    // `1 % m` rather than `1`, so the empty product is still reduced.
    let mut result: i128 = 1 % modulus;
    // Euclidean remainder: non-negative even when `base` is negative.
    let mut square = (base as i128).rem_euclid(modulus);
    while exp > 0 {
        if exp & 1 == 1 {
            result = (result * square) % modulus;
        }
        exp >>= 1;
        square = (square * square) % modulus;
    }
    // `result` is below `m`, so it fits back into an i64.
    result as i64
}
96
97/// Precomputed zetas: zetas[k] = zeta^{BitRev_8(k)} mod q for k = 0..255.
98const fn compute_zetas() -> [i32; N] {
99 let mut table = [0i32; N];
100 let mut k = 0usize;
101 while k < N {
102 let rev = bitrev8(k as u8) as u64;
103 let val = pow_mod(1753, rev, Q as i64);
104 table[k] = val as i32;
105 k += 1;
106 }
107 table
108}
109
/// Precomputed zeta table: `ZETAS[k] = zeta^{BitRev_8(k)} mod q` for k in 0..256.
///
/// This table is computed at compile time via `compute_zetas` (private to this
/// module). Entry 0 is always 1 (since `zeta^0 = 1`). The entries are plain
/// residues, not Montgomery-form values: `ntt` and `ntt_inv` in this file do
/// not read this table, they index the Montgomery-form `ZETAS_MONT` instead.
pub const ZETAS: [i32; N] = compute_zetas();
116
117/// Forward NTT (Algorithm 41 of FIPS 204).
118///
119/// Transforms a polynomial `f` from the standard domain to the NTT domain
120/// in place. Input coefficients should be in [0, q-1]; output coefficients
121/// are also in [0, q-1].
122///
123/// After this call, `f` represents the evaluation of the original polynomial
124/// at the 256 roots of unity used by ML-DSA.
125pub fn ntt(f: &mut [i32; N]) {
126 let mut m = 0usize;
127 let mut len = 128;
128 while len >= 1 {
129 let mut start = 0;
130 while start < N {
131 m += 1;
132 let zeta = ZETAS_MONT[m];
133 let mut j = start;
134 while j < start + len {
135 // Montgomery butterfly: mont_mul(zeta_R, f) = zeta*f (R cancels)
136 let t = mont_mul(zeta, f[j + len]);
137 f[j + len] = f[j] - t;
138 f[j] = f[j] + t;
139 j += 1;
140 }
141 start += 2 * len;
142 }
143 len /= 2;
144 }
145 // Reduce to [0, q-1]
146 for c in f.iter_mut() {
147 *c = mod_q(*c);
148 }
149}
150
151/// Inverse NTT (Algorithm 42 of FIPS 204).
152///
153/// Transforms a polynomial `f` from the NTT domain back to the standard
154/// domain in place, including the final scaling by N^{-1} mod q.
155///
156/// If `ntt(f)` was called first, then `ntt_inv(f)` recovers the original
157/// polynomial exactly.
158/// iNTT scaling factor: R² · 256⁻¹ mod q = 41978.
159/// Compensates both the 256⁻¹ normalization and the /R from pointwise_mul.
160const F_SCALE_DSA: i32 = 41978;
161
162pub fn ntt_inv(f: &mut [i32; N]) {
163 let mut m = N; // 256
164 let mut len = 1;
165 while len <= 128 {
166 let mut start = 0;
167 while start < N {
168 m -= 1;
169 let neg_zeta = mod_q(-ZETAS_MONT[m]);
170 let mut j = start;
171 while j < start + len {
172 let t = f[j];
173 f[j] = t + f[j + len];
174 f[j + len] = montgomery_reduce(neg_zeta as i64 * (t - f[j + len]) as i64);
175 j += 1;
176 }
177 start += 2 * len;
178 }
179 len *= 2;
180 }
181 for coeff in f.iter_mut() {
182 *coeff = montgomery_reduce(F_SCALE_DSA as i64 * *coeff as i64);
183 }
184}
185
186/// Pointwise multiplication of two NTT-domain polynomials.
187///
188/// Implements Algorithm 45 of FIPS 204. Because the ML-DSA NTT decomposes
189/// all the way down to length-1 components, this is a simple element-wise
190/// modular multiplication (no base-case Karatsuba needed).
191/// Pointwise multiplication — full Montgomery.
192/// Output is in /R domain. Use iNTT(F_SCALE_DSA) to compensate.
193/// For accumulation in NTT domain (KeyGen), call to_mont_poly after.
194pub fn pointwise_mul(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
195 let mut c = [0i32; N];
196 for i in 0..N {
197 c[i] = mont_mul(a[i], b[i]);
198 }
199 c
200}
201
202/// Convert polynomial from /R to normal domain (multiply each coeff by R).
203pub fn to_mont_poly(f: &mut [i32; N]) {
204 for c in f.iter_mut() {
205 *c = montgomery_reduce(*c as i64 * R_SQ_MOD_Q);
206 }
207}
208
209/// Add two polynomials coefficient-wise, reducing each result modulo q.
210///
211/// Returns a new polynomial `c` where `c[i] = (a[i] + b[i]) mod q`.
212pub fn poly_add(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
213 let mut c = [0i32; N];
214 for i in 0..N {
215 c[i] = mod_q(a[i] + b[i]);
216 }
217 c
218}
219
220/// Subtract two polynomials coefficient-wise, reducing each result modulo q.
221///
222/// Returns a new polynomial `c` where `c[i] = (a[i] - b[i]) mod q`.
223pub fn poly_sub(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
224 let mut c = [0i32; N];
225 for i in 0..N {
226 c[i] = mod_q(a[i] - b[i]);
227 }
228 c
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// The first entry of the plain zeta table is zeta^0 and must be 1.
    #[test]
    fn test_zetas_0() {
        // zetas[0] = zeta^{BitRev_8(0)} = zeta^0 = 1
        assert_eq!(ZETAS[0], 1);
    }

    /// Multiplying by the constant polynomial 1 in the NTT domain and
    /// transforming back must reproduce the input.
    ///
    /// The `pointwise_mul` step is part of the round trip on purpose: it
    /// contributes the factor of 1/R that `ntt_inv` compensates for.
    #[test]
    fn test_ntt_roundtrip() {
        // Full pipeline: NTT → pointwise_mul (identity) → iNTT
        let mut f = [0i32; N];
        for i in 0..N {
            f[i] = (i as i32 * 17 + 3) % Q;
        }
        let orig = f;
        ntt(&mut f);
        // Multiply by NTT(1), where 1 is the constant polynomial
        // [1, 0, 0, ..., 0] in the standard domain.
        let mut one = [0i32; N];
        one[0] = 1;
        ntt(&mut one);
        let h = pointwise_mul(&one, &f);
        let mut result = h;
        ntt_inv(&mut result);
        for i in 0..N {
            // ntt_inv does not normalise its output, so reduce before comparing.
            let r = mod_q(result[i]);
            assert_eq!(r, orig[i], "mismatch at index {}: got {} expected {}", i, r, orig[i]);
        }
    }

    /// `mod_q` maps negative inputs and exact multiples of q into [0, q-1].
    #[test]
    fn test_mod_q_negative() {
        assert_eq!(mod_q(-1), Q - 1);
        assert_eq!(mod_q(0), 0);
        assert_eq!(mod_q(Q), 0);
    }
}
269}