quantica/ml_kem/kem.rs
//! ML-KEM Key-Encapsulation Mechanism (FIPS 203 Sections 6-7).
//!
//! Implements Algorithms 16-21 with the following side-channel countermeasures:
//!
//! - **Constant-time**: no secret-dependent branches or memory access patterns
//! - **Zeroization**: all secret intermediates erased via volatile writes
//! - **Double decaps**: fault detection by running decapsulation twice (DFA protection)
//! - **dk integrity**: `H(ek)` verification at decapsulation time (DFA protection on key storage)
//!
//! The public API ([`keygen`], [`encaps`], [`decaps`]) includes input validation
//! and full countermeasures ([`decaps_single`] skips only the double computation).
//! The internal variants ([`keygen_internal`], [`encaps_internal`],
//! [`decaps_internal`]) are deterministic and intended for CAVP testing.

use alloc::vec::Vec;

use super::MlKemError;
use super::encode;
use super::kpke;
use super::ntt;
use super::params::N;
use super::params::Params;
use super::rng::CryptoRng;
use super::sha3;
22use alloc::vec::Vec;
23
24/// Maximum module rank across all ML-KEM parameter sets.
25const MAX_K: usize = 4;
26/// Maximum encapsulation key size: 384*4 + 32 = 1568 bytes.
27const MAX_EK_LEN: usize = 384 * MAX_K + 32;
28/// Maximum ciphertext size: 32*(11*4 + 5) = 1568 bytes.
29const MAX_CT_LEN: usize = 1568;
30/// Maximum j_input size: 32 + max ciphertext = 1600 bytes.
31const MAX_J_INPUT_LEN: usize = 32 + MAX_CT_LEN;
32
33// =====================================================================
34// Internal (deterministic) functions — used for CAVP testing
35// =====================================================================
36
37/// Deterministic ML-KEM key generation (Algorithm 16).
38///
39/// Generates an encapsulation/decapsulation key pair from explicit seeds.
40/// The decapsulation key is structured as `dk_pke || ek || H(ek) || z`.
41///
42/// # Arguments
43///
44/// * `d` - 32-byte seed passed to [`kpke::keygen`] for K-PKE key generation.
45/// * `z` - 32-byte implicit rejection value embedded in the decapsulation key.
46///
47/// # Returns
48///
49/// A tuple `(ek, dk)` of byte vectors.
50pub fn keygen_internal<P: Params>(d: &[u8; 32], z: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
51 let k = P::K;
52 let ek_len = 384 * k + 32;
53 let dk_pke_len = 384 * k;
54
55 let mut ek_buf = [0u8; MAX_EK_LEN];
56 let mut dk_pke_buf = [0u8; 384 * MAX_K];
57 kpke::keygen::<P>(d, &mut ek_buf, &mut dk_pke_buf);
58
59 let ek = &ek_buf[..ek_len];
60 let h_ek = sha3::h(ek);
61
62 let mut dk = Vec::with_capacity(P::DK_LEN);
63 dk.extend_from_slice(&dk_pke_buf[..dk_pke_len]);
64 dk.extend_from_slice(ek);
65 dk.extend_from_slice(&h_ek);
66 dk.extend_from_slice(z);
67
68 (ek[..ek_len].to_vec(), dk)
69}
70
71/// SCA-protected ML-KEM key generation.
72///
73/// Uses [`kpke::keygen_sca`] which applies shuffled NTT on secret polynomials,
74/// then assembles the full decapsulation key. The deterministic seed derivation
75/// is identical to [`keygen_internal`].
76#[cfg(feature = "sca-protected")]
77fn keygen_internal_sca<P: Params>(
78 d: &[u8; 32],
79 z: &[u8; 32],
80 rng: &mut impl CryptoRng,
81) -> Result<(Vec<u8>, Vec<u8>), MlKemError> {
82 let k = P::K;
83 let ek_len = 384 * k + 32;
84 let dk_pke_len = 384 * k;
85
86 let mut ek_buf = [0u8; MAX_EK_LEN];
87 let mut dk_pke_buf = [0u8; 384 * MAX_K];
88 kpke::keygen_sca::<P>(d, &mut ek_buf, &mut dk_pke_buf, rng)?;
89
90 let ek = &ek_buf[..ek_len];
91 let h_ek = sha3::h(ek);
92
93 let mut dk = Vec::with_capacity(P::DK_LEN);
94 dk.extend_from_slice(&dk_pke_buf[..dk_pke_len]);
95 dk.extend_from_slice(ek);
96 dk.extend_from_slice(&h_ek);
97 dk.extend_from_slice(z);
98
99 Ok((ek[..ek_len].to_vec(), dk))
100}
101
102/// Deterministic ML-KEM encapsulation (Algorithm 17).
103///
104/// Computes the shared secret and ciphertext from an encapsulation key
105/// and an explicit 32-byte message `m`. The shared secret and encryption
106/// randomness are derived as `(K, r) = G(m || H(ek))`.
107///
108/// No input validation is performed on `ek`; callers should use [`encaps`]
109/// for production use.
110///
111/// # Arguments
112///
113/// * `ek` - The encapsulation (public) key.
114/// * `m` - 32-byte random message seed.
115///
116/// # Returns
117///
118/// A tuple `(shared_secret, ciphertext)`.
119pub fn encaps_internal<P: Params>(ek: &[u8], m: &[u8; 32]) -> ([u8; 32], Vec<u8>) {
120 let h_ek = sha3::h(ek);
121 let mut g_input = [0u8; 64];
122 g_input[..32].copy_from_slice(m);
123 g_input[32..64].copy_from_slice(&h_ek);
124 let (shared_key, r) = sha3::g(&g_input);
125 ntt::zeroize_bytes(&mut g_input);
126
127 let mut ct_buf = [0u8; MAX_CT_LEN];
128 let ct_len = kpke::encrypt::<P>(ek, m, &r, &mut ct_buf);
129 (shared_key, ct_buf[..ct_len].to_vec())
130}
131
132/// Deterministic ML-KEM decapsulation (Algorithm 18).
133///
134/// Implements the Fujisaki-Okamoto transform: decrypts the ciphertext to
135/// recover `m'`, re-encrypts to get `c'`, then uses a constant-time
136/// comparison to select either the real shared key `K'` (if `c == c'`)
137/// or an implicit rejection key `J(z || c)` (otherwise).
138///
139/// All operations are constant-time with no secret-dependent branches.
140/// Intermediates (`m'`, `g_input`, `j_input`) are zeroized after use.
141///
142/// # Arguments
143///
144/// * `dk` - The full decapsulation key (layout: `dk_pke || ek || H(ek) || z`).
145/// * `c` - The ciphertext.
146///
147/// # Returns
148///
149/// The 32-byte shared secret.
150pub fn decaps_internal<P: Params>(dk: &[u8], c: &[u8]) -> [u8; 32] {
151 let k = P::K;
152 let dk_pke = &dk[0..384 * k];
153 let ek_pke = &dk[384 * k..768 * k + 32];
154 let h = &dk[768 * k + 32..768 * k + 64];
155 let z = &dk[768 * k + 64..768 * k + 96];
156
157 let mut m_prime = kpke::decrypt::<P>(dk_pke, c);
158
159 let mut g_input = [0u8; 64];
160 g_input[..32].copy_from_slice(&m_prime);
161 g_input[32..64].copy_from_slice(h);
162 let (k_prime, r_prime) = sha3::g(&g_input);
163 ntt::zeroize_bytes(&mut g_input);
164
165 let j_input_len = 32 + c.len();
166 let mut j_input = [0u8; MAX_J_INPUT_LEN];
167 j_input[..32].copy_from_slice(z);
168 j_input[32..j_input_len].copy_from_slice(c);
169 let k_bar = sha3::j(&j_input[..j_input_len]);
170 ntt::zeroize_bytes(&mut j_input[..j_input_len]);
171
172 let mut c_prime = [0u8; MAX_CT_LEN];
173 let ct_len = kpke::encrypt::<P>(ek_pke, &m_prime, &r_prime, &mut c_prime);
174
175 let eq = ct_eq(c, &c_prime[..ct_len]);
176 let mut result = [0u8; 32];
177 ct_select(&mut result, &k_prime, &k_bar, eq);
178
179 ntt::zeroize_bytes(&mut m_prime);
180 result
181}
182
183/// SCA-protected ML-KEM decapsulation (internal).
184///
185/// Identical to [`decaps_internal`] but uses [`kpke::decrypt_sca`] which applies
186/// masked multiplication and shuffled NTT during decryption.
187#[cfg(feature = "sca-protected")]
188fn decaps_internal_sca<P: Params>(dk: &[u8], c: &[u8], rng: &mut impl CryptoRng) -> Result<[u8; 32], MlKemError> {
189 let k = P::K;
190 let dk_pke = &dk[0..384 * k];
191 let ek_pke = &dk[384 * k..768 * k + 32];
192 let h = &dk[768 * k + 32..768 * k + 64];
193 let z = &dk[768 * k + 64..768 * k + 96];
194
195 let mut m_prime = kpke::decrypt_sca::<P>(dk_pke, c, rng)?;
196
197 let mut g_input = [0u8; 64];
198 g_input[..32].copy_from_slice(&m_prime);
199 g_input[32..64].copy_from_slice(h);
200 let (k_prime, r_prime) = sha3::g(&g_input);
201 ntt::zeroize_bytes(&mut g_input);
202
203 let j_input_len = 32 + c.len();
204 let mut j_input = [0u8; MAX_J_INPUT_LEN];
205 j_input[..32].copy_from_slice(z);
206 j_input[32..j_input_len].copy_from_slice(c);
207 let k_bar = sha3::j(&j_input[..j_input_len]);
208 ntt::zeroize_bytes(&mut j_input[..j_input_len]);
209
210 let mut c_prime = [0u8; MAX_CT_LEN];
211 let ct_len = kpke::encrypt::<P>(ek_pke, &m_prime, &r_prime, &mut c_prime);
212
213 let eq = ct_eq(c, &c_prime[..ct_len]);
214 let mut result = [0u8; 32];
215 ct_select(&mut result, &k_prime, &k_bar, eq);
216
217 ntt::zeroize_bytes(&mut m_prime);
218 Ok(result)
219}
220
221// =====================================================================
222// Public API with full side-channel countermeasures
223// =====================================================================
224
225/// Generate an ML-KEM key pair (Algorithm 19).
226///
227/// Draws 32-byte seeds `d` and `z` from the RNG, delegates to
228/// [`keygen_internal`], then zeroizes both seeds. This is the
229/// recommended entry point for key generation.
230///
231/// # Arguments
232///
233/// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
234///
235/// # Returns
236///
237/// A tuple `(encapsulation_key, decapsulation_key)` as byte vectors.
238///
239/// # Errors
240///
241/// Returns [`MlKemError::RngFailure`] if the RNG fails to produce bytes.
242pub fn keygen<P: Params>(rng: &mut impl CryptoRng) -> Result<(Vec<u8>, Vec<u8>), MlKemError> {
243 let mut d = [0u8; 32];
244 let mut z = [0u8; 32];
245 rng.fill_bytes(&mut d)?;
246 rng.fill_bytes(&mut z)?;
247
248 #[cfg(feature = "sca-protected")]
249 let result = keygen_internal_sca::<P>(&d, &z, rng)?;
250
251 #[cfg(not(feature = "sca-protected"))]
252 let result = keygen_internal::<P>(&d, &z);
253
254 ntt::zeroize_bytes(&mut d);
255 ntt::zeroize_bytes(&mut z);
256 Ok(result)
257}
258
259/// Encapsulate a shared secret (Algorithm 20) with input validation.
260///
261/// Validates the encapsulation key (length check and constant-time modulus
262/// check per FIPS 203 Section 7.2), then draws a random 32-byte message
263/// and delegates to [`encaps_internal`].
264///
265/// The modulus check ensures each 12-bit coefficient in `ek` is less than
266/// q = 3329 by round-tripping through encode/decode.
267///
268/// # Arguments
269///
270/// * `ek` - The encapsulation (public) key, exactly [`Params::EK_LEN`] bytes.
271/// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
272///
273/// # Returns
274///
275/// A tuple `(shared_secret, ciphertext)`.
276///
277/// # Errors
278///
279/// * [`MlKemError::InvalidEncapsulationKey`] if the key has wrong length or
280/// fails the modulus check.
281/// * [`MlKemError::RngFailure`] if the RNG fails.
282pub fn encaps<P: Params>(ek: &[u8], rng: &mut impl CryptoRng) -> Result<([u8; 32], Vec<u8>), MlKemError> {
283 if ek.len() != P::EK_LEN {
284 return Err(MlKemError::InvalidEncapsulationKey);
285 }
286 // Modulus check (constant-time)
287 let k = P::K;
288 for i in 0..k {
289 let slice = &ek[384 * i..384 * (i + 1)];
290 let mut decoded = [0u16; N];
291 encode::byte_decode(12, slice, &mut decoded);
292 let mut reencoded = [0u8; 384];
293 encode::byte_encode(12, &decoded, &mut reencoded);
294 if !ct_eq(slice, &reencoded) {
295 return Err(MlKemError::InvalidEncapsulationKey);
296 }
297 }
298
299 let mut m = [0u8; 32];
300 rng.fill_bytes(&mut m)?;
301 let (shared_key, ct) = encaps_internal::<P>(ek, &m);
302 ntt::zeroize_bytes(&mut m);
303 Ok((shared_key, ct))
304}
305
306/// Algorithm 21: ML-KEM.Decaps(dk, c) with input validation and DFA protection.
307///
308/// # Side-channel / fault countermeasures
309///
310/// Three independent countermeasures are layered inside this function.
311/// Full threat-model context in `doc/sca/countermeasures/ml_kem.rst`,
312/// sections *DFA — double computation + CT fault fallback* and
313/// *DFA on `dk` — `H(ek)` integrity check*.
314///
315/// ## 1. `dk` integrity check — fault on `dk` in storage
316///
317/// The decapsulation key's FIPS 203 layout is
318/// `dk_pke ‖ ek ‖ H(ek) ‖ z`. A fault that alters `dk` in memory
319/// (for example a hot-carrier-induced bit flip in flash) would
320/// undermine the FO security argument because the attacker could
321/// coerce decapsulation into using a crafted `dk_pke`. We recompute
322/// `H(ek_in_dk)` and compare it with the stored `H(ek)` using
323/// `silentops::ct_eq`; a mismatch aborts with
324/// [`MlKemError::InvalidDecapsulationKey`] before `decaps_internal`
325/// runs.
326///
327/// ## 2. Double computation — fault on FO re-encryption
328///
329/// ML-KEM decapsulation is vulnerable to a classical DFA on the FO
330/// re-encryption step: if an attacker can make the re-encryption
331/// return a value close to the real ciphertext on one specific
332/// input, the implicit-rejection path is bypassed and the KEM acts
333/// as a decryption oracle (Boneh–DeMillo–Lipton, EUROCRYPT 1997).
334///
335/// We run `decaps_internal_sca` twice. A single-fault attacker can
336/// only affect one execution; if the two shared secrets differ,
337/// we conclude a fault happened and switch to the constant-time
338/// fallback described below. In the no-fault path both runs agree
339/// and we return either result.
340///
341/// ## 3. Constant-time fault fallback — leakage on the fault branch
342///
343/// Naive code would write
344/// `if !results_match { return k_fault }`, which introduces a
345/// conditional jump on a secret-derived bit (`results_match` is
346/// derived from `k1`, `k2` which both depend on the sk). An
347/// attacker able to inject a fault AND measure timing learns
348/// "fault was detected" from the branch timing alone.
349///
350/// To close this, we **always** compute `k_fault = SHA3(z ‖ 0xFF)`
351/// and select between `k1` and `k_fault` with
352/// [`silentops::ct_select_u8`] (via the local 32-byte `ct_select`
353/// wrapper). The branch is gone; timing is identical in both
354/// cases. `k_fault` is:
355/// * deterministic for a given `(dk, c)` so a repeated faulted
356/// call returns the same value (prevents oracle-by-repetition);
357/// * distinct from both the legitimate FO output and the implicit-
358/// rejection output (`J(z ‖ c)`), so the attacker cannot
359/// distinguish "fault detected" from either correct branch.
360///
361/// Recommended for embedded and high-security contexts where
362/// physical fault attacks are in the threat model.
363pub fn decaps<P: Params>(dk: &[u8], c: &[u8], _rng: &mut impl CryptoRng) -> Result<[u8; 32], MlKemError> {
364 let k = P::K;
365
366 // Length checks
367 if c.len() != P::CT_LEN {
368 return Err(MlKemError::InvalidCiphertext);
369 }
370 if dk.len() != P::DK_LEN {
371 return Err(MlKemError::InvalidDecapsulationKey);
372 }
373
374 // ---- DFA countermeasure 1: dk integrity check ----
375 // Verify H(ek) stored in dk matches recomputed value.
376 // Detects fault injection on dk in memory.
377 let ek_in_dk = &dk[384 * k..768 * k + 32];
378 let h_stored = &dk[768 * k + 32..768 * k + 64];
379 let h_computed = sha3::h(ek_in_dk);
380 if !ct_eq(&h_computed, h_stored) {
381 return Err(MlKemError::InvalidDecapsulationKey);
382 }
383
384 // ---- DFA countermeasure 2: double computation ----
385 // Run decaps_internal twice. A single-fault attack can only affect
386 // one execution. If results differ, a fault was detected.
387 #[cfg(feature = "sca-protected")]
388 let k1 = decaps_internal_sca::<P>(dk, c, _rng)?;
389 #[cfg(feature = "sca-protected")]
390 let k2 = decaps_internal_sca::<P>(dk, c, _rng)?;
391
392 #[cfg(not(feature = "sca-protected"))]
393 let k1 = decaps_internal::<P>(dk, c);
394 #[cfg(not(feature = "sca-protected"))]
395 let k2 = decaps_internal::<P>(dk, c);
396
397 // Always compute the fault-fallback key, then select between `k1`
398 // and `k_fault` in constant time. A `if !match { return k_fault }`
399 // branch would leak, through timing, whether a fault was injected
400 // — measurable under ctgrind, and exploitable when the attacker
401 // can both trigger faults and observe timing. Constant-time
402 // selection costs one extra SHA3 (negligible compared to the two
403 // full decaps calls already performed).
404 let results_match = ct_eq(&k1, &k2);
405 let z = &dk[768 * k + 64..768 * k + 96];
406 let mut fault_input = [0u8; 33];
407 fault_input[..32].copy_from_slice(z);
408 fault_input[32] = 0xFF; // domain separator: distinct from J(z||c)
409 let k_fault = sha3::h(&fault_input);
410 ntt::zeroize_bytes(&mut fault_input);
411
412 let mut out = [0u8; 32];
413 // out = if results_match { k1 } else { k_fault }
414 ct_select(&mut out, &k1, &k_fault, results_match);
415 Ok(out)
416}
417
418/// Decapsulate without double computation (single-pass variant).
419///
420/// Performs length validation and the `H(ek)` integrity check on the
421/// decapsulation key, then runs [`decaps_internal`] once. This is faster
422/// than [`decaps`] but does not detect single-fault injection attacks.
423///
424/// Suitable for software-only environments where physical fault attacks
425/// are not in the threat model.
426///
427/// # Arguments
428///
429/// * `dk` - The decapsulation (private) key, exactly [`Params::DK_LEN`] bytes.
430/// * `c` - The ciphertext, exactly [`Params::CT_LEN`] bytes.
431///
432/// # Returns
433///
434/// The 32-byte shared secret.
435///
436/// # Errors
437///
438/// * [`MlKemError::InvalidDecapsulationKey`] if `dk` has wrong length or
439/// fails the integrity check.
440/// * [`MlKemError::InvalidCiphertext`] if `c` has wrong length.
441pub fn decaps_single<P: Params>(dk: &[u8], c: &[u8]) -> Result<[u8; 32], MlKemError> {
442 let k = P::K;
443 if c.len() != P::CT_LEN {
444 return Err(MlKemError::InvalidCiphertext);
445 }
446 if dk.len() != P::DK_LEN {
447 return Err(MlKemError::InvalidDecapsulationKey);
448 }
449 let ek_in_dk = &dk[384 * k..768 * k + 32];
450 let h_stored = &dk[768 * k + 32..768 * k + 64];
451 let h_computed = sha3::h(ek_in_dk);
452 if !ct_eq(&h_computed, h_stored) {
453 return Err(MlKemError::InvalidDecapsulationKey);
454 }
455 Ok(decaps_internal::<P>(dk, c))
456}
457
458// =====================================================================
459// Constant-time primitives (delegated to silentops crate)
460// =====================================================================
461
462/// Constant-time equality. No early exit, no secret-dependent branch.
463/// Delegates to [`silentops::ct_eq`] and converts the u8 result to bool.
464#[inline(never)]
465fn ct_eq(a: &[u8], b: &[u8]) -> bool {
466 silentops::ct_eq(a, b) == 1
467}
468
469/// Constant-time select: out = condition ? a : b. Branchless via arithmetic mask.
470/// Delegates to [`silentops::ct_select_u8`] per byte.
471#[inline(never)]
472fn ct_select(out: &mut [u8; 32], a: &[u8; 32], b: &[u8; 32], condition: bool) {
473 let cond = condition as u8;
474 for i in 0..32 {
475 out[i] = silentops::ct_select_u8(a[i], b[i], cond);
476 }
477}