Skip to main content

quantica/ml_dsa/
dsa.rs

1//! Core ML-DSA algorithms (FIPS 204, Algorithms 1-8).
2//!
3//! Contains the key generation, signing, and verification routines at both
4//! the public API level (Algorithms 1-3) and the internal/deterministic level
5//! (Algorithms 6-8).
6//!
7//! All internal polynomial vector operations use fixed-size stack arrays
8//! (`[[i32; N]; MAX_K]` / `[[i32; N]; MAX_L]`) to avoid heap allocations.
9//!
10//! # Side-channel countermeasures (`sca-protected` feature)
11//!
12//! When the `sca-protected` Cargo feature is enabled (on by default),
13//! `sign_internal` runs an additional layer of defences on the secret-key
14//! material:
15//!
16//! | Countermeasure        | Module                       | Threat addressed                              |
17//! |-----------------------|------------------------------|-----------------------------------------------|
18//! | Constant-time arith   | always-on                    | Cache- / branch-based timing attacks          |
19//! | Zeroization           | always-on                    | Cold-boot dumps, use-after-free               |
20//! | Hedged signing        | always-on                    | Fault-induced nonce reuse (`rnd ≠ 0`)         |
21//! | Shuffled NTT          | `super::shuffle` (sca)     | SPA, trace-alignment for DPA                  |
22//! | First-order masking   | `super::masked` (sca)      | First-order DPA, template attacks             |
23//! | Mask refresh / hop    | `super::masked` (sca)      | Inter-iteration share correlation             |
24//!
25//! The masking + shuffling layer is deliberately confined to `sign_internal`,
26//! because that is where the secret key `(s1, s2, t0)` is consumed in
27//! polynomial multiplications with values an attacker can influence:
28//!
29//! ```text
30//!   ŝ1, ŝ2, t̂0  ←  NTT(s1), NTT(s2), NTT(t0)               // SPA + DPA target
31//!   loop:
32//!       ĉ ← NTT(SampleInBall(c̃))
33//!       cs1[i]  ← ĉ · ŝ1[i]                                 // ×L  — DPA target
34//!       cs2[i]  ← ĉ · ŝ2[i]                                 // ×K  — DPA target
35//!       ct0[i]  ← ĉ · t̂0[i]                                 // ×K  — DPA target
36//! ```
37//!
38//! The challenge polynomial `ĉ` is **public** (the verifier recomputes it),
39//! so every secret×public multiplication only needs first-order masking:
40//! `(s₀ + s₁) · ĉ = s₀·ĉ + s₁·ĉ`. There is no secret×secret operation in
41//! Sign that would require second-order shares.
42//!
43//! Mask randomness is drawn from a SHAKE256-based deterministic
44//! `ScaRng` seeded with `(K ‖ rnd ‖ tr ‖ M')`, so that:
45//!
46//! * `sign_internal` keeps a deterministic signature (no `&mut dyn CryptoRng`
47//!   parameter), and the NIST ACVP fixed-`rnd = 0` test vectors still match
48//!   bit-for-bit;
49//! * different `rnd` values produce independent share streams (hedged
50//!   signing entropy is preserved through the SCA layer).
51//!
52//! The standard build (without `sca-protected`) still benefits from the
53//! always-on countermeasures listed above; only the masking + shuffling
54//! defences are conditionally compiled out.
55
56use super::MlDsaError;
57use super::decompose;
58use super::encode;
59use super::ntt::{self, mod_q};
60use super::params::{D, MAX_K, MAX_L, N, Params, Q};
61use super::rng::CryptoRng;
62use super::sample;
63use super::sha3;
64use alloc::vec::Vec;
65
66#[cfg(any(feature = "compressed-poly", feature = "compressed-challenge"))]
67use super::compressed;
68#[cfg(feature = "sca-protected")]
69use super::masked::{self, MaskedPoly};
70
// Feature combinations that cannot be built together. Each guard turns an
// unsupported combination into one explicit compile-time error instead of
// unrelated type / name-resolution errors deep inside `sign_internal`.

// The masked multiplications work on NTT-domain shares, whereas
// `compressed-challenge` keeps the secrets in the time domain.
#[cfg(all(feature = "sca-protected", feature = "compressed-challenge"))]
compile_error!(
    "features `sca-protected` and `compressed-challenge` are mutually exclusive: masking requires NTT-domain multiplication, schoolbook operates in time domain"
);

#[cfg(all(feature = "sca-protected", feature = "small-secret"))]
compile_error!("features `sca-protected` and `small-secret` are mutually exclusive: masking operates in i32 domain");

#[cfg(all(feature = "sca-protected", feature = "union-buffer"))]
compile_error!("features `sca-protected` and `union-buffer` are mutually exclusive");

// The `sca-masked-y` path in `sign_internal` calls into `masked::…`, and
// this file only imports `super::masked` when `sca-protected` is enabled.
// Without this guard the combination fails with an unresolved-module error.
#[cfg(all(feature = "sca-masked-y", not(feature = "sca-protected")))]
compile_error!(
    "feature `sca-masked-y` requires feature `sca-protected`: the masked-y signing path uses the `masked` module, which is only available with `sca-protected`"
);
81
82#[cfg(feature = "sca-protected")]
83use super::sha3::KeccakState;
84#[cfg(feature = "sca-protected")]
85use super::shuffle;
86#[cfg(feature = "small-secret")]
87use super::smallpoly::{self, SmallPoly};
88
/// Per-signature deterministic randomness source for the SCA layer.
///
/// `sign_internal` takes no `&mut dyn CryptoRng`. It has to stay
/// deterministic for a fixed `rnd` so that the NIST ACVP fixed-`rnd`
/// vectors reproduce bit-for-bit, which also rules out using
/// [`super::OsRng`] inside the masking and shuffling code. Those modules
/// therefore draw from one `ScaRng` per call. Its seed is derived by
/// `sign_internal` with SHAKE256 from `(K ‖ rnd ‖ tr ‖ M')`:
///
/// * `K` — the secret-key seed FIPS 204 uses for hedged signing;
/// * `rnd` — the 32-byte signing randomness: all-zero in deterministic /
///   ACVP mode, fresh entropy otherwise;
/// * `tr`, `M'` — tie the stream to the public key and the message, so two
///   signatures over different messages get unrelated share streams even
///   when `rnd = 0`.
///
/// Before absorbing the seed, [`ScaRng::from_seed`] absorbs the
/// domain-separation tag `b"quantica-mldsa-sca-v1"`. This keeps this
/// squeeze stream from coinciding with any other SHAKE256 use in the
/// algorithm.
///
/// One instance supplies all share/shuffle randomness of a signature, in a
/// fixed order, so the stream is reproducible from the inputs:
///
/// * the initial masking of `(s1, s2, t0)`;
/// * the shuffled-NTT permutations;
/// * the per-rejection-iteration `MaskedPoly::refresh()` calls.
///
/// The bytes come from a SHAKE256 squeeze but carry no fresh entropy: they
/// are a pure function of the seed and are never mixed with system
/// randomness. Their only job is to keep the internal mask shares
/// unpredictable to a passive side-channel observer who does not know `K`.
/// The security of the signature itself still rests on FIPS 204's own
/// randomness (`rnd`), which feeds into the seed.
///
/// NOTE(review): the sponge state is derived from secret material; whether
/// `KeccakState` clears itself on drop is not visible from this file.
#[cfg(feature = "sca-protected")]
struct ScaRng {
    state: KeccakState,
}

#[cfg(feature = "sca-protected")]
impl ScaRng {
    /// Create a generator from `seed`: absorb the domain-separation tag
    /// first, then the seed.
    fn from_seed(seed: &[u8]) -> Self {
        let mut sponge = sha3::shake256();
        sponge.absorb(b"quantica-mldsa-sca-v1");
        sponge.absorb(seed);
        ScaRng { state: sponge }
    }
}

#[cfg(feature = "sca-protected")]
impl CryptoRng for ScaRng {
    /// Fill `dest` with the next bytes of the SHAKE256 squeeze; never fails.
    fn fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), MlDsaError> {
        let ScaRng { state } = self;
        state.squeeze(dest);
        Ok(())
    }
}
145
146/// Compute the matrix-vector product A_hat * s in the NTT domain.
147///
148/// `a_hat` is a k-by-l matrix of NTT-domain polynomials and `s` is a vector
149/// of l NTT-domain polynomials. The result is a vector of k polynomials.
150///
151/// Used when `low-mem` is **disabled** (default). The full matrix is
152/// pre-expanded in RAM for maximum throughput.
153#[cfg(not(feature = "low-mem"))]
154fn mat_vec_mul(a_hat: &[[[i32; N]; MAX_L]; MAX_K], s: &[[i32; N]], k: usize, l: usize, result: &mut [[i32; N]]) {
155    for i in 0..k {
156        result[i] = [0i32; N];
157        for j in 0..l {
158            let prod = ntt::pointwise_mul(&a_hat[i][j], &s[j]);
159            result[i] = ntt::poly_add(&result[i], &prod);
160        }
161    }
162}
163
/// Low-memory variant of the NTT-domain product `Â · ŝ`.
///
/// This version does not take an expanded matrix. Each entry `Â[i][j]` is
/// regenerated from `rho` with `sample::rej_ntt_poly(rho, j, i)` right
/// before it is multiplied, then dropped. This is intended to reproduce
/// the entry `expand_a` would have produced. The function therefore never
/// holds the full k×l matrix, which is about 57 KB for ML-DSA-87
/// (8×7 polynomials of 256 `i32`s).
///
/// Cost: every call re-runs the matrix sampling (SHAKE128) for all k·l
/// entries. `sign_internal` calls this once per rejection iteration, so
/// that overhead is multiplied by the average number of iterations. FIPS
/// 204 gives the expected repetitions as 4.25 / 5.1 / 3.85 for ML-DSA-44 /
/// 65 / 87.
///
/// `result[i]` is overwritten for every `i < k`. Compiled only with the
/// `low-mem` feature.
#[cfg(feature = "low-mem")]
fn mat_vec_mul_lazy(rho: &[u8; 32], s: &[[i32; N]], k: usize, l: usize, result: &mut [[i32; N]]) {
    for (i, out) in result[..k].iter_mut().enumerate() {
        let mut acc = [0i32; N];
        for (j, s_j) in s[..l].iter().enumerate() {
            // Â[i][j] is rebuilt from rho on each call and discarded after use.
            let a_ij = sample::rej_ntt_poly(rho, j as u8, i as u8);
            acc = ntt::poly_add(&acc, &ntt::pointwise_mul(&a_ij, s_j));
        }
        *out = acc;
    }
}
187
188/// NTT of a polynomial vector (in-place, first `len` elements).
189fn ntt_vec(v: &mut [[i32; N]], len: usize) {
190    for poly in v[..len].iter_mut() {
191        ntt::ntt(poly);
192    }
193}
194
195/// Inverse NTT of a polynomial vector (in-place, first `len` elements).
196fn ntt_inv_vec(v: &mut [[i32; N]], len: usize) {
197    for poly in v[..len].iter_mut() {
198        ntt::ntt_inv(poly);
199    }
200}
201
202/// Add two polynomial vectors into `out`. Only processes `len` elements.
203fn vec_add(a: &[[i32; N]], b: &[[i32; N]], out: &mut [[i32; N]], len: usize) {
204    for i in 0..len {
205        out[i] = ntt::poly_add(&a[i], &b[i]);
206    }
207}
208
209/// Subtract two polynomial vectors into `out`. Only processes `len` elements.
210fn vec_sub(a: &[[i32; N]], b: &[[i32; N]], out: &mut [[i32; N]], len: usize) {
211    for i in 0..len {
212        out[i] = ntt::poly_sub(&a[i], &b[i]);
213    }
214}
215
// ---------------------------------------------------------------------
// low-stack helpers: heap-allocated polynomial vectors
// ---------------------------------------------------------------------
//
// With `low-stack` enabled, sign_internal keeps its rejection-loop
// temporaries in heap `Vec`s with scoped lifetimes and explicit drop()
// calls, which keeps the high-water mark low (~23 KB). Downstream helpers
// such as vec_add, check_norm_vec and decompose::* take `&[[i32; N]]`
// slices, so they work unchanged on stack arrays and heap vectors alike.

/// Return a heap-allocated vector of `len` all-zero polynomials.
#[cfg(feature = "low-stack")]
fn poly_vec(len: usize) -> Vec<[i32; N]> {
    let mut polys = Vec::with_capacity(len);
    polys.resize(len, [0i32; N]);
    polys
}
231
232/// Check infinity norm of polynomial: all coefficients strictly below `bound`.
233///
234/// Returns `true` iff every coefficient `c` satisfies `|c| < bound`
235/// (i.e., `c ∈ (−bound, bound)`). This implements the **strict**
236/// inequality `||v||∞ < bound` required by FIPS 204 Algorithm 7
237/// step 25 / Algorithm 8 step 15: "if `||z||∞ ≥ γ₁ − β` then
238/// return ⊥".
239fn check_norm(v: &[i32; N], bound: i32) -> bool {
240    for &c in v.iter() {
241        // Bring to centered representation
242        let mut val = mod_q(c);
243        if val > Q / 2 {
244            val -= Q;
245        }
246        // Strict: reject if |val| >= bound
247        if val >= bound || val <= -bound {
248            return false;
249        }
250    }
251    true
252}
253
254/// Check infinity norm of polynomial vector. Only checks first `len` elements.
255fn check_norm_vec(v: &[[i32; N]], bound: i32, len: usize) -> bool {
256    for poly in v[..len].iter() {
257        if !check_norm(poly, bound) {
258            return false;
259        }
260    }
261    true
262}
263
264/// Deterministic key generation from a 32-byte seed.
265///
266/// Implements Algorithm 6 of FIPS 204 (ML-DSA.KeyGen_internal).
267///
268/// Given a 32-byte seed `xi`, derives the public matrix A (via ExpandA),
269/// secret vectors s1 and s2 (via ExpandS), and computes the public key
270/// `pk = (rho, t1)` and secret key `sk = (rho, K, tr, s1, s2, t0)`.
271///
272/// - `xi`: 32-byte random seed.
273///
274/// Returns `(pk, sk)` as byte vectors.
275pub fn keygen_internal<P: Params>(xi: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
276    let k = P::K;
277    let l = P::L;
278
279    // (rho, rho', K) = H(xi || k || l)
280    let mut h_input = [0u8; 34];
281    h_input[..32].copy_from_slice(xi);
282    h_input[32] = k as u8;
283    h_input[33] = l as u8;
284
285    let mut hash_out = [0u8; 128]; // need 32 + 64 + 32 = 128 bytes
286    let mut state = sha3::shake256();
287    state.absorb(&h_input);
288    state.squeeze(&mut hash_out);
289
290    let mut rho = [0u8; 32];
291    rho.copy_from_slice(&hash_out[..32]);
292    let mut rho_prime = [0u8; 64];
293    rho_prime.copy_from_slice(&hash_out[32..96]);
294    let mut k_seed = [0u8; 32];
295    k_seed.copy_from_slice(&hash_out[96..128]);
296
297    // Generate s1, s2 from rho'
298    let (mut s1, mut s2) = sample::expand_s::<P>(&rho_prime);
299
300    // s1_hat = NTT(s1)
301    let mut s1_hat = s1;
302    ntt_vec(&mut s1_hat, l);
303
304    // t = NTT^{-1}(A-hat * s1_hat) + s2
305    let mut t = [[0i32; N]; MAX_K];
306    #[cfg(not(feature = "low-mem"))]
307    {
308        let a_hat = sample::expand_a::<P>(&rho);
309        mat_vec_mul(&a_hat, &s1_hat, k, l, &mut t);
310    }
311    #[cfg(feature = "low-mem")]
312    mat_vec_mul_lazy(&rho, &s1_hat, k, l, &mut t);
313    ntt_inv_vec(&mut t, k);
314    // t = t + s2 (in-place into t)
315    {
316        let mut tmp = [[0i32; N]; MAX_K];
317        vec_add(&t, &s2, &mut tmp, k);
318        t = tmp;
319    }
320
321    // (t1, t0) = Power2Round(t)
322    let (t1, t0) = encode::power2round_vec(&t, k);
323
324    // Encode public key
325    let pk = encode::pk_encode::<P>(&rho, &t1);
326
327    // tr = H(pk)  (SHAKE256, 64 bytes)
328    let mut tr = [0u8; 64];
329    sha3::shake256_digest(&pk, &mut tr);
330
331    // Encode secret key
332    let sk = encode::sk_encode::<P>(&rho, &k_seed, &tr, &s1, &s2, &t0);
333
334    // Zeroize sensitive data
335    for poly in s1[..l].iter_mut() {
336        for c in poly.iter_mut() {
337            *c = 0;
338        }
339    }
340    for poly in s2[..k].iter_mut() {
341        for c in poly.iter_mut() {
342            *c = 0;
343        }
344    }
345    for byte in rho_prime.iter_mut() {
346        *byte = 0;
347    }
348    for byte in k_seed.iter_mut() {
349        *byte = 0;
350    }
351
352    (pk, sk)
353}
354
355/// Sign a pre-formatted message (deterministic or hedged).
356///
357/// Implements Algorithm 7 of FIPS 204 (ML-DSA.Sign_internal).
358///
359/// This function contains the core rejection sampling loop: candidate
360/// signatures `(z, h)` are generated from a masking vector `y` and the
361/// challenge polynomial `c`, then tested against the norm bounds
362/// `||z||_inf < gamma1 - beta` and `||r0||_inf < gamma2 - beta`. If any
363/// check fails, the counter `kappa` is incremented and a new attempt begins.
364///
365/// - `sk`: encoded secret key bytes.
366/// - `m_prime`: pre-formatted message (e.g., `0x00 || len(ctx) || ctx || msg`).
367/// - `rnd`: 32-byte randomness. Use random bytes for hedged signing or
368///   all-zeros for fully deterministic signing.
369///
370/// # Side-channel countermeasures
371///
372/// With the `sca-protected` Cargo feature enabled (default), this
373/// function activates the additional defences described in the
374/// crate-level documentation:
375///
376/// 1. **Shuffled NTT** on `s1`, `s2`, `t0` — runs once at entry,
377///    via [`super::shuffle::ntt_shuffled`]. Defends against SPA on
378///    the secret-key NTT and disrupts trace alignment for any
379///    later DPA campaign that tries to average aligned traces.
380/// 2. **First-order additive masking** of the NTT-domain secrets,
381///    via [`super::masked::MaskedPoly`]. Each polynomial is split
382///    into two shares mod `q = 8 380 417`; no single intermediate
383///    value reveals the secret to a first-order observer.
384/// 3. **Per-iteration `c·sₓ` multiplications** go through
385///    [`super::masked::masked_pointwise_mul_public`], which
386///    multiplies each share independently by the public challenge
387///    `ĉ`. Because `ĉ` is public, first-order shares are sufficient
388///    — no secret×secret operation is performed.
389/// 4. **Mask refresh after every use**: the share pair is
390///    re-randomized via `MaskedPoly::refresh()` between rejection
391///    iterations, so the same secret never multiplies the same
392///    share twice — defeating higher-order correlation attacks
393///    that would otherwise become available across many rejection
394///    retries on the same key.
395///
396/// All randomness for the SCA layer comes from a deterministic
397/// SHAKE256-based `ScaRng` seeded with `(K ‖ rnd ‖ tr ‖ M')`, so
398/// the function remains deterministic for fixed `rnd`. The masked
399/// path produces signatures **bit-identical** to the unmasked path
400/// — proven by the NIST ACVP siggen vectors, which the SCA build
401/// passes unchanged.
402///
403/// # Errors
404///
405/// Returns [`MlDsaError::InvalidSecretKey`] if `sk` has incorrect length
406/// (checked by the caller in the public API).
407pub fn sign_internal<P: Params>(sk: &[u8], m_prime: &[u8], rnd: &[u8; 32]) -> Result<Vec<u8>, MlDsaError> {
408    let k = P::K;
409    let l = P::L;
410    let gamma1 = P::GAMMA1;
411    let gamma2 = P::GAMMA2;
412    let beta = P::BETA;
413    let omega = P::OMEGA;
414    let c_tilde_len = P::LAMBDA / 4;
415
416    // Decode secret key seeds (128 bytes on stack).
417    //
418    // indexed-sk: decode only rho/K/tr here; the polynomial vectors
419    // s1/s2/t0 are decoded one-at-a-time below, directly into the
420    // NTT-domain destination arrays, avoiding the 23 KB intermediate
421    // tuple that sk_decode() would put on the stack.
422    //
423    // Default: sk_decode() returns the full tuple at once (simpler,
424    // but 23 KB of stack for the return value alone).
425    #[cfg(feature = "indexed-sk")]
426    let (rho, k_seed, tr) = encode::sk_decode_seeds::<P>(sk);
427    #[cfg(not(feature = "indexed-sk"))]
428    let (rho, k_seed, tr, s1, s2, t0) = encode::sk_decode::<P>(sk);
429
430    // ----- ŝ1, ŝ2, t̂0 = NTT(s1, s2, t0) ------------------------------
431    //
432    // This is the most leakage-prone step in Sign.
433    //
434    // indexed-sk: decode each polynomial from the packed sk directly
435    // into the destination slot, then NTT in-place. Only one decoded
436    // polynomial (1 KB) is live at a time instead of the full 23 KB
437    // tuple from sk_decode.
438    //
439    // SCA-protected build: after NTT, each polynomial is additionally
440    // split into masked shares (see below).
441    //
442    // Standard build: straight in-place Montgomery NTT.
443    #[cfg(not(feature = "low-stack"))]
444    let mut s1_hat = {
445        #[cfg(feature = "indexed-sk")]
446        {
447            let mut v = [[0i32; N]; MAX_L];
448            for i in 0..l {
449                encode::sk_decode_s1::<P>(sk, i, &mut v[i]);
450            }
451            v
452        }
453        #[cfg(not(feature = "indexed-sk"))]
454        {
455            s1
456        }
457    };
458    #[cfg(feature = "low-stack")]
459    let mut s1_hat = {
460        let mut v = poly_vec(l);
461        #[cfg(feature = "indexed-sk")]
462        for i in 0..l {
463            encode::sk_decode_s1::<P>(sk, i, &mut v[i]);
464        }
465        #[cfg(not(feature = "indexed-sk"))]
466        for i in 0..l {
467            v[i] = s1[i];
468        }
469        v
470    };
471
472    #[cfg(not(feature = "low-stack"))]
473    let mut s2_hat = {
474        #[cfg(feature = "indexed-sk")]
475        {
476            let mut v = [[0i32; N]; MAX_K];
477            for i in 0..k {
478                encode::sk_decode_s2::<P>(sk, i, &mut v[i]);
479            }
480            v
481        }
482        #[cfg(not(feature = "indexed-sk"))]
483        {
484            s2
485        }
486    };
487    #[cfg(feature = "low-stack")]
488    let mut s2_hat = {
489        let mut v = poly_vec(k);
490        #[cfg(feature = "indexed-sk")]
491        for i in 0..k {
492            encode::sk_decode_s2::<P>(sk, i, &mut v[i]);
493        }
494        #[cfg(not(feature = "indexed-sk"))]
495        for i in 0..k {
496            v[i] = s2[i];
497        }
498        v
499    };
500
501    #[cfg(not(feature = "low-stack"))]
502    let mut t0_hat = {
503        #[cfg(feature = "indexed-sk")]
504        {
505            let mut v = [[0i32; N]; MAX_K];
506            for i in 0..k {
507                encode::sk_decode_t0::<P>(sk, i, &mut v[i]);
508            }
509            v
510        }
511        #[cfg(not(feature = "indexed-sk"))]
512        {
513            t0
514        }
515    };
516    #[cfg(feature = "low-stack")]
517    let mut t0_hat = {
518        let mut v = poly_vec(k);
519        #[cfg(feature = "indexed-sk")]
520        for i in 0..k {
521            encode::sk_decode_t0::<P>(sk, i, &mut v[i]);
522        }
523        #[cfg(not(feature = "indexed-sk"))]
524        for i in 0..k {
525            v[i] = t0[i];
526        }
527        v
528    };
529    #[cfg(feature = "sca-protected")]
530    let (mut s1_hat_m, mut s2_hat_m, mut t0_hat_m, mut sca_rng) = {
531        // Seed the SCA RNG from K ‖ rnd ‖ tr ‖ M'. K and rnd give us
532        // the FIPS 204 hedged-signing entropy; tr and M' bind the
533        // share stream to this particular (key, message) pair so two
534        // signatures over different inputs use uncorrelated shares
535        // even when rnd = 0 (deterministic / ACVP test mode).
536        let mut sca_seed = [0u8; 64];
537        {
538            let mut h = sha3::shake256();
539            h.absorb(b"quantica-mldsa-sca-seed-v1");
540            h.absorb(&k_seed);
541            h.absorb(rnd);
542            h.absorb(&tr);
543            h.absorb(m_prime);
544            h.squeeze(&mut sca_seed);
545        }
546        let mut rng = ScaRng::from_seed(&sca_seed);
547
548        // Step 1 — SPA defence: NTT each secret polynomial through
549        // the Fisher-Yates shuffled NTT, drawing fresh per-level and
550        // per-group permutations from the SCA RNG.
551        for i in 0..l {
552            shuffle::ntt_shuffled(&mut s1_hat[i], &mut rng)?;
553        }
554        for i in 0..k {
555            shuffle::ntt_shuffled(&mut s2_hat[i], &mut rng)?;
556        }
557        for i in 0..k {
558            shuffle::ntt_shuffled(&mut t0_hat[i], &mut rng)?;
559        }
560
561        // Step 2 — DPA defence: split each NTT-domain secret into two
562        // additive shares mod q. The MaskedPoly::zero() initializer
563        // is a stack-resident no-allocation array fill; the real
564        // shares are written immediately below by MaskedPoly::mask.
565        let mut s1m: [MaskedPoly; MAX_L] = core::array::from_fn(|_| MaskedPoly::zero());
566        let mut s2m: [MaskedPoly; MAX_K] = core::array::from_fn(|_| MaskedPoly::zero());
567        let mut t0m: [MaskedPoly; MAX_K] = core::array::from_fn(|_| MaskedPoly::zero());
568        for i in 0..l {
569            s1m[i] = MaskedPoly::mask(&s1_hat[i], &mut rng)?;
570        }
571        for i in 0..k {
572            s2m[i] = MaskedPoly::mask(&s2_hat[i], &mut rng)?;
573        }
574        for i in 0..k {
575            t0m[i] = MaskedPoly::mask(&t0_hat[i], &mut rng)?;
576        }
577
578        // Step 3 — wipe the unmasked NTT-domain buffers. From this
579        // point on the secret only exists as `(share0, share1)` pairs;
580        // any side-channel observation of `s1_hat[i]` etc. yields zero
581        // information about the underlying coefficients.
582        for i in 0..l {
583            s1_hat[i] = [0i32; N];
584        }
585        for i in 0..k {
586            s2_hat[i] = [0i32; N];
587        }
588        for i in 0..k {
589            t0_hat[i] = [0i32; N];
590        }
591        (s1m, s2m, t0m, rng)
592    };
593    #[cfg(not(feature = "sca-protected"))]
594    {
595        // compressed-challenge: secrets stay in time domain for
596        // schoolbook multiplication. No NTT needed.
597        // small-secret: s1/s2 are converted to SmallPoly and NTT'd
598        // via the i16 Kyber NTT instead. t0 still uses i32 NTT
599        // (coefficients too large for i16).
600        #[cfg(not(any(feature = "compressed-challenge", feature = "small-secret")))]
601        {
602            ntt_vec(&mut s1_hat, l);
603            ntt_vec(&mut s2_hat, k);
604            ntt_vec(&mut t0_hat, k);
605        }
606        #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
607        {
608            // t0 still needs i32 NTT (coefficients up to 4096).
609            ntt_vec(&mut t0_hat, k);
610            // s1/s2 are converted to SmallPoly below; we don't NTT the i32 versions.
611        }
612    }
613
614    // small-secret: convert s1/s2 to i16 SmallPoly and NTT via Kyber NTT.
615    // The i32 s1_hat/s2_hat arrays are kept for any non-small-secret
616    // code paths but are effectively unused when small-secret is on.
617    #[cfg(feature = "small-secret")]
618    let (s1_small, s2_small) = {
619        let mut s1s: [SmallPoly; MAX_L] = core::array::from_fn(|_| SmallPoly::zero());
620        let mut s2s: [SmallPoly; MAX_K] = core::array::from_fn(|_| SmallPoly::zero());
621        for i in 0..l {
622            s1s[i] = SmallPoly::from_i32(&s1_hat[i]);
623            smallpoly::small_ntt(&mut s1s[i]);
624        }
625        for i in 0..k {
626            s2s[i] = SmallPoly::from_i32(&s2_hat[i]);
627            smallpoly::small_ntt(&mut s2s[i]);
628        }
629        (s1s, s2s)
630    };
631
632    // A-hat = ExpandA(rho)
633    // Default: full matrix on stack (57 KB). Low-mem: recomputed on-the-fly.
634    #[cfg(not(feature = "low-mem"))]
635    let a_hat = sample::expand_a::<P>(&rho);
636
637    // mu = H(tr || M')
638    let mut mu = [0u8; 64];
639    {
640        let mut state = sha3::shake256();
641        state.absorb(&tr);
642        state.absorb(m_prime);
643        state.squeeze(&mut mu);
644    }
645
646    // rho'' = H(K || rnd || mu)
647    let mut rho_double_prime = [0u8; 64];
648    {
649        let mut state = sha3::shake256();
650        state.absorb(&k_seed);
651        state.absorb(rnd);
652        state.absorb(&mu);
653        state.squeeze(&mut rho_double_prime);
654    }
655
656    let mut kappa: u16 = 0;
657
658    loop {
659        // T1-A — refresh the persistent masked-secret-poly shares at
660        // the **start** of every rejection iteration, before any
661        // operation on them (Hermelink-Ning-Petri 2025/276 §4).
662        // `s1_hat_m`, `s2_hat_m`, `t0_hat_m` survive across all
663        // iterations (declared at line 530); without per-iteration
664        // refresh, higher-order DPA aggregating traces over multiple
665        // iterations sees correlated share pairs. Cost is one
666        // `MaskedPoly::refresh` per polynomial per iteration —
667        // identical to the previous end-of-cs/ct refresh placement;
668        // KAT output bytes are unchanged because the mask cancels in
669        // every `unmask()`.
670        #[cfg(feature = "sca-protected")]
671        {
672            for i in 0..l {
673                s1_hat_m[i].refresh(&mut sca_rng)?;
674            }
675            for i in 0..k {
676                s2_hat_m[i].refresh(&mut sca_rng)?;
677            }
678            for i in 0..k {
679                t0_hat_m[i].refresh(&mut sca_rng)?;
680            }
681        }
682
683        // y = ExpandMask(rho'', kappa)
684        //
685        // sca-masked-y: sample y directly as arithmetic shares from
686        // SHAKE256 (MaskedPoly::sample_expand_mask), keep it masked
687        // through NTT + mat_vec_mul + iNTT. Only unmask y and w at
688        // the end of the linear ops: w is about to be published via
689        // w1 in c_tilde, and y is recoverable from z = y + cs1 in
690        // the final signature anyway.
691        //
692        // Default: sample y in clear via expand_mask.
693        #[cfg(not(feature = "sca-masked-y"))]
694        let y = sample::expand_mask::<P>(&rho_double_prime, kappa);
695
696        #[cfg(feature = "sca-masked-y")]
697        let (y, w_precomputed) = {
698            // Full masking pipeline: y stays split into two arithmetic
699            // shares from sampling through NTT, A·y, and iNTT. Only
700            // once w reaches its "about-to-be-published" form do we
701            // unmask y and w together.
702            //
703            // Countermeasure references:
704            //   ePrint 2025/276 — Hermelink–Ning–Petri, DPA on y
705            //   ePrint 2025/582 — Rejected-signature timing leak
706
707            // 1. Sample y as arithmetic shares from SHAKE256. The
708            //    unmasked coefficient value only transits through CPU
709            //    registers — never written to RAM.
710            let mut y_m: [masked::MaskedPoly; MAX_L] = core::array::from_fn(|_| masked::MaskedPoly::zero());
711            for r in 0..l {
712                y_m[r] = masked::MaskedPoly::sample_expand_mask(
713                    &rho_double_prime,
714                    kappa + r as u16,
715                    gamma1,
716                    P::BITLEN_GAMMA1_MINUS1,
717                );
718            }
719
720            // 2. Masked NTT into y_hat_m — y_m is preserved for the
721            //    later time-domain unmask (needed by z = y + cs1).
722            let mut y_hat_m: [masked::MaskedPoly; MAX_L] = core::array::from_fn(|_| masked::MaskedPoly::zero());
723            for r in 0..l {
724                y_hat_m[r].share0 = y_m[r].share0;
725                y_hat_m[r].share1 = y_m[r].share1;
726                masked::masked_ntt(&mut y_hat_m[r]);
727            }
728
729            // 3. Masked A · y_hat → w_m (NTT domain, masked).
730            //    A is public; the matrix multiplication touches each
731            //    share independently.
732            let mut w_m: [masked::MaskedPoly; MAX_K] = core::array::from_fn(|_| masked::MaskedPoly::zero());
733            #[cfg(not(feature = "low-mem"))]
734            masked::masked_mat_vec_mul(&a_hat, &y_hat_m, k, l, &mut w_m);
735            #[cfg(feature = "low-mem")]
736            masked::masked_mat_vec_mul_lazy(&rho, &y_hat_m, k, l, &mut w_m);
737
738            for r in 0..l {
739                y_hat_m[r].zeroize();
740            }
741
742            // 4. Masked iNTT on each share.
743            for i in 0..k {
744                masked::masked_ntt_inv(&mut w_m[i]);
745            }
746
747            // 5. Unmask w — w1 = HighBits(w) is public (it ends up in
748            //    c_tilde). Output in [0, q-1]; `decompose` handles that.
749            let mut w_tmp = [[0i32; N]; MAX_K];
750            for i in 0..k {
751                w_tmp[i] = w_m[i].unmask();
752                w_m[i].zeroize();
753            }
754
755            // 6. Unmask y to centered (-gamma1, gamma1] time domain,
756            //    matching the default `expand_mask` output range.
757            let mut y_out = [[0i32; N]; MAX_L];
758            for r in 0..l {
759                let um = y_m[r].unmask();
760                for n in 0..N {
761                    let mut v = um[n];
762                    if v > Q / 2 {
763                        v -= Q;
764                    }
765                    y_out[r][n] = v;
766                }
767                y_m[r].zeroize();
768            }
769            (y_out, w_tmp)
770        };
771
772        // ============================================================
773        // Rejection-loop body.
774        //
775        // low-stack build: temporary polynomial vectors (w, w1, cs1,
776        // cs2, w_minus_cs2, r0, ct0, neg_ct0, w_cs2_ct0) are
777        // heap-allocated via Vec with scoped lifetimes and explicit
778        // drop() calls. Only ~23 KB of heap is live at peak instead
779        // of ~96 KB of stack.
780        //
781        // Default build: everything on the stack as fixed arrays.
782        // ============================================================
783
784        // w = NTT^{-1}(A-hat * NTT(y))
785        //
786        // Default path: compute y_hat = NTT(y), then w_tmp = iNTT(A·y_hat).
787        // sca-masked-y: w was already computed in the masked block
788        // above (w_precomputed) — the unmasked y was never in RAM.
789
790        // --- Compute w and w1, then derive c_tilde ---------------
791        //
792        // compressed-poly: after iNTT(w), pack w into 3-byte/coeff
793        // compressed form (−25% RAM), then derive w1 and w_minus_cs2
794        // from the compressed representation.
795        #[cfg(not(feature = "compressed-poly"))]
796        {
797            #[cfg(not(feature = "low-stack"))]
798            let mut _w_full = [[0i32; N]; MAX_K];
799            #[cfg(feature = "low-stack")]
800            let mut _w_full = poly_vec(k);
801            // (assigned below, used via w_ref)
802        }
803
804        // Compute w into a temporary full-poly buffer, then either
805        // keep it (default) or compress it (compressed-poly).
806        #[cfg(not(feature = "sca-masked-y"))]
807        let mut w_tmp = {
808            let mut y_hat = y;
809            ntt_vec(&mut y_hat, l);
810            let mut wt = [[0i32; N]; MAX_K];
811            #[cfg(not(feature = "low-mem"))]
812            mat_vec_mul(&a_hat, &y_hat, k, l, &mut wt);
813            #[cfg(feature = "low-mem")]
814            mat_vec_mul_lazy(&rho, &y_hat, k, l, &mut wt);
815            ntt_inv_vec(&mut wt, k);
816            wt
817        };
818        #[cfg(feature = "sca-masked-y")]
819        let mut w_tmp = w_precomputed;
820
821        // compressed-poly: pack w into 3-byte/coeff storage.
822        #[cfg(feature = "compressed-poly")]
823        let w_comp = {
824            let mut wc = compressed::CompressedVecK::new(k);
825            for i in 0..k {
826                // Reduce to [0, q-1] before packing.
827                for c in w_tmp[i].iter_mut() {
828                    *c = mod_q(*c);
829                }
830                wc.pack(i, &w_tmp[i]);
831            }
832            wc
833        };
834
835        // w1 = HighBits(w) — works on the full-poly tmp (before we drop it).
836        #[cfg(not(feature = "low-stack"))]
837        let mut w1 = [[0i32; N]; MAX_K];
838        #[cfg(feature = "low-stack")]
839        let mut w1 = poly_vec(k);
840        decompose::high_bits_vec(&w_tmp, gamma2, &mut w1, k);
841
842        // In non-compressed mode, keep w_tmp as "w" for later vec_sub.
843        // In compressed mode, drop w_tmp (we'll read from w_comp).
844        #[cfg(not(feature = "compressed-poly"))]
845        let w = w_tmp;
846        #[cfg(feature = "compressed-poly")]
847        drop(w_tmp);
848
849        let w1_encoded = encode::w1_encode::<P>(&w1);
850        #[cfg(feature = "low-stack")]
851        drop(w1);
852
853        let mut c_tilde_buf = [0u8; 64];
854        {
855            let mut state = sha3::shake256();
856            state.absorb(&mu);
857            state.absorb(&w1_encoded);
858            state.squeeze(&mut c_tilde_buf[..c_tilde_len]);
859        }
860        let c_tilde = &c_tilde_buf[..c_tilde_len];
861
862        // --- challenge computation --------------------------------
863        //
864        // Default: c_hat = NTT(SampleInBall(c_tilde)), 2 KB stack.
865        // compressed-challenge: compress c into 68 bytes and use
866        // schoolbook multiplication in time domain, saving ~2 KB.
867        let c = sample::sample_in_ball::<P>(c_tilde);
868        #[cfg(not(feature = "compressed-challenge"))]
869        let c_hat = {
870            let mut ch = c;
871            for coeff in ch.iter_mut() {
872                *coeff = mod_q(*coeff);
873            }
874            ntt::ntt(&mut ch);
875            ch
876        };
877        #[cfg(feature = "compressed-challenge")]
878        let c_comp = {
879            let mut cc = [0u8; compressed::COMPRESSED_CHALLENGE_BYTES];
880            compressed::challenge_compress(&mut cc, &c, P::TAU);
881            cc
882        };
883        // small-secret: also convert c to SmallPoly NTT for basemul.
884        #[cfg(feature = "small-secret")]
885        let c_small_ntt = {
886            let mut cs = SmallPoly::from_i32(&c);
887            smallpoly::small_ntt(&mut cs);
888            cs
889        };
890
891        // ============================================================
892        // union-buffer path: single 1 KB workspace reused per poly.
893        // Processes L + K iterations sequentially, only z and h persist.
894        // ============================================================
895        #[cfg(feature = "union-buffer")]
896        {
897            let mut z = [[0i32; N]; MAX_L];
898            let mut tmp = [0i32; N];
899            let mut rejected = false;
900
901            // Phase 1: z[i] = y[i] + c*s1[i]
902            for l_idx in 0..l {
903                #[cfg(all(not(feature = "compressed-challenge"), not(feature = "small-secret")))]
904                {
905                    tmp = ntt::pointwise_mul(&c_hat, &s1_hat[l_idx]);
906                    ntt::ntt_inv(&mut tmp);
907                }
908                #[cfg(feature = "compressed-challenge")]
909                {
910                    tmp = [0i32; N];
911                    compressed::schoolbook_mul_add(&mut tmp, &c_comp, &s1_hat[l_idx], P::TAU);
912                }
913                #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
914                {
915                    tmp = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s1_small[l_idx]);
916                }
917                z[l_idx] = ntt::poly_add(&y[l_idx], &tmp);
918            }
919            if !check_norm_vec(&z, gamma1 - beta, l) {
920                kappa += l as u16;
921                continue;
922            }
923
924            // Phase 2: per k_idx — cs2, r0, ct0, hint
925            let mut h = [[0i32; N]; MAX_K];
926            let mut total_hints = 0usize;
927            let mut wbuf = [0i32; N];
928
929            for k_idx in 0..k {
930                if rejected {
931                    break;
932                }
933                // cs2 → tmp
934                #[cfg(all(not(feature = "compressed-challenge"), not(feature = "small-secret")))]
935                {
936                    tmp = ntt::pointwise_mul(&c_hat, &s2_hat[k_idx]);
937                    ntt::ntt_inv(&mut tmp);
938                }
939                #[cfg(feature = "compressed-challenge")]
940                {
941                    tmp = [0i32; N];
942                    compressed::schoolbook_mul_add(&mut tmp, &c_comp, &s2_hat[k_idx], P::TAU);
943                }
944                #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
945                {
946                    tmp = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s2_small[k_idx]);
947                }
948
949                // wbuf = w[k_idx] - cs2
950                #[cfg(not(feature = "compressed-poly"))]
951                for j in 0..N {
952                    wbuf[j] = w[k_idx][j] - tmp[j];
953                }
954                #[cfg(feature = "compressed-poly")]
955                w_comp.sub_into(k_idx, &tmp, &mut wbuf);
956
957                // r0 check in tmp
958                for j in 0..N {
959                    tmp[j] = decompose::low_bits(wbuf[j], gamma2);
960                }
961                if !check_norm(&tmp, gamma2 - beta) {
962                    rejected = true;
963                    continue;
964                }
965
966                // ct0 → tmp
967                #[cfg(not(feature = "compressed-challenge"))]
968                {
969                    tmp = ntt::pointwise_mul(&c_hat, &t0_hat[k_idx]);
970                    ntt::ntt_inv(&mut tmp);
971                }
972                #[cfg(feature = "compressed-challenge")]
973                {
974                    tmp = [0i32; N];
975                    compressed::schoolbook_mul_add(&mut tmp, &c_comp, &t0_hat[k_idx], P::TAU);
976                }
977
978                if !check_norm(&tmp, gamma2) {
979                    rejected = true;
980                    continue;
981                }
982
983                // hint for this k_idx
984                for j in 0..N {
985                    h[k_idx][j] = decompose::make_hint(mod_q(-tmp[j]), wbuf[j] + tmp[j], gamma2);
986                    if h[k_idx][j] == 1 {
987                        total_hints += 1;
988                    }
989                }
990            }
991
992            if rejected || total_hints > omega {
993                kappa += l as u16;
994                continue;
995            }
996
997            // Center z and encode
998            for poly in z[..l].iter_mut() {
999                for c in poly.iter_mut() {
1000                    *c = mod_q(*c);
1001                    if *c > Q / 2 {
1002                        *c -= Q;
1003                    }
1004                }
1005            }
1006            let sig = encode::sig_encode::<P>(c_tilde, &z, &h);
1007            return Ok(sig);
1008        }
1009
1010        // ============================================================
1011        // Standard path (non-union-buffer)
1012        // ============================================================
1013        #[cfg(not(feature = "union-buffer"))]
1014        {
1015            // --- cs1 = ĉ · ŝ1, then z = y + cs1 --------------------
1016            #[cfg(not(feature = "low-stack"))]
1017            let mut cs1 = [[0i32; N]; MAX_L];
1018            #[cfg(feature = "low-stack")]
1019            let mut cs1 = poly_vec(l);
1020
1021            #[cfg(feature = "sca-protected")]
1022            for i in 0..l {
1023                let prod = masked::masked_pointwise_mul_public(&s1_hat_m[i], &c_hat);
1024                cs1[i] = prod.unmask();
1025                ntt::ntt_inv(&mut cs1[i]);
1026                // refresh of s1_hat_m happens at the head of the
1027                // next rejection iteration (T1-A).
1028            }
1029            #[cfg(all(
1030                not(feature = "sca-protected"),
1031                not(feature = "compressed-challenge"),
1032                not(feature = "small-secret")
1033            ))]
1034            for i in 0..l {
1035                cs1[i] = ntt::pointwise_mul(&c_hat, &s1_hat[i]);
1036                ntt::ntt_inv(&mut cs1[i]);
1037            }
1038            #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1039            for i in 0..l {
1040                cs1[i] = [0i32; N];
1041                compressed::schoolbook_mul_add(&mut cs1[i], &c_comp, &s1_hat[i], P::TAU);
1042            }
1043            #[cfg(all(
1044                not(feature = "sca-protected"),
1045                feature = "small-secret",
1046                not(feature = "compressed-challenge")
1047            ))]
1048            for i in 0..l {
1049                cs1[i] = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s1_small[i]);
1050            }
1051
1052            #[cfg(not(feature = "low-stack"))]
1053            let mut z = [[0i32; N]; MAX_L];
1054            #[cfg(feature = "low-stack")]
1055            let mut z = poly_vec(l);
1056            vec_add(&y, &cs1, &mut z, l);
1057            // cs1 no longer needed.
1058            #[cfg(feature = "low-stack")]
1059            drop(cs1);
1060
1061            // --- cs2, w_minus_cs2, r0 --------------------------------
1062            #[cfg(not(feature = "low-stack"))]
1063            let mut cs2 = [[0i32; N]; MAX_K];
1064            #[cfg(feature = "low-stack")]
1065            let mut cs2 = poly_vec(k);
1066
1067            #[cfg(feature = "sca-protected")]
1068            for i in 0..k {
1069                let prod = masked::masked_pointwise_mul_public(&s2_hat_m[i], &c_hat);
1070                cs2[i] = prod.unmask();
1071                ntt::ntt_inv(&mut cs2[i]);
1072                // refresh of s2_hat_m happens at the head of the
1073                // next rejection iteration (T1-A).
1074            }
1075            #[cfg(all(
1076                not(feature = "sca-protected"),
1077                not(feature = "compressed-challenge"),
1078                not(feature = "small-secret")
1079            ))]
1080            for i in 0..k {
1081                cs2[i] = ntt::pointwise_mul(&c_hat, &s2_hat[i]);
1082                ntt::ntt_inv(&mut cs2[i]);
1083            }
1084            #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1085            for i in 0..k {
1086                cs2[i] = [0i32; N];
1087                compressed::schoolbook_mul_add(&mut cs2[i], &c_comp, &s2_hat[i], P::TAU);
1088            }
1089            #[cfg(all(
1090                not(feature = "sca-protected"),
1091                feature = "small-secret",
1092                not(feature = "compressed-challenge")
1093            ))]
1094            for i in 0..k {
1095                cs2[i] = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s2_small[i]);
1096            }
1097
1098            #[cfg(not(feature = "low-stack"))]
1099            let mut w_minus_cs2 = [[0i32; N]; MAX_K];
1100            #[cfg(feature = "low-stack")]
1101            let mut w_minus_cs2 = poly_vec(k);
1102            #[cfg(not(feature = "compressed-poly"))]
1103            vec_sub(&w, &cs2, &mut w_minus_cs2, k);
1104            #[cfg(feature = "compressed-poly")]
1105            for i in 0..k {
1106                w_comp.sub_into(i, &cs2[i], &mut w_minus_cs2[i]);
1107            }
1108            // cs2 and w/w_comp no longer needed for the norm checks.
1109            #[cfg(feature = "low-stack")]
1110            drop(cs2);
1111            #[cfg(all(feature = "low-stack", not(feature = "compressed-poly")))]
1112            drop(w);
1113            #[cfg(feature = "compressed-poly")]
1114            drop(w_comp);
1115
1116            #[cfg(not(feature = "low-stack"))]
1117            let mut r0 = [[0i32; N]; MAX_K];
1118            #[cfg(feature = "low-stack")]
1119            let mut r0 = poly_vec(k);
1120            decompose::low_bits_vec(&w_minus_cs2, gamma2, &mut r0, k);
1121
1122            // Norm checks. Standard build: early-abort for performance.
1123            // sca-ct-rejection build: collect all flags and decide at end.
1124            #[cfg(not(feature = "sca-ct-rejection"))]
1125            {
1126                if !check_norm_vec(&z, gamma1 - beta, l) {
1127                    kappa += l as u16;
1128                    continue;
1129                }
1130                if !check_norm_vec(&r0, gamma2 - beta, k) {
1131                    kappa += l as u16;
1132                    continue;
1133                }
1134            }
1135            #[cfg(feature = "sca-ct-rejection")]
1136            let mut _reject_flag = {
1137                let z_ok = check_norm_vec(&z, gamma1 - beta, l);
1138                let r0_ok = check_norm_vec(&r0, gamma2 - beta, k);
1139                !(z_ok & r0_ok)
1140            };
1141            // r0 no longer needed.
1142            #[cfg(feature = "low-stack")]
1143            drop(r0);
1144
1145            // --- ct0, hint computation --------------------------------
1146            #[cfg(not(feature = "low-stack"))]
1147            let mut ct0 = [[0i32; N]; MAX_K];
1148            #[cfg(feature = "low-stack")]
1149            let mut ct0 = poly_vec(k);
1150
1151            #[cfg(feature = "sca-protected")]
1152            for i in 0..k {
1153                let prod = masked::masked_pointwise_mul_public(&t0_hat_m[i], &c_hat);
1154                ct0[i] = prod.unmask();
1155                ntt::ntt_inv(&mut ct0[i]);
1156                // refresh of t0_hat_m happens at the head of the
1157                // next rejection iteration (T1-A).
1158            }
1159            #[cfg(all(not(feature = "sca-protected"), not(feature = "compressed-challenge")))]
1160            for i in 0..k {
1161                ct0[i] = ntt::pointwise_mul(&c_hat, &t0_hat[i]);
1162                ntt::ntt_inv(&mut ct0[i]);
1163            }
1164            #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1165            for i in 0..k {
1166                ct0[i] = [0i32; N];
1167                compressed::schoolbook_mul_add(&mut ct0[i], &c_comp, &t0_hat[i], P::TAU);
1168            }
1169
1170            // Check ||ct0||_inf < gamma2
1171            #[cfg(not(feature = "sca-ct-rejection"))]
1172            {
1173                if !check_norm_vec(&ct0, gamma2, k) {
1174                    kappa += l as u16;
1175                    continue;
1176                }
1177            }
1178            #[cfg(feature = "sca-ct-rejection")]
1179            {
1180                _reject_flag |= !check_norm_vec(&ct0, gamma2, k);
1181            }
1182
1183            // h = MakeHint(-ct0, w_minus_cs2 + ct0)
1184            #[cfg(not(feature = "low-stack"))]
1185            let mut w_cs2_ct0 = [[0i32; N]; MAX_K];
1186            #[cfg(feature = "low-stack")]
1187            let mut w_cs2_ct0 = poly_vec(k);
1188            vec_add(&w_minus_cs2, &ct0, &mut w_cs2_ct0, k);
1189
1190            #[cfg(not(feature = "low-stack"))]
1191            let mut neg_ct0 = [[0i32; N]; MAX_K];
1192            #[cfg(feature = "low-stack")]
1193            let mut neg_ct0 = poly_vec(k);
1194            for i in 0..k {
1195                for j in 0..N {
1196                    neg_ct0[i][j] = mod_q(-ct0[i][j]);
1197                }
1198            }
1199            // ct0 and w_minus_cs2 no longer needed.
1200            #[cfg(feature = "low-stack")]
1201            {
1202                drop(ct0);
1203                drop(w_minus_cs2);
1204            }
1205
1206            let (h, num_ones) = decompose::make_hint_vec(&neg_ct0, &w_cs2_ct0, gamma2, k);
1207            // neg_ct0 and w_cs2_ct0 no longer needed.
1208            #[cfg(feature = "low-stack")]
1209            {
1210                drop(neg_ct0);
1211                drop(w_cs2_ct0);
1212            }
1213
1214            #[cfg(not(feature = "sca-ct-rejection"))]
1215            {
1216                if num_ones > omega {
1217                    kappa += l as u16;
1218                    continue;
1219                }
1220            }
1221            #[cfg(feature = "sca-ct-rejection")]
1222            {
1223                _reject_flag |= num_ones > omega;
1224                if _reject_flag {
1225                    kappa += l as u16;
1226                    continue;
1227                }
1228            }
1229
1230            // Center z coefficients to [-gamma1+1, gamma1] before encoding
1231            for poly in z[..l].iter_mut() {
1232                for c in poly.iter_mut() {
1233                    *c = mod_q(*c);
1234                    if *c > Q / 2 {
1235                        *c -= Q;
1236                    }
1237                }
1238            }
1239
1240            // Encode signature
1241            let sig = encode::sig_encode::<P>(c_tilde, &z, &h);
1242            return Ok(sig);
1243        } // end #[cfg(not(feature = "union-buffer"))]
1244    }
1245}
1246
1247/// Verify a signature against a pre-formatted message.
1248///
1249/// Implements Algorithm 8 of FIPS 204 (ML-DSA.Verify_internal).
1250///
1251/// Recomputes the commitment w1' from the public key, signature components
1252/// (c_tilde, z, h), and the message hash mu. Verification succeeds when the
1253/// recomputed commitment hash matches the c_tilde embedded in the signature.
1254///
1255/// - `pk`: encoded public key (must be `P::PK_LEN` bytes).
1256/// - `m_prime`: pre-formatted message.
1257/// - `sig`: encoded signature (must be `P::SIG_LEN` bytes).
1258///
1259/// Returns `Ok(true)` if the signature is valid, `Ok(false)` otherwise.
1260///
1261/// # Errors
1262///
1263/// - [`MlDsaError::InvalidPublicKey`] if `pk` has the wrong length.
1264/// - [`MlDsaError::InvalidSignature`] if `sig` has the wrong length.
1265pub fn verify_internal<P: Params>(pk: &[u8], m_prime: &[u8], sig: &[u8]) -> Result<bool, MlDsaError> {
1266    let k = P::K;
1267    let l = P::L;
1268    let gamma1 = P::GAMMA1;
1269    let gamma2 = P::GAMMA2;
1270    let beta = P::BETA;
1271    let omega = P::OMEGA;
1272    let c_tilde_len = P::LAMBDA / 4;
1273
1274    if pk.len() != P::PK_LEN {
1275        return Err(MlDsaError::InvalidPublicKey);
1276    }
1277    if sig.len() != P::SIG_LEN {
1278        return Err(MlDsaError::InvalidSignature);
1279    }
1280
1281    // Decode public key
1282    let (rho, t1) = encode::pk_decode::<P>(pk);
1283
1284    // tr = H(pk) (64 bytes)
1285    let mut tr = [0u8; 64];
1286    sha3::shake256_digest(pk, &mut tr);
1287
1288    // Decode signature
1289    let (c_tilde, z, h) = match encode::sig_decode::<P>(sig) {
1290        Some(x) => x,
1291        None => return Ok(false),
1292    };
1293
1294    // Check ||z||_inf < gamma1 - beta
1295    if !check_norm_vec(&z, gamma1 - beta, l) {
1296        return Ok(false);
1297    }
1298
1299    // A-hat = ExpandA(rho)
1300    #[cfg(not(feature = "low-mem"))]
1301    let a_hat = sample::expand_a::<P>(&rho);
1302
1303    // mu = H(tr || M')
1304    let mut mu = [0u8; 64];
1305    {
1306        let mut state = sha3::shake256();
1307        state.absorb(&tr);
1308        state.absorb(m_prime);
1309        state.squeeze(&mut mu);
1310    }
1311
1312    // c = SampleInBall(c_tilde)
1313    let mut c = sample::sample_in_ball::<P>(&c_tilde);
1314    for coeff in c.iter_mut() {
1315        *coeff = mod_q(*coeff);
1316    }
1317    let mut c_hat = c;
1318    ntt::ntt(&mut c_hat);
1319
1320    // z_hat = NTT(z)
1321    let mut z_hat = z;
1322    ntt_vec(&mut z_hat, l);
1323
1324    // w'_approx = NTT^{-1}(A-hat * z_hat - c_hat * NTT(t1 * 2^d))
1325    // First compute NTT(t1 * 2^d)
1326    let mut t1_2d_hat = [[0i32; N]; MAX_K];
1327    for i in 0..k {
1328        for j in 0..N {
1329            t1_2d_hat[i][j] = mod_q(t1[i][j] * (1 << D));
1330        }
1331        ntt::ntt(&mut t1_2d_hat[i]);
1332    }
1333
1334    // A-hat * z_hat
1335    let mut az = [[0i32; N]; MAX_K];
1336    #[cfg(not(feature = "low-mem"))]
1337    mat_vec_mul(&a_hat, &z_hat, k, l, &mut az);
1338    #[cfg(feature = "low-mem")]
1339    mat_vec_mul_lazy(&rho, &z_hat, k, l, &mut az);
1340
1341    // c_hat * t1_2d_hat (component-wise)
1342    let mut ct1 = [[0i32; N]; MAX_K];
1343    for i in 0..k {
1344        ct1[i] = ntt::pointwise_mul(&c_hat, &t1_2d_hat[i]);
1345    }
1346
1347    // w'_approx = NTT^{-1}(az - ct1)
1348    let mut w_approx = [[0i32; N]; MAX_K];
1349    vec_sub(&az, &ct1, &mut w_approx, k);
1350    ntt_inv_vec(&mut w_approx, k);
1351
1352    // w1' = UseHint(h, w'_approx)
1353    let w1_prime = decompose::use_hint_vec(&h, &w_approx, gamma2, k);
1354
1355    // Recompute c_tilde' = H(mu || w1Encode(w1'))
1356    let w1_encoded = encode::w1_encode::<P>(&w1_prime);
1357    let mut c_tilde_prime = vec![0u8; c_tilde_len];
1358    {
1359        let mut state = sha3::shake256();
1360        state.absorb(&mu);
1361        state.absorb(&w1_encoded);
1362        state.squeeze(&mut c_tilde_prime);
1363    }
1364
1365    // Check c_tilde == c_tilde'
1366    // Also verify hint weight
1367    let mut hint_count = 0usize;
1368    for i in 0..k {
1369        for &c in h[i].iter() {
1370            hint_count += c as usize;
1371        }
1372    }
1373    if hint_count > omega {
1374        return Ok(false);
1375    }
1376
1377    Ok(c_tilde == c_tilde_prime)
1378}
1379
1380/// Generate an ML-DSA key pair.
1381///
1382/// Implements Algorithm 1 of FIPS 204 (ML-DSA.KeyGen). Draws 32 random
1383/// bytes from `rng` and delegates to [`keygen_internal`].
1384///
1385/// Returns `(pk, sk)` as byte vectors.
1386///
1387/// # Errors
1388///
1389/// Returns [`MlDsaError::RngFailure`] if the RNG cannot provide bytes.
1390pub fn keygen<P: Params>(rng: &mut dyn CryptoRng) -> Result<(Vec<u8>, Vec<u8>), MlDsaError> {
1391    let mut xi = [0u8; 32];
1392    rng.fill_bytes(&mut xi)?;
1393    let result = keygen_internal::<P>(&xi);
1394    Ok(result)
1395}
1396
1397/// Sign a message with an optional context string (hedged mode).
1398///
1399/// Implements Algorithm 2 of FIPS 204 (ML-DSA.Sign). Constructs the
1400/// pre-formatted message `M' = 0x00 || len(ctx) || ctx || msg`, draws 32
1401/// random bytes for hedged signing, and calls `sign_internal`.
1402///
1403/// - `sk`: secret key (must be `P::SK_LEN` bytes).
1404/// - `msg`: message to sign.
1405/// - `ctx`: optional context string (at most 255 bytes).
1406/// - `rng`: source of randomness for the hedged nonce.
1407///
1408/// # Errors
1409///
1410/// - [`MlDsaError::ContextTooLong`] if `ctx` exceeds 255 bytes.
1411/// - [`MlDsaError::InvalidSecretKey`] if `sk` has the wrong length.
1412/// - [`MlDsaError::RngFailure`] if the RNG cannot provide bytes.
1413pub fn sign<P: Params>(sk: &[u8], msg: &[u8], ctx: &[u8], rng: &mut dyn CryptoRng) -> Result<Vec<u8>, MlDsaError> {
1414    if ctx.len() > 255 {
1415        return Err(MlDsaError::ContextTooLong);
1416    }
1417    if sk.len() != P::SK_LEN {
1418        return Err(MlDsaError::InvalidSecretKey);
1419    }
1420
1421    // M' = 0x00 || len(ctx) || ctx || M
1422    let mut m_prime = Vec::with_capacity(1 + 1 + ctx.len() + msg.len());
1423    m_prime.push(0x00);
1424    m_prime.push(ctx.len() as u8);
1425    m_prime.extend_from_slice(ctx);
1426    m_prime.extend_from_slice(msg);
1427
1428    // Random bytes for hedged signing
1429    let mut rnd = [0u8; 32];
1430    rng.fill_bytes(&mut rnd)?;
1431
1432    sign_internal::<P>(sk, &m_prime, &rnd)
1433}
1434
1435/// Verify a signature on a message with an optional context string.
1436///
1437/// Implements Algorithm 3 of FIPS 204 (ML-DSA.Verify). Constructs the
1438/// pre-formatted message `M' = 0x00 || len(ctx) || ctx || msg` and
1439/// delegates to [`verify_internal`].
1440///
1441/// - `pk`: public key (must be `P::PK_LEN` bytes).
1442/// - `msg`: the signed message.
1443/// - `ctx`: the context string used at signing time (at most 255 bytes).
1444/// - `sig`: the signature (must be `P::SIG_LEN` bytes).
1445///
1446/// Returns `Ok(true)` if the signature is valid, `Ok(false)` otherwise.
1447///
1448/// # Errors
1449///
1450/// - [`MlDsaError::ContextTooLong`] if `ctx` exceeds 255 bytes.
1451/// - [`MlDsaError::InvalidPublicKey`] if `pk` has the wrong length.
1452/// - [`MlDsaError::InvalidSignature`] if `sig` has the wrong length.
1453pub fn verify<P: Params>(pk: &[u8], msg: &[u8], ctx: &[u8], sig: &[u8]) -> Result<bool, MlDsaError> {
1454    if ctx.len() > 255 {
1455        return Err(MlDsaError::ContextTooLong);
1456    }
1457    if pk.len() != P::PK_LEN {
1458        return Err(MlDsaError::InvalidPublicKey);
1459    }
1460    if sig.len() != P::SIG_LEN {
1461        return Err(MlDsaError::InvalidSignature);
1462    }
1463
1464    // M' = 0x00 || len(ctx) || ctx || M
1465    let mut m_prime = Vec::with_capacity(1 + 1 + ctx.len() + msg.len());
1466    m_prime.push(0x00);
1467    m_prime.push(ctx.len() as u8);
1468    m_prime.extend_from_slice(ctx);
1469    m_prime.extend_from_slice(msg);
1470
1471    verify_internal::<P>(pk, &m_prime, sig)
1472}