Skip to main content

quantica/ml_kem/
mod.rs

1//! # ML-KEM — Module-Lattice-Based Key-Encapsulation Mechanism
2//!
3//! A pure-Rust implementation of **FIPS 203** (ML-KEM) providing three
4//! parameter sets: [`MlKem512`], [`MlKem768`], and [`MlKem1024`].
5//!
6//! Uses only the Rust standard library. No external dependencies.
7//!
8//! ## Quick start
9//!
10//! ```no_run
11//! use quantica::ml_kem::*;
12//!
13//! let mut rng = OsRng;
14//!
15//! // Key generation
16//! let (ek, dk) = MlKem::<MlKem768>::keygen(&mut rng).unwrap();
17//!
18//! // Encapsulation (sender)
19//! let (shared_secret_s, ciphertext) = MlKem::<MlKem768>::encaps(&ek, &mut rng).unwrap();
20//!
21//! // Decapsulation (receiver)
22//! let shared_secret_r = MlKem::<MlKem768>::decaps(&dk, &ciphertext, &mut rng).unwrap();
23//!
24//! assert_eq!(shared_secret_s, shared_secret_r);
25//! ```
26//!
27//! ## Side-channel countermeasures
28//!
//! This implementation includes multiple layers of protection against
//! physical side-channel and fault-injection attacks:
31//!
32//! - **Constant-time**: no secret-dependent branches or memory accesses
33//! - **Zeroization**: all secret intermediates erased via volatile writes
34//! - **First-order masking**: secret polynomials split into additive shares (DPA/template)
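//! - **Double decaps**: [`MlKem::decaps`] runs decapsulation twice and compares the
//!   results to detect faults on the FO comparison (DFA); [`MlKem::decaps_fast`] omits this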
36//! - **dk integrity**: H(ek) verification at decaps time (DFA)
37//! - **NTT shuffling**: randomized butterfly order (SPA)
38//!
39//! ## Module overview
40//!
41//! | Module       | Description |
42//! |--------------|-------------|
43//! | [`params`]   | Parameter sets and the [`Params`] trait |
44//! | [`kem`]      | Top-level ML-KEM algorithms (Algorithms 16-21) |
45//! | [`kpke`]     | K-PKE component scheme (Algorithms 13-15) |
46//! | [`ntt`]      | Number-Theoretic Transform and polynomial arithmetic |
47//! | [`encode`]   | Encoding, decoding, compression, decompression |
48//! | [`sample`]   | Polynomial sampling (NTT domain and CBD) |
49//! | [`sha3`]     | SHA-3 / SHAKE primitives (FIPS 202) |
50//! | [`rng`]      | Cryptographic RNG trait and OS-backed implementation |
51//! | [`masked`]   | First-order arithmetic masking for polynomials |
52//! | [`shuffle`]  | Fisher-Yates shuffle for NTT butterfly randomization |
53
54/// Byte encoding/decoding and compression/decompression (Algorithms 3-6).
55pub mod encode;
56/// ML-KEM key encapsulation: keygen, encaps, decaps (Algorithms 16-21).
57pub mod kem;
58/// K-PKE component scheme: key generation, encryption, decryption (Algorithms 13-15).
59pub mod kpke;
60/// First-order arithmetic masking for DPA/template attack protection.
61pub mod masked;
62/// Number-Theoretic Transform and modular polynomial arithmetic.
63pub mod ntt;
64/// ML-KEM parameter sets and the [`Params`] trait.
65pub mod params;
66/// Cryptographic random number generation trait and OS-backed implementation.
67pub mod rng;
68/// Polynomial sampling algorithms: [`sample::sample_ntt`] and [`sample::sample_poly_cbd`].
69pub mod sample;
70/// SHA-3 and SHAKE hash function primitives (FIPS 202).
71pub mod sha3;
72/// Fisher-Yates shuffle for NTT butterfly index randomization (SPA protection).
73pub mod shuffle;
74
75pub use params::{MlKem512, MlKem768, MlKem1024, Params};
76pub use rng::CryptoRng;
77#[cfg(feature = "std")]
78pub use rng::OsRng;
79
80use crate::secret::{SecretArray, SecretBytes};
81use alloc::vec::Vec;
82use core::marker::PhantomData;
83
84// =====================================================================
85// Typed key / ciphertext / shared-secret wrappers
86// =====================================================================
87
88/// ML-KEM **encapsulation key** (the public half of a key pair).
89///
90/// Type-tagged with the parameter set `P` so the type system can
91/// catch mismatches between security levels at compile time. The
92/// underlying bytes are a plain `Vec<u8>` because the encapsulation
93/// key is public material — no zeroization is performed on drop.
94pub struct EncapsulationKey<P: Params> {
95    bytes: Vec<u8>,
96    _marker: PhantomData<P>,
97}
98
99impl<P: Params> EncapsulationKey<P> {
100    /// Wrap a raw byte vector. Length is validated against
101    /// [`Params::EK_LEN`] for the parameter set `P`.
102    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlKemError> {
103        if bytes.len() != P::EK_LEN {
104            return Err(MlKemError::InvalidEncapsulationKey);
105        }
106        Ok(Self {
107            bytes: bytes.to_vec(),
108            _marker: PhantomData,
109        })
110    }
111
112    /// Borrow the encoded encapsulation key as a byte slice.
113    pub fn as_bytes(&self) -> &[u8] {
114        &self.bytes
115    }
116
117    /// Length in bytes (always [`Params::EK_LEN`]).
118    pub fn len(&self) -> usize {
119        self.bytes.len()
120    }
121}
122
123impl<P: Params> AsRef<[u8]> for EncapsulationKey<P> {
124    fn as_ref(&self) -> &[u8] {
125        &self.bytes
126    }
127}
128
129impl<P: Params> core::ops::Deref for EncapsulationKey<P> {
130    type Target = [u8];
131    fn deref(&self) -> &[u8] {
132        &self.bytes
133    }
134}
135
136impl<P: Params> Clone for EncapsulationKey<P> {
137    fn clone(&self) -> Self {
138        Self {
139            bytes: self.bytes.clone(),
140            _marker: PhantomData,
141        }
142    }
143}
144
145/// ML-KEM **decapsulation key** (the private half of a key pair).
146///
147/// Backed by a [`SecretBytes`] container that wipes its memory on
148/// [`Drop`] using `silentops::ct_zeroize`. Type-tagged with `P` to
149/// prevent accidental cross-parameter-set use.
150pub struct DecapsulationKey<P: Params> {
151    bytes: SecretBytes,
152    _marker: PhantomData<P>,
153}
154
155impl<P: Params> DecapsulationKey<P> {
156    /// Wrap a raw byte slice into a zeroizing decapsulation key.
157    /// Length is validated against [`Params::DK_LEN`].
158    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlKemError> {
159        if bytes.len() != P::DK_LEN {
160            return Err(MlKemError::InvalidDecapsulationKey);
161        }
162        Ok(Self {
163            bytes: SecretBytes::from_slice(bytes),
164            _marker: PhantomData,
165        })
166    }
167
168    /// Borrow the encoded decapsulation key as a byte slice.
169    pub fn as_bytes(&self) -> &[u8] {
170        self.bytes.as_bytes()
171    }
172
173    /// Length in bytes (always [`Params::DK_LEN`]).
174    pub fn len(&self) -> usize {
175        self.bytes.len()
176    }
177}
178
179impl<P: Params> AsRef<[u8]> for DecapsulationKey<P> {
180    fn as_ref(&self) -> &[u8] {
181        self.bytes.as_bytes()
182    }
183}
184
185impl<P: Params> core::ops::Deref for DecapsulationKey<P> {
186    type Target = [u8];
187    fn deref(&self) -> &[u8] {
188        self.bytes.as_bytes()
189    }
190}
191
192/// ML-KEM **ciphertext** wrapping the encapsulated shared secret.
193///
194/// Type-tagged with `P`. Ciphertexts are not secret (they travel on
195/// the wire), so no zeroization is performed.
196pub struct Ciphertext<P: Params> {
197    bytes: Vec<u8>,
198    _marker: PhantomData<P>,
199}
200
201impl<P: Params> Ciphertext<P> {
202    /// Wrap a raw byte slice. Length is validated against
203    /// [`Params::CT_LEN`].
204    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlKemError> {
205        if bytes.len() != P::CT_LEN {
206            return Err(MlKemError::InvalidCiphertext);
207        }
208        Ok(Self {
209            bytes: bytes.to_vec(),
210            _marker: PhantomData,
211        })
212    }
213
214    /// Borrow the encoded ciphertext as a byte slice.
215    pub fn as_bytes(&self) -> &[u8] {
216        &self.bytes
217    }
218
219    /// Length in bytes (always [`Params::CT_LEN`]).
220    pub fn len(&self) -> usize {
221        self.bytes.len()
222    }
223}
224
225impl<P: Params> AsRef<[u8]> for Ciphertext<P> {
226    fn as_ref(&self) -> &[u8] {
227        &self.bytes
228    }
229}
230
231impl<P: Params> core::ops::Deref for Ciphertext<P> {
232    type Target = [u8];
233    fn deref(&self) -> &[u8] {
234        &self.bytes
235    }
236}
237
238impl<P: Params> Clone for Ciphertext<P> {
239    fn clone(&self) -> Self {
240        Self {
241            bytes: self.bytes.clone(),
242            _marker: PhantomData,
243        }
244    }
245}
246
247/// 32-byte ML-KEM shared secret.
248///
249/// Backed by a [`SecretArray<32>`] which wipes itself on [`Drop`].
250/// Equality is constant-time via `silentops::ct_eq`.
251pub type SharedSecret = SecretArray<32>;
252
/// Errors returned by ML-KEM operations.
///
/// The variants are deliberately coarse (one per kind of input), so a
/// caller learns *which* argument was rejected but not which internal
/// check failed. For example, [`MlKemError::InvalidDecapsulationKey`]
/// covers both a wrong length and a failed `H(ek)` integrity check.
///
/// Note that the length checks in the `from_bytes` constructors return
/// early. Their timing therefore depends only on the (public) input
/// length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlKemError {
    /// The cryptographic random number generator failed to produce bytes.
    ///
    /// This typically indicates a system-level failure (e.g., `/dev/urandom`
    /// is unavailable).
    RngFailure,
    /// The provided encapsulation (public) key has an invalid length or
    /// fails the modulus check required by FIPS 203 Section 7.2.
    InvalidEncapsulationKey,
    /// The provided decapsulation (private) key has an invalid length or
    /// fails the `H(ek)` integrity check embedded in the key.
    ///
    /// An integrity failure may indicate storage corruption or fault injection.
    InvalidDecapsulationKey,
    /// The provided ciphertext has an invalid length for the parameter set.
    InvalidCiphertext,
}

impl core::fmt::Display for MlKemError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::RngFailure => "Random bit generation failed",
            Self::InvalidEncapsulationKey => "Invalid encapsulation key",
            Self::InvalidDecapsulationKey => "Invalid decapsulation key",
            Self::InvalidCiphertext => "Invalid ciphertext",
        };
        f.write_str(message)
    }
}

// `std::error::Error` is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for MlKemError {}
289
290/// Main ML-KEM interface, generic over a [`Params`] parameter set.
291///
292/// This is the primary API for ML-KEM key encapsulation. It wraps the
293/// lower-level functions in [`kem`] and provides a convenient, type-safe
294/// interface parameterized by security level.
295///
296/// # Type parameters
297///
298/// * `P` - One of [`MlKem512`], [`MlKem768`], or [`MlKem1024`], selecting
299///   the security category (1, 3, or 5 respectively).
300///
301/// # Example
302///
303/// ```no_run
304/// use quantica::ml_kem::*;
305///
306/// let mut rng = OsRng;
307/// let (ek, dk) = MlKem::<MlKem768>::keygen(&mut rng).unwrap();
308/// let (ss, ct) = MlKem::<MlKem768>::encaps(&ek, &mut rng).unwrap();
309/// let ss2 = MlKem::<MlKem768>::decaps(&dk, &ct, &mut rng).unwrap();
310/// assert_eq!(ss, ss2);
311/// ```
312pub struct MlKem<P: Params>(core::marker::PhantomData<P>);
313
314impl<P: Params> MlKem<P> {
315    /// Generate an ML-KEM key pair.
316    ///
317    /// Produces an encapsulation key (public) and a decapsulation key (private)
318    /// using randomness from the provided RNG. Implements Algorithm 19 of FIPS 203.
319    ///
320    /// The decapsulation key includes an integrity hash `H(ek)` that is verified
321    /// during decapsulation to detect storage corruption (DFA protection).
322    ///
323    /// # Arguments
324    ///
325    /// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
326    ///
327    /// # Returns
328    ///
329    /// A tuple `(encapsulation_key, decapsulation_key)` of typed
330    /// wrappers. The decapsulation key auto-zeroizes on `Drop`.
331    ///
332    /// # Errors
333    ///
334    /// Returns [`MlKemError::RngFailure`] if the RNG fails to produce bytes.
335    ///
336    /// # Example
337    ///
338    /// ```no_run
339    /// use quantica::ml_kem::*;
340    /// let mut rng = OsRng;
341    /// let (ek, dk) = MlKem::<MlKem768>::keygen(&mut rng).unwrap();
342    /// assert_eq!(ek.len(), MlKem768::EK_LEN);
343    /// assert_eq!(dk.len(), MlKem768::DK_LEN);
344    /// ```
345    pub fn keygen(rng: &mut impl CryptoRng) -> Result<(EncapsulationKey<P>, DecapsulationKey<P>), MlKemError> {
346        let (ek_v, dk_v) = kem::keygen::<P>(rng)?;
347        Ok((
348            EncapsulationKey {
349                bytes: ek_v,
350                _marker: PhantomData,
351            },
352            DecapsulationKey {
353                bytes: SecretBytes::from_vec(dk_v),
354                _marker: PhantomData,
355            },
356        ))
357    }
358
359    /// Encapsulate a shared secret against an encapsulation key.
360    ///
361    /// Given a public encapsulation key, generates a fresh 32-byte shared
362    /// secret and the corresponding ciphertext. Implements Algorithm 20 of
363    /// FIPS 203 with input validation (length and modulus checks).
364    ///
365    /// # Arguments
366    ///
367    /// * `ek` - The encapsulation (public) key, exactly [`Params::EK_LEN`] bytes.
368    /// * `rng` - A cryptographic random number generator implementing [`CryptoRng`].
369    ///
370    /// # Returns
371    ///
372    /// A tuple `(shared_secret, ciphertext)` where the shared secret is 32 bytes.
373    ///
374    /// # Errors
375    ///
376    /// * [`MlKemError::InvalidEncapsulationKey`] if `ek` has wrong length or fails
377    ///   the modulus check.
378    /// * [`MlKemError::RngFailure`] if the RNG fails.
379    ///
380    /// # Example
381    ///
382    /// ```no_run
383    /// use quantica::ml_kem::*;
384    /// let mut rng = OsRng;
385    /// let (ek, _dk) = MlKem::<MlKem768>::keygen(&mut rng).unwrap();
386    /// let (shared_secret, ciphertext) = MlKem::<MlKem768>::encaps(&ek, &mut rng).unwrap();
387    /// ```
388    pub fn encaps(
389        ek: &EncapsulationKey<P>,
390        rng: &mut impl CryptoRng,
391    ) -> Result<(SharedSecret, Ciphertext<P>), MlKemError> {
392        let (ss, ct) = kem::encaps::<P>(ek.as_bytes(), rng)?;
393        Ok((
394            SharedSecret::new(ss),
395            Ciphertext {
396                bytes: ct,
397                _marker: PhantomData,
398            },
399        ))
400    }
401
402    /// Decapsulate a ciphertext with full DFA protection.
403    ///
404    /// Recovers the 32-byte shared secret from a ciphertext using the
405    /// decapsulation (private) key. Implements Algorithm 21 of FIPS 203
406    /// with two additional DFA countermeasures:
407    ///
408    /// 1. **dk integrity check** -- verifies `H(ek)` stored in `dk` to detect
409    ///    fault injection on key material in memory.
410    /// 2. **Double computation** -- runs the internal decapsulation twice and
411    ///    compares results. A single-fault attack can only corrupt one execution,
412    ///    so divergent results indicate fault injection.
413    ///
414    /// Recommended for embedded and high-security contexts where physical
415    /// fault attacks are in the threat model.
416    ///
417    /// # Arguments
418    ///
419    /// * `dk` - The decapsulation (private) key, exactly [`Params::DK_LEN`] bytes.
420    /// * `ct` - The ciphertext, exactly [`Params::CT_LEN`] bytes.
421    /// * `rng` - A cryptographic random number generator (reserved for future use).
422    ///
423    /// # Returns
424    ///
425    /// The 32-byte shared secret.
426    ///
427    /// # Errors
428    ///
429    /// * [`MlKemError::InvalidDecapsulationKey`] if `dk` has wrong length or
430    ///   fails the integrity check.
431    /// * [`MlKemError::InvalidCiphertext`] if `ct` has wrong length.
432    ///
433    /// # Example
434    ///
435    /// ```no_run
436    /// use quantica::ml_kem::*;
437    /// let mut rng = OsRng;
438    /// let (ek, dk) = MlKem::<MlKem768>::keygen(&mut rng).unwrap();
439    /// let (ss, ct) = MlKem::<MlKem768>::encaps(&ek, &mut rng).unwrap();
440    /// let ss2 = MlKem::<MlKem768>::decaps(&dk, &ct, &mut rng).unwrap();
441    /// assert_eq!(ss, ss2);
442    /// ```
443    pub fn decaps(
444        dk: &DecapsulationKey<P>,
445        ct: &Ciphertext<P>,
446        rng: &mut impl CryptoRng,
447    ) -> Result<SharedSecret, MlKemError> {
448        kem::decaps::<P>(dk.as_bytes(), ct.as_bytes(), rng).map(SharedSecret::new)
449    }
450
451    /// Decapsulate a ciphertext without double computation (faster variant).
452    ///
453    /// Same as [`MlKem::decaps`] but omits the double-computation DFA
454    /// countermeasure, making it roughly twice as fast. The `H(ek)` integrity
455    /// check on the decapsulation key is still performed.
456    ///
457    /// Use this for software-only contexts where physical fault injection
458    /// is not part of the threat model.
459    ///
460    /// # Arguments
461    ///
462    /// * `dk` - The decapsulation (private) key, exactly [`Params::DK_LEN`] bytes.
463    /// * `ct` - The ciphertext, exactly [`Params::CT_LEN`] bytes.
464    ///
465    /// # Returns
466    ///
467    /// The 32-byte shared secret.
468    ///
469    /// # Errors
470    ///
471    /// * [`MlKemError::InvalidDecapsulationKey`] if `dk` has wrong length or
472    ///   fails the integrity check.
473    /// * [`MlKemError::InvalidCiphertext`] if `ct` has wrong length.
474    pub fn decaps_fast(dk: &DecapsulationKey<P>, ct: &Ciphertext<P>) -> Result<SharedSecret, MlKemError> {
475        kem::decaps_single::<P>(dk.as_bytes(), ct.as_bytes()).map(SharedSecret::new)
476    }
477
478    // --- Deterministic internal functions (for testing / CAVP) ---
479
480    /// Deterministic key generation for testing and CAVP validation.
481    ///
482    /// Implements Algorithm 16 of FIPS 203 directly, using caller-supplied
483    /// seeds `d` and `z` instead of drawing them from an RNG.
484    ///
485    /// # Arguments
486    ///
487    /// * `d` - 32-byte seed for K-PKE key generation.
488    /// * `z` - 32-byte implicit rejection value stored in the decapsulation key.
489    ///
490    /// # Returns
491    ///
492    /// A tuple `(encapsulation_key, decapsulation_key)` as byte vectors.
493    pub fn keygen_internal(d: &[u8; 32], z: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
494        kem::keygen_internal::<P>(d, z)
495    }
496
497    /// Deterministic encapsulation for testing and CAVP validation.
498    ///
499    /// Implements Algorithm 17 of FIPS 203 directly, using a caller-supplied
500    /// message `m` instead of drawing it from an RNG. No input validation
501    /// is performed on `ek`.
502    ///
503    /// # Arguments
504    ///
505    /// * `ek` - The encapsulation (public) key.
506    /// * `m` - 32-byte random message seed.
507    ///
508    /// # Returns
509    ///
510    /// A tuple `(shared_secret, ciphertext)`.
511    pub fn encaps_internal(ek: &[u8], m: &[u8; 32]) -> ([u8; 32], Vec<u8>) {
512        kem::encaps_internal::<P>(ek, m)
513    }
514
515    /// Deterministic decapsulation for testing and CAVP validation.
516    ///
517    /// Implements Algorithm 18 of FIPS 203 directly, with no input
518    /// validation or DFA countermeasures. All comparisons are still
519    /// constant-time.
520    ///
521    /// # Arguments
522    ///
523    /// * `dk` - The decapsulation (private) key.
524    /// * `ct` - The ciphertext.
525    ///
526    /// # Returns
527    ///
528    /// The 32-byte shared secret.
529    pub fn decaps_internal(dk: &[u8], ct: &[u8]) -> [u8; 32] {
530        kem::decaps_internal::<P>(dk, ct)
531    }
532}