1use super::params::N;
10
/// Prime modulus used for all coefficient arithmetic in this module
/// (3329, the ML-KEM / Kyber modulus).
pub const Q: i16 = 3329;

/// [`Q`] widened to `i32` for use inside the Montgomery reductions.
const Q32: i32 = Q as i32;

/// Signed inverse of `Q` modulo 2^16: `Q * QINV ≡ 1 (mod 2^16)`.
const QINV: i32 = -3327;

/// `2^32 mod Q`. Montgomery-reducing `a * R_SQ_MOD_Q` yields `a * 2^16 mod Q`,
/// i.e. `a` converted to Montgomery form.
const R_SQ_MOD_Q: i32 = 1353;

/// Twiddle factors in plain (non-Montgomery) form, eight per row.
///
/// The entries satisfy `ZETAS[2k]^2 ≡ ZETAS[k]` and
/// `ZETAS[2k + 1] ≡ 1729 * ZETAS[2k] (mod Q)`, i.e. they are the powers of 17
/// in 7-bit bit-reversed order.
const ZETAS: [i16; 128] = [
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
];
28
/// Per-pair constants for the degree-one base-case multiplication in
/// `multiply_ntts`, in plain (non-Montgomery) form.
///
/// Entries come in adjacent pairs `(g, Q - g)`, so every row below holds four
/// complete pairs.
const GAMMAS: [i16; 128] = [
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
];
39
40const fn mont_reduce_const(a: i32) -> i16 {
43 let t = (a as i16).wrapping_mul(QINV as i16);
44 ((a - t as i32 * Q32) >> 16) as i16
45}
46
47const fn to_mont_const(a: i16) -> i16 {
48 mont_reduce_const(a as i32 * R_SQ_MOD_Q)
49}
50
51const fn compute_table_mont(table: &[i16; 128]) -> [i16; 128] {
52 let mut r = [0i16; 128];
53 let mut i = 0;
54 while i < 128 {
55 r[i] = to_mont_const(table[i]);
56 i += 1;
57 }
58 r
59}
60
/// [`ZETAS`] with every entry passed through `to_mont_const` at compile time.
/// Read by `ntt` and `ntt_inv` as the butterfly multipliers.
const ZETAS_MONT: [i16; 128] = compute_table_mont(&ZETAS);
/// [`GAMMAS`] with every entry passed through `to_mont_const` at compile time.
/// Read by `multiply_ntts`.
const GAMMAS_MONT: [i16; 128] = compute_table_mont(&GAMMAS);

/// Final scaling constant that `ntt_inv` applies to every coefficient through a
/// Montgomery reduction. `1441 * 128 ≡ R_SQ_MOD_Q (mod Q)`, so that step both
/// divides by 128 and multiplies by `2^16 mod Q`.
const F_SCALE: i16 = 1441;
67
68#[inline(always)]
72fn montgomery_reduce(a: i32) -> i16 {
73 let t = (a as i16).wrapping_mul(QINV as i16);
74 ((a - t as i32 * Q32) >> 16) as i16
75}
76
77#[inline(always)]
79fn mont_mul(a: i16, b: i16) -> i16 {
80 montgomery_reduce(a as i32 * b as i32)
81}
82
83#[inline(always)]
85pub fn barrett_reduce(a: i16) -> i16 {
86 let t = ((20159i32 * a as i32 + (1 << 25)) >> 26) as i16;
87 let mut r = a - t.wrapping_mul(Q);
88 r += (r >> 15) & Q;
89 r
90}
91
92pub fn ntt(f: &mut [i16; N]) {
99 let mut k = 1usize;
100 let mut len = 128;
101 while len >= 2 {
102 let mut start = 0;
103 while start < N {
104 let zeta = ZETAS_MONT[k];
105 k += 1;
106 for j in start..start + len {
107 let t = mont_mul(zeta, f[j + len]);
108 f[j + len] = f[j] - t;
109 f[j] = f[j] + t;
110 }
111 start += 2 * len;
112 }
113 len >>= 1;
114 }
115 for c in f.iter_mut() {
117 *c = barrett_reduce(*c);
118 }
119}
120
121pub fn ntt_inv(f: &mut [i16; N]) {
130 let mut w = [0i32; N];
134 for i in 0..N {
135 w[i] = f[i] as i32;
136 }
137
138 let mut k = 127usize;
139 let mut len = 2;
140 while len <= 128 {
141 let mut start = 0;
142 while start < N {
143 let neg_zeta = ZETAS_MONT[k].wrapping_neg();
144 k = k.wrapping_sub(1);
145 for j in start..start + len {
146 let t = w[j];
147 let u = w[j + len];
148 w[j] = t + u;
149 w[j + len] = montgomery_reduce(neg_zeta as i32 * ((t - u) as i32)) as i32;
151 }
152 start += 2 * len;
153 }
154 len <<= 1;
155 }
156 for i in 0..N {
157 f[i] = montgomery_reduce(F_SCALE as i32 * w[i]) as i16;
159 }
160}
161
162pub fn multiply_ntts(f: &[i16; N], g: &[i16; N], h: &mut [i16; N]) {
173 for i in 0..128 {
174 let gamma = GAMMAS_MONT[i];
175 let a0 = f[2 * i];
176 let a1 = f[2 * i + 1];
177 let b0 = g[2 * i];
178 let b1 = g[2 * i + 1];
179 let t = mont_mul(mont_mul(a1, b1), gamma);
181 h[2 * i] = mont_mul(a0, b0) + t;
182 h[2 * i + 1] = mont_mul(a0, b1) + mont_mul(a1, b0);
184 }
185}
186
187pub fn to_mont_poly(f: &mut [i16; N]) {
192 for c in f.iter_mut() {
193 *c = montgomery_reduce(*c as i32 * R_SQ_MOD_Q);
194 }
195}
196
197pub fn poly_add(a: &[i16; N], b: &[i16; N], c: &mut [i16; N]) {
200 for i in 0..N {
201 c[i] = a[i] + b[i];
202 }
203}
204
205pub fn poly_sub(a: &[i16; N], b: &[i16; N], c: &mut [i16; N]) {
206 for i in 0..N {
207 c[i] = a[i] - b[i];
208 }
209}
210
211pub fn reduce(f: &mut [i16; N]) {
213 for c in f.iter_mut() {
214 *c = barrett_reduce(*c);
215 }
216}
217
218#[inline(never)]
221pub fn zeroize_poly(f: &mut [i16; N]) {
222 for c in f.iter_mut() {
223 unsafe { core::ptr::write_volatile(c, 0) };
224 }
225 core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
226}
227
/// Overwrites every byte of `b` with zero.
///
/// Each store is a volatile write, and a sequentially consistent compiler
/// fence follows the loop. The function is also marked `#[inline(never)]`.
/// An empty slice is a no-op apart from the fence.
#[inline(never)]
pub fn zeroize_bytes(b: &mut [u8]) {
    b.iter_mut().for_each(|byte| {
        let slot: *mut u8 = byte;
        // SAFETY: `slot` comes from a live `&mut u8`, so it is non-null,
        // aligned and valid for a write of one byte.
        unsafe { slot.write_volatile(0) };
    });
    core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// 42 * (1 + X) must come back as 42 + 42 X after NTT, pairwise
    /// multiplication, inverse NTT and a final reduction.
    #[test]
    fn test_ntt_basemul_intt_simple() {
        let mut f = [0i16; N];
        f[0] = 42;
        let mut g = [0i16; N];
        g[0] = 1;
        g[1] = 1;

        ntt(&mut f);
        ntt(&mut g);

        let mut h = [0i16; N];
        multiply_ntts(&f, &g, &mut h);
        ntt_inv(&mut h);
        reduce(&mut h);

        assert_eq!(h[0], 42, "h[0]={} expected 42", h[0]);
        assert_eq!(h[1], 42, "h[1]={} expected 42", h[1]);
        for i in 2..N {
            assert_eq!(h[i], 0, "h[{}]={} expected 0", i, h[i]);
        }
    }

    /// Multiplying by the constant polynomial 1 must return the other operand.
    #[test]
    fn test_ntt_basemul_intt_identity() {
        let mut one = [0i16; N];
        one[0] = 1;
        let mut b = [0i16; N];
        for i in 0..N {
            b[i] = (i as i16 * 7 + 13) % Q;
        }
        let orig = b;

        ntt(&mut one);
        ntt(&mut b);
        let mut c = [0i16; N];
        multiply_ntts(&one, &b, &mut c);
        ntt_inv(&mut c);
        reduce(&mut c);

        for i in 0..N {
            assert_eq!(c[i], orig[i], "mismatch at {}: got {} expected {}", i, c[i], orig[i]);
        }
    }

    /// Key-generation style computation `t = a * s + e`, carried out entirely
    /// in the NTT domain. `to_mont_poly` removes the extra Montgomery factor
    /// left by `multiply_ntts` so that `e` can be added directly.
    #[test]
    fn test_to_mont_poly_keygen_pattern() {
        let mut a = [0i16; N];
        a[0] = 100;
        let mut s = [0i16; N];
        s[0] = 1;
        let mut e = [0i16; N];
        e[0] = 5;

        ntt(&mut a);
        ntt(&mut s);
        ntt(&mut e);

        let mut t = [0i16; N];
        multiply_ntts(&a, &s, &mut t);
        to_mont_poly(&mut t);
        for i in 0..N {
            t[i] = t[i] + e[i];
        }

        reduce(&mut t);

        // a * s + e = 100 * 1 + 5 = 105, so `t` must equal the NTT of the
        // constant polynomial 105. Both sides are reduced into [0, Q).
        let mut expected = [0i16; N];
        expected[0] = 105;
        ntt(&mut expected);
        for i in 0..N {
            assert_eq!(t[i], expected[i], "t[{}]={} expected {}", i, t[i], expected[i]);
        }
    }

    /// Converting to Montgomery form and reducing once must round-trip.
    #[test]
    fn test_montgomery_reduce_basic() {
        let a: i16 = 1729;
        let in_mont = to_mont_const(a);
        let back = montgomery_reduce(in_mont as i32);
        let back_reduced = barrett_reduce(back);
        assert_eq!(back_reduced, a);
    }
}