1use super::ntt::mod_q;
11use super::params::{D, MAX_K, MAX_L, N, Params, Q};
12use alloc::vec::Vec;
13
14pub fn simple_bit_pack(w: &[i32; N], b: u32, out: &mut [u8]) {
27 let bits = 32 - b.leading_zeros() as usize; let mut bit_pos = 0usize;
29 for byte in out.iter_mut() {
31 *byte = 0;
32 }
33 for i in 0..N {
34 let val = w[i] as u32;
35 for bit in 0..bits {
36 if (val >> bit) & 1 == 1 {
37 out[bit_pos / 8] |= 1 << (bit_pos % 8);
38 }
39 bit_pos += 1;
40 }
41 }
42}
43
44pub fn simple_bit_unpack(data: &[u8], b: u32, w: &mut [i32; N]) {
53 let bits = 32 - b.leading_zeros() as usize;
54 let mut bit_pos = 0usize;
55 for i in 0..N {
56 let mut val = 0u32;
57 for bit in 0..bits {
58 if (data[bit_pos / 8] >> (bit_pos % 8)) & 1 == 1 {
59 val |= 1 << bit;
60 }
61 bit_pos += 1;
62 }
63 w[i] = val as i32;
64 }
65}
66
67pub fn bit_pack(w: &[i32; N], a: u32, b: u32, out: &mut [u8]) {
78 let range = a + b;
79 let bits = 32 - range.leading_zeros() as usize;
80 let mut bit_pos = 0usize;
81 for byte in out.iter_mut() {
82 *byte = 0;
83 }
84 for i in 0..N {
85 let val = (b as i32 - w[i]) as u32;
86 for bit in 0..bits {
87 if (val >> bit) & 1 == 1 {
88 out[bit_pos / 8] |= 1 << (bit_pos % 8);
89 }
90 bit_pos += 1;
91 }
92 }
93}
94
95pub fn bit_unpack(data: &[u8], a: u32, b: u32, w: &mut [i32; N]) {
104 let range = a + b;
105 let bits = 32 - range.leading_zeros() as usize;
106 let mut bit_pos = 0usize;
107 for i in 0..N {
108 let mut val = 0u32;
109 for bit in 0..bits {
110 if (data[bit_pos / 8] >> (bit_pos % 8)) & 1 == 1 {
111 val |= 1 << bit;
112 }
113 bit_pos += 1;
114 }
115 w[i] = b as i32 - val as i32;
116 }
117}
118
119pub fn hint_bit_pack<P: Params>(h: &[[i32; N]], out: &mut [u8]) {
130 let omega = P::OMEGA;
131 let k = P::K;
132 for byte in out.iter_mut() {
134 *byte = 0;
135 }
136 let mut idx = 0usize;
137 for i in 0..k {
138 for j in 0..N {
139 if h[i][j] != 0 {
140 out[idx] = j as u8;
141 idx += 1;
142 }
143 }
144 out[omega + i] = idx as u8;
145 }
146}
147
148pub fn hint_bit_unpack<P: Params>(data: &[u8]) -> Option<[[i32; N]; MAX_K]> {
157 let omega = P::OMEGA;
158 let k = P::K;
159 let mut h = [[0i32; N]; MAX_K];
160 let mut idx = 0usize;
161 for i in 0..k {
162 let upper = data[omega + i] as usize;
163 if upper < idx || upper > omega {
164 return None;
165 }
166 let first = idx;
167 while idx < upper {
168 if idx > first && data[idx] <= data[idx - 1] {
170 return None;
171 }
172 let j = data[idx] as usize;
173 if j >= N {
174 return None;
175 }
176 h[i][j] = 1;
177 idx += 1;
178 }
179 }
180 while idx < omega {
182 if data[idx] != 0 {
183 return None;
184 }
185 idx += 1;
186 }
187 Some(h)
188}
189
190pub fn pk_encode<P: Params>(rho: &[u8; 32], t1: &[[i32; N]]) -> Vec<u8> {
201 let k = P::K;
202 let coeff_bits = 10; let poly_bytes = N * coeff_bits / 8; let mut pk = vec![0u8; P::PK_LEN];
206 pk[..32].copy_from_slice(rho);
207 for i in 0..k {
208 let offset = 32 + i * poly_bytes;
209 simple_bit_pack(&t1[i], 1023, &mut pk[offset..offset + poly_bytes]);
210 }
211 pk
212}
213
214pub fn pk_decode<P: Params>(pk: &[u8]) -> ([u8; 32], [[i32; N]; MAX_K]) {
221 let k = P::K;
222 let poly_bytes = 320; let mut rho = [0u8; 32];
224 rho.copy_from_slice(&pk[..32]);
225 let mut t1 = [[0i32; N]; MAX_K];
226 for i in 0..k {
227 let offset = 32 + i * poly_bytes;
228 simple_bit_unpack(&pk[offset..offset + poly_bytes], 1023, &mut t1[i]);
229 }
230 (rho, t1)
231}
232
233pub fn sk_encode<P: Params>(
250 rho: &[u8; 32],
251 k_seed: &[u8; 32],
252 tr: &[u8; 64],
253 s1: &[[i32; N]],
254 s2: &[[i32; N]],
255 t0: &[[i32; N]],
256) -> Vec<u8> {
257 let eta = P::ETA as u32;
258 let l = P::L;
259 let k = P::K;
260 let eta_bits = P::BITLEN_2ETA;
261 let poly_eta_bytes = N * eta_bits / 8;
262 let d = D;
263 let poly_t0_bytes = N * d / 8; let mut sk = vec![0u8; P::SK_LEN];
266 let mut offset = 0;
267
268 sk[offset..offset + 32].copy_from_slice(rho);
269 offset += 32;
270 sk[offset..offset + 32].copy_from_slice(k_seed);
271 offset += 32;
272 sk[offset..offset + 64].copy_from_slice(tr);
273 offset += 64;
274
275 for i in 0..l {
276 bit_pack(&s1[i], eta, eta, &mut sk[offset..offset + poly_eta_bytes]);
277 offset += poly_eta_bytes;
278 }
279 for i in 0..k {
280 bit_pack(&s2[i], eta, eta, &mut sk[offset..offset + poly_eta_bytes]);
281 offset += poly_eta_bytes;
282 }
283 for i in 0..k {
284 bit_pack(&t0[i], 4095, 4096, &mut sk[offset..offset + poly_t0_bytes]);
287 offset += poly_t0_bytes;
288 }
289
290 sk
291}
292
293pub fn sk_decode<P: Params>(
299 sk: &[u8],
300) -> (
301 [u8; 32],
302 [u8; 32],
303 [u8; 64],
304 [[i32; N]; MAX_L],
305 [[i32; N]; MAX_K],
306 [[i32; N]; MAX_K],
307) {
308 let eta = P::ETA as u32;
309 let l = P::L;
310 let k = P::K;
311 let eta_bits = P::BITLEN_2ETA;
312 let poly_eta_bytes = N * eta_bits / 8;
313 let d = D;
314 let poly_t0_bytes = N * d / 8;
315
316 let mut offset = 0;
317 let mut rho = [0u8; 32];
318 rho.copy_from_slice(&sk[offset..offset + 32]);
319 offset += 32;
320 let mut k_seed = [0u8; 32];
321 k_seed.copy_from_slice(&sk[offset..offset + 32]);
322 offset += 32;
323 let mut tr = [0u8; 64];
324 tr.copy_from_slice(&sk[offset..offset + 64]);
325 offset += 64;
326
327 let mut s1 = [[0i32; N]; MAX_L];
328 for i in 0..l {
329 bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, &mut s1[i]);
330 offset += poly_eta_bytes;
331 }
332 let mut s2 = [[0i32; N]; MAX_K];
333 for i in 0..k {
334 bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, &mut s2[i]);
335 offset += poly_eta_bytes;
336 }
337 let mut t0 = [[0i32; N]; MAX_K];
338 for i in 0..k {
339 bit_unpack(&sk[offset..offset + poly_t0_bytes], 4095, 4096, &mut t0[i]);
340 offset += poly_t0_bytes;
341 }
342
343 (rho, k_seed, tr, s1, s2, t0)
344}
345
346pub fn sk_decode_seeds<P: Params>(sk: &[u8]) -> ([u8; 32], [u8; 32], [u8; 64]) {
349 let mut rho = [0u8; 32];
350 rho.copy_from_slice(&sk[..32]);
351 let mut k_seed = [0u8; 32];
352 k_seed.copy_from_slice(&sk[32..64]);
353 let mut tr = [0u8; 64];
354 tr.copy_from_slice(&sk[64..128]);
355 (rho, k_seed, tr)
356}
357
358pub fn sk_decode_s1<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
362 let eta = P::ETA as u32;
363 let eta_bits = P::BITLEN_2ETA;
364 let poly_eta_bytes = N * eta_bits / 8;
365 let base = 128; let offset = base + idx * poly_eta_bytes;
367 bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, out);
368}
369
370pub fn sk_decode_s2<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
374 let eta = P::ETA as u32;
375 let l = P::L;
376 let eta_bits = P::BITLEN_2ETA;
377 let poly_eta_bytes = N * eta_bits / 8;
378 let base = 128 + l * poly_eta_bytes;
379 let offset = base + idx * poly_eta_bytes;
380 bit_unpack(&sk[offset..offset + poly_eta_bytes], eta, eta, out);
381}
382
383pub fn sk_decode_t0<P: Params>(sk: &[u8], idx: usize, out: &mut [i32; N]) {
387 let l = P::L;
388 let k = P::K;
389 let eta_bits = P::BITLEN_2ETA;
390 let poly_eta_bytes = N * eta_bits / 8;
391 let d = D;
392 let poly_t0_bytes = N * d / 8;
393 let base = 128 + (l + k) * poly_eta_bytes;
394 let offset = base + idx * poly_t0_bytes;
395 bit_unpack(&sk[offset..offset + poly_t0_bytes], 4095, 4096, out);
396}
397
398pub fn sig_encode<P: Params>(c_tilde: &[u8], z: &[[i32; N]], h: &[[i32; N]]) -> Vec<u8> {
411 let l = P::L;
412 let gamma1 = P::GAMMA1 as u32;
413 let gamma1_bits = P::BITLEN_GAMMA1_MINUS1 + 1; let poly_z_bytes = N * gamma1_bits / 8;
415 let c_tilde_len = P::LAMBDA / 4;
416
417 let mut sig = vec![0u8; P::SIG_LEN];
418 let mut offset = 0;
419
420 sig[offset..offset + c_tilde_len].copy_from_slice(&c_tilde[..c_tilde_len]);
421 offset += c_tilde_len;
422
423 for i in 0..l {
424 bit_pack(&z[i], gamma1 - 1, gamma1, &mut sig[offset..offset + poly_z_bytes]);
426 offset += poly_z_bytes;
427 }
428
429 hint_bit_pack::<P>(h, &mut sig[offset..]);
430
431 sig
432}
433
434pub fn sig_decode<P: Params>(sig: &[u8]) -> Option<(Vec<u8>, [[i32; N]; MAX_L], [[i32; N]; MAX_K])> {
441 let l = P::L;
442 let gamma1 = P::GAMMA1 as u32;
443 let gamma1_bits = P::BITLEN_GAMMA1_MINUS1 + 1;
444 let poly_z_bytes = N * gamma1_bits / 8;
445 let c_tilde_len = P::LAMBDA / 4;
446
447 let mut offset = 0;
448 let c_tilde = sig[offset..offset + c_tilde_len].to_vec();
449 offset += c_tilde_len;
450
451 let mut z = [[0i32; N]; MAX_L];
452 for i in 0..l {
453 bit_unpack(&sig[offset..offset + poly_z_bytes], gamma1 - 1, gamma1, &mut z[i]);
454 offset += poly_z_bytes;
455 }
456
457 let h = hint_bit_unpack::<P>(&sig[offset..])?;
458
459 Some((c_tilde, z, h))
460}
461
462pub fn w1_encode<P: Params>(w1: &[[i32; N]]) -> Vec<u8> {
473 let k = P::K;
474 let gamma2 = P::GAMMA2;
475 let max_w1 = ((Q - 1) / (2 * gamma2) - 1) as u32;
477 let bits = 32 - max_w1.leading_zeros() as usize;
478 let poly_bytes = N * bits / 8;
479 let mut out = vec![0u8; k * poly_bytes];
480 for i in 0..k {
481 let offset = i * poly_bytes;
482 simple_bit_pack(&w1[i], max_w1, &mut out[offset..offset + poly_bytes]);
483 }
484 out
485}
486
487pub fn power2round(r: i32) -> (i32, i32) {
496 let rp = mod_q(r);
497 let two_d = 1i32 << D;
499 let half = two_d >> 1;
500 let mut r0v = rp % two_d; if r0v > half {
502 r0v -= two_d;
503 }
504 let r1 = (rp - r0v) / two_d;
505 (r1, r0v)
506}
507
508pub fn power2round_vec(t: &[[i32; N]], len: usize) -> ([[i32; N]; MAX_K], [[i32; N]; MAX_K]) {
513 let mut t1 = [[0i32; N]; MAX_K];
514 let mut t0 = [[0i32; N]; MAX_K];
515 for i in 0..len {
516 for j in 0..N {
517 let (r1, r0) = power2round(t[i][j]);
518 t1[i][j] = r1;
519 t0[i][j] = r0;
520 }
521 }
522 (t1, t0)
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// SimpleBitPack followed by SimpleBitUnpack must reproduce 10-bit
    /// coefficients.
    #[test]
    fn test_simple_bit_pack_unpack() {
        let mut original = [0i32; N];
        for (i, coeff) in original.iter_mut().enumerate() {
            *coeff = (i as i32 * 3) % 1024;
        }
        // 10 bits per coefficient: N * 10 / 8 = 320 bytes for N = 256.
        let mut packed = [0u8; 320];
        simple_bit_pack(&original, 1023, &mut packed);
        let mut decoded = [0i32; N];
        simple_bit_unpack(&packed, 1023, &mut decoded);
        assert_eq!(original, decoded);
    }

    /// BitPack/BitUnpack with a = b = 2 (3 bits per coefficient) must
    /// round-trip values in [-2, 2].
    #[test]
    fn test_bit_pack_unpack() {
        let mut original = [0i32; N];
        for (i, coeff) in original.iter_mut().enumerate() {
            *coeff = (i as i32 % 5) - 2;
        }
        // 3 bits per coefficient: N * 3 / 8 = 96 bytes for N = 256.
        let mut packed = [0u8; 96];
        bit_pack(&original, 2, 2, &mut packed);
        let mut decoded = [0i32; N];
        bit_unpack(&packed, 2, 2, &mut decoded);
        assert_eq!(original, decoded);
    }

    /// Power2Round must recombine to its input and keep |r0| <= 2^(D-1).
    #[test]
    fn test_power2round() {
        let input = 1234567;
        let (high, low) = power2round(input);
        assert_eq!(high * (1 << D) + low, input);
        assert!(low.abs() <= (1 << (D - 1)));
    }
}