1#[cfg(feature = "sca-protected")]
2use super::MlKemError;
3use super::encode;
4#[cfg(feature = "sca-protected")]
5use super::masked::{self, MaskedPoly};
6use super::ntt;
7use super::params::{N, Params};
26#[cfg(feature = "sca-protected")]
27use super::rng::CryptoRng;
28use super::sample;
29use super::sha3;
30#[cfg(feature = "sca-protected")]
31use super::shuffle;
32
/// Largest module rank `k` handled by this file: fixed-size polynomial vectors
/// are declared with `MAX_K` entries and the matrix `A_hat` with `MAX_K * MAX_K`.
const MAX_K: usize = 4;
/// Capacity of the PRF scratch buffer. Each CBD sample consumes `64 * eta`
/// bytes of PRF output, so this covers every `eta` up to 3.
const MAX_PRF_LEN: usize = 64 * 3;
37
38pub fn keygen<P: Params>(d: &[u8; 32], ek_out: &mut [u8], dk_out: &mut [u8]) -> (usize, usize) {
58 let k = P::K;
59 let ek_len = 384 * k + 32;
60 let dk_len = 384 * k;
61
62 let mut g_input = [0u8; 33];
64 g_input[..32].copy_from_slice(d);
65 g_input[32] = k as u8;
66 let (rho, sigma) = sha3::g(&g_input);
67 ntt::zeroize_bytes(&mut g_input);
68
69 let mut a_hat = [[0i16; N]; MAX_K * MAX_K];
71 for i in 0..k {
72 for j in 0..k {
73 let mut seed = [0u8; 34];
74 seed[..32].copy_from_slice(&rho);
75 seed[32] = j as u8;
76 seed[33] = i as u8;
77 a_hat[i * k + j] = sample::sample_ntt(&seed);
78 }
79 }
80
81 let mut n_counter = 0u8;
83 let mut s_hat = [[0i16; N]; MAX_K];
84 let mut prf_buf = [0u8; MAX_PRF_LEN];
85 for i in 0..k {
86 sha3::prf(P::ETA1, &sigma, n_counter, &mut prf_buf);
87 s_hat[i] = sample::sample_poly_cbd(P::ETA1, &prf_buf[..64 * P::ETA1]);
88 ntt::zeroize_bytes(&mut prf_buf[..64 * P::ETA1]);
89 ntt::ntt(&mut s_hat[i]);
90 n_counter += 1;
91 }
92
93 let mut e_hat = [[0i16; N]; MAX_K];
94 for i in 0..k {
95 sha3::prf(P::ETA1, &sigma, n_counter, &mut prf_buf);
96 e_hat[i] = sample::sample_poly_cbd(P::ETA1, &prf_buf[..64 * P::ETA1]);
97 ntt::zeroize_bytes(&mut prf_buf[..64 * P::ETA1]);
98 ntt::ntt(&mut e_hat[i]);
99 n_counter += 1;
100 }
101
102 let mut t_hat = [[0i16; N]; MAX_K];
106 for i in 0..k {
107 for j in 0..k {
108 let mut tmp = [0i16; N];
109 ntt::multiply_ntts(&a_hat[i * k + j], &s_hat[j], &mut tmp);
110 for l in 0..N {
111 t_hat[i][l] = t_hat[i][l] + tmp[l];
112 }
113 }
114 ntt::to_mont_poly(&mut t_hat[i]); for l in 0..N {
116 t_hat[i][l] = t_hat[i][l] + e_hat[i][l];
117 }
118 }
119
120 for i in 0..k {
122 let mut t_u16 = [0u16; N];
123 for l in 0..N {
124 t_u16[l] = ntt::barrett_reduce(t_hat[i][l]) as u16;
125 }
126 encode::byte_encode(12, &t_u16, &mut ek_out[384 * i..384 * (i + 1)]);
127 }
128 ek_out[384 * k..384 * k + 32].copy_from_slice(&rho);
129
130 for i in 0..k {
132 let mut s_u16 = [0u16; N];
133 for l in 0..N {
134 s_u16[l] = ntt::barrett_reduce(s_hat[i][l]) as u16;
135 }
136 encode::byte_encode(12, &s_u16, &mut dk_out[384 * i..384 * (i + 1)]);
137 }
138
139 for poly in s_hat[..k].iter_mut() {
141 ntt::zeroize_poly(poly);
142 }
143 for poly in e_hat[..k].iter_mut() {
144 ntt::zeroize_poly(poly);
145 }
146
147 (ek_len, dk_len)
148}
149
/// Side-channel-hardened variant of [`keygen`].
///
/// Follows the same steps as `keygen`, with one difference: the NTT of every
/// noise polynomial is computed with `shuffle::ntt_shuffled`, which consumes
/// randomness from `rng`.
///
/// On success it writes `ByteEncode12(t_hat) || rho` to `ek_out[..384 * k + 32]`
/// and `ByteEncode12(s_hat)` to `dk_out[..384 * k]`, and returns
/// `Ok((ek_len, dk_len))`.
///
/// # Errors
///
/// Propagates the first error from `shuffle::ntt_shuffled`. At that point
/// nothing has been written to `ek_out`/`dk_out`. `sigma`, `s_hat` and `e_hat`
/// are wiped before the error is returned.
///
/// # Panics
///
/// Panics (slice indexing) if `ek_out` is shorter than `384 * k + 32` bytes or
/// `dk_out` is shorter than `384 * k` bytes, where `k = P::K`.
#[cfg(feature = "sca-protected")]
pub fn keygen_sca<P: Params>(
    d: &[u8; 32],
    ek_out: &mut [u8],
    dk_out: &mut [u8],
    rng: &mut impl CryptoRng,
) -> Result<(usize, usize), MlKemError> {
    // Zeroes a `u16` scratch polynomial that held secret coefficients.
    // Best-effort: `black_box` keeps the stores from being removed as dead writes.
    fn wipe_u16(buf: &mut [u16; N]) {
        buf.fill(0);
        let _ = core::hint::black_box(&*buf);
    }

    let k = P::K;
    let ek_len = 384 * k + 32;
    let dk_len = 384 * k;

    // (rho, sigma) = G(d || k).
    let mut g_input = [0u8; 33];
    g_input[..32].copy_from_slice(d);
    g_input[32] = k as u8;
    let (rho, mut sigma) = sha3::g(&g_input);
    ntt::zeroize_bytes(&mut g_input);

    // A_hat[i][j] = SampleNTT(rho || j || i), stored row-major at i * k + j.
    let mut a_hat = [[0i16; N]; MAX_K * MAX_K];
    let mut seed = [0u8; 34];
    seed[..32].copy_from_slice(&rho);
    for i in 0..k {
        for j in 0..k {
            seed[32] = j as u8;
            seed[33] = i as u8;
            a_hat[i * k + j] = sample::sample_ntt(&seed);
        }
    }

    // s_hat takes PRF counters 0..k and e_hat takes k..2k, both with eta1.
    // A shuffle failure is recorded instead of returned immediately, so the
    // secrets sampled so far can be wiped first.
    let mut s_hat = [[0i16; N]; MAX_K];
    let mut e_hat = [[0i16; N]; MAX_K];
    let eta1_len = 64 * P::ETA1;
    let mut prf_buf = [0u8; MAX_PRF_LEN];
    let mut n_counter = 0u8;
    let mut status = Ok(());
    for poly in s_hat[..k].iter_mut().chain(e_hat[..k].iter_mut()) {
        sha3::prf(P::ETA1, &sigma, n_counter, &mut prf_buf);
        *poly = sample::sample_poly_cbd(P::ETA1, &prf_buf[..eta1_len]);
        ntt::zeroize_bytes(&mut prf_buf[..eta1_len]);
        if let Err(err) = shuffle::ntt_shuffled(poly, rng) {
            status = Err(err);
            break;
        }
        n_counter += 1;
    }
    // sigma regenerates all of s and e, and it is not needed past this point.
    ntt::zeroize_bytes(&mut sigma);
    if status.is_err() {
        for poly in s_hat[..k].iter_mut().chain(e_hat[..k].iter_mut()) {
            ntt::zeroize_poly(poly);
        }
    }
    status?;

    // t_hat = A_hat ∘ s_hat + e_hat (see `to_mont_poly` after accumulation).
    let mut t_hat = [[0i16; N]; MAX_K];
    for (i, t_poly) in t_hat[..k].iter_mut().enumerate() {
        for (j, s_poly) in s_hat[..k].iter().enumerate() {
            let mut tmp = [0i16; N];
            ntt::multiply_ntts(&a_hat[i * k + j], s_poly, &mut tmp);
            for (acc, &x) in t_poly.iter_mut().zip(tmp.iter()) {
                *acc = acc.wrapping_add(x);
            }
            // A_hat is public, so this product can reveal s_hat[j].
            ntt::zeroize_poly(&mut tmp);
        }
        ntt::to_mont_poly(t_poly);
        for (acc, &x) in t_poly.iter_mut().zip(e_hat[i].iter()) {
            *acc = acc.wrapping_add(x);
        }
    }

    // ek = ByteEncode12(t_hat) || rho.
    for (i, t_poly) in t_hat[..k].iter().enumerate() {
        let mut t_u16 = [0u16; N];
        for (dst, &x) in t_u16.iter_mut().zip(t_poly.iter()) {
            *dst = ntt::barrett_reduce(x) as u16;
        }
        encode::byte_encode(12, &t_u16, &mut ek_out[384 * i..384 * (i + 1)]);
    }
    ek_out[384 * k..384 * k + 32].copy_from_slice(&rho);

    // dk = ByteEncode12(s_hat). The scratch holds secret-key coefficients.
    for (i, s_poly) in s_hat[..k].iter().enumerate() {
        let mut s_u16 = [0u16; N];
        for (dst, &x) in s_u16.iter_mut().zip(s_poly.iter()) {
            *dst = ntt::barrett_reduce(x) as u16;
        }
        encode::byte_encode(12, &s_u16, &mut dk_out[384 * i..384 * (i + 1)]);
        wipe_u16(&mut s_u16);
    }

    for poly in s_hat[..k].iter_mut().chain(e_hat[..k].iter_mut()) {
        ntt::zeroize_poly(poly);
    }

    Ok((ek_len, dk_len))
}
264
265pub fn encrypt<P: Params>(ek_pke: &[u8], m: &[u8; 32], r: &[u8; 32], ct_out: &mut [u8]) -> usize {
286 let k = P::K;
287 let du = P::DU;
288 let dv = P::DV;
289
290 let mut t_hat = [[0i16; N]; MAX_K];
292 for i in 0..k {
293 let mut t_decoded = [0u16; N];
294 encode::byte_decode(12, &ek_pke[384 * i..384 * (i + 1)], &mut t_decoded);
295 for l in 0..N {
296 t_hat[i][l] = t_decoded[l] as i16;
297 }
298 }
299
300 let rho = &ek_pke[384 * k..384 * k + 32];
302 let mut a_hat = [[0i16; N]; MAX_K * MAX_K];
303 for i in 0..k {
304 for j in 0..k {
305 let mut seed = [0u8; 34];
306 seed[..32].copy_from_slice(rho);
307 seed[32] = j as u8;
308 seed[33] = i as u8;
309 a_hat[i * k + j] = sample::sample_ntt(&seed);
310 }
311 }
312
313 let mut n_counter = 0u8;
315 let mut y_hat = [[0i16; N]; MAX_K];
316 let mut prf_buf = [0u8; MAX_PRF_LEN];
317 for i in 0..k {
318 sha3::prf(P::ETA1, r, n_counter, &mut prf_buf);
319 y_hat[i] = sample::sample_poly_cbd(P::ETA1, &prf_buf[..64 * P::ETA1]);
320 ntt::zeroize_bytes(&mut prf_buf[..64 * P::ETA1]);
321 ntt::ntt(&mut y_hat[i]);
322 n_counter += 1;
323 }
324
325 let mut e1 = [[0i16; N]; MAX_K];
326 for i in 0..k {
327 sha3::prf(P::ETA2, r, n_counter, &mut prf_buf);
328 e1[i] = sample::sample_poly_cbd(P::ETA2, &prf_buf[..64 * P::ETA2]);
329 ntt::zeroize_bytes(&mut prf_buf[..64 * P::ETA2]);
330 n_counter += 1;
331 }
332
333 sha3::prf(P::ETA2, r, n_counter, &mut prf_buf);
334 let mut e2 = sample::sample_poly_cbd(P::ETA2, &prf_buf[..64 * P::ETA2]);
335 ntt::zeroize_bytes(&mut prf_buf[..64 * P::ETA2]);
336
337 let mut u = [[0i16; N]; MAX_K];
339 for i in 0..k {
340 let mut acc = [0i16; N];
341 for j in 0..k {
342 let mut tmp = [0i16; N];
343 ntt::multiply_ntts(&a_hat[j * k + i], &y_hat[j], &mut tmp);
344 for l in 0..N {
345 acc[l] = acc[l].wrapping_add(tmp[l]);
346 }
347 }
348 ntt::ntt_inv(&mut acc);
349 ntt::poly_add(&acc, &e1[i], &mut u[i]);
350 }
351
352 let mut mu = [0u16; N];
354 encode::byte_decode(1, m, &mut mu);
355
356 let mut v = [0i16; N];
358 {
359 let mut acc = [0i16; N];
360 for j in 0..k {
361 let mut tmp = [0i16; N];
362 ntt::multiply_ntts(&t_hat[j], &y_hat[j], &mut tmp);
363 for l in 0..N {
364 acc[l] = acc[l].wrapping_add(tmp[l]);
365 }
366 }
367 ntt::ntt_inv(&mut acc);
368 for l in 0..N {
369 let mu_dec = encode::decompress(1, mu[l]) as i16;
370 v[l] = acc[l].wrapping_add(e2[l]).wrapping_add(mu_dec);
371 }
372 }
373
374 let ct_len = 32 * (du * k + dv);
376
377 for i in 0..k {
378 let mut u_comp = [0u16; N];
379 for l in 0..N {
380 u_comp[l] = encode::compress(du as u32, ntt::barrett_reduce(u[i][l]) as u16);
381 }
382 encode::byte_encode(du, &u_comp, &mut ct_out[32 * du * i..32 * du * (i + 1)]);
383 }
384
385 let c2_off = 32 * du * k;
386 let mut v_comp = [0u16; N];
387 for l in 0..N {
388 v_comp[l] = encode::compress(dv as u32, ntt::barrett_reduce(v[l]) as u16);
389 }
390 encode::byte_encode(dv, &v_comp, &mut ct_out[c2_off..c2_off + 32 * dv]);
391
392 for poly in y_hat[..k].iter_mut() {
394 ntt::zeroize_poly(poly);
395 }
396 for poly in e1[..k].iter_mut() {
397 ntt::zeroize_poly(poly);
398 }
399 ntt::zeroize_poly(&mut e2);
400
401 ct_len
402}
403
404pub fn decrypt<P: Params>(dk_pke: &[u8], c: &[u8]) -> [u8; 32] {
422 let k = P::K;
423 let du = P::DU;
424 let dv = P::DV;
425
426 let mut u = [[0i16; N]; MAX_K];
428 for i in 0..k {
429 let mut u_comp = [0u16; N];
430 encode::byte_decode(du, &c[32 * du * i..32 * du * (i + 1)], &mut u_comp);
431 for l in 0..N {
432 u[i][l] = encode::decompress(du as u32, u_comp[l]) as i16;
433 }
434 }
435
436 let c2_off = 32 * du * k;
438 let mut v_comp = [0u16; N];
439 encode::byte_decode(dv, &c[c2_off..c2_off + 32 * dv], &mut v_comp);
440 let mut v = [0i16; N];
441 for l in 0..N {
442 v[l] = encode::decompress(dv as u32, v_comp[l]) as i16;
443 }
444
445 let mut s_hat = [[0i16; N]; MAX_K];
447 for i in 0..k {
448 let mut s_dec = [0u16; N];
449 encode::byte_decode(12, &dk_pke[384 * i..384 * (i + 1)], &mut s_dec);
450 for l in 0..N {
451 s_hat[i][l] = s_dec[l] as i16;
452 }
453 }
454
455 for poly in u[..k].iter_mut() {
457 ntt::ntt(poly);
458 }
459
460 let mut acc = [0i16; N];
462 for j in 0..k {
463 let mut tmp = [0i16; N];
464 ntt::multiply_ntts(&s_hat[j], &u[j], &mut tmp);
465 for l in 0..N {
466 acc[l] = acc[l].wrapping_add(tmp[l]);
467 }
468 }
469 ntt::reduce(&mut acc);
470 ntt::ntt_inv(&mut acc);
471
472 let mut w = [0i16; N];
474 ntt::poly_sub(&v, &acc, &mut w);
475
476 let mut w_comp = [0u16; N];
478 for l in 0..N {
479 w_comp[l] = encode::compress(1, ntt::barrett_reduce(w[l]) as u16);
480 }
481 let mut m = [0u8; 32];
482 encode::byte_encode(1, &w_comp, &mut m);
483
484 for poly in s_hat[..k].iter_mut() {
486 ntt::zeroize_poly(poly);
487 }
488 ntt::zeroize_poly(&mut acc);
489
490 m
491}
492
/// Side-channel-hardened variant of [`decrypt`].
///
/// It differs from `decrypt` in two ways:
/// - The NTT of `u` uses `shuffle::ntt_shuffled`.
/// - Each `s_hat[j]` is split into two shares with `MaskedPoly::mask` before
///   its product with `u_hat[j]` is accumulated by
///   `masked::masked_multiply_accumulate`.
///
/// The accumulated shares are unmasked only once, after all products are summed.
///
/// # Errors
///
/// Propagates errors from `shuffle::ntt_shuffled` or `MaskedPoly::mask`.
/// - The shuffled NTT runs before the secret key is decoded, so its error path
///   holds no secrets.
/// - On a masking error, `s_hat`, the decode scratch and the masked accumulator
///   are wiped before the error is returned.
///
/// # Panics
///
/// Panics (slice indexing) if `c` is shorter than `32 * (du * k + dv)` bytes or
/// `dk_pke` is shorter than `384 * k` bytes, where `k = P::K`.
#[cfg(feature = "sca-protected")]
pub fn decrypt_sca<P: Params>(
    dk_pke: &[u8],
    c: &[u8],
    rng: &mut impl CryptoRng,
) -> Result<[u8; 32], MlKemError> {
    // Zeroes a `u16` scratch polynomial that held secret coefficients.
    // Best-effort: `black_box` keeps the stores from being removed as dead writes.
    fn wipe_u16(buf: &mut [u16; N]) {
        buf.fill(0);
        let _ = core::hint::black_box(&*buf);
    }

    let k = P::K;
    let du = P::DU;
    let dv = P::DV;

    // u = Decompress_du(ByteDecode_du(c1)).
    let mut u = [[0i16; N]; MAX_K];
    for (i, u_poly) in u[..k].iter_mut().enumerate() {
        let mut u_comp = [0u16; N];
        encode::byte_decode(du, &c[32 * du * i..32 * du * (i + 1)], &mut u_comp);
        for (dst, &x) in u_poly.iter_mut().zip(u_comp.iter()) {
            *dst = encode::decompress(du as u32, x) as i16;
        }
    }

    // v = Decompress_dv(ByteDecode_dv(c2)).
    let c2_off = 32 * du * k;
    let mut v_comp = [0u16; N];
    encode::byte_decode(dv, &c[c2_off..c2_off + 32 * dv], &mut v_comp);
    let mut v = [0i16; N];
    for (dst, &x) in v.iter_mut().zip(v_comp.iter()) {
        *dst = encode::decompress(dv as u32, x) as i16;
    }

    // Shuffled NTT of the public u. Doing this before loading the secret key
    // means an early return here leaves nothing secret behind.
    for poly in u[..k].iter_mut() {
        shuffle::ntt_shuffled(poly, rng)?;
    }

    // s_hat = ByteDecode12(dk_pke). The scratch holds secret-key coefficients.
    let mut s_hat = [[0i16; N]; MAX_K];
    for (i, s_poly) in s_hat[..k].iter_mut().enumerate() {
        let mut s_dec = [0u16; N];
        encode::byte_decode(12, &dk_pke[384 * i..384 * (i + 1)], &mut s_dec);
        for (dst, &x) in s_poly.iter_mut().zip(s_dec.iter()) {
            *dst = x as i16;
        }
        wipe_u16(&mut s_dec);
    }

    // Masked accumulation of s_hat^T ∘ u_hat. A masking failure is recorded
    // instead of returned immediately, so the secrets can be wiped first.
    let mut acc_masked = MaskedPoly {
        share0: [0i16; N],
        share1: [0i16; N],
    };
    let mut status = Ok(());
    for (s_poly, u_poly) in s_hat[..k].iter().zip(u[..k].iter()) {
        match MaskedPoly::mask(s_poly, rng) {
            Ok(mut s_masked) => {
                masked::masked_multiply_accumulate(&mut acc_masked, &s_masked, u_poly);
                // The two shares recombine to s_hat[j]; wipe them once consumed.
                s_masked.zeroize();
            }
            Err(err) => {
                status = Err(err);
                break;
            }
        }
    }
    // The unmasked secret key is not needed past this point.
    for poly in s_hat[..k].iter_mut() {
        ntt::zeroize_poly(poly);
    }
    if status.is_err() {
        acc_masked.zeroize();
    }
    status?;

    let mut acc = acc_masked.unmask();
    acc_masked.zeroize();
    ntt::reduce(&mut acc);
    ntt::ntt_inv(&mut acc);

    // w = v - acc carries the message plus secret-dependent noise.
    let mut w = [0i16; N];
    ntt::poly_sub(&v, &acc, &mut w);

    let mut w_comp = [0u16; N];
    for (dst, &x) in w_comp.iter_mut().zip(w.iter()) {
        *dst = encode::compress(1, ntt::barrett_reduce(x) as u16);
    }
    let mut m = [0u8; 32];
    encode::byte_encode(1, &w_comp, &mut m);

    ntt::zeroize_poly(&mut acc);
    ntt::zeroize_poly(&mut w);
    wipe_u16(&mut w_comp);

    Ok(m)
}