Skip to main content

quantica/ml_kem/
kem.rs

1use super::MlKemError;
2use super::encode;
3use super::kpke;
4use super::ntt;
5use super::params::N;
6use super::params::Params;
7use super::rng::CryptoRng;
8use super::sha3;
9/// ML-KEM Key-Encapsulation Mechanism (FIPS 203 Sections 6-7).
10///
11/// Implements Algorithms 16-21 with the following side-channel countermeasures:
12///
13/// - **Constant-time**: no secret-dependent branches or memory access patterns
14/// - **Zeroization**: all secret intermediates erased via volatile writes
15/// - **Double decaps**: fault detection by running decapsulation twice (DFA protection)
16/// - **dk integrity**: `H(ek)` verification at decapsulation time (DFA protection on key storage)
17///
18/// The public API ([`keygen`], [`encaps`], [`decaps`]) includes input validation
19/// and full countermeasures. The internal variants ([`keygen_internal`],
20/// [`encaps_internal`], [`decaps_internal`]) are deterministic and intended
21/// for CAVP testing.
22use alloc::vec::Vec;
23
/// Largest module rank `k` used by any ML-KEM parameter set.
const MAX_K: usize = 4;
/// Upper bound on the encapsulation key length: `384 * MAX_K + 32 = 1568` bytes.
const MAX_EK_LEN: usize = 32 + MAX_K * 384;
/// Upper bound on the ciphertext length: `32 * (11 * MAX_K + 5) = 1568` bytes.
const MAX_CT_LEN: usize = 32 * (11 * MAX_K + 5);
/// Upper bound on the `J` input length (32 bytes followed by the ciphertext):
/// `32 + MAX_CT_LEN = 1600` bytes.
const MAX_J_INPUT_LEN: usize = MAX_CT_LEN + 32;
32
33// =====================================================================
34// Internal (deterministic) functions — used for CAVP testing
35// =====================================================================
36
37/// Deterministic ML-KEM key generation (Algorithm 16).
38///
39/// Generates an encapsulation/decapsulation key pair from explicit seeds.
40/// The decapsulation key is structured as `dk_pke || ek || H(ek) || z`.
41///
42/// # Arguments
43///
44/// * `d` - 32-byte seed passed to [`kpke::keygen`] for K-PKE key generation.
45/// * `z` - 32-byte implicit rejection value embedded in the decapsulation key.
46///
47/// # Returns
48///
49/// A tuple `(ek, dk)` of byte vectors.
50pub fn keygen_internal<P: Params>(d: &[u8; 32], z: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
51    let k = P::K;
52    let ek_len = 384 * k + 32;
53    let dk_pke_len = 384 * k;
54
55    let mut ek_buf = [0u8; MAX_EK_LEN];
56    let mut dk_pke_buf = [0u8; 384 * MAX_K];
57    kpke::keygen::<P>(d, &mut ek_buf, &mut dk_pke_buf);
58
59    let ek = &ek_buf[..ek_len];
60    let h_ek = sha3::h(ek);
61
62    let mut dk = Vec::with_capacity(P::DK_LEN);
63    dk.extend_from_slice(&dk_pke_buf[..dk_pke_len]);
64    dk.extend_from_slice(ek);
65    dk.extend_from_slice(&h_ek);
66    dk.extend_from_slice(z);
67
68    (ek[..ek_len].to_vec(), dk)
69}
70
/// SCA-protected ML-KEM key generation.
///
/// Same assembly as [`keygen_internal`] (`dk = dk_pke || ek || H(ek) || z`),
/// but the K-PKE key pair comes from [`kpke::keygen_sca`], the side-channel
/// hardened variant that additionally consumes randomness from `rng`.
///
/// The temporary stack buffer that receives the K-PKE decryption key is
/// zeroized on every exit path, including when [`kpke::keygen_sca`] fails.
///
/// # Arguments
///
/// * `d` - 32-byte seed for K-PKE key generation.
/// * `z` - 32-byte implicit rejection value embedded in the decapsulation key.
/// * `rng` - Randomness source forwarded to [`kpke::keygen_sca`].
///
/// # Errors
///
/// Propagates any error returned by [`kpke::keygen_sca`].
#[cfg(feature = "sca-protected")]
fn keygen_internal_sca<P: Params>(
    d: &[u8; 32],
    z: &[u8; 32],
    rng: &mut impl CryptoRng,
) -> Result<(Vec<u8>, Vec<u8>), MlKemError> {
    let k = P::K;
    let ek_len = 384 * k + 32;
    let dk_pke_len = 384 * k;

    let mut ek_buf = [0u8; MAX_EK_LEN];
    let mut dk_pke_buf = [0u8; 384 * MAX_K];
    let generated = kpke::keygen_sca::<P>(d, &mut ek_buf, &mut dk_pke_buf, rng);
    if generated.is_err() {
        // A failed run may already have written partial secret material.
        ntt::zeroize_bytes(&mut dk_pke_buf);
    }
    generated?;

    let ek = &ek_buf[..ek_len];
    let h_ek = sha3::h(ek);

    // dk = dk_pke || ek || H(ek) || z
    let mut dk = Vec::with_capacity(P::DK_LEN);
    dk.extend_from_slice(&dk_pke_buf[..dk_pke_len]);
    dk.extend_from_slice(ek);
    dk.extend_from_slice(&h_ek);
    dk.extend_from_slice(z);

    // The secret key has been copied into `dk`; wipe the stack copy.
    ntt::zeroize_bytes(&mut dk_pke_buf);

    Ok((ek.to_vec(), dk))
}
101
102/// Deterministic ML-KEM encapsulation (Algorithm 17).
103///
104/// Computes the shared secret and ciphertext from an encapsulation key
105/// and an explicit 32-byte message `m`. The shared secret and encryption
106/// randomness are derived as `(K, r) = G(m || H(ek))`.
107///
108/// No input validation is performed on `ek`; callers should use [`encaps`]
109/// for production use.
110///
111/// # Arguments
112///
113/// * `ek` - The encapsulation (public) key.
114/// * `m` - 32-byte random message seed.
115///
116/// # Returns
117///
118/// A tuple `(shared_secret, ciphertext)`.
119pub fn encaps_internal<P: Params>(ek: &[u8], m: &[u8; 32]) -> ([u8; 32], Vec<u8>) {
120    let h_ek = sha3::h(ek);
121    let mut g_input = [0u8; 64];
122    g_input[..32].copy_from_slice(m);
123    g_input[32..64].copy_from_slice(&h_ek);
124    let (shared_key, r) = sha3::g(&g_input);
125    ntt::zeroize_bytes(&mut g_input);
126
127    let mut ct_buf = [0u8; MAX_CT_LEN];
128    let ct_len = kpke::encrypt::<P>(ek, m, &r, &mut ct_buf);
129    (shared_key, ct_buf[..ct_len].to_vec())
130}
131
132/// Deterministic ML-KEM decapsulation (Algorithm 18).
133///
134/// Implements the Fujisaki-Okamoto transform: decrypts the ciphertext to
135/// recover `m'`, re-encrypts to get `c'`, then uses a constant-time
136/// comparison to select either the real shared key `K'` (if `c == c'`)
137/// or an implicit rejection key `J(z || c)` (otherwise).
138///
139/// All operations are constant-time with no secret-dependent branches.
140/// Intermediates (`m'`, `g_input`, `j_input`) are zeroized after use.
141///
142/// # Arguments
143///
144/// * `dk` - The full decapsulation key (layout: `dk_pke || ek || H(ek) || z`).
145/// * `c`  - The ciphertext.
146///
147/// # Returns
148///
149/// The 32-byte shared secret.
150pub fn decaps_internal<P: Params>(dk: &[u8], c: &[u8]) -> [u8; 32] {
151    let k = P::K;
152    let dk_pke = &dk[0..384 * k];
153    let ek_pke = &dk[384 * k..768 * k + 32];
154    let h = &dk[768 * k + 32..768 * k + 64];
155    let z = &dk[768 * k + 64..768 * k + 96];
156
157    let mut m_prime = kpke::decrypt::<P>(dk_pke, c);
158
159    let mut g_input = [0u8; 64];
160    g_input[..32].copy_from_slice(&m_prime);
161    g_input[32..64].copy_from_slice(h);
162    let (k_prime, r_prime) = sha3::g(&g_input);
163    ntt::zeroize_bytes(&mut g_input);
164
165    let j_input_len = 32 + c.len();
166    let mut j_input = [0u8; MAX_J_INPUT_LEN];
167    j_input[..32].copy_from_slice(z);
168    j_input[32..j_input_len].copy_from_slice(c);
169    let k_bar = sha3::j(&j_input[..j_input_len]);
170    ntt::zeroize_bytes(&mut j_input[..j_input_len]);
171
172    let mut c_prime = [0u8; MAX_CT_LEN];
173    let ct_len = kpke::encrypt::<P>(ek_pke, &m_prime, &r_prime, &mut c_prime);
174
175    let eq = ct_eq(c, &c_prime[..ct_len]);
176    let mut result = [0u8; 32];
177    ct_select(&mut result, &k_prime, &k_bar, eq);
178
179    ntt::zeroize_bytes(&mut m_prime);
180    result
181}
182
/// SCA-protected ML-KEM decapsulation (internal).
///
/// Follows the same Fujisaki-Okamoto flow as [`decaps_internal`], but the
/// decryption step goes through [`kpke::decrypt_sca`], the side-channel
/// hardened variant that consumes randomness from `rng`.
///
/// All secret intermediates held by this function (`m'`, `g_input`,
/// `j_input`, `r'`, `K'`, and the rejection key) are zeroized before
/// returning.
///
/// # Arguments
///
/// * `dk` - The full decapsulation key (layout: `dk_pke || ek || H(ek) || z`).
/// * `c`  - The ciphertext.
/// * `rng` - Randomness source forwarded to [`kpke::decrypt_sca`].
///
/// # Errors
///
/// Propagates any error returned by [`kpke::decrypt_sca`].
///
/// # Panics
///
/// Panics if `dk` is shorter than `768 * P::K + 96` bytes or if `c` is
/// longer than the maximum ciphertext length, because no length checks are
/// performed here.
#[cfg(feature = "sca-protected")]
fn decaps_internal_sca<P: Params>(dk: &[u8], c: &[u8], rng: &mut impl CryptoRng) -> Result<[u8; 32], MlKemError> {
    // Split dk = dk_pke || ek || H(ek) || z.
    let k = P::K;
    let dk_pke = &dk[0..384 * k];
    let ek_pke = &dk[384 * k..768 * k + 32];
    let h = &dk[768 * k + 32..768 * k + 64];
    let z = &dk[768 * k + 64..768 * k + 96];

    // Nothing secret has been derived locally yet, so an early return on
    // failure here leaves nothing to wipe.
    let mut m_prime = kpke::decrypt_sca::<P>(dk_pke, c, rng)?;

    // (K', r') = G(m' || H(ek))
    let mut g_input = [0u8; 64];
    g_input[..32].copy_from_slice(&m_prime);
    g_input[32..64].copy_from_slice(h);
    let (mut k_prime, mut r_prime) = sha3::g(&g_input);
    ntt::zeroize_bytes(&mut g_input);

    // Implicit rejection key J(z || c); always computed.
    let j_input_len = 32 + c.len();
    let mut j_input = [0u8; MAX_J_INPUT_LEN];
    j_input[..32].copy_from_slice(z);
    j_input[32..j_input_len].copy_from_slice(c);
    let mut k_bar = sha3::j(&j_input[..j_input_len]);
    ntt::zeroize_bytes(&mut j_input[..j_input_len]);

    // Re-encrypt m' and compare with the received ciphertext.
    let mut c_prime = [0u8; MAX_CT_LEN];
    let ct_len = kpke::encrypt::<P>(ek_pke, &m_prime, &r_prime, &mut c_prime);

    let eq = ct_eq(c, &c_prime[..ct_len]);
    let mut result = [0u8; 32];
    ct_select(&mut result, &k_prime, &k_bar, eq);

    // Wipe every secret intermediate; only `result` leaves this function.
    ntt::zeroize_bytes(&mut m_prime);
    ntt::zeroize_bytes(&mut r_prime);
    ntt::zeroize_bytes(&mut k_prime);
    ntt::zeroize_bytes(&mut k_bar);
    Ok(result)
}
220
221// =====================================================================
222// Public API with full side-channel countermeasures
223// =====================================================================
224
225/// Generate an ML-KEM key pair (Algorithm 19).
226///
227/// Draws 32-byte seeds `d` and `z` from the RNG, delegates to
228/// [`keygen_internal`], then zeroizes both seeds. This is the
229/// recommended entry point for key generation.
230///
231/// # Arguments
232///
233/// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
234///
235/// # Returns
236///
237/// A tuple `(encapsulation_key, decapsulation_key)` as byte vectors.
238///
239/// # Errors
240///
241/// Returns [`MlKemError::RngFailure`] if the RNG fails to produce bytes.
242pub fn keygen<P: Params>(rng: &mut impl CryptoRng) -> Result<(Vec<u8>, Vec<u8>), MlKemError> {
243    let mut d = [0u8; 32];
244    let mut z = [0u8; 32];
245    rng.fill_bytes(&mut d)?;
246    rng.fill_bytes(&mut z)?;
247
248    #[cfg(feature = "sca-protected")]
249    let result = keygen_internal_sca::<P>(&d, &z, rng)?;
250
251    #[cfg(not(feature = "sca-protected"))]
252    let result = keygen_internal::<P>(&d, &z);
253
254    ntt::zeroize_bytes(&mut d);
255    ntt::zeroize_bytes(&mut z);
256    Ok(result)
257}
258
259/// Encapsulate a shared secret (Algorithm 20) with input validation.
260///
261/// Validates the encapsulation key (length check and constant-time modulus
262/// check per FIPS 203 Section 7.2), then draws a random 32-byte message
263/// and delegates to [`encaps_internal`].
264///
265/// The modulus check ensures each 12-bit coefficient in `ek` is less than
266/// q = 3329 by round-tripping through encode/decode.
267///
268/// # Arguments
269///
270/// * `ek` - The encapsulation (public) key, exactly [`Params::EK_LEN`] bytes.
271/// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
272///
273/// # Returns
274///
275/// A tuple `(shared_secret, ciphertext)`.
276///
277/// # Errors
278///
279/// * [`MlKemError::InvalidEncapsulationKey`] if the key has wrong length or
280///   fails the modulus check.
281/// * [`MlKemError::RngFailure`] if the RNG fails.
282pub fn encaps<P: Params>(ek: &[u8], rng: &mut impl CryptoRng) -> Result<([u8; 32], Vec<u8>), MlKemError> {
283    if ek.len() != P::EK_LEN {
284        return Err(MlKemError::InvalidEncapsulationKey);
285    }
286    // Modulus check (constant-time)
287    let k = P::K;
288    for i in 0..k {
289        let slice = &ek[384 * i..384 * (i + 1)];
290        let mut decoded = [0u16; N];
291        encode::byte_decode(12, slice, &mut decoded);
292        let mut reencoded = [0u8; 384];
293        encode::byte_encode(12, &decoded, &mut reencoded);
294        if !ct_eq(slice, &reencoded) {
295            return Err(MlKemError::InvalidEncapsulationKey);
296        }
297    }
298
299    let mut m = [0u8; 32];
300    rng.fill_bytes(&mut m)?;
301    let (shared_key, ct) = encaps_internal::<P>(ek, &m);
302    ntt::zeroize_bytes(&mut m);
303    Ok((shared_key, ct))
304}
305
306/// Algorithm 21: ML-KEM.Decaps(dk, c) with input validation and DFA protection.
307///
308/// # Side-channel / fault countermeasures
309///
310/// Three independent countermeasures are layered inside this function.
311/// Full threat-model context in `doc/sca/countermeasures/ml_kem.rst`,
312/// sections *DFA — double computation + CT fault fallback* and
313/// *DFA on `dk` — `H(ek)` integrity check*.
314///
315/// ## 1. `dk` integrity check — fault on `dk` in storage
316///
317/// The decapsulation key's FIPS 203 layout is
318/// `dk_pke ‖ ek ‖ H(ek) ‖ z`. A fault that alters `dk` in memory
319/// (for example a hot-carrier-induced bit flip in flash) would
320/// undermine the FO security argument because the attacker could
321/// coerce decapsulation into using a crafted `dk_pke`. We recompute
322/// `H(ek_in_dk)` and compare it with the stored `H(ek)` using
323/// `silentops::ct_eq`; a mismatch aborts with
324/// [`MlKemError::InvalidDecapsulationKey`] before `decaps_internal`
325/// runs.
326///
327/// ## 2. Double computation — fault on FO re-encryption
328///
329/// ML-KEM decapsulation is vulnerable to a classical DFA on the FO
330/// re-encryption step: if an attacker can make the re-encryption
331/// return a value close to the real ciphertext on one specific
332/// input, the implicit-rejection path is bypassed and the KEM acts
333/// as a decryption oracle (Boneh–DeMillo–Lipton, EUROCRYPT 1997).
334///
335/// We run `decaps_internal_sca` twice. A single-fault attacker can
336/// only affect one execution; if the two shared secrets differ,
337/// we conclude a fault happened and switch to the constant-time
338/// fallback described below. In the no-fault path both runs agree
339/// and we return either result.
340///
341/// ## 3. Constant-time fault fallback — leakage on the fault branch
342///
343/// Naive code would write
344/// `if !results_match { return k_fault }`, which introduces a
345/// conditional jump on a secret-derived bit (`results_match` is
346/// derived from `k1`, `k2` which both depend on the sk). An
347/// attacker able to inject a fault AND measure timing learns
348/// "fault was detected" from the branch timing alone.
349///
350/// To close this, we **always** compute `k_fault = SHA3(z ‖ 0xFF)`
351/// and select between `k1` and `k_fault` with
352/// [`silentops::ct_select_u8`] (via the local 32-byte `ct_select`
353/// wrapper). The branch is gone; timing is identical in both
354/// cases. `k_fault` is:
355/// * deterministic for a given `(dk, c)` so a repeated faulted
356///   call returns the same value (prevents oracle-by-repetition);
357/// * distinct from both the legitimate FO output and the implicit-
358///   rejection output (`J(z ‖ c)`), so the attacker cannot
359///   distinguish "fault detected" from either correct branch.
360///
361/// Recommended for embedded and high-security contexts where
362/// physical fault attacks are in the threat model.
363pub fn decaps<P: Params>(dk: &[u8], c: &[u8], _rng: &mut impl CryptoRng) -> Result<[u8; 32], MlKemError> {
364    let k = P::K;
365
366    // Length checks
367    if c.len() != P::CT_LEN {
368        return Err(MlKemError::InvalidCiphertext);
369    }
370    if dk.len() != P::DK_LEN {
371        return Err(MlKemError::InvalidDecapsulationKey);
372    }
373
374    // ---- DFA countermeasure 1: dk integrity check ----
375    // Verify H(ek) stored in dk matches recomputed value.
376    // Detects fault injection on dk in memory.
377    let ek_in_dk = &dk[384 * k..768 * k + 32];
378    let h_stored = &dk[768 * k + 32..768 * k + 64];
379    let h_computed = sha3::h(ek_in_dk);
380    if !ct_eq(&h_computed, h_stored) {
381        return Err(MlKemError::InvalidDecapsulationKey);
382    }
383
384    // ---- DFA countermeasure 2: double computation ----
385    // Run decaps_internal twice. A single-fault attack can only affect
386    // one execution. If results differ, a fault was detected.
387    #[cfg(feature = "sca-protected")]
388    let k1 = decaps_internal_sca::<P>(dk, c, _rng)?;
389    #[cfg(feature = "sca-protected")]
390    let k2 = decaps_internal_sca::<P>(dk, c, _rng)?;
391
392    #[cfg(not(feature = "sca-protected"))]
393    let k1 = decaps_internal::<P>(dk, c);
394    #[cfg(not(feature = "sca-protected"))]
395    let k2 = decaps_internal::<P>(dk, c);
396
397    // Always compute the fault-fallback key, then select between `k1`
398    // and `k_fault` in constant time. A `if !match { return k_fault }`
399    // branch would leak, through timing, whether a fault was injected
400    // — measurable under ctgrind, and exploitable when the attacker
401    // can both trigger faults and observe timing. Constant-time
402    // selection costs one extra SHA3 (negligible compared to the two
403    // full decaps calls already performed).
404    let results_match = ct_eq(&k1, &k2);
405    let z = &dk[768 * k + 64..768 * k + 96];
406    let mut fault_input = [0u8; 33];
407    fault_input[..32].copy_from_slice(z);
408    fault_input[32] = 0xFF; // domain separator: distinct from J(z||c)
409    let k_fault = sha3::h(&fault_input);
410    ntt::zeroize_bytes(&mut fault_input);
411
412    let mut out = [0u8; 32];
413    // out = if results_match { k1 } else { k_fault }
414    ct_select(&mut out, &k1, &k_fault, results_match);
415    Ok(out)
416}
417
418/// Decapsulate without double computation (single-pass variant).
419///
420/// Performs length validation and the `H(ek)` integrity check on the
421/// decapsulation key, then runs [`decaps_internal`] once. This is faster
422/// than [`decaps`] but does not detect single-fault injection attacks.
423///
424/// Suitable for software-only environments where physical fault attacks
425/// are not in the threat model.
426///
427/// # Arguments
428///
429/// * `dk` - The decapsulation (private) key, exactly [`Params::DK_LEN`] bytes.
430/// * `c`  - The ciphertext, exactly [`Params::CT_LEN`] bytes.
431///
432/// # Returns
433///
434/// The 32-byte shared secret.
435///
436/// # Errors
437///
438/// * [`MlKemError::InvalidDecapsulationKey`] if `dk` has wrong length or
439///   fails the integrity check.
440/// * [`MlKemError::InvalidCiphertext`] if `c` has wrong length.
441pub fn decaps_single<P: Params>(dk: &[u8], c: &[u8]) -> Result<[u8; 32], MlKemError> {
442    let k = P::K;
443    if c.len() != P::CT_LEN {
444        return Err(MlKemError::InvalidCiphertext);
445    }
446    if dk.len() != P::DK_LEN {
447        return Err(MlKemError::InvalidDecapsulationKey);
448    }
449    let ek_in_dk = &dk[384 * k..768 * k + 32];
450    let h_stored = &dk[768 * k + 32..768 * k + 64];
451    let h_computed = sha3::h(ek_in_dk);
452    if !ct_eq(&h_computed, h_stored) {
453        return Err(MlKemError::InvalidDecapsulationKey);
454    }
455    Ok(decaps_internal::<P>(dk, c))
456}
457
458// =====================================================================
459// Constant-time primitives (delegated to silentops crate)
460// =====================================================================
461
462/// Constant-time equality. No early exit, no secret-dependent branch.
463/// Delegates to [`silentops::ct_eq`] and converts the u8 result to bool.
464#[inline(never)]
465fn ct_eq(a: &[u8], b: &[u8]) -> bool {
466    silentops::ct_eq(a, b) == 1
467}
468
469/// Constant-time select: out = condition ? a : b. Branchless via arithmetic mask.
470/// Delegates to [`silentops::ct_select_u8`] per byte.
471#[inline(never)]
472fn ct_select(out: &mut [u8; 32], a: &[u8; 32], b: &[u8; 32], condition: bool) {
473    let cond = condition as u8;
474    for i in 0..32 {
475        out[i] = silentops::ct_select_u8(a[i], b[i], cond);
476    }
477}