Skip to main content

quantica/ml_kem/
encode.rs

//! Encoding, decoding, compression, and decompression algorithms.
//!
//! Implements the conversion and compression routines of FIPS 203 Section 4.2.1:
//! ByteEncode/ByteDecode (Algorithms 5-6, with the bit/byte conversions of
//! Algorithms 3-4 folded in) and the Compress/Decompress maps (equations 4.7-4.8).
//!
//! - [`byte_encode`] / [`byte_decode`] -- bit-pack/unpack 256 integers into/from bytes
//! - [`compress`] / [`decompress`] -- lossy compression between Z_q and Z_{2^d}
//! - [`compress_poly`] / [`decompress_poly`] -- element-wise compression over full polynomials
8use super::params::{N, Q};
9
10/// Bit-pack 256 integers into `32*d` bytes (Algorithm 5: ByteEncode_d).
11///
12/// Each of the 256 coefficients in `f` is encoded using `d` bits.
13/// For `d < 12`, coefficients are treated modulo `2^d`.
14/// For `d = 12`, coefficients are treated modulo q = 3329.
15///
16/// # Arguments
17///
18/// * `d` - The bit-width per coefficient (1..=12).
19/// * `f` - Array of 256 unsigned 16-bit coefficients.
20/// * `out` - Output buffer of exactly `32 * d` bytes.
21///
22/// # Panics
23///
24/// Debug-asserts that `out.len() == 32 * d`.
25pub fn byte_encode(d: usize, f: &[u16; N], out: &mut [u8]) {
26    debug_assert_eq!(out.len(), 32 * d);
27
28    if d < 12 {
29        // Bit-packing: each coefficient is d bits
30        let mut bit_idx = 0usize;
31        for i in 0..N {
32            let mut val = f[i] as u32;
33            for _ in 0..d {
34                let byte_pos = bit_idx >> 3;
35                let bit_pos = bit_idx & 7;
36                if bit_pos == 0 && byte_pos < out.len() {
37                    out[byte_pos] = 0;
38                }
39                if byte_pos < out.len() {
40                    out[byte_pos] |= ((val & 1) as u8) << bit_pos;
41                }
42                val >>= 1;
43                bit_idx += 1;
44            }
45        }
46    } else {
47        // d = 12: same logic but coefficients are mod q
48        let mut bit_idx = 0usize;
49        for b in out.iter_mut() {
50            *b = 0;
51        }
52        for i in 0..N {
53            let mut val = f[i] as u32;
54            for _ in 0..12 {
55                let byte_pos = bit_idx >> 3;
56                let bit_pos = bit_idx & 7;
57                if byte_pos < out.len() {
58                    out[byte_pos] |= ((val & 1) as u8) << bit_pos;
59                }
60                val >>= 1;
61                bit_idx += 1;
62            }
63        }
64    }
65}
66
67/// Unpack `32*d` bytes into 256 integers (Algorithm 6: ByteDecode_d).
68///
69/// The inverse of [`byte_encode`]. Each coefficient is extracted from `d`
70/// consecutive bits and reduced modulo `2^d` (for `d < 12`) or modulo
71/// q = 3329 (for `d = 12`).
72///
73/// # Arguments
74///
75/// * `d` - The bit-width per coefficient (1..=12).
76/// * `input` - Input buffer of exactly `32 * d` bytes.
77/// * `f` - Output array of 256 unsigned 16-bit coefficients.
78///
79/// # Panics
80///
81/// Debug-asserts that `input.len() == 32 * d`.
82pub fn byte_decode(d: usize, input: &[u8], f: &mut [u16; N]) {
83    debug_assert_eq!(input.len(), 32 * d);
84
85    let m = if d < 12 { 1u32 << d } else { Q as u32 };
86
87    let mut bit_idx = 0usize;
88    for i in 0..N {
89        let mut val = 0u32;
90        for j in 0..d {
91            let byte_pos = bit_idx >> 3;
92            let bit_pos = bit_idx & 7;
93            val |= (((input[byte_pos] >> bit_pos) & 1) as u32) << j;
94            bit_idx += 1;
95        }
96        f[i] = (val % m) as u16;
97    }
98}
99
100/// Lossy compression from Z_q to Z_{2^d} (FIPS 203 eq. 4.7).
101///
102/// Computes `x -> round(2^d / q * x) mod 2^d` using integer-only
103/// arithmetic: `floor((2^d * x + q/2) / q) mod 2^d`.
104///
105/// # Arguments
106///
107/// * `d` - Target bit-width (1..=12).
108/// * `x` - A coefficient in `[0, q-1]`.
109///
110/// # Returns
111///
112/// The compressed value in `[0, 2^d - 1]`.
113#[inline(always)]
114pub fn compress(d: u32, x: u16) -> u16 {
115    // ⌈(2^d / q) · x⌋ = ⌊(2^d · x + q/2) / q⌋ mod 2^d
116    let shifted = ((x as u64) << d) + (Q as u64 / 2);
117    ((shifted / Q as u64) & ((1u64 << d) - 1)) as u16
118}
119
120/// Decompression from Z_{2^d} back to Z_q (FIPS 203 eq. 4.8).
121///
122/// Computes `y -> round(q / 2^d * y)` using integer-only arithmetic:
123/// `floor((q * y + 2^{d-1}) / 2^d)`.
124///
125/// This is the approximate inverse of [`compress`]. The round-trip
126/// introduces a bounded quantization error.
127///
128/// # Arguments
129///
130/// * `d` - Source bit-width (1..=12).
131/// * `y` - A compressed value in `[0, 2^d - 1]`.
132///
133/// # Returns
134///
135/// The decompressed value in `[0, q-1]`.
136#[inline(always)]
137pub fn decompress(d: u32, y: u16) -> u16 {
138    // ⌈(q / 2^d) · y⌋ = ⌊(q · y + 2^(d-1)) / 2^d⌋
139    let val = (Q as u32 * y as u32 + (1u32 << (d - 1))) >> d;
140    val as u16
141}
142
143/// Compress all 256 coefficients of a polynomial from Z_q to Z_{2^d}.
144///
145/// Applies [`compress`] element-wise.
146///
147/// # Arguments
148///
149/// * `d` - Target bit-width.
150/// * `f` - Input polynomial with coefficients in `[0, q-1]`.
151/// * `out` - Output polynomial with coefficients in `[0, 2^d - 1]`.
152pub fn compress_poly(d: u32, f: &[u16; N], out: &mut [u16; N]) {
153    for i in 0..N {
154        out[i] = compress(d, f[i]);
155    }
156}
157
158/// Decompress all 256 coefficients of a polynomial from Z_{2^d} to Z_q.
159///
160/// Applies [`decompress`] element-wise.
161///
162/// # Arguments
163///
164/// * `d` - Source bit-width.
165/// * `f` - Input polynomial with coefficients in `[0, 2^d - 1]`.
166/// * `out` - Output polynomial with coefficients in `[0, q-1]`.
167pub fn decompress_poly(d: u32, f: &[u16; N], out: &mut [u16; N]) {
168    for i in 0..N {
169        out[i] = decompress(d, f[i]);
170    }
171}