Skip to main content

quantica/ml_dsa/
sample.rs

//! Sampling algorithms for ML-DSA (FIPS 204, Algorithms 14-15 and 29-34).
//!
//! Provides functions to sample polynomials and polynomial vectors from
//! various distributions using SHAKE-based extendable output functions.
//! `sample_in_ball`, `rej_ntt_poly`, and `rej_bounded_poly` use rejection
//! sampling to ensure uniform output; `expand_mask` unpacks SHAKE256 output
//! directly without rejection.
//! All returned data lives on the stack (no heap allocations).
7
8use super::params::{MAX_K, MAX_L, N, Params, Q};
9use super::sha3;
10
11/// CoeffFromThreeBytes (Algorithm 14): rejection-sample a coefficient from 3 bytes.
12/// Returns Some(coeff) if valid, None otherwise.
13#[inline]
14fn coeff_from_three_bytes(b0: u8, b1: u8, b2: u8) -> Option<i32> {
15    let b2_masked = b2 & 0x7F; // top bit is discarded
16    let z = (b0 as i32) | ((b1 as i32) << 8) | ((b2_masked as i32) << 16);
17    if z < Q { Some(z) } else { None }
18}
19
/// CoeffFromHalfByte (FIPS 204, Algorithm 15): map a half-byte to a small
/// signed coefficient.
///
/// - `b`: candidate value (callers pass a nibble, 0..=15).
/// - `eta`: coefficient bound. `2` selects the mod-5 mapping; every other
///   value is handled as `eta == 4`.
///
/// Returns `Some(coeff)` with `coeff` in `[-eta, eta]` when `b` is accepted
/// (`b < 15` for eta 2, `b < 9` for eta 4), and `None` when it is rejected.
#[inline]
fn coeff_from_half_byte(b: u8, eta: usize) -> Option<i32> {
    let nibble = i32::from(b);
    match eta {
        2 if nibble < 15 => Some(2 - nibble % 5),
        2 => None,
        // Fallback arm: anything that is not eta == 2 is treated as eta == 4.
        _ if nibble < 9 => Some(4 - nibble),
        _ => None,
    }
}
32
33/// Sample a sparse challenge polynomial with exactly tau non-zero entries.
34///
35/// Implements Algorithm 29 of FIPS 204 (SampleInBall). The output polynomial
36/// `c` has exactly `P::TAU` coefficients equal to +/-1 (the rest are 0).
37/// Signs are determined by squeezing 8 bytes of sign bits from SHAKE256,
38/// and positions are chosen via rejection sampling to ensure uniformity.
39///
40/// - `c_tilde`: commitment hash seed (lambda/4 bytes).
41///
42/// Returns a polynomial with coefficients in {-1, 0, 1}.
43pub fn sample_in_ball<P: Params>(c_tilde: &[u8]) -> [i32; N] {
44    let tau = P::TAU;
45    let mut c = [0i32; N];
46
47    let mut state = sha3::shake256();
48    state.absorb(c_tilde);
49
50    // Get 8 bytes for sign bits
51    let mut sign_bytes = [0u8; 8];
52    state.squeeze(&mut sign_bytes);
53    let signs = u64::from_le_bytes(sign_bytes);
54
55    let mut sign_idx = 0usize;
56    for i in (N - tau)..N {
57        // Sample j uniformly from [0, i]
58        let mut j;
59        loop {
60            let mut buf = [0u8; 1];
61            state.squeeze(&mut buf);
62            j = buf[0] as usize;
63            if j <= i {
64                break;
65            }
66        }
67        c[i] = c[j];
68        let sign_bit = (signs >> sign_idx) & 1;
69        c[j] = if sign_bit == 1 { -1 } else { 1 };
70        sign_idx += 1;
71    }
72
73    c
74}
75
76/// Generate an NTT-domain polynomial via rejection sampling.
77///
78/// Implements Algorithm 30 of FIPS 204 (RejNTTPoly). Samples coefficients
79/// uniformly from [0, q) by reading 3 bytes at a time from a SHAKE128
80/// stream seeded with `rho || j1 || j2`. Candidates >= q are rejected.
81///
82/// - `rho`: 32-byte public seed.
83/// - `j1`: column index of the matrix entry (s index).
84/// - `j2`: row index of the matrix entry (r index).
85///
86/// Returns a polynomial in NTT domain with coefficients in [0, q-1].
87pub fn rej_ntt_poly(rho: &[u8; 32], j1: u8, j2: u8) -> [i32; N] {
88    let mut a = [0i32; N];
89    let mut state = sha3::shake128();
90    state.absorb(rho);
91    state.absorb(&[j1, j2]);
92
93    let mut idx = 0;
94    while idx < N {
95        let mut buf = [0u8; 3];
96        state.squeeze(&mut buf);
97        if let Some(coeff) = coeff_from_three_bytes(buf[0], buf[1], buf[2]) {
98            a[idx] = coeff;
99            idx += 1;
100        }
101    }
102    a
103}
104
105/// Generate a polynomial with small coefficients via rejection sampling.
106///
107/// Implements Algorithm 31 of FIPS 204 (RejBoundedPoly). Samples
108/// coefficients in [-eta, eta] by reading half-bytes from a SHAKE256
109/// stream seeded with `rho_prime || nonce`. Invalid half-byte values are
110/// rejected.
111///
112/// - `rho_prime`: 64-byte secret seed.
113/// - `nonce`: 16-bit counter distinguishing different polynomials.
114/// - `eta`: coefficient bound (2 or 4 depending on parameter set).
115///
116/// Returns a polynomial with coefficients in [-eta, eta].
117pub fn rej_bounded_poly(rho_prime: &[u8; 64], nonce: u16, eta: usize) -> [i32; N] {
118    let mut a = [0i32; N];
119    let mut state = sha3::shake256();
120    state.absorb(rho_prime);
121    state.absorb(&nonce.to_le_bytes());
122
123    let mut idx = 0;
124    while idx < N {
125        let mut buf = [0u8; 1];
126        state.squeeze(&mut buf);
127        let z0 = coeff_from_half_byte(buf[0] & 0x0F, eta);
128        let z1 = coeff_from_half_byte(buf[0] >> 4, eta);
129        if let Some(c) = z0 {
130            if idx < N {
131                a[idx] = c;
132                idx += 1;
133            }
134        }
135        if let Some(c) = z1 {
136            if idx < N {
137                a[idx] = c;
138                idx += 1;
139            }
140        }
141    }
142    a
143}
144
145/// Expand the public matrix A in NTT domain from a seed.
146///
147/// Implements Algorithm 32 of FIPS 204 (ExpandA). Generates a k-by-l matrix
148/// of NTT-domain polynomials by calling [`rej_ntt_poly`] for each entry
149/// with indices `(s, r)`.
150///
151/// - `rho`: 32-byte public seed (part of the public key).
152///
153/// Returns the matrix A-hat as a `[[[i32; N]; MAX_L]; MAX_K]` with valid
154/// entries in rows 0..k and columns 0..l.
155pub fn expand_a<P: Params>(rho: &[u8; 32]) -> [[[i32; N]; MAX_L]; MAX_K] {
156    let k = P::K;
157    let l = P::L;
158    let mut a_hat = [[[0i32; N]; MAX_L]; MAX_K];
159    for r in 0..k {
160        for s in 0..l {
161            a_hat[r][s] = rej_ntt_poly(rho, s as u8, r as u8);
162        }
163    }
164    a_hat
165}
166
167/// Expand the secret vectors s1 and s2 from a seed.
168///
169/// Implements Algorithm 33 of FIPS 204 (ExpandS). Generates the secret
170/// vectors by calling [`rej_bounded_poly`] with incrementing nonces:
171/// nonces 0..l for s1, and l..(l+k) for s2.
172///
173/// - `rho_prime`: 64-byte secret seed derived during key generation.
174///
175/// Returns `(s1, s2)` where s1 has MAX_L polynomials (valid 0..l) and s2
176/// has MAX_K polynomials (valid 0..k), each with coefficients in [-eta, eta].
177pub fn expand_s<P: Params>(rho_prime: &[u8; 64]) -> ([[i32; N]; MAX_L], [[i32; N]; MAX_K]) {
178    let k = P::K;
179    let l = P::L;
180    let eta = P::ETA;
181    let mut s1 = [[0i32; N]; MAX_L];
182    for s in 0..l {
183        s1[s] = rej_bounded_poly(rho_prime, s as u16, eta);
184    }
185    let mut s2 = [[0i32; N]; MAX_K];
186    for s in 0..k {
187        s2[s] = rej_bounded_poly(rho_prime, (l + s) as u16, eta);
188    }
189    (s1, s2)
190}
191
192/// Expand the masking vector y from a seed and counter.
193///
194/// Implements Algorithm 34 of FIPS 204 (ExpandMask). Generates l polynomials
195/// with coefficients in [-(gamma1-1), gamma1] by squeezing SHAKE256 output
196/// and unpacking via `bit_unpack`. Each polynomial uses a distinct nonce
197/// derived from `kappa`.
198///
199/// - `rho_double_prime`: 64-byte seed derived from the secret key and randomness.
200/// - `kappa`: counter incremented by l on each rejection loop iteration.
201///
202/// Returns a fixed array of MAX_L polynomials (valid entries 0..l).
203pub fn expand_mask<P: Params>(rho_double_prime: &[u8; 64], kappa: u16) -> [[i32; N]; MAX_L] {
204    let l = P::L;
205    let gamma1 = P::GAMMA1 as u32;
206    let c = P::BITLEN_GAMMA1_MINUS1 + 1;
207    let poly_bytes = 32 * c;
208    let mut y = [[0i32; N]; MAX_L];
209
210    for r in 0..l {
211        let nonce = kappa + r as u16;
212        let mut state = sha3::shake256();
213        state.absorb(rho_double_prime);
214        state.absorb(&nonce.to_le_bytes());
215        // Max poly_bytes = 32 * 20 = 640
216        let mut buf = [0u8; 640];
217        state.squeeze(&mut buf[..poly_bytes]);
218
219        super::encode::bit_unpack(&buf[..poly_bytes], gamma1 - 1, gamma1, &mut y[r]);
220    }
221    y
222}