Skip to main content

quantica/ml_dsa/
ntt.rs

1//! Number Theoretic Transform for ML-DSA (FIPS 204, Algorithms 41-45).
2//!
3//! Implements the forward and inverse NTT over `Z_q[X]/(X^256 + 1)` with
4//! q = 8380417 and zeta = 1753 (a primitive 512th root of unity mod q).
5//! The NTT uses 8-bit reversal (BitRev_8) and goes all the way down to
6//! length-1 butterflies, so pointwise multiplication is simple element-wise
7//! multiplication.
8
9use super::params::{N, N_INV, Q};
10
11const Q64: i64 = Q as i64;
12
13/// Montgomery constant: q^{-1} mod 2^32.
14const QINV: i32 = 58728449;
15/// R^2 mod q where R = 2^32.
16const R_SQ_MOD_Q: i64 = 2365951;
17
18/// Reduce `a` modulo q to the range [0, q-1].
19#[inline(always)]
20pub fn mod_q(a: i32) -> i32 {
21    let mut r = a % Q;
22    r += Q & (r >> 31);
23    r
24}
25
26/// Modular multiplication: `(a * b) mod q`.
27#[inline(always)]
28pub fn mul_mod_q(a: i32, b: i32) -> i32 {
29    mod_q(((a as i64 * b as i64) % Q64) as i32)
30}
31
32/// Montgomery reduction: a * R^{-1} mod q, where R = 2^32.
33/// Input: |a| < q * R. Output: roughly [-q, q].
34#[inline(always)]
35fn montgomery_reduce(a: i64) -> i32 {
36    let t = (a as i32).wrapping_mul(QINV) as i64;
37    ((a - t * Q64) >> 32) as i32
38}
39
40/// Montgomery multiply: (a * b * R^{-1}) mod q.
41#[inline(always)]
42fn mont_mul(a: i32, b: i32) -> i32 {
43    montgomery_reduce(a as i64 * b as i64)
44}
45
46/// Convert to Montgomery domain: a * R mod q.
47#[inline(always)]
48fn to_mont(a: i32) -> i32 {
49    montgomery_reduce(a as i64 * R_SQ_MOD_Q)
50}
51
52/// Precomputed zetas in Montgomery domain (compile-time).
53const fn compute_zetas_mont() -> [i32; N] {
54    let mut table = [0i32; N];
55    let mut k = 0;
56    while k < N {
57        let rev = bitrev8(k as u8) as u64;
58        let z = pow_mod(1753, rev, Q64);
59        // to_mont(z): montgomery_reduce(z * R_SQ_MOD_Q)
60        let a = z as i64 * R_SQ_MOD_Q;
61        let t = (a as i32).wrapping_mul(QINV) as i64;
62        table[k] = ((a - t * Q64) >> 32) as i32;
63        k += 1;
64    }
65    table
66}
67
68const ZETAS_MONT: [i32; N] = compute_zetas_mont();
69
/// Reverse the bit order of the 8-bit value `k` (the `BitRev_8` map of FIPS 204).
///
/// For example, `0b0000_0001` maps to `0b1000_0000`.
const fn bitrev8(k: u8) -> u8 {
    k.reverse_bits()
}
82
/// Modular exponentiation: `base^exp mod m` (usable in const contexts).
///
/// Uses square-and-multiply with 128-bit intermediates, so no product can overflow
/// for any positive `i64` modulus. The result is always in `[0, m - 1]`. This holds
/// even for a negative `base`, and when `m == 1` (where every power is 0).
///
/// `m` must be positive. `m == 0` is a division by zero: a compile error when
/// evaluated in a constant, a panic at run time.
const fn pow_mod(base: i64, mut exp: u64, m: i64) -> i64 {
    let modulus = m as i128;
    // Normalize the base into [0, m - 1] so every product below is non-negative.
    let mut acc = (base as i128).rem_euclid(modulus);
    // `1 % m` makes `x^0 mod 1` come out as 0 rather than 1.
    let mut result = 1i128 % modulus;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * acc % modulus;
        }
        exp >>= 1;
        acc = acc * acc % modulus;
    }
    // result < m <= i64::MAX, so the narrowing cast is lossless.
    result as i64
}
96
97/// Precomputed zetas: zetas[k] = zeta^{BitRev_8(k)} mod q for k = 0..255.
98const fn compute_zetas() -> [i32; N] {
99    let mut table = [0i32; N];
100    let mut k = 0usize;
101    while k < N {
102        let rev = bitrev8(k as u8) as u64;
103        let val = pow_mod(1753, rev, Q as i64);
104        table[k] = val as i32;
105        k += 1;
106    }
107    table
108}
109
110/// Precomputed zeta table: `ZETAS[k] = zeta^{BitRev_8(k)} mod q` for k in 0..256.
111///
112/// This table is computed at compile time via `compute_zetas` (private to this module). Entry 0 is
113/// always 1 (since `zeta^0 = 1`). The NTT and inverse NTT index into this
114/// table sequentially during their butterfly passes.
115pub const ZETAS: [i32; N] = compute_zetas();
116
117/// Forward NTT (Algorithm 41 of FIPS 204).
118///
119/// Transforms a polynomial `f` from the standard domain to the NTT domain
120/// in place. Input coefficients should be in [0, q-1]; output coefficients
121/// are also in [0, q-1].
122///
123/// After this call, `f` represents the evaluation of the original polynomial
124/// at the 256 roots of unity used by ML-DSA.
125pub fn ntt(f: &mut [i32; N]) {
126    let mut m = 0usize;
127    let mut len = 128;
128    while len >= 1 {
129        let mut start = 0;
130        while start < N {
131            m += 1;
132            let zeta = ZETAS_MONT[m];
133            let mut j = start;
134            while j < start + len {
135                // Montgomery butterfly: mont_mul(zeta_R, f) = zeta*f (R cancels)
136                let t = mont_mul(zeta, f[j + len]);
137                f[j + len] = f[j] - t;
138                f[j] = f[j] + t;
139                j += 1;
140            }
141            start += 2 * len;
142        }
143        len /= 2;
144    }
145    // Reduce to [0, q-1]
146    for c in f.iter_mut() {
147        *c = mod_q(*c);
148    }
149}
150
151/// Inverse NTT (Algorithm 42 of FIPS 204).
152///
153/// Transforms a polynomial `f` from the NTT domain back to the standard
154/// domain in place, including the final scaling by N^{-1} mod q.
155///
156/// If `ntt(f)` was called first, then `ntt_inv(f)` recovers the original
157/// polynomial exactly.
158/// iNTT scaling factor: R² · 256⁻¹ mod q = 41978.
159/// Compensates both the 256⁻¹ normalization and the /R from pointwise_mul.
160const F_SCALE_DSA: i32 = 41978;
161
162pub fn ntt_inv(f: &mut [i32; N]) {
163    let mut m = N; // 256
164    let mut len = 1;
165    while len <= 128 {
166        let mut start = 0;
167        while start < N {
168            m -= 1;
169            let neg_zeta = mod_q(-ZETAS_MONT[m]);
170            let mut j = start;
171            while j < start + len {
172                let t = f[j];
173                f[j] = t + f[j + len];
174                f[j + len] = montgomery_reduce(neg_zeta as i64 * (t - f[j + len]) as i64);
175                j += 1;
176            }
177            start += 2 * len;
178        }
179        len *= 2;
180    }
181    for coeff in f.iter_mut() {
182        *coeff = montgomery_reduce(F_SCALE_DSA as i64 * *coeff as i64);
183    }
184}
185
186/// Pointwise multiplication of two NTT-domain polynomials.
187///
188/// Implements Algorithm 45 of FIPS 204. Because the ML-DSA NTT decomposes
189/// all the way down to length-1 components, this is a simple element-wise
190/// modular multiplication (no base-case Karatsuba needed).
191/// Pointwise multiplication — full Montgomery.
192/// Output is in /R domain. Use iNTT(F_SCALE_DSA) to compensate.
193/// For accumulation in NTT domain (KeyGen), call to_mont_poly after.
194pub fn pointwise_mul(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
195    let mut c = [0i32; N];
196    for i in 0..N {
197        c[i] = mont_mul(a[i], b[i]);
198    }
199    c
200}
201
202/// Convert polynomial from /R to normal domain (multiply each coeff by R).
203pub fn to_mont_poly(f: &mut [i32; N]) {
204    for c in f.iter_mut() {
205        *c = montgomery_reduce(*c as i64 * R_SQ_MOD_Q);
206    }
207}
208
209/// Add two polynomials coefficient-wise, reducing each result modulo q.
210///
211/// Returns a new polynomial `c` where `c[i] = (a[i] + b[i]) mod q`.
212pub fn poly_add(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
213    let mut c = [0i32; N];
214    for i in 0..N {
215        c[i] = mod_q(a[i] + b[i]);
216    }
217    c
218}
219
220/// Subtract two polynomials coefficient-wise, reducing each result modulo q.
221///
222/// Returns a new polynomial `c` where `c[i] = (a[i] - b[i]) mod q`.
223pub fn poly_sub(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
224    let mut c = [0i32; N];
225    for i in 0..N {
226        c[i] = mod_q(a[i] - b[i]);
227    }
228    c
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// R = 2^32 reduced modulo q.
    fn r_mod_q() -> i32 {
        ((1i64 << 32) % Q64) as i32
    }

    /// Deterministic sample polynomial with coefficients in [0, q-1].
    fn sample_poly(mul: i32, add: i32) -> [i32; N] {
        let mut f = [0i32; N];
        for (i, c) in f.iter_mut().enumerate() {
            *c = mod_q(i as i32 * mul + add);
        }
        f
    }

    /// Reference schoolbook multiplication in Z_q[X]/(X^256 + 1).
    fn negacyclic_mul(a: &[i32; N], b: &[i32; N]) -> [i32; N] {
        let mut acc = [0i64; N];
        for i in 0..N {
            for j in 0..N {
                let prod = a[i] as i64 * b[j] as i64 % Q64;
                if i + j < N {
                    acc[i + j] += prod;
                } else {
                    // X^256 = -1, so wrapped-around terms are subtracted.
                    acc[i + j - N] -= prod;
                }
            }
        }
        let mut out = [0i32; N];
        for (o, &v) in out.iter_mut().zip(acc.iter()) {
            *o = v.rem_euclid(Q64) as i32;
        }
        out
    }

    #[test]
    fn test_zetas_0() {
        // zetas[0] = zeta^{BitRev_8(0)} = zeta^0 = 1
        assert_eq!(ZETAS[0], 1);
    }

    #[test]
    fn test_zetas_1_squares_to_minus_one() {
        // ZETAS[1] = zeta^{BitRev_8(1)} = zeta^128, and zeta^256 = -1 mod q.
        assert_eq!(mul_mod_q(ZETAS[1], ZETAS[1]), Q - 1);
    }

    #[test]
    fn test_zetas_mont_matches_zetas() {
        // ZETAS_MONT[k] must be congruent to ZETAS[k] * R mod q.
        let r = r_mod_q();
        for k in 0..N {
            assert_eq!(mod_q(ZETAS_MONT[k]), mul_mod_q(ZETAS[k], r), "index {}", k);
        }
    }

    #[test]
    fn test_ntt_of_one_is_all_ones() {
        // The constant polynomial 1 evaluates to 1 at every root.
        let mut one = [0i32; N];
        one[0] = 1;
        ntt(&mut one);
        assert!(one.iter().all(|&c| c == 1));
    }

    #[test]
    fn test_ntt_output_is_canonical() {
        let mut f = sample_poly(999, 1);
        ntt(&mut f);
        assert!(f.iter().all(|&c| (0..Q).contains(&c)));
    }

    #[test]
    fn test_ntt_roundtrip() {
        // Full pipeline: NTT -> pointwise_mul by NTT(1) -> iNTT. The R^{-1} introduced by
        // pointwise_mul cancels the factor R applied by ntt_inv.
        let orig = sample_poly(17, 3);
        let mut f = orig;
        ntt(&mut f);
        // `one` is the constant polynomial 1 (coefficients [1, 0, ..., 0]);
        // its NTT is the all-ones vector.
        let mut one = [0i32; N];
        one[0] = 1;
        ntt(&mut one);
        let mut result = pointwise_mul(&one, &f);
        ntt_inv(&mut result);
        for i in 0..N {
            let r = mod_q(result[i]);
            assert_eq!(r, orig[i], "mismatch at index {}: got {} expected {}", i, r, orig[i]);
        }
    }

    #[test]
    fn test_ntt_inv_alone_scales_by_r() {
        // ntt_inv applied directly to ntt(f) yields f * R mod q, not f.
        let orig = sample_poly(31, 7);
        let mut f = orig;
        ntt(&mut f);
        ntt_inv(&mut f);
        let r = r_mod_q();
        for i in 0..N {
            assert_eq!(mod_q(f[i]), mul_mod_q(orig[i], r), "index {}", i);
        }
    }

    #[test]
    fn test_ntt_multiplication_matches_schoolbook() {
        let a = sample_poly(12345, 678);
        let b = sample_poly(-4321, 99);
        let mut a_hat = a;
        let mut b_hat = b;
        ntt(&mut a_hat);
        ntt(&mut b_hat);
        let mut prod = pointwise_mul(&a_hat, &b_hat);
        ntt_inv(&mut prod);
        let expected = negacyclic_mul(&a, &b);
        for i in 0..N {
            assert_eq!(mod_q(prod[i]), expected[i], "index {}", i);
        }
    }

    #[test]
    fn test_pointwise_then_to_mont_is_plain_product() {
        // pointwise_mul leaves a factor R^{-1}; to_mont_poly multiplies it back by R.
        let a = sample_poly(5, 11);
        let b = sample_poly(7, 13);
        let mut c = pointwise_mul(&a, &b);
        to_mont_poly(&mut c);
        for i in 0..N {
            assert_eq!(mod_q(c[i]), mul_mod_q(a[i], b[i]), "index {}", i);
        }
    }

    #[test]
    fn test_mod_q_negative() {
        assert_eq!(mod_q(-1), Q - 1);
        assert_eq!(mod_q(0), 0);
        assert_eq!(mod_q(Q), 0);
        assert_eq!(mod_q(-Q), 0);
        assert_eq!(mod_q(i32::MIN), i32::MIN.rem_euclid(Q));
        assert_eq!(mod_q(i32::MAX), i32::MAX.rem_euclid(Q));
    }

    #[test]
    fn test_poly_add_sub_wrap() {
        let mut a = [0i32; N];
        let mut b = [0i32; N];
        a[0] = Q - 1;
        b[0] = 2;
        // (q - 1) + 2 = q + 1 = 1 (mod q)
        assert_eq!(poly_add(&a, &b)[0], 1);
        // 2 - (q - 1) = 3 - q = 3 (mod q)
        assert_eq!(poly_sub(&b, &a)[0], 3);
        assert_eq!(poly_sub(&a, &a), [0i32; N]);
    }
}