quantica/ml_kem/encode.rs
//! Encoding, decoding, compression, and decompression algorithms.
//!
//! Implements FIPS 203 Section 4.2.1, Algorithms 3-6 and equations (4.7)-(4.8):
//!
//! - [`byte_encode`] / [`byte_decode`] -- bit-pack/unpack 256 integers into/from bytes
//!   (ByteEncode_d / ByteDecode_d, with the BitsToBytes / BytesToBits steps folded in)
//! - [`compress`] / [`decompress`] -- lossy compression between Z_q and Z_{2^d}
//! - [`compress_poly`] / [`decompress_poly`] -- vectorized compression over full polynomials
8use super::params::{N, Q};
9
10/// Bit-pack 256 integers into `32*d` bytes (Algorithm 5: ByteEncode_d).
11///
12/// Each of the 256 coefficients in `f` is encoded using `d` bits.
13/// For `d < 12`, coefficients are treated modulo `2^d`.
14/// For `d = 12`, coefficients are treated modulo q = 3329.
15///
16/// # Arguments
17///
18/// * `d` - The bit-width per coefficient (1..=12).
19/// * `f` - Array of 256 unsigned 16-bit coefficients.
20/// * `out` - Output buffer of exactly `32 * d` bytes.
21///
22/// # Panics
23///
24/// Debug-asserts that `out.len() == 32 * d`.
25pub fn byte_encode(d: usize, f: &[u16; N], out: &mut [u8]) {
26 debug_assert_eq!(out.len(), 32 * d);
27
28 if d < 12 {
29 // Bit-packing: each coefficient is d bits
30 let mut bit_idx = 0usize;
31 for i in 0..N {
32 let mut val = f[i] as u32;
33 for _ in 0..d {
34 let byte_pos = bit_idx >> 3;
35 let bit_pos = bit_idx & 7;
36 if bit_pos == 0 && byte_pos < out.len() {
37 out[byte_pos] = 0;
38 }
39 if byte_pos < out.len() {
40 out[byte_pos] |= ((val & 1) as u8) << bit_pos;
41 }
42 val >>= 1;
43 bit_idx += 1;
44 }
45 }
46 } else {
47 // d = 12: same logic but coefficients are mod q
48 let mut bit_idx = 0usize;
49 for b in out.iter_mut() {
50 *b = 0;
51 }
52 for i in 0..N {
53 let mut val = f[i] as u32;
54 for _ in 0..12 {
55 let byte_pos = bit_idx >> 3;
56 let bit_pos = bit_idx & 7;
57 if byte_pos < out.len() {
58 out[byte_pos] |= ((val & 1) as u8) << bit_pos;
59 }
60 val >>= 1;
61 bit_idx += 1;
62 }
63 }
64 }
65}
66
67/// Unpack `32*d` bytes into 256 integers (Algorithm 6: ByteDecode_d).
68///
69/// The inverse of [`byte_encode`]. Each coefficient is extracted from `d`
70/// consecutive bits and reduced modulo `2^d` (for `d < 12`) or modulo
71/// q = 3329 (for `d = 12`).
72///
73/// # Arguments
74///
75/// * `d` - The bit-width per coefficient (1..=12).
76/// * `input` - Input buffer of exactly `32 * d` bytes.
77/// * `f` - Output array of 256 unsigned 16-bit coefficients.
78///
79/// # Panics
80///
81/// Debug-asserts that `input.len() == 32 * d`.
82pub fn byte_decode(d: usize, input: &[u8], f: &mut [u16; N]) {
83 debug_assert_eq!(input.len(), 32 * d);
84
85 let m = if d < 12 { 1u32 << d } else { Q as u32 };
86
87 let mut bit_idx = 0usize;
88 for i in 0..N {
89 let mut val = 0u32;
90 for j in 0..d {
91 let byte_pos = bit_idx >> 3;
92 let bit_pos = bit_idx & 7;
93 val |= (((input[byte_pos] >> bit_pos) & 1) as u32) << j;
94 bit_idx += 1;
95 }
96 f[i] = (val % m) as u16;
97 }
98}
99
100/// Lossy compression from Z_q to Z_{2^d} (FIPS 203 eq. 4.7).
101///
102/// Computes `x -> round(2^d / q * x) mod 2^d` using integer-only
103/// arithmetic: `floor((2^d * x + q/2) / q) mod 2^d`.
104///
105/// # Arguments
106///
107/// * `d` - Target bit-width (1..=12).
108/// * `x` - A coefficient in `[0, q-1]`.
109///
110/// # Returns
111///
112/// The compressed value in `[0, 2^d - 1]`.
113#[inline(always)]
114pub fn compress(d: u32, x: u16) -> u16 {
115 // ⌈(2^d / q) · x⌋ = ⌊(2^d · x + q/2) / q⌋ mod 2^d
116 let shifted = ((x as u64) << d) + (Q as u64 / 2);
117 ((shifted / Q as u64) & ((1u64 << d) - 1)) as u16
118}
119
120/// Decompression from Z_{2^d} back to Z_q (FIPS 203 eq. 4.8).
121///
122/// Computes `y -> round(q / 2^d * y)` using integer-only arithmetic:
123/// `floor((q * y + 2^{d-1}) / 2^d)`.
124///
125/// This is the approximate inverse of [`compress`]. The round-trip
126/// introduces a bounded quantization error.
127///
128/// # Arguments
129///
130/// * `d` - Source bit-width (1..=12).
131/// * `y` - A compressed value in `[0, 2^d - 1]`.
132///
133/// # Returns
134///
135/// The decompressed value in `[0, q-1]`.
136#[inline(always)]
137pub fn decompress(d: u32, y: u16) -> u16 {
138 // ⌈(q / 2^d) · y⌋ = ⌊(q · y + 2^(d-1)) / 2^d⌋
139 let val = (Q as u32 * y as u32 + (1u32 << (d - 1))) >> d;
140 val as u16
141}
142
143/// Compress all 256 coefficients of a polynomial from Z_q to Z_{2^d}.
144///
145/// Applies [`compress`] element-wise.
146///
147/// # Arguments
148///
149/// * `d` - Target bit-width.
150/// * `f` - Input polynomial with coefficients in `[0, q-1]`.
151/// * `out` - Output polynomial with coefficients in `[0, 2^d - 1]`.
152pub fn compress_poly(d: u32, f: &[u16; N], out: &mut [u16; N]) {
153 for i in 0..N {
154 out[i] = compress(d, f[i]);
155 }
156}
157
158/// Decompress all 256 coefficients of a polynomial from Z_{2^d} to Z_q.
159///
160/// Applies [`decompress`] element-wise.
161///
162/// # Arguments
163///
164/// * `d` - Source bit-width.
165/// * `f` - Input polynomial with coefficients in `[0, 2^d - 1]`.
166/// * `out` - Output polynomial with coefficients in `[0, q-1]`.
167pub fn decompress_poly(d: u32, f: &[u16; N], out: &mut [u16; N]) {
168 for i in 0..N {
169 out[i] = decompress(d, f[i]);
170 }
171}