Skip to main content

quantica/ml_kem/
sample.rs

//! Polynomial sampling algorithms for ML-KEM (FIPS 203 Section 4.2.2).
//!
//! Two sampling methods are provided:
//!
//! - [`sample_ntt`] -- Algorithm 7 (SampleNTT): rejection sampling from a SHAKE128
//!   XOF to produce a uniformly random polynomial in NTT domain. Operates on
//!   **public** data (the seed rho is public), so rejection branching does not
//!   leak secret information.
//!
//! - [`sample_poly_cbd`] -- Algorithm 8 (SamplePolyCBD): centered binomial
//!   distribution sampling from PRF output. Used for secret and error
//!   polynomials. Constant-time: no branch or memory index depends on secret
//!   bit values.
13use super::params::{N, Q};
14use super::sha3;
15
16/// Sample a uniformly random NTT-domain polynomial (Algorithm 7: SampleNTT).
17///
18/// Uses SHAKE128 as an XOF (extendable output function) seeded with the
19/// 34-byte input `seed = rho || j || i`. Pairs of 12-bit candidates are
20/// extracted from each 3-byte block and accepted if less than q = 3329
21/// (rejection sampling).
22///
23/// Since the seed rho is public data, the variable-time rejection loop
24/// does not leak secret information through timing.
25///
26/// # Arguments
27///
28/// * `seed` - A 34-byte XOF seed: `rho (32 bytes) || column_index (1 byte) || row_index (1 byte)`.
29///
30/// # Returns
31///
32/// A 256-coefficient polynomial in NTT domain with coefficients in `[0, q-1]`.
33pub fn sample_ntt(seed: &[u8; 34]) -> [i16; N] {
34    let mut a_hat = [0i16; N];
35    let mut xof = sha3::Xof::new();
36    xof.absorb(seed);
37
38    let mut j = 0usize;
39    let mut buf = [0u8; 3];
40
41    while j < N {
42        xof.squeeze(&mut buf);
43        let d1 = (buf[0] as u16) | (((buf[1] as u16) & 0x0F) << 8);
44        let d2 = ((buf[1] as u16) >> 4) | ((buf[2] as u16) << 4);
45
46        if d1 < Q {
47            a_hat[j] = d1 as i16;
48            j += 1;
49        }
50        if d2 < Q && j < N {
51            a_hat[j] = d2 as i16;
52            j += 1;
53        }
54    }
55    a_hat
56}
57
58/// Sample a polynomial from the centered binomial distribution CBD_eta (Algorithm 8).
59///
60/// For each of the 256 coefficients, sums `eta` random bits for `x` and
61/// `eta` random bits for `y`, then computes `(x - y) mod q`. The result
62/// lies in `[-eta, eta]` before reduction, corresponding to the centered
63/// binomial distribution.
64///
65/// Fully constant-time: no branches depend on secret bit values. The
66/// branchless modular reduction adds q when the difference is negative
67/// using an arithmetic shift mask.
68///
69/// # Arguments
70///
71/// * `eta` - The CBD parameter (2 or 3 for ML-KEM).
72/// * `bytes` - Exactly `64 * eta` bytes of PRF output.
73///
74/// # Returns
75///
76/// A 256-coefficient polynomial with coefficients in `[0, q-1]`.
77///
78/// # Panics
79///
80/// Debug-asserts that `bytes.len() == 64 * eta`.
81pub fn sample_poly_cbd(eta: usize, bytes: &[u8]) -> [i16; N] {
82    debug_assert_eq!(bytes.len(), 64 * eta);
83    let mut f = [0i16; N];
84
85    for i in 0..N {
86        let mut x = 0i16;
87        let mut y = 0i16;
88        for j in 0..eta {
89            let bit_x = (i * 2 * eta + j) as usize;
90            let bit_y = (i * 2 * eta + eta + j) as usize;
91            x += ((bytes[bit_x >> 3] >> (bit_x & 7)) & 1) as i16;
92            y += ((bytes[bit_y >> 3] >> (bit_y & 7)) & 1) as i16;
93        }
94        // Branchless mod q: diff is in [-η, η], add q if negative
95        let diff = x - y;
96        // Constant-time: q & (diff >> 15) adds q iff diff < 0
97        f[i] = diff.wrapping_add(super::ntt::Q & (diff >> 15));
98    }
99    f
100}