Skip to main content

quantica/ml_kem/
ntt.rs

1/// Number-Theoretic Transform for ML-KEM (FIPS 203 Section 4.3).
2///
3/// Full Montgomery arithmetic: all NTT butterflies, basemul, and iNTT use
4/// Montgomery multiplication (shifts instead of divisions).
5///
6/// q = 3329, R = 2^16. Coefficients stored as i16 with lazy reduction.
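/// Bounds: inside `ntt` the intermediates stay below 8q and the output is Barrett-reduced
/// to [0, q-1]; after basemul |c| < 2q (sum of two Montgomery products); `ntt_inv`
/// accumulates in i32 and narrows back to i16 with one final Montgomery reduction.
/// The forward-NTT intermediates fit in i16 range [-32768, 32767] since 8q = 26632 < 32768.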
9use super::params::N;
10
11// ---- Constants ----
12
/// The ML-KEM modulus q = 3329 as an `i16`, the type used for stored coefficients.
pub const Q: i16 = 3329;
/// The same modulus as an `i32`, for the 32-bit intermediates of Montgomery reduction.
const Q32: i32 = 3329;
/// Inverse of q modulo 2^16 in signed form: q · QINV ≡ 1 (mod 2^16).
const QINV: i32 = -3327; // q^{-1} mod 2^16 (signed)
/// R² mod q for R = 2^16; a Montgomery reduction of `x · R_SQ_MOD_Q` yields x·R mod q.
const R_SQ_MOD_Q: i32 = 1353; // R² mod q
17
/// Zetas: ζ^{BitRev_7(i)} mod q (FIPS 203 Appendix A)
///
/// Stored as plain residues in [0, q-1]. The butterflies do not read this table
/// directly; they use `ZETAS_MONT`, the Montgomery-form copy derived from it at
/// compile time.
const ZETAS: [i16; 128] = [
    1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476,
    3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220,
    375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
];
28
/// Gammas: γᵢ = ζ^{2·BitRev_7(i)+1} mod q (FIPS 203 Appendix A)
///
/// Stored as plain residues in [0, q-1]; entry i is the constant of the quadratic
/// modulus (X² − γᵢ) used for coefficient pair i in `multiply_ntts`, which reads the
/// Montgomery-form copy `GAMMAS_MONT` rather than this table.
const GAMMAS: [i16; 128] = [
    17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,
    3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
    268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237,
    403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029,
    2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
];
39
40// ---- Compile-time precomputation ----
41
42const fn mont_reduce_const(a: i32) -> i16 {
43    let t = (a as i16).wrapping_mul(QINV as i16);
44    ((a - t as i32 * Q32) >> 16) as i16
45}
46
47const fn to_mont_const(a: i16) -> i16 {
48    mont_reduce_const(a as i32 * R_SQ_MOD_Q)
49}
50
51const fn compute_table_mont(table: &[i16; 128]) -> [i16; 128] {
52    let mut r = [0i16; 128];
53    let mut i = 0;
54    while i < 128 {
55        r[i] = to_mont_const(table[i]);
56        i += 1;
57    }
58    r
59}
60
/// `ZETAS` converted entry-by-entry to Montgomery form (ζ·R mod q) at compile time.
const ZETAS_MONT: [i16; 128] = compute_table_mont(&ZETAS);
/// `GAMMAS` converted entry-by-entry to Montgomery form (γ·R mod q) at compile time.
const GAMMAS_MONT: [i16; 128] = compute_table_mont(&GAMMAS);
63
/// iNTT scaling factor: R² · 128⁻¹ mod q = 1441.
/// Compensates: 128⁻¹ (NTT normalization) + R⁻¹ (from basemul via mont_mul).
///
/// `ntt_inv` applies it through one Montgomery reduction, so the net factor on each
/// coefficient is R · 128⁻¹: the 128⁻¹ normalises the seven butterfly levels and the
/// remaining R cancels the R⁻¹ that `multiply_ntts` leaves on its output.
const F_SCALE: i16 = 1441;
67
68// ---- Montgomery arithmetic ----
69
70/// Montgomery reduction: a·R⁻¹ mod q. Result bounded by |q|.
71#[inline(always)]
72fn montgomery_reduce(a: i32) -> i16 {
73    let t = (a as i16).wrapping_mul(QINV as i16);
74    ((a - t as i32 * Q32) >> 16) as i16
75}
76
77/// Montgomery multiply: a·b·R⁻¹ mod q.
78#[inline(always)]
79fn mont_mul(a: i16, b: i16) -> i16 {
80    montgomery_reduce(a as i32 * b as i32)
81}
82
83/// Barrett reduction to [0, q-1]. Constant-time.
84#[inline(always)]
85pub fn barrett_reduce(a: i16) -> i16 {
86    let t = ((20159i32 * a as i32 + (1 << 25)) >> 26) as i16;
87    let mut r = a - t.wrapping_mul(Q);
88    r += (r >> 15) & Q;
89    r
90}
91
92// ---- NTT ----
93
94/// Forward NTT (Algorithm 9) — full Montgomery.
95///
96/// Input: coefficients in [0, q-1].
97/// Output: NTT-domain coefficients, |c| <= 4q (lazy reduced).
98pub fn ntt(f: &mut [i16; N]) {
99    let mut k = 1usize;
100    let mut len = 128;
101    while len >= 2 {
102        let mut start = 0;
103        while start < N {
104            let zeta = ZETAS_MONT[k];
105            k += 1;
106            for j in start..start + len {
107                let t = mont_mul(zeta, f[j + len]);
108                f[j + len] = f[j] - t;
109                f[j] = f[j] + t;
110            }
111            start += 2 * len;
112        }
113        len >>= 1;
114    }
115    // Reduce output to [0,q-1] for basemul
116    for c in f.iter_mut() {
117        *c = barrett_reduce(*c);
118    }
119}
120
121/// Inverse NTT (Algorithm 10) — full Montgomery.
122///
123/// Input: NTT-domain (from basemul, in /R domain), |c| <= q.
124/// Output: standard-domain, |c| <= q after final scaling.
125///
126/// Uses **negative** Montgomery zetas: `-ZETAS_MONT[k]`.
127/// GS butterfly: `f[j] = t + f[j+l]`, `f[j+l] = fqmul(-z, t - f[j+l])`.
128/// Final scaling by F_SCALE = 1441 = R²·128⁻¹ mod q.
129pub fn ntt_inv(f: &mut [i16; N]) {
130    // Use i32 working array to avoid i16 overflow in butterfly additions.
131    // The accumulation through 7 GS levels can reach ~8q ≈ 26632,
132    // which fits i16 but leaves no margin. Using i32 eliminates this concern.
133    let mut w = [0i32; N];
134    for i in 0..N {
135        w[i] = f[i] as i32;
136    }
137
138    let mut k = 127usize;
139    let mut len = 2;
140    while len <= 128 {
141        let mut start = 0;
142        while start < N {
143            let neg_zeta = ZETAS_MONT[k].wrapping_neg();
144            k = k.wrapping_sub(1);
145            for j in start..start + len {
146                let t = w[j];
147                let u = w[j + len];
148                w[j] = t + u;
149                // mont_reduce(neg_zeta * (t-u)) with i32 diff
150                w[j + len] = montgomery_reduce(neg_zeta as i32 * ((t - u) as i32)) as i32;
151            }
152            start += 2 * len;
153        }
154        len <<= 1;
155    }
156    for i in 0..N {
157        // w[i] can exceed i16 range; use montgomery_reduce with i32 multiply
158        f[i] = montgomery_reduce(F_SCALE as i32 * w[i]) as i16;
159    }
160}
161
162/// Pointwise basemul in NTT domain (Algorithms 11+12) — full Montgomery.
163///
164/// Each pair `(h[2i], h[2i+1])` is the product of input pairs modulo `(X²-γᵢ)`.
165/// Output is in **/R domain** (one Montgomery division per coefficient).
166///
167/// For the encrypt/decrypt pipeline (basemul → iNTT), the /R is compensated
168/// by the F_SCALE factor in iNTT.
169///
170/// For KeyGen (basemul accumulate in NTT domain), call [`to_mont_poly`] to
171/// convert from /R to normal domain before adding NTT(e).
172pub fn multiply_ntts(f: &[i16; N], g: &[i16; N], h: &mut [i16; N]) {
173    for i in 0..128 {
174        let gamma = GAMMAS_MONT[i];
175        let a0 = f[2 * i];
176        let a1 = f[2 * i + 1];
177        let b0 = g[2 * i];
178        let b1 = g[2 * i + 1];
179        // c0 = a0·b0·R⁻¹ + a1·b1·R⁻¹·γ·R·R⁻¹ = (a0·b0 + a1·b1·γ)·R⁻¹
180        let t = mont_mul(mont_mul(a1, b1), gamma);
181        h[2 * i] = mont_mul(a0, b0) + t;
182        // c1 = (a0·b1 + a1·b0)·R⁻¹
183        h[2 * i + 1] = mont_mul(a0, b1) + mont_mul(a1, b0);
184    }
185}
186
187/// Convert polynomial from /R domain to normal domain (multiply by R).
188///
189/// Used in KeyGen after accumulating basemul results, before adding NTT(e).
190/// `to_mont(c) = c·R mod q` via Montgomery: `mont_reduce(c · R²_mod_q)`.
191pub fn to_mont_poly(f: &mut [i16; N]) {
192    for c in f.iter_mut() {
193        *c = montgomery_reduce(*c as i32 * R_SQ_MOD_Q);
194    }
195}
196
197// ---- Polynomial helpers ----
198
199pub fn poly_add(a: &[i16; N], b: &[i16; N], c: &mut [i16; N]) {
200    for i in 0..N {
201        c[i] = a[i] + b[i];
202    }
203}
204
205pub fn poly_sub(a: &[i16; N], b: &[i16; N], c: &mut [i16; N]) {
206    for i in 0..N {
207        c[i] = a[i] - b[i];
208    }
209}
210
211/// Reduce all coefficients to [0, q-1].
212pub fn reduce(f: &mut [i16; N]) {
213    for c in f.iter_mut() {
214        *c = barrett_reduce(*c);
215    }
216}
217
218// ---- Zeroize ----
219
220#[inline(never)]
221pub fn zeroize_poly(f: &mut [i16; N]) {
222    for c in f.iter_mut() {
223        unsafe { core::ptr::write_volatile(c, 0) };
224    }
225    core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
226}
227
/// Overwrites every byte of `b` with zero.
///
/// Uses volatile stores followed by a compiler fence so the writes are not
/// optimised away as dead stores; `#[inline(never)]` keeps the call out of line.
#[inline(never)]
pub fn zeroize_bytes(b: &mut [u8]) {
    use core::sync::atomic::{compiler_fence, Ordering};

    b.iter_mut().for_each(|slot| {
        // SAFETY: `slot` is a unique reference to an initialised `u8` inside the
        // slice, so it is valid for a volatile write.
        unsafe { core::ptr::write_volatile(slot, 0) }
    });
    compiler_fence(Ordering::SeqCst);
}
235
236// ---- Tests ----
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Full pipeline on a tiny product: NTT → basemul → iNTT → reduce.
    #[test]
    fn test_ntt_basemul_intt_simple() {
        // f(X) = 42, g(X) = 1 + X → product = 42 + 42X
        let mut f = [0i16; N];
        f[0] = 42;
        let mut g = [0i16; N];
        g[0] = 1;
        g[1] = 1;

        ntt(&mut f);
        ntt(&mut g);

        let mut h = [0i16; N];
        multiply_ntts(&f, &g, &mut h);
        ntt_inv(&mut h);
        reduce(&mut h);

        assert_eq!(h[0], 42, "h[0]={} expected 42", h[0]);
        assert_eq!(h[1], 42, "h[1]={} expected 42", h[1]);
        for i in 2..N {
            assert_eq!(h[i], 0, "h[{}]={} expected 0", i, h[i]);
        }
    }

    /// Multiplying by the constant polynomial 1 must return the other operand.
    #[test]
    fn test_ntt_basemul_intt_identity() {
        // NTT(1) × NTT(b) → iNTT should recover b
        let mut one = [0i16; N];
        one[0] = 1;
        let mut b = [0i16; N];
        for i in 0..N {
            b[i] = (i as i16 * 7 + 13) % Q;
        }
        let orig = b;

        ntt(&mut one);
        ntt(&mut b);
        let mut c = [0i16; N];
        multiply_ntts(&one, &b, &mut c);
        ntt_inv(&mut c);
        reduce(&mut c);

        for i in 0..N {
            assert_eq!(c[i], orig[i], "mismatch at {}: got {} expected {}", i, c[i], orig[i]);
        }
    }

    /// KeyGen pattern: basemul result → `to_mont_poly` → add NTT(e), all in the NTT domain.
    #[test]
    fn test_to_mont_poly_keygen_pattern() {
        // a(X) = 100, s(X) = 1, e(X) = 5  →  a·s + e = 105 (a constant polynomial)
        let mut a = [0i16; N];
        a[0] = 100;
        let mut s = [0i16; N];
        s[0] = 1;
        let mut e = [0i16; N];
        e[0] = 5;

        ntt(&mut a);
        ntt(&mut s);
        ntt(&mut e);

        let mut t = [0i16; N];
        multiply_ntts(&a, &s, &mut t);
        // t is in /R domain
        to_mont_poly(&mut t);
        // t is now in normal domain, safe to add e
        for i in 0..N {
            t[i] += e[i];
        }
        reduce(&mut t);

        // The NTT of a constant polynomial c is c in the constant slot of every
        // coefficient pair and 0 in the linear slot, so t must be [105, 0, 105, 0, ...].
        for i in 0..N / 2 {
            assert_eq!(t[2 * i], 105, "t[{}]={} expected 105", 2 * i, t[2 * i]);
            assert_eq!(t[2 * i + 1], 0, "t[{}]={} expected 0", 2 * i + 1, t[2 * i + 1]);
        }
    }

    /// Converting into Montgomery form and reducing again must round-trip.
    #[test]
    fn test_montgomery_reduce_basic() {
        let a: i16 = 1729;
        let in_mont = to_mont_const(a);
        let back = montgomery_reduce(in_mont as i32);
        let back_reduced = barrett_reduce(back);
        assert_eq!(back_reduced, a);
    }
}