Skip to main content

quantica/ml_dsa/
encode.rs

//! Encoding and decoding algorithms for ML-DSA (FIPS 204, Algorithms 16-28, plus Power2Round from Algorithm 35).
2//!
3//! Provides bit-level packing and unpacking of polynomials and polynomial
4//! vectors into compact byte representations, as well as the Power2Round
5//! decomposition used during key generation.
6//!
7//! All vector functions accept slices (`&[[i32; N]]`) instead of `&Vec<...>`
8//! to avoid requiring heap-allocated containers.
9
10use super::ntt::mod_q;
11use super::params::{D, MAX_K, MAX_L, N, Params, Q};
12use alloc::vec::Vec;
13
14// ============================================================
// Bit packing utilities (Algorithms 16-19; the bit/byte conversions of Algorithms 9-13 are inlined)
16// ============================================================
17
18/// Pack a polynomial whose coefficients lie in [0, b].
19///
20/// Implements Algorithm 16 of FIPS 204 (SimpleBitPack). Each coefficient
21/// is stored using `bitlen(b)` bits in little-endian bit order.
22///
23/// - `w`: input polynomial with coefficients in [0, b].
24/// - `b`: upper bound on coefficient values.
25/// - `out`: output buffer (must have length >= N * bitlen(b) / 8).
26pub fn simple_bit_pack(w: &[i32; N], b: u32, out: &mut [u8]) {
27    let bits = 32 - b.leading_zeros() as usize; // bitlen(b)
28    let mut bit_pos = 0usize;
29    // Clear output
30    for byte in out.iter_mut() {
31        *byte = 0;
32    }
33    for i in 0..N {
34        let val = w[i] as u32;
35        for bit in 0..bits {
36            if (val >> bit) & 1 == 1 {
37                out[bit_pos / 8] |= 1 << (bit_pos % 8);
38            }
39            bit_pos += 1;
40        }
41    }
42}
43
44/// Unpack a polynomial whose coefficients lie in [0, b].
45///
46/// Implements Algorithm 18 of FIPS 204 (SimpleBitUnpack). Inverse of
47/// [`simple_bit_pack`].
48///
49/// - `data`: packed byte data.
50/// - `b`: upper bound on coefficient values.
51/// - `w`: output polynomial (filled with decoded coefficients).
52pub fn simple_bit_unpack(data: &[u8], b: u32, w: &mut [i32; N]) {
53    let bits = 32 - b.leading_zeros() as usize;
54    let mut bit_pos = 0usize;
55    for i in 0..N {
56        let mut val = 0u32;
57        for bit in 0..bits {
58            if (data[bit_pos / 8] >> (bit_pos % 8)) & 1 == 1 {
59                val |= 1 << bit;
60            }
61            bit_pos += 1;
62        }
63        w[i] = val as i32;
64    }
65}
66
67/// Pack a polynomial whose coefficients lie in [-a, b].
68///
69/// Implements Algorithm 17 of FIPS 204 (BitPack). Stores each coefficient
70/// as `(b - coeff)`, mapping the range [-a, b] to [0, a+b], then packs
71/// using `bitlen(a+b)` bits per coefficient.
72///
73/// - `w`: input polynomial with coefficients in [-a, b].
74/// - `a`: magnitude of the negative bound.
75/// - `b`: positive bound.
76/// - `out`: output buffer (must have length >= N * bitlen(a+b) / 8).
77pub fn bit_pack(w: &[i32; N], a: u32, b: u32, out: &mut [u8]) {
78    let range = a + b;
79    let bits = 32 - range.leading_zeros() as usize;
80    let mut bit_pos = 0usize;
81    for byte in out.iter_mut() {
82        *byte = 0;
83    }
84    for i in 0..N {
85        let val = (b as i32 - w[i]) as u32;
86        for bit in 0..bits {
87            if (val >> bit) & 1 == 1 {
88                out[bit_pos / 8] |= 1 << (bit_pos % 8);
89            }
90            bit_pos += 1;
91        }
92    }
93}
94
95/// Unpack a polynomial whose coefficients lie in [-a, b].
96///
97/// Implements Algorithm 19 of FIPS 204 (BitUnpack). Inverse of [`bit_pack`].
98///
99/// - `data`: packed byte data.
100/// - `a`: magnitude of the negative bound.
101/// - `b`: positive bound.
102/// - `w`: output polynomial (filled with decoded coefficients in [-a, b]).
103pub fn bit_unpack(data: &[u8], a: u32, b: u32, w: &mut [i32; N]) {
104    let range = a + b;
105    let bits = 32 - range.leading_zeros() as usize;
106    let mut bit_pos = 0usize;
107    for i in 0..N {
108        let mut val = 0u32;
109        for bit in 0..bits {
110            if (data[bit_pos / 8] >> (bit_pos % 8)) & 1 == 1 {
111                val |= 1 << bit;
112            }
113            bit_pos += 1;
114        }
115        w[i] = b as i32 - val as i32;
116    }
117}
118
119/// Pack a hint vector into bytes.
120///
121/// Implements Algorithm 20 of FIPS 204 (HintBitPack). The hint vector `h`
122/// consists of k binary polynomials with at most `omega` total non-zero
123/// entries. The output uses `omega + k` bytes: the first `omega` bytes
124/// store the indices of non-zero coefficients, and the last k bytes store
125/// cumulative index counts.
126///
127/// - `h`: hint vector (k polynomials with entries in {0, 1}).
128/// - `out`: output buffer (must have length omega + k).
129pub fn hint_bit_pack<P: Params>(h: &[[i32; N]], out: &mut [u8]) {
130    let omega = P::OMEGA;
131    let k = P::K;
132    // out has length omega + k
133    for byte in out.iter_mut() {
134        *byte = 0;
135    }
136    let mut idx = 0usize;
137    for i in 0..k {
138        for j in 0..N {
139            if h[i][j] != 0 {
140                out[idx] = j as u8;
141                idx += 1;
142            }
143        }
144        out[omega + i] = idx as u8;
145    }
146}
147
148/// Unpack a hint vector from bytes.
149///
150/// Implements Algorithm 21 of FIPS 204 (HintBitUnpack). Inverse of
151/// [`hint_bit_pack`]. Performs validity checks on the encoding: indices
152/// must be strictly increasing within each polynomial, and cumulative
153/// counts must be non-decreasing and within bounds.
154///
155/// Returns `None` if the encoding is malformed.
156pub fn hint_bit_unpack<P: Params>(data: &[u8]) -> Option<[[i32; N]; MAX_K]> {
157    let omega = P::OMEGA;
158    let k = P::K;
159    let mut h = [[0i32; N]; MAX_K];
160    let mut idx = 0usize;
161    for i in 0..k {
162        let upper = data[omega + i] as usize;
163        if upper < idx || upper > omega {
164            return None;
165        }
166        let first = idx;
167        while idx < upper {
168            // Check ordering (indices must be strictly increasing within each polynomial)
169            if idx > first && data[idx] <= data[idx - 1] {
170                return None;
171            }
172            let j = data[idx] as usize;
173            if j >= N {
174                return None;
175            }
176            h[i][j] = 1;
177            idx += 1;
178        }
179    }
180    // Check remaining bytes are zero
181    while idx < omega {
182        if data[idx] != 0 {
183            return None;
184        }
185        idx += 1;
186    }
187    Some(h)
188}
189
190/// Encode a public key as bytes.
191///
192/// Implements Algorithm 22 of FIPS 204 (pkEncode). The public key consists
193/// of the 32-byte seed `rho` followed by the k polynomials of `t1`, each
194/// packed with 10 bits per coefficient (since t1 values lie in [0, 1023]).
195///
196/// - `rho`: 32-byte public seed for matrix A generation.
197/// - `t1`: vector of k polynomials (the high bits of t).
198///
199/// Returns a byte vector of length `P::PK_LEN`.
200pub fn pk_encode<P: Params>(rho: &[u8; 32], t1: &[[i32; N]]) -> Vec<u8> {
201    let k = P::K;
202    // t1 coefficients are in [0, 2^{bitlen(q-1)-d} - 1] = [0, 2^10 - 1] = [0, 1023]
203    let coeff_bits = 10; // bitlen(q-1) - d = 23 - 13
204    let poly_bytes = N * coeff_bits / 8; // 256 * 10 / 8 = 320
205    let mut pk = vec![0u8; P::PK_LEN];
206    pk[..32].copy_from_slice(rho);
207    for i in 0..k {
208        let offset = 32 + i * poly_bytes;
209        simple_bit_pack(&t1[i], 1023, &mut pk[offset..offset + poly_bytes]);
210    }
211    pk
212}
213
214/// Decode a public key from bytes.
215///
216/// Implements Algorithm 23 of FIPS 204 (pkDecode). Inverse of [`pk_encode`].
217///
218/// Returns `(rho, t1)` where `rho` is the 32-byte seed and `t1` is the
219/// fixed array of k polynomials with coefficients in [0, 1023].
220pub fn pk_decode<P: Params>(pk: &[u8]) -> ([u8; 32], [[i32; N]; MAX_K]) {
221    let k = P::K;
222    let poly_bytes = 320; // 256 * 10 / 8
223    let mut rho = [0u8; 32];
224    rho.copy_from_slice(&pk[..32]);
225    let mut t1 = [[0i32; N]; MAX_K];
226    for i in 0..k {
227        let offset = 32 + i * poly_bytes;
228        simple_bit_unpack(&pk[offset..offset + poly_bytes], 1023, &mut t1[i]);
229    }
230    (rho, t1)
231}
232
233/// Encode a secret key as bytes.
234///
235/// Implements Algorithm 24 of FIPS 204 (skEncode). The secret key is the
236/// concatenation of `rho` (32 bytes), `K` (32 bytes), `tr` (64 bytes),
237/// the l polynomials of `s1` and k polynomials of `s2` (each packed with
238/// `bitlen(2*eta)` bits per coefficient), and the k polynomials of `t0`
239/// (each packed with 13 bits per coefficient).
240///
241/// - `rho`: 32-byte public seed.
242/// - `k_seed`: 32-byte secret seed K.
243/// - `tr`: 64-byte hash of the public key.
244/// - `s1`: secret vector of l polynomials with coefficients in [-eta, eta].
245/// - `s2`: secret vector of k polynomials with coefficients in [-eta, eta].
246/// - `t0`: low-order bits vector (k polynomials from Power2Round).
247///
248/// Returns a byte vector of length `P::SK_LEN`.
249pub fn sk_encode<P: Params>(
250    rho: &[u8; 32],
251    k_seed: &[u8; 32],
252    tr: &[u8; 64],
253    s1: &[[i32; N]],
254    s2: &[[i32; N]],
255    t0: &[[i32; N]],
256) -> Vec<u8> {
257    let eta = P::ETA as u32;
258    let l = P::L;
259    let k = P::K;
260    let eta_bits = P::BITLEN_2ETA;
261    let poly_eta_bytes = N * eta_bits / 8;
262    let d = D;
263    let poly_t0_bytes = N * d / 8; // 256 * 13 / 8 = 416
264
265    let mut sk = vec![0u8; P::SK_LEN];
266    let mut offset = 0;
267
268    sk[offset..offset + 32].copy_from_slice(rho);
269    offset += 32;
270    sk[offset..offset + 32].copy_from_slice(k_seed);
271    offset += 32;
272    sk[offset..offset + 64].copy_from_slice(tr);
273    offset += 64;
274
275    for i in 0..l {
276        bit_pack(&s1[i], eta, eta, &mut sk[offset..offset + poly_eta_bytes]);
277        offset += poly_eta_bytes;
278    }
279    for i in 0..k {
280        bit_pack(&s2[i], eta, eta, &mut sk[offset..offset + poly_eta_bytes]);
281        offset += poly_eta_bytes;
282    }
283    for i in 0..k {
284        // t0 coefficients are in [-(2^{d-1}-1), 2^{d-1}] = [-4095, 4096]
285        // We use bit_pack with a=2^{d-1}-1=4095 and b=2^{d-1}=4096
286        bit_pack(&t0[i], 4095, 4096, &mut sk[offset..offset + poly_t0_bytes]);
287        offset += poly_t0_bytes;
288    }
289
290    sk
291}
292
293/// Decode a secret key from bytes.
294///
295/// Implements Algorithm 25 of FIPS 204 (skDecode). Inverse of [`sk_encode`].
296///
297/// Returns `(rho, K, tr, s1, s2, t0)` using fixed arrays.
298pub fn sk_decode<P: Params>(
299    sk: &[u8],
300) -> (
301    [u8; 32],
302    [u8; 32],
303    [u8; 64],
304    [[i32; N]; MAX_L],
305    [[i32; N]; MAX_K],
306    [[i32; N]; MAX_K],
307) {
308    let eta = P::ETA as u32;
309    let l = P::L;
310    let k = P::K;
311    let eta_bits = P::BITLEN_2ETA;
312    let poly_eta_bytes = N * eta_bits / 8;
313    let d = D;
314    let poly_t0_bytes = N * d / 8;
315
316    let mut offset = 0;
317    let mut rho = [0u8; 32];
318    rho.copy_from_slice(&sk[offset..offset + 32]);
319    offset += 32;
320    let mut k_seed = [0u8; 32];
321    k_seed.copy_from_slice(&sk[offset..offset + 32]);
322    offset += 32;
323    let mut tr = [0u8; 64];
324    tr.copy_from_slice(&sk[offset..offset + 64]);
325    offset += 64;
326
327    let mut s1 = [[0i32; N]; MAX_L];
328    for i in 0..l {
329        bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, &mut s1[i]);
330        offset += poly_eta_bytes;
331    }
332    let mut s2 = [[0i32; N]; MAX_K];
333    for i in 0..k {
334        bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, &mut s2[i]);
335        offset += poly_eta_bytes;
336    }
337    let mut t0 = [[0i32; N]; MAX_K];
338    for i in 0..k {
339        bit_unpack(&sk[offset..offset + poly_t0_bytes], 4095, 4096, &mut t0[i]);
340        offset += poly_t0_bytes;
341    }
342
343    (rho, k_seed, tr, s1, s2, t0)
344}
345
346/// Decode only the seeds (rho, K, tr) from a secret key, without
347/// unpacking any polynomial. Returns 128 bytes of stack.
348pub fn sk_decode_seeds<P: Params>(sk: &[u8]) -> ([u8; 32], [u8; 32], [u8; 64]) {
349    let mut rho = [0u8; 32];
350    rho.copy_from_slice(&sk[..32]);
351    let mut k_seed = [0u8; 32];
352    k_seed.copy_from_slice(&sk[32..64]);
353    let mut tr = [0u8; 64];
354    tr.copy_from_slice(&sk[64..128]);
355    (rho, k_seed, tr)
356}
357
358/// Decode a single polynomial of s1 from the packed secret key.
359///
360/// `idx` must be in `0..P::L`.
361pub fn sk_decode_s1<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
362    let eta = P::ETA as u32;
363    let eta_bits = P::BITLEN_2ETA;
364    let poly_eta_bytes = N * eta_bits / 8;
365    let base = 128; // rho(32) + K(32) + tr(64)
366    let offset = base + idx * poly_eta_bytes;
367    bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, out);
368}
369
370/// Decode a single polynomial of s2 from the packed secret key.
371///
372/// `idx` must be in `0..P::K`.
373pub fn sk_decode_s2<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
374    let eta = P::ETA as u32;
375    let l = P::L;
376    let eta_bits = P::BITLEN_2ETA;
377    let poly_eta_bytes = N * eta_bits / 8;
378    let base = 128 + l * poly_eta_bytes;
379    let offset = base + idx * poly_eta_bytes;
380    bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, out);
381}
382
383/// Decode a single polynomial of t0 from the packed secret key.
384///
385/// `idx` must be in `0..P::K`.
386pub fn sk_decode_t0<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
387    let l = P::L;
388    let k = P::K;
389    let eta_bits = P::BITLEN_2ETA;
390    let poly_eta_bytes = N * eta_bits / 8;
391    let d = D;
392    let poly_t0_bytes = N * d / 8;
393    let base = 128 + (l + k) * poly_eta_bytes;
394    let offset = base + idx * poly_t0_bytes;
395    bit_unpack(&sk[offset..offset + poly_t0_bytes], 4095, 4096, out);
396}
397
398/// Encode a signature as bytes.
399///
400/// Implements Algorithm 26 of FIPS 204 (sigEncode). The signature is the
401/// concatenation of `c_tilde` (lambda/4 bytes), the l polynomials of `z`
402/// (each packed with `1 + bitlen(gamma1-1)` bits per coefficient), and the
403/// hint vector `h` (packed via [`hint_bit_pack`]).
404///
405/// - `c_tilde`: commitment hash (lambda/4 bytes).
406/// - `z`: response vector (l polynomials with coefficients in [-(gamma1-1), gamma1]).
407/// - `h`: hint vector (k binary polynomials).
408///
409/// Returns a byte vector of length `P::SIG_LEN`.
410pub fn sig_encode<P: Params>(c_tilde: &[u8], z: &[[i32; N]], h: &[[i32; N]]) -> Vec<u8> {
411    let l = P::L;
412    let gamma1 = P::GAMMA1 as u32;
413    let gamma1_bits = P::BITLEN_GAMMA1_MINUS1 + 1; // bitlen(gamma1-1)+1 = 1+bitlen(gamma1-1)
414    let poly_z_bytes = N * gamma1_bits / 8;
415    let c_tilde_len = P::LAMBDA / 4;
416
417    let mut sig = vec![0u8; P::SIG_LEN];
418    let mut offset = 0;
419
420    sig[offset..offset + c_tilde_len].copy_from_slice(&c_tilde[..c_tilde_len]);
421    offset += c_tilde_len;
422
423    for i in 0..l {
424        // z coefficients are in [-(gamma1-1), gamma1], so a=gamma1-1, b=gamma1
425        bit_pack(&z[i], gamma1 - 1, gamma1, &mut sig[offset..offset + poly_z_bytes]);
426        offset += poly_z_bytes;
427    }
428
429    hint_bit_pack::<P>(h, &mut sig[offset..]);
430
431    sig
432}
433
434/// Decode a signature from bytes.
435///
436/// Implements Algorithm 27 of FIPS 204 (sigDecode). Inverse of [`sig_encode`].
437///
438/// Returns `Some((c_tilde, z, h))` on success, or `None` if the hint
439/// encoding is malformed. Uses fixed arrays instead of Vec.
440pub fn sig_decode<P: Params>(sig: &[u8]) -> Option<(Vec<u8>, [[i32; N]; MAX_L], [[i32; N]; MAX_K])> {
441    let l = P::L;
442    let gamma1 = P::GAMMA1 as u32;
443    let gamma1_bits = P::BITLEN_GAMMA1_MINUS1 + 1;
444    let poly_z_bytes = N * gamma1_bits / 8;
445    let c_tilde_len = P::LAMBDA / 4;
446
447    let mut offset = 0;
448    let c_tilde = sig[offset..offset + c_tilde_len].to_vec();
449    offset += c_tilde_len;
450
451    let mut z = [[0i32; N]; MAX_L];
452    for i in 0..l {
453        bit_unpack(&sig[offset..offset + poly_z_bytes], gamma1 - 1, gamma1, &mut z[i]);
454        offset += poly_z_bytes;
455    }
456
457    let h = hint_bit_unpack::<P>(&sig[offset..])?;
458
459    Some((c_tilde, z, h))
460}
461
462/// Encode the high-order bits vector w1 as bytes.
463///
464/// Implements Algorithm 28 of FIPS 204 (w1Encode). The w1 coefficients lie
465/// in [0, (q-1)/(2*gamma2) - 1] and are packed using [`simple_bit_pack`].
466/// The encoded output is hashed together with the message digest to form
467/// the commitment hash c_tilde during signing and verification.
468///
469/// - `w1`: vector of k polynomials (high bits from Decompose).
470///
471/// Returns the packed byte representation of w1.
472pub fn w1_encode<P: Params>(w1: &[[i32; N]]) -> Vec<u8> {
473    let k = P::K;
474    let gamma2 = P::GAMMA2;
475    // w1 coefficients are in [0, (q-1)/(2*gamma2)]
476    let max_w1 = ((Q - 1) / (2 * gamma2) - 1) as u32;
477    let bits = 32 - max_w1.leading_zeros() as usize;
478    let poly_bytes = N * bits / 8;
479    let mut out = vec![0u8; k * poly_bytes];
480    for i in 0..k {
481        let offset = i * poly_bytes;
482        simple_bit_pack(&w1[i], max_w1, &mut out[offset..offset + poly_bytes]);
483    }
484    out
485}
486
487/// Decompose a coefficient into high and low parts using a power-of-2 divisor.
488///
489/// Implements Algorithm 35 of FIPS 204 (Power2Round). Splits `r` into
490/// `(r1, r0)` such that `r = r1 * 2^d + r0` with `r0` in the centered
491/// range `[-(2^{d-1} - 1), 2^{d-1}]`.
492///
493/// Used during key generation to compress the public vector t into t1
494/// (stored in the public key) and t0 (stored in the secret key).
495pub fn power2round(r: i32) -> (i32, i32) {
496    let rp = mod_q(r);
497    // r0 = r mod+ 2^d (centered modular reduction)
498    let two_d = 1i32 << D;
499    let half = two_d >> 1;
500    let mut r0v = rp % two_d; // [0, 2^d - 1]
501    if r0v > half {
502        r0v -= two_d;
503    }
504    let r1 = (rp - r0v) / two_d;
505    (r1, r0v)
506}
507
508/// Apply [`power2round`] to every coefficient of a polynomial vector.
509///
510/// Returns `(t1, t0)` as fixed arrays where each element satisfies
511/// `t[i][j] = t1[i][j] * 2^d + t0[i][j]`.
512pub fn power2round_vec(t: &[[i32; N]], len: usize) -> ([[i32; N]; MAX_K], [[i32; N]; MAX_K]) {
513    let mut t1 = [[0i32; N]; MAX_K];
514    let mut t0 = [[0i32; N]; MAX_K];
515    for i in 0..len {
516        for j in 0..N {
517            let (r1, r0) = power2round(t[i][j]);
518            t1[i][j] = r1;
519            t0[i][j] = r0;
520        }
521    }
522    (t1, t0)
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_bit_pack_unpack() {
        let mut w = [0i32; N];
        for i in 0..N {
            w[i] = (i as i32 * 3) % 1024;
        }
        let mut buf = [0u8; 320]; // 256 * 10 / 8
        simple_bit_pack(&w, 1023, &mut buf);
        let mut w2 = [0i32; N];
        simple_bit_unpack(&buf, 1023, &mut w2);
        assert_eq!(w, w2);
    }

    #[test]
    fn test_bit_pack_unpack() {
        let mut w = [0i32; N];
        for i in 0..N {
            w[i] = (i as i32 % 5) - 2; // values in [-2, 2]
        }
        // a=2, b=2, a+b=4, bitlen(4)=3, 256*3/8=96
        let mut buf = [0u8; 96];
        bit_pack(&w, 2, 2, &mut buf);
        let mut w2 = [0i32; N];
        bit_unpack(&buf, 2, 2, &mut w2);
        assert_eq!(w, w2);
    }

    #[test]
    fn test_bit_pack_t0_extremes() {
        // Both ends of the t0 range [-(2^{d-1}-1), 2^{d-1}] = [-4095, 4096]
        // must survive a round trip (13 bits per coefficient, 416 bytes).
        let mut w = [0i32; N];
        for (i, c) in w.iter_mut().enumerate() {
            *c = if i % 2 == 0 { -4095 } else { 4096 };
        }
        let mut buf = [0u8; 416];
        bit_pack(&w, 4095, 4096, &mut buf);
        let mut w2 = [0i32; N];
        bit_unpack(&buf, 4095, 4096, &mut w2);
        assert_eq!(w, w2);
    }

    #[test]
    fn test_power2round() {
        let (r1, r0) = power2round(1234567);
        assert_eq!(r1 * (1 << D) + r0, 1234567);
        assert!(r0.abs() <= (1 << (D - 1)));
    }

    #[test]
    fn test_power2round_rounding_boundary() {
        // r0 = 2^{d-1} stays positive; one more wraps to -(2^{d-1} - 1).
        let half = 1i32 << (D - 1);
        assert_eq!(power2round(0), (0, 0));
        assert_eq!(power2round(half), (0, half));
        assert_eq!(power2round(half + 1), (1, -(half - 1)));
    }
}