quantica/ml_dsa/dsa.rs
1//! Core ML-DSA algorithms (FIPS 204, Algorithms 1-8).
2//!
3//! Contains the key generation, signing, and verification routines at both
4//! the public API level (Algorithms 1-3) and the internal/deterministic level
5//! (Algorithms 6-8).
6//!
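//! By default, internal polynomial vector operations use fixed-size stack
//! arrays (`[[i32; N]; MAX_K]` / `[[i32; N]; MAX_L]`) to avoid heap
//! allocations; with the `low-stack` feature, the signing loop's temporary
//! vectors are heap-allocated instead to lower peak stack usage.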
9//!
10//! # Side-channel countermeasures (`sca-protected` feature)
11//!
12//! When the `sca-protected` Cargo feature is enabled (on by default),
13//! `sign_internal` runs an additional layer of defences on the secret-key
14//! material:
15//!
16//! | Countermeasure | Module | Threat addressed |
17//! |-----------------------|------------------------------|-----------------------------------------------|
18//! | Constant-time arith | always-on | Cache- / branch-based timing attacks |
19//! | Zeroization | always-on | Cold-boot dumps, use-after-free |
20//! | Hedged signing | always-on | Fault-induced nonce reuse (`rnd ≠ 0`) |
21//! | Shuffled NTT | `super::shuffle` (sca) | SPA, trace-alignment for DPA |
22//! | First-order masking | `super::masked` (sca) | First-order DPA, template attacks |
23//! | Mask refresh / hop | `super::masked` (sca) | Inter-iteration share correlation |
24//!
25//! The masking + shuffling layer is deliberately confined to `sign_internal`,
26//! because that is where the secret key `(s1, s2, t0)` is consumed in
27//! polynomial multiplications with values an attacker can influence:
28//!
29//! ```text
30//! ŝ1, ŝ2, t̂0 ← NTT(s1), NTT(s2), NTT(t0) // SPA + DPA target
31//! loop:
32//! ĉ ← NTT(SampleInBall(c̃))
33//! cs1[i] ← ĉ · ŝ1[i] // ×L — DPA target
34//! cs2[i] ← ĉ · ŝ2[i] // ×K — DPA target
35//! ct0[i] ← ĉ · t̂0[i] // ×K — DPA target
36//! ```
37//!
38//! The challenge polynomial `ĉ` is **public** (the verifier recomputes it),
39//! so every secret×public multiplication only needs first-order masking:
40//! `(s₀ + s₁) · ĉ = s₀·ĉ + s₁·ĉ`. There is no secret×secret operation in
41//! Sign that would require second-order shares.
42//!
43//! Mask randomness is drawn from a SHAKE256-based deterministic
44//! `ScaRng` seeded with `(K ‖ rnd ‖ tr ‖ M')`, so that:
45//!
46//! * `sign_internal` keeps a deterministic signature (no `&mut dyn CryptoRng`
47//! parameter), and the NIST ACVP fixed-`rnd = 0` test vectors still match
48//! bit-for-bit;
49//! * different `rnd` values produce independent share streams (hedged
50//! signing entropy is preserved through the SCA layer).
51//!
52//! The standard build (without `sca-protected`) still benefits from the
53//! always-on countermeasures listed above; only the masking + shuffling
54//! defences are conditionally compiled out.
55
56use super::MlDsaError;
57use super::decompose;
58use super::encode;
59use super::ntt::{self, mod_q};
60use super::params::{D, MAX_K, MAX_L, N, Params, Q};
61use super::rng::CryptoRng;
62use super::sample;
63use super::sha3;
64use alloc::vec::Vec;
65
66#[cfg(any(feature = "compressed-poly", feature = "compressed-challenge"))]
67use super::compressed;
68#[cfg(feature = "sca-protected")]
69use super::masked::{self, MaskedPoly};
70
71#[cfg(all(feature = "sca-protected", feature = "compressed-challenge"))]
72compile_error!(
73 "features `sca-protected` and `compressed-challenge` are mutually exclusive: masking requires NTT-domain multiplication, schoolbook operates in time domain"
74);
75
76#[cfg(all(feature = "sca-protected", feature = "small-secret"))]
77compile_error!("features `sca-protected` and `small-secret` are mutually exclusive: masking operates in i32 domain");
78
79#[cfg(all(feature = "sca-protected", feature = "union-buffer"))]
80compile_error!("features `sca-protected` and `union-buffer` are mutually exclusive");
81
82#[cfg(feature = "sca-protected")]
83use super::sha3::KeccakState;
84#[cfg(feature = "sca-protected")]
85use super::shuffle;
86#[cfg(feature = "small-secret")]
87use super::smallpoly::{self, SmallPoly};
88
/// Deterministic SHAKE256-based randomness source for the SCA layer.
///
/// `sign_internal` does not take a `&mut dyn CryptoRng` parameter
/// (it must stay fully deterministic so that the NIST ACVP fixed-`rnd`
/// vectors still match bit-for-bit), so the masking and shuffling
/// modules cannot reach for [`super::OsRng`] either. Instead they
/// share a per-call `ScaRng` whose seed is derived from
/// `(K ‖ rnd ‖ tr ‖ M')` via SHAKE256:
///
/// * `K` is the secret-key field used by FIPS 204 hedged signing.
/// * `rnd` is the 32-byte hedged-signing randomness — all-zero in
///   deterministic / ACVP test mode, fresh entropy otherwise.
/// * `tr` and `M'` make the seed bind to the public key + message,
///   so two signatures over different messages produce uncorrelated
///   share streams even when `rnd = 0`.
///
/// A short domain-separation tag (`b"quantica-mldsa-sca-v1"`) is
/// absorbed first (see `from_seed`) to keep the SHAKE squeeze stream
/// disjoint from any other SHAKE use elsewhere in the algorithm.
///
/// All ML-DSA share / shuffle randomness for one signature flows
/// from one `ScaRng` instance: the initial mask of `(s1, s2, t0)`,
/// the shuffled-NTT permutations, and the per-rejection-iteration
/// `MaskedPoly::refresh()` calls. This guarantees a single coherent
/// stream that the test vectors can reproduce.
///
/// The PRG itself is **not cryptographic in the standard sense** —
/// it is not seeded from system entropy. Its job is purely to make
/// internal mask shares unpredictable to a passive side-channel
/// observer. The actual signature security still derives from FIPS
/// 204's own randomness (`rnd`), which is mixed into the seed.
///
/// Output bytes are produced through the [`CryptoRng`] implementation,
/// which squeezes them from the sponge and never reports an error.
#[cfg(feature = "sca-protected")]
struct ScaRng {
    /// SHAKE256 sponge: `from_seed` absorbs the tag and seed into it,
    /// `fill_bytes` squeezes the output stream from it.
    state: KeccakState,
}
124
#[cfg(feature = "sca-protected")]
impl ScaRng {
    /// Build an SCA RNG whose output stream is determined by `seed`.
    ///
    /// The sponge absorbs the fixed tag `b"quantica-mldsa-sca-v1"` before
    /// the seed, which domain-separates this stream from other SHAKE256
    /// uses that absorb the same bytes without the tag.
    fn from_seed(seed: &[u8]) -> Self {
        let mut state = sha3::shake256();
        // Tag first, then the caller's seed — the order is part of the
        // stream definition and must not change.
        state.absorb(b"quantica-mldsa-sca-v1");
        state.absorb(seed);
        Self { state }
    }
}
137
#[cfg(feature = "sca-protected")]
impl CryptoRng for ScaRng {
    /// Fill `dest` with the next bytes squeezed from the SHAKE256 state.
    ///
    /// Always returns `Ok(())`; the `Result` exists only to satisfy the
    /// [`CryptoRng`] trait signature.
    fn fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), MlDsaError> {
        let Self { state } = self;
        state.squeeze(dest);
        Ok(())
    }
}
145
146/// Compute the matrix-vector product A_hat * s in the NTT domain.
147///
148/// `a_hat` is a k-by-l matrix of NTT-domain polynomials and `s` is a vector
149/// of l NTT-domain polynomials. The result is a vector of k polynomials.
150///
151/// Used when `low-mem` is **disabled** (default). The full matrix is
152/// pre-expanded in RAM for maximum throughput.
153#[cfg(not(feature = "low-mem"))]
154fn mat_vec_mul(a_hat: &[[[i32; N]; MAX_L]; MAX_K], s: &[[i32; N]], k: usize, l: usize, result: &mut [[i32; N]]) {
155 for i in 0..k {
156 result[i] = [0i32; N];
157 for j in 0..l {
158 let prod = ntt::pointwise_mul(&a_hat[i][j], &s[j]);
159 result[i] = ntt::poly_add(&result[i], &prod);
160 }
161 }
162}
163
/// Matrix-vector product that regenerates the matrix instead of storing it.
///
/// Each entry `a_hat[i][j]` is re-derived from `rho` with
/// `sample::rej_ntt_poly(rho, j, i)` right before it is used, so the
/// full k×l matrix (57 KB for ML-DSA-87) never has to sit on the stack.
///
/// The price is one extra polynomial sampling per matrix entry on every
/// call. `sign_internal` calls this once per rejection-loop iteration,
/// so that overhead scales with the number of rejections (~4 on average
/// for ML-DSA-65).
///
/// Compiled when the `low-mem` feature is **enabled**.
#[cfg(feature = "low-mem")]
fn mat_vec_mul_lazy(rho: &[u8; 32], s: &[[i32; N]], k: usize, l: usize, result: &mut [[i32; N]]) {
    for (row, out) in result[..k].iter_mut().enumerate() {
        let mut acc = [0i32; N];
        for (col, s_col) in s[..l].iter().enumerate() {
            // Same derivation as expand_a: column index first, then row index.
            let a_entry = sample::rej_ntt_poly(rho, col as u8, row as u8);
            let term = ntt::pointwise_mul(&a_entry, s_col);
            acc = ntt::poly_add(&acc, &term);
        }
        *out = acc;
    }
}
187
188/// NTT of a polynomial vector (in-place, first `len` elements).
189fn ntt_vec(v: &mut [[i32; N]], len: usize) {
190 for poly in v[..len].iter_mut() {
191 ntt::ntt(poly);
192 }
193}
194
195/// Inverse NTT of a polynomial vector (in-place, first `len` elements).
196fn ntt_inv_vec(v: &mut [[i32; N]], len: usize) {
197 for poly in v[..len].iter_mut() {
198 ntt::ntt_inv(poly);
199 }
200}
201
202/// Add two polynomial vectors into `out`. Only processes `len` elements.
203fn vec_add(a: &[[i32; N]], b: &[[i32; N]], out: &mut [[i32; N]], len: usize) {
204 for i in 0..len {
205 out[i] = ntt::poly_add(&a[i], &b[i]);
206 }
207}
208
209/// Subtract two polynomial vectors into `out`. Only processes `len` elements.
210fn vec_sub(a: &[[i32; N]], b: &[[i32; N]], out: &mut [[i32; N]], len: usize) {
211 for i in 0..len {
212 out[i] = ntt::poly_sub(&a[i], &b[i]);
213 }
214}
215
216// =====================================================================
217// low-stack helpers: heap-allocated polynomial vectors
218// =====================================================================
219//
220// When `low-stack` is enabled, the rejection-loop temporaries in
221// sign_internal are allocated on the heap (Vec) with scoped lifetimes
222// and explicit drop() calls to keep the high-water mark low (~23 KB).
223// The downstream functions (vec_add, check_norm_vec, decompose::*)
224// take &[[i32; N]] slices and work unchanged with either stack or heap.
225
/// Heap-allocate a vector of `len` polynomials, every coefficient zero.
///
/// Only compiled with the `low-stack` feature.
#[cfg(feature = "low-stack")]
fn poly_vec(len: usize) -> Vec<[i32; N]> {
    let mut polys = Vec::with_capacity(len);
    polys.resize(len, [0i32; N]);
    polys
}
231
232/// Check infinity norm of polynomial: all coefficients strictly below `bound`.
233///
234/// Returns `true` iff every coefficient `c` satisfies `|c| < bound`
235/// (i.e., `c ∈ (−bound, bound)`). This implements the **strict**
236/// inequality `||v||∞ < bound` required by FIPS 204 Algorithm 7
237/// step 25 / Algorithm 8 step 15: "if `||z||∞ ≥ γ₁ − β` then
238/// return ⊥".
239fn check_norm(v: &[i32; N], bound: i32) -> bool {
240 for &c in v.iter() {
241 // Bring to centered representation
242 let mut val = mod_q(c);
243 if val > Q / 2 {
244 val -= Q;
245 }
246 // Strict: reject if |val| >= bound
247 if val >= bound || val <= -bound {
248 return false;
249 }
250 }
251 true
252}
253
254/// Check infinity norm of polynomial vector. Only checks first `len` elements.
255fn check_norm_vec(v: &[[i32; N]], bound: i32, len: usize) -> bool {
256 for poly in v[..len].iter() {
257 if !check_norm(poly, bound) {
258 return false;
259 }
260 }
261 true
262}
263
264/// Deterministic key generation from a 32-byte seed.
265///
266/// Implements Algorithm 6 of FIPS 204 (ML-DSA.KeyGen_internal).
267///
268/// Given a 32-byte seed `xi`, derives the public matrix A (via ExpandA),
269/// secret vectors s1 and s2 (via ExpandS), and computes the public key
270/// `pk = (rho, t1)` and secret key `sk = (rho, K, tr, s1, s2, t0)`.
271///
272/// - `xi`: 32-byte random seed.
273///
274/// Returns `(pk, sk)` as byte vectors.
275pub fn keygen_internal<P: Params>(xi: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
276 let k = P::K;
277 let l = P::L;
278
279 // (rho, rho', K) = H(xi || k || l)
280 let mut h_input = [0u8; 34];
281 h_input[..32].copy_from_slice(xi);
282 h_input[32] = k as u8;
283 h_input[33] = l as u8;
284
285 let mut hash_out = [0u8; 128]; // need 32 + 64 + 32 = 128 bytes
286 let mut state = sha3::shake256();
287 state.absorb(&h_input);
288 state.squeeze(&mut hash_out);
289
290 let mut rho = [0u8; 32];
291 rho.copy_from_slice(&hash_out[..32]);
292 let mut rho_prime = [0u8; 64];
293 rho_prime.copy_from_slice(&hash_out[32..96]);
294 let mut k_seed = [0u8; 32];
295 k_seed.copy_from_slice(&hash_out[96..128]);
296
297 // Generate s1, s2 from rho'
298 let (mut s1, mut s2) = sample::expand_s::<P>(&rho_prime);
299
300 // s1_hat = NTT(s1)
301 let mut s1_hat = s1;
302 ntt_vec(&mut s1_hat, l);
303
304 // t = NTT^{-1}(A-hat * s1_hat) + s2
305 let mut t = [[0i32; N]; MAX_K];
306 #[cfg(not(feature = "low-mem"))]
307 {
308 let a_hat = sample::expand_a::<P>(&rho);
309 mat_vec_mul(&a_hat, &s1_hat, k, l, &mut t);
310 }
311 #[cfg(feature = "low-mem")]
312 mat_vec_mul_lazy(&rho, &s1_hat, k, l, &mut t);
313 ntt_inv_vec(&mut t, k);
314 // t = t + s2 (in-place into t)
315 {
316 let mut tmp = [[0i32; N]; MAX_K];
317 vec_add(&t, &s2, &mut tmp, k);
318 t = tmp;
319 }
320
321 // (t1, t0) = Power2Round(t)
322 let (t1, t0) = encode::power2round_vec(&t, k);
323
324 // Encode public key
325 let pk = encode::pk_encode::<P>(&rho, &t1);
326
327 // tr = H(pk) (SHAKE256, 64 bytes)
328 let mut tr = [0u8; 64];
329 sha3::shake256_digest(&pk, &mut tr);
330
331 // Encode secret key
332 let sk = encode::sk_encode::<P>(&rho, &k_seed, &tr, &s1, &s2, &t0);
333
334 // Zeroize sensitive data
335 for poly in s1[..l].iter_mut() {
336 for c in poly.iter_mut() {
337 *c = 0;
338 }
339 }
340 for poly in s2[..k].iter_mut() {
341 for c in poly.iter_mut() {
342 *c = 0;
343 }
344 }
345 for byte in rho_prime.iter_mut() {
346 *byte = 0;
347 }
348 for byte in k_seed.iter_mut() {
349 *byte = 0;
350 }
351
352 (pk, sk)
353}
354
355/// Sign a pre-formatted message (deterministic or hedged).
356///
357/// Implements Algorithm 7 of FIPS 204 (ML-DSA.Sign_internal).
358///
359/// This function contains the core rejection sampling loop: candidate
360/// signatures `(z, h)` are generated from a masking vector `y` and the
361/// challenge polynomial `c`, then tested against the norm bounds
362/// `||z||_inf < gamma1 - beta` and `||r0||_inf < gamma2 - beta`. If any
363/// check fails, the counter `kappa` is incremented and a new attempt begins.
364///
365/// - `sk`: encoded secret key bytes.
366/// - `m_prime`: pre-formatted message (e.g., `0x00 || len(ctx) || ctx || msg`).
367/// - `rnd`: 32-byte randomness. Use random bytes for hedged signing or
368/// all-zeros for fully deterministic signing.
369///
370/// # Side-channel countermeasures
371///
372/// With the `sca-protected` Cargo feature enabled (default), this
373/// function activates the additional defences described in the
374/// crate-level documentation:
375///
376/// 1. **Shuffled NTT** on `s1`, `s2`, `t0` — runs once at entry,
377/// via [`super::shuffle::ntt_shuffled`]. Defends against SPA on
378/// the secret-key NTT and disrupts trace alignment for any
379/// later DPA campaign that tries to average aligned traces.
380/// 2. **First-order additive masking** of the NTT-domain secrets,
381/// via [`super::masked::MaskedPoly`]. Each polynomial is split
382/// into two shares mod `q = 8 380 417`; no single intermediate
383/// value reveals the secret to a first-order observer.
384/// 3. **Per-iteration `c·sₓ` multiplications** go through
385/// [`super::masked::masked_pointwise_mul_public`], which
386/// multiplies each share independently by the public challenge
387/// `ĉ`. Because `ĉ` is public, first-order shares are sufficient
388/// — no secret×secret operation is performed.
389/// 4. **Mask refresh after every use**: the share pair is
390/// re-randomized via `MaskedPoly::refresh()` between rejection
391/// iterations, so the same secret never multiplies the same
392/// share twice — defeating higher-order correlation attacks
393/// that would otherwise become available across many rejection
394/// retries on the same key.
395///
396/// All randomness for the SCA layer comes from a deterministic
397/// SHAKE256-based `ScaRng` seeded with `(K ‖ rnd ‖ tr ‖ M')`, so
398/// the function remains deterministic for fixed `rnd`. The masked
399/// path produces signatures **bit-identical** to the unmasked path
400/// — proven by the NIST ACVP siggen vectors, which the SCA build
401/// passes unchanged.
402///
403/// # Errors
404///
405/// Returns [`MlDsaError::InvalidSecretKey`] if `sk` has incorrect length
406/// (checked by the caller in the public API).
407pub fn sign_internal<P: Params>(sk: &[u8], m_prime: &[u8], rnd: &[u8; 32]) -> Result<Vec<u8>, MlDsaError> {
408 let k = P::K;
409 let l = P::L;
410 let gamma1 = P::GAMMA1;
411 let gamma2 = P::GAMMA2;
412 let beta = P::BETA;
413 let omega = P::OMEGA;
414 let c_tilde_len = P::LAMBDA / 4;
415
416 // Decode secret key seeds (128 bytes on stack).
417 //
418 // indexed-sk: decode only rho/K/tr here; the polynomial vectors
419 // s1/s2/t0 are decoded one-at-a-time below, directly into the
420 // NTT-domain destination arrays, avoiding the 23 KB intermediate
421 // tuple that sk_decode() would put on the stack.
422 //
423 // Default: sk_decode() returns the full tuple at once (simpler,
424 // but 23 KB of stack for the return value alone).
425 #[cfg(feature = "indexed-sk")]
426 let (rho, k_seed, tr) = encode::sk_decode_seeds::<P>(sk);
427 #[cfg(not(feature = "indexed-sk"))]
428 let (rho, k_seed, tr, s1, s2, t0) = encode::sk_decode::<P>(sk);
429
430 // ----- ŝ1, ŝ2, t̂0 = NTT(s1, s2, t0) ------------------------------
431 //
432 // This is the most leakage-prone step in Sign.
433 //
434 // indexed-sk: decode each polynomial from the packed sk directly
435 // into the destination slot, then NTT in-place. Only one decoded
436 // polynomial (1 KB) is live at a time instead of the full 23 KB
437 // tuple from sk_decode.
438 //
439 // SCA-protected build: after NTT, each polynomial is additionally
440 // split into masked shares (see below).
441 //
442 // Standard build: straight in-place Montgomery NTT.
443 #[cfg(not(feature = "low-stack"))]
444 let mut s1_hat = {
445 #[cfg(feature = "indexed-sk")]
446 {
447 let mut v = [[0i32; N]; MAX_L];
448 for i in 0..l {
449 encode::sk_decode_s1::<P>(sk, i, &mut v[i]);
450 }
451 v
452 }
453 #[cfg(not(feature = "indexed-sk"))]
454 {
455 s1
456 }
457 };
458 #[cfg(feature = "low-stack")]
459 let mut s1_hat = {
460 let mut v = poly_vec(l);
461 #[cfg(feature = "indexed-sk")]
462 for i in 0..l {
463 encode::sk_decode_s1::<P>(sk, i, &mut v[i]);
464 }
465 #[cfg(not(feature = "indexed-sk"))]
466 for i in 0..l {
467 v[i] = s1[i];
468 }
469 v
470 };
471
472 #[cfg(not(feature = "low-stack"))]
473 let mut s2_hat = {
474 #[cfg(feature = "indexed-sk")]
475 {
476 let mut v = [[0i32; N]; MAX_K];
477 for i in 0..k {
478 encode::sk_decode_s2::<P>(sk, i, &mut v[i]);
479 }
480 v
481 }
482 #[cfg(not(feature = "indexed-sk"))]
483 {
484 s2
485 }
486 };
487 #[cfg(feature = "low-stack")]
488 let mut s2_hat = {
489 let mut v = poly_vec(k);
490 #[cfg(feature = "indexed-sk")]
491 for i in 0..k {
492 encode::sk_decode_s2::<P>(sk, i, &mut v[i]);
493 }
494 #[cfg(not(feature = "indexed-sk"))]
495 for i in 0..k {
496 v[i] = s2[i];
497 }
498 v
499 };
500
501 #[cfg(not(feature = "low-stack"))]
502 let mut t0_hat = {
503 #[cfg(feature = "indexed-sk")]
504 {
505 let mut v = [[0i32; N]; MAX_K];
506 for i in 0..k {
507 encode::sk_decode_t0::<P>(sk, i, &mut v[i]);
508 }
509 v
510 }
511 #[cfg(not(feature = "indexed-sk"))]
512 {
513 t0
514 }
515 };
516 #[cfg(feature = "low-stack")]
517 let mut t0_hat = {
518 let mut v = poly_vec(k);
519 #[cfg(feature = "indexed-sk")]
520 for i in 0..k {
521 encode::sk_decode_t0::<P>(sk, i, &mut v[i]);
522 }
523 #[cfg(not(feature = "indexed-sk"))]
524 for i in 0..k {
525 v[i] = t0[i];
526 }
527 v
528 };
529 #[cfg(feature = "sca-protected")]
530 let (mut s1_hat_m, mut s2_hat_m, mut t0_hat_m, mut sca_rng) = {
531 // Seed the SCA RNG from K ‖ rnd ‖ tr ‖ M'. K and rnd give us
532 // the FIPS 204 hedged-signing entropy; tr and M' bind the
533 // share stream to this particular (key, message) pair so two
534 // signatures over different inputs use uncorrelated shares
535 // even when rnd = 0 (deterministic / ACVP test mode).
536 let mut sca_seed = [0u8; 64];
537 {
538 let mut h = sha3::shake256();
539 h.absorb(b"quantica-mldsa-sca-seed-v1");
540 h.absorb(&k_seed);
541 h.absorb(rnd);
542 h.absorb(&tr);
543 h.absorb(m_prime);
544 h.squeeze(&mut sca_seed);
545 }
546 let mut rng = ScaRng::from_seed(&sca_seed);
547
548 // Step 1 — SPA defence: NTT each secret polynomial through
549 // the Fisher-Yates shuffled NTT, drawing fresh per-level and
550 // per-group permutations from the SCA RNG.
551 for i in 0..l {
552 shuffle::ntt_shuffled(&mut s1_hat[i], &mut rng)?;
553 }
554 for i in 0..k {
555 shuffle::ntt_shuffled(&mut s2_hat[i], &mut rng)?;
556 }
557 for i in 0..k {
558 shuffle::ntt_shuffled(&mut t0_hat[i], &mut rng)?;
559 }
560
561 // Step 2 — DPA defence: split each NTT-domain secret into two
562 // additive shares mod q. The MaskedPoly::zero() initializer
563 // is a stack-resident no-allocation array fill; the real
564 // shares are written immediately below by MaskedPoly::mask.
565 let mut s1m: [MaskedPoly; MAX_L] = core::array::from_fn(|_| MaskedPoly::zero());
566 let mut s2m: [MaskedPoly; MAX_K] = core::array::from_fn(|_| MaskedPoly::zero());
567 let mut t0m: [MaskedPoly; MAX_K] = core::array::from_fn(|_| MaskedPoly::zero());
568 for i in 0..l {
569 s1m[i] = MaskedPoly::mask(&s1_hat[i], &mut rng)?;
570 }
571 for i in 0..k {
572 s2m[i] = MaskedPoly::mask(&s2_hat[i], &mut rng)?;
573 }
574 for i in 0..k {
575 t0m[i] = MaskedPoly::mask(&t0_hat[i], &mut rng)?;
576 }
577
578 // Step 3 — wipe the unmasked NTT-domain buffers. From this
579 // point on the secret only exists as `(share0, share1)` pairs;
580 // any side-channel observation of `s1_hat[i]` etc. yields zero
581 // information about the underlying coefficients.
582 for i in 0..l {
583 s1_hat[i] = [0i32; N];
584 }
585 for i in 0..k {
586 s2_hat[i] = [0i32; N];
587 }
588 for i in 0..k {
589 t0_hat[i] = [0i32; N];
590 }
591 (s1m, s2m, t0m, rng)
592 };
593 #[cfg(not(feature = "sca-protected"))]
594 {
595 // compressed-challenge: secrets stay in time domain for
596 // schoolbook multiplication. No NTT needed.
597 // small-secret: s1/s2 are converted to SmallPoly and NTT'd
598 // via the i16 Kyber NTT instead. t0 still uses i32 NTT
599 // (coefficients too large for i16).
600 #[cfg(not(any(feature = "compressed-challenge", feature = "small-secret")))]
601 {
602 ntt_vec(&mut s1_hat, l);
603 ntt_vec(&mut s2_hat, k);
604 ntt_vec(&mut t0_hat, k);
605 }
606 #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
607 {
608 // t0 still needs i32 NTT (coefficients up to 4096).
609 ntt_vec(&mut t0_hat, k);
610 // s1/s2 are converted to SmallPoly below; we don't NTT the i32 versions.
611 }
612 }
613
614 // small-secret: convert s1/s2 to i16 SmallPoly and NTT via Kyber NTT.
615 // The i32 s1_hat/s2_hat arrays are kept for any non-small-secret
616 // code paths but are effectively unused when small-secret is on.
617 #[cfg(feature = "small-secret")]
618 let (s1_small, s2_small) = {
619 let mut s1s: [SmallPoly; MAX_L] = core::array::from_fn(|_| SmallPoly::zero());
620 let mut s2s: [SmallPoly; MAX_K] = core::array::from_fn(|_| SmallPoly::zero());
621 for i in 0..l {
622 s1s[i] = SmallPoly::from_i32(&s1_hat[i]);
623 smallpoly::small_ntt(&mut s1s[i]);
624 }
625 for i in 0..k {
626 s2s[i] = SmallPoly::from_i32(&s2_hat[i]);
627 smallpoly::small_ntt(&mut s2s[i]);
628 }
629 (s1s, s2s)
630 };
631
632 // A-hat = ExpandA(rho)
633 // Default: full matrix on stack (57 KB). Low-mem: recomputed on-the-fly.
634 #[cfg(not(feature = "low-mem"))]
635 let a_hat = sample::expand_a::<P>(&rho);
636
637 // mu = H(tr || M')
638 let mut mu = [0u8; 64];
639 {
640 let mut state = sha3::shake256();
641 state.absorb(&tr);
642 state.absorb(m_prime);
643 state.squeeze(&mut mu);
644 }
645
646 // rho'' = H(K || rnd || mu)
647 let mut rho_double_prime = [0u8; 64];
648 {
649 let mut state = sha3::shake256();
650 state.absorb(&k_seed);
651 state.absorb(rnd);
652 state.absorb(&mu);
653 state.squeeze(&mut rho_double_prime);
654 }
655
656 let mut kappa: u16 = 0;
657
658 loop {
659 // T1-A — refresh the persistent masked-secret-poly shares at
660 // the **start** of every rejection iteration, before any
661 // operation on them (Hermelink-Ning-Petri 2025/276 §4).
662 // `s1_hat_m`, `s2_hat_m`, `t0_hat_m` survive across all
663 // iterations (declared at line 530); without per-iteration
664 // refresh, higher-order DPA aggregating traces over multiple
665 // iterations sees correlated share pairs. Cost is one
666 // `MaskedPoly::refresh` per polynomial per iteration —
667 // identical to the previous end-of-cs/ct refresh placement;
668 // KAT output bytes are unchanged because the mask cancels in
669 // every `unmask()`.
670 #[cfg(feature = "sca-protected")]
671 {
672 for i in 0..l {
673 s1_hat_m[i].refresh(&mut sca_rng)?;
674 }
675 for i in 0..k {
676 s2_hat_m[i].refresh(&mut sca_rng)?;
677 }
678 for i in 0..k {
679 t0_hat_m[i].refresh(&mut sca_rng)?;
680 }
681 }
682
683 // y = ExpandMask(rho'', kappa)
684 //
685 // sca-masked-y: sample y directly as arithmetic shares from
686 // SHAKE256 (MaskedPoly::sample_expand_mask), keep it masked
687 // through NTT + mat_vec_mul + iNTT. Only unmask y and w at
688 // the end of the linear ops: w is about to be published via
689 // w1 in c_tilde, and y is recoverable from z = y + cs1 in
690 // the final signature anyway.
691 //
692 // Default: sample y in clear via expand_mask.
693 #[cfg(not(feature = "sca-masked-y"))]
694 let y = sample::expand_mask::<P>(&rho_double_prime, kappa);
695
696 #[cfg(feature = "sca-masked-y")]
697 let (y, w_precomputed) = {
698 // Full masking pipeline: y stays split into two arithmetic
699 // shares from sampling through NTT, A·y, and iNTT. Only
700 // once w reaches its "about-to-be-published" form do we
701 // unmask y and w together.
702 //
703 // Countermeasure references:
704 // ePrint 2025/276 — Hermelink–Ning–Petri, DPA on y
705 // ePrint 2025/582 — Rejected-signature timing leak
706
707 // 1. Sample y as arithmetic shares from SHAKE256. The
708 // unmasked coefficient value only transits through CPU
709 // registers — never written to RAM.
710 let mut y_m: [masked::MaskedPoly; MAX_L] = core::array::from_fn(|_| masked::MaskedPoly::zero());
711 for r in 0..l {
712 y_m[r] = masked::MaskedPoly::sample_expand_mask(
713 &rho_double_prime,
714 kappa + r as u16,
715 gamma1,
716 P::BITLEN_GAMMA1_MINUS1,
717 );
718 }
719
720 // 2. Masked NTT into y_hat_m — y_m is preserved for the
721 // later time-domain unmask (needed by z = y + cs1).
722 let mut y_hat_m: [masked::MaskedPoly; MAX_L] = core::array::from_fn(|_| masked::MaskedPoly::zero());
723 for r in 0..l {
724 y_hat_m[r].share0 = y_m[r].share0;
725 y_hat_m[r].share1 = y_m[r].share1;
726 masked::masked_ntt(&mut y_hat_m[r]);
727 }
728
729 // 3. Masked A · y_hat → w_m (NTT domain, masked).
730 // A is public; the matrix multiplication touches each
731 // share independently.
732 let mut w_m: [masked::MaskedPoly; MAX_K] = core::array::from_fn(|_| masked::MaskedPoly::zero());
733 #[cfg(not(feature = "low-mem"))]
734 masked::masked_mat_vec_mul(&a_hat, &y_hat_m, k, l, &mut w_m);
735 #[cfg(feature = "low-mem")]
736 masked::masked_mat_vec_mul_lazy(&rho, &y_hat_m, k, l, &mut w_m);
737
738 for r in 0..l {
739 y_hat_m[r].zeroize();
740 }
741
742 // 4. Masked iNTT on each share.
743 for i in 0..k {
744 masked::masked_ntt_inv(&mut w_m[i]);
745 }
746
747 // 5. Unmask w — w1 = HighBits(w) is public (it ends up in
748 // c_tilde). Output in [0, q-1]; `decompose` handles that.
749 let mut w_tmp = [[0i32; N]; MAX_K];
750 for i in 0..k {
751 w_tmp[i] = w_m[i].unmask();
752 w_m[i].zeroize();
753 }
754
755 // 6. Unmask y to centered (-gamma1, gamma1] time domain,
756 // matching the default `expand_mask` output range.
757 let mut y_out = [[0i32; N]; MAX_L];
758 for r in 0..l {
759 let um = y_m[r].unmask();
760 for n in 0..N {
761 let mut v = um[n];
762 if v > Q / 2 {
763 v -= Q;
764 }
765 y_out[r][n] = v;
766 }
767 y_m[r].zeroize();
768 }
769 (y_out, w_tmp)
770 };
771
772 // ============================================================
773 // Rejection-loop body.
774 //
775 // low-stack build: temporary polynomial vectors (w, w1, cs1,
776 // cs2, w_minus_cs2, r0, ct0, neg_ct0, w_cs2_ct0) are
777 // heap-allocated via Vec with scoped lifetimes and explicit
778 // drop() calls. Only ~23 KB of heap is live at peak instead
779 // of ~96 KB of stack.
780 //
781 // Default build: everything on the stack as fixed arrays.
782 // ============================================================
783
784 // w = NTT^{-1}(A-hat * NTT(y))
785 //
786 // Default path: compute y_hat = NTT(y), then w_tmp = iNTT(A·y_hat).
787 // sca-masked-y: w was already computed in the masked block
788 // above (w_precomputed) — the unmasked y was never in RAM.
789
790 // --- Compute w and w1, then derive c_tilde ---------------
791 //
792 // compressed-poly: after iNTT(w), pack w into 3-byte/coeff
793 // compressed form (−25% RAM), then derive w1 and w_minus_cs2
794 // from the compressed representation.
795 #[cfg(not(feature = "compressed-poly"))]
796 {
797 #[cfg(not(feature = "low-stack"))]
798 let mut _w_full = [[0i32; N]; MAX_K];
799 #[cfg(feature = "low-stack")]
800 let mut _w_full = poly_vec(k);
801 // (assigned below, used via w_ref)
802 }
803
804 // Compute w into a temporary full-poly buffer, then either
805 // keep it (default) or compress it (compressed-poly).
806 #[cfg(not(feature = "sca-masked-y"))]
807 let mut w_tmp = {
808 let mut y_hat = y;
809 ntt_vec(&mut y_hat, l);
810 let mut wt = [[0i32; N]; MAX_K];
811 #[cfg(not(feature = "low-mem"))]
812 mat_vec_mul(&a_hat, &y_hat, k, l, &mut wt);
813 #[cfg(feature = "low-mem")]
814 mat_vec_mul_lazy(&rho, &y_hat, k, l, &mut wt);
815 ntt_inv_vec(&mut wt, k);
816 wt
817 };
818 #[cfg(feature = "sca-masked-y")]
819 let mut w_tmp = w_precomputed;
820
821 // compressed-poly: pack w into 3-byte/coeff storage.
822 #[cfg(feature = "compressed-poly")]
823 let w_comp = {
824 let mut wc = compressed::CompressedVecK::new(k);
825 for i in 0..k {
826 // Reduce to [0, q-1] before packing.
827 for c in w_tmp[i].iter_mut() {
828 *c = mod_q(*c);
829 }
830 wc.pack(i, &w_tmp[i]);
831 }
832 wc
833 };
834
835 // w1 = HighBits(w) — works on the full-poly tmp (before we drop it).
836 #[cfg(not(feature = "low-stack"))]
837 let mut w1 = [[0i32; N]; MAX_K];
838 #[cfg(feature = "low-stack")]
839 let mut w1 = poly_vec(k);
840 decompose::high_bits_vec(&w_tmp, gamma2, &mut w1, k);
841
842 // In non-compressed mode, keep w_tmp as "w" for later vec_sub.
843 // In compressed mode, drop w_tmp (we'll read from w_comp).
844 #[cfg(not(feature = "compressed-poly"))]
845 let w = w_tmp;
846 #[cfg(feature = "compressed-poly")]
847 drop(w_tmp);
848
849 let w1_encoded = encode::w1_encode::<P>(&w1);
850 #[cfg(feature = "low-stack")]
851 drop(w1);
852
853 let mut c_tilde_buf = [0u8; 64];
854 {
855 let mut state = sha3::shake256();
856 state.absorb(&mu);
857 state.absorb(&w1_encoded);
858 state.squeeze(&mut c_tilde_buf[..c_tilde_len]);
859 }
860 let c_tilde = &c_tilde_buf[..c_tilde_len];
861
862 // --- challenge computation --------------------------------
863 //
864 // Default: c_hat = NTT(SampleInBall(c_tilde)), 2 KB stack.
865 // compressed-challenge: compress c into 68 bytes and use
866 // schoolbook multiplication in time domain, saving ~2 KB.
867 let c = sample::sample_in_ball::<P>(c_tilde);
868 #[cfg(not(feature = "compressed-challenge"))]
869 let c_hat = {
870 let mut ch = c;
871 for coeff in ch.iter_mut() {
872 *coeff = mod_q(*coeff);
873 }
874 ntt::ntt(&mut ch);
875 ch
876 };
877 #[cfg(feature = "compressed-challenge")]
878 let c_comp = {
879 let mut cc = [0u8; compressed::COMPRESSED_CHALLENGE_BYTES];
880 compressed::challenge_compress(&mut cc, &c, P::TAU);
881 cc
882 };
883 // small-secret: also convert c to SmallPoly NTT for basemul.
884 #[cfg(feature = "small-secret")]
885 let c_small_ntt = {
886 let mut cs = SmallPoly::from_i32(&c);
887 smallpoly::small_ntt(&mut cs);
888 cs
889 };
890
891 // ============================================================
892 // union-buffer path: single 1 KB workspace reused per poly.
893 // Processes L + K iterations sequentially, only z and h persist.
894 // ============================================================
895 #[cfg(feature = "union-buffer")]
896 {
897 let mut z = [[0i32; N]; MAX_L];
898 let mut tmp = [0i32; N];
899 let mut rejected = false;
900
901 // Phase 1: z[i] = y[i] + c*s1[i]
902 for l_idx in 0..l {
903 #[cfg(all(not(feature = "compressed-challenge"), not(feature = "small-secret")))]
904 {
905 tmp = ntt::pointwise_mul(&c_hat, &s1_hat[l_idx]);
906 ntt::ntt_inv(&mut tmp);
907 }
908 #[cfg(feature = "compressed-challenge")]
909 {
910 tmp = [0i32; N];
911 compressed::schoolbook_mul_add(&mut tmp, &c_comp, &s1_hat[l_idx], P::TAU);
912 }
913 #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
914 {
915 tmp = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s1_small[l_idx]);
916 }
917 z[l_idx] = ntt::poly_add(&y[l_idx], &tmp);
918 }
919 if !check_norm_vec(&z, gamma1 - beta, l) {
920 kappa += l as u16;
921 continue;
922 }
923
924 // Phase 2: per k_idx — cs2, r0, ct0, hint
925 let mut h = [[0i32; N]; MAX_K];
926 let mut total_hints = 0usize;
927 let mut wbuf = [0i32; N];
928
929 for k_idx in 0..k {
930 if rejected {
931 break;
932 }
933 // cs2 → tmp
934 #[cfg(all(not(feature = "compressed-challenge"), not(feature = "small-secret")))]
935 {
936 tmp = ntt::pointwise_mul(&c_hat, &s2_hat[k_idx]);
937 ntt::ntt_inv(&mut tmp);
938 }
939 #[cfg(feature = "compressed-challenge")]
940 {
941 tmp = [0i32; N];
942 compressed::schoolbook_mul_add(&mut tmp, &c_comp, &s2_hat[k_idx], P::TAU);
943 }
944 #[cfg(all(feature = "small-secret", not(feature = "compressed-challenge")))]
945 {
946 tmp = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s2_small[k_idx]);
947 }
948
949 // wbuf = w[k_idx] - cs2
950 #[cfg(not(feature = "compressed-poly"))]
951 for j in 0..N {
952 wbuf[j] = w[k_idx][j] - tmp[j];
953 }
954 #[cfg(feature = "compressed-poly")]
955 w_comp.sub_into(k_idx, &tmp, &mut wbuf);
956
957 // r0 check in tmp
958 for j in 0..N {
959 tmp[j] = decompose::low_bits(wbuf[j], gamma2);
960 }
961 if !check_norm(&tmp, gamma2 - beta) {
962 rejected = true;
963 continue;
964 }
965
966 // ct0 → tmp
967 #[cfg(not(feature = "compressed-challenge"))]
968 {
969 tmp = ntt::pointwise_mul(&c_hat, &t0_hat[k_idx]);
970 ntt::ntt_inv(&mut tmp);
971 }
972 #[cfg(feature = "compressed-challenge")]
973 {
974 tmp = [0i32; N];
975 compressed::schoolbook_mul_add(&mut tmp, &c_comp, &t0_hat[k_idx], P::TAU);
976 }
977
978 if !check_norm(&tmp, gamma2) {
979 rejected = true;
980 continue;
981 }
982
983 // hint for this k_idx
984 for j in 0..N {
985 h[k_idx][j] = decompose::make_hint(mod_q(-tmp[j]), wbuf[j] + tmp[j], gamma2);
986 if h[k_idx][j] == 1 {
987 total_hints += 1;
988 }
989 }
990 }
991
992 if rejected || total_hints > omega {
993 kappa += l as u16;
994 continue;
995 }
996
997 // Center z and encode
998 for poly in z[..l].iter_mut() {
999 for c in poly.iter_mut() {
1000 *c = mod_q(*c);
1001 if *c > Q / 2 {
1002 *c -= Q;
1003 }
1004 }
1005 }
1006 let sig = encode::sig_encode::<P>(c_tilde, &z, &h);
1007 return Ok(sig);
1008 }
1009
1010 // ============================================================
1011 // Standard path (non-union-buffer)
1012 // ============================================================
1013 #[cfg(not(feature = "union-buffer"))]
1014 {
1015 // --- cs1 = ĉ · ŝ1, then z = y + cs1 --------------------
1016 #[cfg(not(feature = "low-stack"))]
1017 let mut cs1 = [[0i32; N]; MAX_L];
1018 #[cfg(feature = "low-stack")]
1019 let mut cs1 = poly_vec(l);
1020
1021 #[cfg(feature = "sca-protected")]
1022 for i in 0..l {
1023 let prod = masked::masked_pointwise_mul_public(&s1_hat_m[i], &c_hat);
1024 cs1[i] = prod.unmask();
1025 ntt::ntt_inv(&mut cs1[i]);
1026 // refresh of s1_hat_m happens at the head of the
1027 // next rejection iteration (T1-A).
1028 }
1029 #[cfg(all(
1030 not(feature = "sca-protected"),
1031 not(feature = "compressed-challenge"),
1032 not(feature = "small-secret")
1033 ))]
1034 for i in 0..l {
1035 cs1[i] = ntt::pointwise_mul(&c_hat, &s1_hat[i]);
1036 ntt::ntt_inv(&mut cs1[i]);
1037 }
1038 #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1039 for i in 0..l {
1040 cs1[i] = [0i32; N];
1041 compressed::schoolbook_mul_add(&mut cs1[i], &c_comp, &s1_hat[i], P::TAU);
1042 }
1043 #[cfg(all(
1044 not(feature = "sca-protected"),
1045 feature = "small-secret",
1046 not(feature = "compressed-challenge")
1047 ))]
1048 for i in 0..l {
1049 cs1[i] = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s1_small[i]);
1050 }
1051
1052 #[cfg(not(feature = "low-stack"))]
1053 let mut z = [[0i32; N]; MAX_L];
1054 #[cfg(feature = "low-stack")]
1055 let mut z = poly_vec(l);
1056 vec_add(&y, &cs1, &mut z, l);
1057 // cs1 no longer needed.
1058 #[cfg(feature = "low-stack")]
1059 drop(cs1);
1060
1061 // --- cs2, w_minus_cs2, r0 --------------------------------
1062 #[cfg(not(feature = "low-stack"))]
1063 let mut cs2 = [[0i32; N]; MAX_K];
1064 #[cfg(feature = "low-stack")]
1065 let mut cs2 = poly_vec(k);
1066
1067 #[cfg(feature = "sca-protected")]
1068 for i in 0..k {
1069 let prod = masked::masked_pointwise_mul_public(&s2_hat_m[i], &c_hat);
1070 cs2[i] = prod.unmask();
1071 ntt::ntt_inv(&mut cs2[i]);
1072 // refresh of s2_hat_m happens at the head of the
1073 // next rejection iteration (T1-A).
1074 }
1075 #[cfg(all(
1076 not(feature = "sca-protected"),
1077 not(feature = "compressed-challenge"),
1078 not(feature = "small-secret")
1079 ))]
1080 for i in 0..k {
1081 cs2[i] = ntt::pointwise_mul(&c_hat, &s2_hat[i]);
1082 ntt::ntt_inv(&mut cs2[i]);
1083 }
1084 #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1085 for i in 0..k {
1086 cs2[i] = [0i32; N];
1087 compressed::schoolbook_mul_add(&mut cs2[i], &c_comp, &s2_hat[i], P::TAU);
1088 }
1089 #[cfg(all(
1090 not(feature = "sca-protected"),
1091 feature = "small-secret",
1092 not(feature = "compressed-challenge")
1093 ))]
1094 for i in 0..k {
1095 cs2[i] = smallpoly::small_basemul_invntt_widen(&c_small_ntt, &s2_small[i]);
1096 }
1097
1098 #[cfg(not(feature = "low-stack"))]
1099 let mut w_minus_cs2 = [[0i32; N]; MAX_K];
1100 #[cfg(feature = "low-stack")]
1101 let mut w_minus_cs2 = poly_vec(k);
1102 #[cfg(not(feature = "compressed-poly"))]
1103 vec_sub(&w, &cs2, &mut w_minus_cs2, k);
1104 #[cfg(feature = "compressed-poly")]
1105 for i in 0..k {
1106 w_comp.sub_into(i, &cs2[i], &mut w_minus_cs2[i]);
1107 }
1108 // cs2 and w/w_comp no longer needed for the norm checks.
1109 #[cfg(feature = "low-stack")]
1110 drop(cs2);
1111 #[cfg(all(feature = "low-stack", not(feature = "compressed-poly")))]
1112 drop(w);
1113 #[cfg(feature = "compressed-poly")]
1114 drop(w_comp);
1115
1116 #[cfg(not(feature = "low-stack"))]
1117 let mut r0 = [[0i32; N]; MAX_K];
1118 #[cfg(feature = "low-stack")]
1119 let mut r0 = poly_vec(k);
1120 decompose::low_bits_vec(&w_minus_cs2, gamma2, &mut r0, k);
1121
1122 // Norm checks. Standard build: early-abort for performance.
1123 // sca-ct-rejection build: collect all flags and decide at end.
1124 #[cfg(not(feature = "sca-ct-rejection"))]
1125 {
1126 if !check_norm_vec(&z, gamma1 - beta, l) {
1127 kappa += l as u16;
1128 continue;
1129 }
1130 if !check_norm_vec(&r0, gamma2 - beta, k) {
1131 kappa += l as u16;
1132 continue;
1133 }
1134 }
1135 #[cfg(feature = "sca-ct-rejection")]
1136 let mut _reject_flag = {
1137 let z_ok = check_norm_vec(&z, gamma1 - beta, l);
1138 let r0_ok = check_norm_vec(&r0, gamma2 - beta, k);
1139 !(z_ok & r0_ok)
1140 };
1141 // r0 no longer needed.
1142 #[cfg(feature = "low-stack")]
1143 drop(r0);
1144
1145 // --- ct0, hint computation --------------------------------
1146 #[cfg(not(feature = "low-stack"))]
1147 let mut ct0 = [[0i32; N]; MAX_K];
1148 #[cfg(feature = "low-stack")]
1149 let mut ct0 = poly_vec(k);
1150
1151 #[cfg(feature = "sca-protected")]
1152 for i in 0..k {
1153 let prod = masked::masked_pointwise_mul_public(&t0_hat_m[i], &c_hat);
1154 ct0[i] = prod.unmask();
1155 ntt::ntt_inv(&mut ct0[i]);
1156 // refresh of t0_hat_m happens at the head of the
1157 // next rejection iteration (T1-A).
1158 }
1159 #[cfg(all(not(feature = "sca-protected"), not(feature = "compressed-challenge")))]
1160 for i in 0..k {
1161 ct0[i] = ntt::pointwise_mul(&c_hat, &t0_hat[i]);
1162 ntt::ntt_inv(&mut ct0[i]);
1163 }
1164 #[cfg(all(not(feature = "sca-protected"), feature = "compressed-challenge"))]
1165 for i in 0..k {
1166 ct0[i] = [0i32; N];
1167 compressed::schoolbook_mul_add(&mut ct0[i], &c_comp, &t0_hat[i], P::TAU);
1168 }
1169
1170 // Check ||ct0||_inf < gamma2
1171 #[cfg(not(feature = "sca-ct-rejection"))]
1172 {
1173 if !check_norm_vec(&ct0, gamma2, k) {
1174 kappa += l as u16;
1175 continue;
1176 }
1177 }
1178 #[cfg(feature = "sca-ct-rejection")]
1179 {
1180 _reject_flag |= !check_norm_vec(&ct0, gamma2, k);
1181 }
1182
1183 // h = MakeHint(-ct0, w_minus_cs2 + ct0)
1184 #[cfg(not(feature = "low-stack"))]
1185 let mut w_cs2_ct0 = [[0i32; N]; MAX_K];
1186 #[cfg(feature = "low-stack")]
1187 let mut w_cs2_ct0 = poly_vec(k);
1188 vec_add(&w_minus_cs2, &ct0, &mut w_cs2_ct0, k);
1189
1190 #[cfg(not(feature = "low-stack"))]
1191 let mut neg_ct0 = [[0i32; N]; MAX_K];
1192 #[cfg(feature = "low-stack")]
1193 let mut neg_ct0 = poly_vec(k);
1194 for i in 0..k {
1195 for j in 0..N {
1196 neg_ct0[i][j] = mod_q(-ct0[i][j]);
1197 }
1198 }
1199 // ct0 and w_minus_cs2 no longer needed.
1200 #[cfg(feature = "low-stack")]
1201 {
1202 drop(ct0);
1203 drop(w_minus_cs2);
1204 }
1205
1206 let (h, num_ones) = decompose::make_hint_vec(&neg_ct0, &w_cs2_ct0, gamma2, k);
1207 // neg_ct0 and w_cs2_ct0 no longer needed.
1208 #[cfg(feature = "low-stack")]
1209 {
1210 drop(neg_ct0);
1211 drop(w_cs2_ct0);
1212 }
1213
1214 #[cfg(not(feature = "sca-ct-rejection"))]
1215 {
1216 if num_ones > omega {
1217 kappa += l as u16;
1218 continue;
1219 }
1220 }
1221 #[cfg(feature = "sca-ct-rejection")]
1222 {
1223 _reject_flag |= num_ones > omega;
1224 if _reject_flag {
1225 kappa += l as u16;
1226 continue;
1227 }
1228 }
1229
1230 // Center z coefficients to [-gamma1+1, gamma1] before encoding
1231 for poly in z[..l].iter_mut() {
1232 for c in poly.iter_mut() {
1233 *c = mod_q(*c);
1234 if *c > Q / 2 {
1235 *c -= Q;
1236 }
1237 }
1238 }
1239
1240 // Encode signature
1241 let sig = encode::sig_encode::<P>(c_tilde, &z, &h);
1242 return Ok(sig);
1243 } // end #[cfg(not(feature = "union-buffer"))]
1244 }
1245}
1246
1247/// Verify a signature against a pre-formatted message.
1248///
1249/// Implements Algorithm 8 of FIPS 204 (ML-DSA.Verify_internal).
1250///
1251/// Recomputes the commitment w1' from the public key, signature components
1252/// (c_tilde, z, h), and the message hash mu. Verification succeeds when the
1253/// recomputed commitment hash matches the c_tilde embedded in the signature.
1254///
1255/// - `pk`: encoded public key (must be `P::PK_LEN` bytes).
1256/// - `m_prime`: pre-formatted message.
1257/// - `sig`: encoded signature (must be `P::SIG_LEN` bytes).
1258///
1259/// Returns `Ok(true)` if the signature is valid, `Ok(false)` otherwise.
1260///
1261/// # Errors
1262///
1263/// - [`MlDsaError::InvalidPublicKey`] if `pk` has the wrong length.
1264/// - [`MlDsaError::InvalidSignature`] if `sig` has the wrong length.
1265pub fn verify_internal<P: Params>(pk: &[u8], m_prime: &[u8], sig: &[u8]) -> Result<bool, MlDsaError> {
1266 let k = P::K;
1267 let l = P::L;
1268 let gamma1 = P::GAMMA1;
1269 let gamma2 = P::GAMMA2;
1270 let beta = P::BETA;
1271 let omega = P::OMEGA;
1272 let c_tilde_len = P::LAMBDA / 4;
1273
1274 if pk.len() != P::PK_LEN {
1275 return Err(MlDsaError::InvalidPublicKey);
1276 }
1277 if sig.len() != P::SIG_LEN {
1278 return Err(MlDsaError::InvalidSignature);
1279 }
1280
1281 // Decode public key
1282 let (rho, t1) = encode::pk_decode::<P>(pk);
1283
1284 // tr = H(pk) (64 bytes)
1285 let mut tr = [0u8; 64];
1286 sha3::shake256_digest(pk, &mut tr);
1287
1288 // Decode signature
1289 let (c_tilde, z, h) = match encode::sig_decode::<P>(sig) {
1290 Some(x) => x,
1291 None => return Ok(false),
1292 };
1293
1294 // Check ||z||_inf < gamma1 - beta
1295 if !check_norm_vec(&z, gamma1 - beta, l) {
1296 return Ok(false);
1297 }
1298
1299 // A-hat = ExpandA(rho)
1300 #[cfg(not(feature = "low-mem"))]
1301 let a_hat = sample::expand_a::<P>(&rho);
1302
1303 // mu = H(tr || M')
1304 let mut mu = [0u8; 64];
1305 {
1306 let mut state = sha3::shake256();
1307 state.absorb(&tr);
1308 state.absorb(m_prime);
1309 state.squeeze(&mut mu);
1310 }
1311
1312 // c = SampleInBall(c_tilde)
1313 let mut c = sample::sample_in_ball::<P>(&c_tilde);
1314 for coeff in c.iter_mut() {
1315 *coeff = mod_q(*coeff);
1316 }
1317 let mut c_hat = c;
1318 ntt::ntt(&mut c_hat);
1319
1320 // z_hat = NTT(z)
1321 let mut z_hat = z;
1322 ntt_vec(&mut z_hat, l);
1323
1324 // w'_approx = NTT^{-1}(A-hat * z_hat - c_hat * NTT(t1 * 2^d))
1325 // First compute NTT(t1 * 2^d)
1326 let mut t1_2d_hat = [[0i32; N]; MAX_K];
1327 for i in 0..k {
1328 for j in 0..N {
1329 t1_2d_hat[i][j] = mod_q(t1[i][j] * (1 << D));
1330 }
1331 ntt::ntt(&mut t1_2d_hat[i]);
1332 }
1333
1334 // A-hat * z_hat
1335 let mut az = [[0i32; N]; MAX_K];
1336 #[cfg(not(feature = "low-mem"))]
1337 mat_vec_mul(&a_hat, &z_hat, k, l, &mut az);
1338 #[cfg(feature = "low-mem")]
1339 mat_vec_mul_lazy(&rho, &z_hat, k, l, &mut az);
1340
1341 // c_hat * t1_2d_hat (component-wise)
1342 let mut ct1 = [[0i32; N]; MAX_K];
1343 for i in 0..k {
1344 ct1[i] = ntt::pointwise_mul(&c_hat, &t1_2d_hat[i]);
1345 }
1346
1347 // w'_approx = NTT^{-1}(az - ct1)
1348 let mut w_approx = [[0i32; N]; MAX_K];
1349 vec_sub(&az, &ct1, &mut w_approx, k);
1350 ntt_inv_vec(&mut w_approx, k);
1351
1352 // w1' = UseHint(h, w'_approx)
1353 let w1_prime = decompose::use_hint_vec(&h, &w_approx, gamma2, k);
1354
1355 // Recompute c_tilde' = H(mu || w1Encode(w1'))
1356 let w1_encoded = encode::w1_encode::<P>(&w1_prime);
1357 let mut c_tilde_prime = vec![0u8; c_tilde_len];
1358 {
1359 let mut state = sha3::shake256();
1360 state.absorb(&mu);
1361 state.absorb(&w1_encoded);
1362 state.squeeze(&mut c_tilde_prime);
1363 }
1364
1365 // Check c_tilde == c_tilde'
1366 // Also verify hint weight
1367 let mut hint_count = 0usize;
1368 for i in 0..k {
1369 for &c in h[i].iter() {
1370 hint_count += c as usize;
1371 }
1372 }
1373 if hint_count > omega {
1374 return Ok(false);
1375 }
1376
1377 Ok(c_tilde == c_tilde_prime)
1378}
1379
1380/// Generate an ML-DSA key pair.
1381///
1382/// Implements Algorithm 1 of FIPS 204 (ML-DSA.KeyGen). Draws 32 random
1383/// bytes from `rng` and delegates to [`keygen_internal`].
1384///
1385/// Returns `(pk, sk)` as byte vectors.
1386///
1387/// # Errors
1388///
1389/// Returns [`MlDsaError::RngFailure`] if the RNG cannot provide bytes.
1390pub fn keygen<P: Params>(rng: &mut dyn CryptoRng) -> Result<(Vec<u8>, Vec<u8>), MlDsaError> {
1391 let mut xi = [0u8; 32];
1392 rng.fill_bytes(&mut xi)?;
1393 let result = keygen_internal::<P>(&xi);
1394 Ok(result)
1395}
1396
1397/// Sign a message with an optional context string (hedged mode).
1398///
1399/// Implements Algorithm 2 of FIPS 204 (ML-DSA.Sign). Constructs the
1400/// pre-formatted message `M' = 0x00 || len(ctx) || ctx || msg`, draws 32
1401/// random bytes for hedged signing, and calls `sign_internal`.
1402///
1403/// - `sk`: secret key (must be `P::SK_LEN` bytes).
1404/// - `msg`: message to sign.
1405/// - `ctx`: optional context string (at most 255 bytes).
1406/// - `rng`: source of randomness for the hedged nonce.
1407///
1408/// # Errors
1409///
1410/// - [`MlDsaError::ContextTooLong`] if `ctx` exceeds 255 bytes.
1411/// - [`MlDsaError::InvalidSecretKey`] if `sk` has the wrong length.
1412/// - [`MlDsaError::RngFailure`] if the RNG cannot provide bytes.
1413pub fn sign<P: Params>(sk: &[u8], msg: &[u8], ctx: &[u8], rng: &mut dyn CryptoRng) -> Result<Vec<u8>, MlDsaError> {
1414 if ctx.len() > 255 {
1415 return Err(MlDsaError::ContextTooLong);
1416 }
1417 if sk.len() != P::SK_LEN {
1418 return Err(MlDsaError::InvalidSecretKey);
1419 }
1420
1421 // M' = 0x00 || len(ctx) || ctx || M
1422 let mut m_prime = Vec::with_capacity(1 + 1 + ctx.len() + msg.len());
1423 m_prime.push(0x00);
1424 m_prime.push(ctx.len() as u8);
1425 m_prime.extend_from_slice(ctx);
1426 m_prime.extend_from_slice(msg);
1427
1428 // Random bytes for hedged signing
1429 let mut rnd = [0u8; 32];
1430 rng.fill_bytes(&mut rnd)?;
1431
1432 sign_internal::<P>(sk, &m_prime, &rnd)
1433}
1434
1435/// Verify a signature on a message with an optional context string.
1436///
1437/// Implements Algorithm 3 of FIPS 204 (ML-DSA.Verify). Constructs the
1438/// pre-formatted message `M' = 0x00 || len(ctx) || ctx || msg` and
1439/// delegates to [`verify_internal`].
1440///
1441/// - `pk`: public key (must be `P::PK_LEN` bytes).
1442/// - `msg`: the signed message.
1443/// - `ctx`: the context string used at signing time (at most 255 bytes).
1444/// - `sig`: the signature (must be `P::SIG_LEN` bytes).
1445///
1446/// Returns `Ok(true)` if the signature is valid, `Ok(false)` otherwise.
1447///
1448/// # Errors
1449///
1450/// - [`MlDsaError::ContextTooLong`] if `ctx` exceeds 255 bytes.
1451/// - [`MlDsaError::InvalidPublicKey`] if `pk` has the wrong length.
1452/// - [`MlDsaError::InvalidSignature`] if `sig` has the wrong length.
1453pub fn verify<P: Params>(pk: &[u8], msg: &[u8], ctx: &[u8], sig: &[u8]) -> Result<bool, MlDsaError> {
1454 if ctx.len() > 255 {
1455 return Err(MlDsaError::ContextTooLong);
1456 }
1457 if pk.len() != P::PK_LEN {
1458 return Err(MlDsaError::InvalidPublicKey);
1459 }
1460 if sig.len() != P::SIG_LEN {
1461 return Err(MlDsaError::InvalidSignature);
1462 }
1463
1464 // M' = 0x00 || len(ctx) || ctx || M
1465 let mut m_prime = Vec::with_capacity(1 + 1 + ctx.len() + msg.len());
1466 m_prime.push(0x00);
1467 m_prime.push(ctx.len() as u8);
1468 m_prime.extend_from_slice(ctx);
1469 m_prime.extend_from_slice(msg);
1470
1471 verify_internal::<P>(pk, &m_prime, sig)
1472}