Skip to main content

arcana/rsa/
bigint.rs

1//! Big integer arithmetic for RSA (up to ~4096-bit numbers).
2//!
3//! Represents large integers as little-endian `Vec<u64>` limbs.
4//! Provides addition, subtraction, multiplication, division, modular
5//! exponentiation (Montgomery ladder), extended GCD, and Miller-Rabin
6//! primality. `BigInt` is the underlying storage for every component
7//! of [`super::rsa::RsaPublicKey`] and [`super::rsa::RsaSecretKey`]
8//! and the workhorse of every operation in [`super::pkcs1`],
9//! [`super::oaep`] and [`super::pss`].
10//!
11//! # Side-channel posture
12//!
13//! Roadmap item **`T1-E`** (see `arcana/doc/sca/countermeasures/
14//! rsa.rst`): the operations below need a CT audit before the
15//! evaluation pass, with the same `core::hint::black_box` shielding
16//! pattern as [`super::super::ecc::field`] (commit `76191c1`).
17//!
18//! | Operation              | Risk                                                 | Action  |
19//! |------------------------|------------------------------------------------------|---------|
//! | `cmp_to` (backs `Ord`/`PartialOrd`/`PartialEq`) | Variable-iteration limb-by-limb compare leaks bits | Rewrite to borrow-only branchless pattern |
//! | `MontParams::mont_mul` | Conditional final subtract leaks (Walter 2002)       | Apply `black_box` mask shielding |
//! | `MontParams::mod_exp` / `BigInt::mod_exp` | Square-and-multiply must be square-always | Validate Montgomery-ladder structure + `black_box` |
//! | `mod_inv` (extended GCD)| Variable-time GCD historically Minerva target       | Prefer Fermat (`a^(p-2) mod p`) for prime moduli |
//! | `sub` / `add`          | Borrow / carry propagation                           | Confirm fixed iteration count |
25//!
26//! Once `T1-E` lands the layers above (RSA-CRT decrypt, PKCS#1,
27//! OAEP, PSS) inherit a CT bigint base; combined with `T1-C`
28//! Aumüller and `T2-I` message blinding it gives the full
29//! evaluation-grade RSA stack.
30
31use core::cmp::Ordering;
32
/// Arbitrary-precision unsigned integer used for every RSA quantity.
///
/// The value is `sum(limbs[i] * 2^(64 * i))`, i.e. limbs are stored
/// least-significant first. The arithmetic in this module trims its
/// results (no leading zero limbs, zero is a single `0` limb), but the
/// field is public, so hand-built values may carry leading zeros.
#[derive(Debug, Clone)]
pub struct BigInt {
    /// 64-bit words, least significant first (`limbs[0]` is the low word).
    pub limbs: Vec<u64>,
}
39
40// ---------------------------------------------------------------------------
41// Construction helpers
42// ---------------------------------------------------------------------------
43
44impl BigInt {
45    /// Zero value.
46    pub fn zero() -> Self {
47        Self { limbs: vec![0] }
48    }
49
50    /// From a single u64.
51    pub fn from_u64(v: u64) -> Self {
52        Self { limbs: vec![v] }
53    }
54
55    /// From big-endian bytes (as in RSA wire format).
56    pub fn from_be_bytes(bytes: &[u8]) -> Self {
57        if bytes.is_empty() {
58            return Self::zero();
59        }
60        // Pad to multiple of 8
61        let padded_len = (bytes.len() + 7) / 8 * 8;
62        let mut padded = vec![0u8; padded_len];
63        padded[padded_len - bytes.len()..].copy_from_slice(bytes);
64
65        let n_limbs = padded_len / 8;
66        let mut limbs = Vec::with_capacity(n_limbs);
67        for i in (0..n_limbs).rev() {
68            let off = i * 8;
69            let limb = u64::from_be_bytes([
70                padded[off],
71                padded[off + 1],
72                padded[off + 2],
73                padded[off + 3],
74                padded[off + 4],
75                padded[off + 5],
76                padded[off + 6],
77                padded[off + 7],
78            ]);
79            limbs.push(limb);
80        }
81        let mut r = Self { limbs };
82        r.trim();
83        r
84    }
85
86    /// Convert to big-endian bytes, padded to at least `min_len` bytes.
87    pub fn to_be_bytes(&self, min_len: usize) -> Vec<u8> {
88        let n = self.limbs.len();
89        let byte_len = n * 8;
90        let mut buf = vec![0u8; byte_len];
91        for (i, &limb) in self.limbs.iter().enumerate() {
92            let off = byte_len - (i + 1) * 8;
93            buf[off..off + 8].copy_from_slice(&limb.to_be_bytes());
94        }
95        // Strip leading zeros
96        let start = buf.iter().position(|&b| b != 0).unwrap_or(buf.len());
97        let significant = &buf[start..];
98        if significant.len() >= min_len {
99            significant.to_vec()
100        } else {
101            let mut out = vec![0u8; min_len];
102            out[min_len - significant.len()..].copy_from_slice(significant);
103            out
104        }
105    }
106
107    /// Number of significant bits.
108    pub fn bit_len(&self) -> usize {
109        let n = self.limbs.len();
110        if n == 0 {
111            return 0;
112        }
113        let top = self.limbs[n - 1];
114        if top == 0 && n == 1 {
115            return 0;
116        }
117        (n - 1) * 64 + (64 - top.leading_zeros() as usize)
118    }
119
120    /// Test whether bit `i` is set.
121    pub fn bit(&self, i: usize) -> bool {
122        let limb_idx = i / 64;
123        let bit_idx = i % 64;
124        if limb_idx >= self.limbs.len() {
125            false
126        } else {
127            (self.limbs[limb_idx] >> bit_idx) & 1 == 1
128        }
129    }
130
131    /// Set bit `i`.
132    pub fn set_bit(&mut self, i: usize) {
133        let limb_idx = i / 64;
134        let bit_idx = i % 64;
135        while self.limbs.len() <= limb_idx {
136            self.limbs.push(0);
137        }
138        self.limbs[limb_idx] |= 1u64 << bit_idx;
139    }
140
141    /// Is this number zero?
142    pub fn is_zero(&self) -> bool {
143        self.limbs.iter().all(|&l| l == 0)
144    }
145
146    /// Is this number even?
147    pub fn is_even(&self) -> bool {
148        self.limbs.first().map_or(true, |&l| l & 1 == 0)
149    }
150
151    /// Is this number odd?
152    pub fn is_odd(&self) -> bool {
153        !self.is_even()
154    }
155
156    /// Remove leading zero limbs (keep at least one).
157    fn trim(&mut self) {
158        while self.limbs.len() > 1 && *self.limbs.last().unwrap() == 0 {
159            self.limbs.pop();
160        }
161    }
162
163    /// Byte length of the modulus (for RSA octet-string conversion).
164    pub fn byte_len(&self) -> usize {
165        (self.bit_len() + 7) / 8
166    }
167
168    /// Generate a random BigInt with exactly `bits` bits using the provided RNG callback.
169    pub fn random(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> Self {
170        let byte_len = (bits + 7) / 8;
171        let mut buf = vec![0u8; byte_len];
172        rng(&mut buf);
173        // Set the top bit to ensure we get exactly `bits` bits.
174        let top_bit = (bits - 1) % 8;
175        // Clear bits above top_bit. When top_bit == 7, keep all 8 bits.
176        if top_bit < 7 {
177            buf[0] &= (1u8 << (top_bit + 1)) - 1;
178        }
179        buf[0] |= 1u8 << top_bit; // set top bit
180        Self::from_be_bytes(&buf)
181    }
182
183    /// Generate a random odd BigInt with exactly `bits` bits.
184    pub fn random_odd(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> Self {
185        let mut n = Self::random(bits, rng);
186        n.limbs[0] |= 1; // force odd
187        n
188    }
189}
190
191// ---------------------------------------------------------------------------
192// Comparison
193// ---------------------------------------------------------------------------
194
195impl PartialEq for BigInt {
196    fn eq(&self, other: &Self) -> bool {
197        self.cmp_to(other) == Ordering::Equal
198    }
199}
200impl Eq for BigInt {}
201
202impl PartialOrd for BigInt {
203    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
204        Some(self.cmp_to(other))
205    }
206}
207
208impl Ord for BigInt {
209    fn cmp(&self, other: &Self) -> Ordering {
210        self.cmp_to(other)
211    }
212}
213
214impl BigInt {
215    /// Compare `self` to `other` as unsigned big integers. Used by
216    /// the `Ord` / `PartialOrd` impls and exposed publicly so callers
217    /// can perform a comparison without allocating an `Ordering` via
218    /// the trait machinery in tight loops.
219    pub fn cmp_to(&self, other: &Self) -> Ordering {
220        let a_len = self.limbs.len();
221        let b_len = other.limbs.len();
222        let max_len = a_len.max(b_len);
223        for i in (0..max_len).rev() {
224            let a = if i < a_len { self.limbs[i] } else { 0 };
225            let b = if i < b_len { other.limbs[i] } else { 0 };
226            match a.cmp(&b) {
227                Ordering::Equal => continue,
228                ord => return ord,
229            }
230        }
231        Ordering::Equal
232    }
233}
234
235// ---------------------------------------------------------------------------
236// Addition
237// ---------------------------------------------------------------------------
238
239impl BigInt {
240    /// self + other
241    pub fn add(&self, other: &BigInt) -> BigInt {
242        let max_len = self.limbs.len().max(other.limbs.len());
243        let mut result = Vec::with_capacity(max_len + 1);
244        let mut carry: u64 = 0;
245        for i in 0..max_len {
246            let a = if i < self.limbs.len() { self.limbs[i] } else { 0 };
247            let b = if i < other.limbs.len() { other.limbs[i] } else { 0 };
248            let (sum1, c1) = a.overflowing_add(b);
249            let (sum2, c2) = sum1.overflowing_add(carry);
250            result.push(sum2);
251            carry = (c1 as u64) + (c2 as u64);
252        }
253        if carry > 0 {
254            result.push(carry);
255        }
256        let mut r = BigInt { limbs: result };
257        r.trim();
258        r
259    }
260
261    /// self + small (u64)
262    pub fn add_u64(&self, v: u64) -> BigInt {
263        self.add(&BigInt::from_u64(v))
264    }
265}
266
267// ---------------------------------------------------------------------------
268// Subtraction (assumes self >= other)
269// ---------------------------------------------------------------------------
270
271impl BigInt {
272    /// self - other  (panics if result would be negative)
273    pub fn sub(&self, other: &BigInt) -> BigInt {
274        debug_assert!(self >= other, "BigInt::sub: underflow");
275        let mut result = Vec::with_capacity(self.limbs.len());
276        let mut borrow: u64 = 0;
277        for i in 0..self.limbs.len() {
278            let a = self.limbs[i];
279            let b = if i < other.limbs.len() { other.limbs[i] } else { 0 };
280            let (diff1, b1) = a.overflowing_sub(b);
281            let (diff2, b2) = diff1.overflowing_sub(borrow);
282            result.push(diff2);
283            borrow = (b1 as u64) + (b2 as u64);
284        }
285        let mut r = BigInt { limbs: result };
286        r.trim();
287        r
288    }
289
290    /// self - 1
291    pub fn sub_one(&self) -> BigInt {
292        self.sub(&BigInt::from_u64(1))
293    }
294}
295
296// ---------------------------------------------------------------------------
297// Multiplication (schoolbook)
298// ---------------------------------------------------------------------------
299
300impl BigInt {
301    /// self * other (schoolbook O(n^2))
302    pub fn mul(&self, other: &BigInt) -> BigInt {
303        let n = self.limbs.len();
304        let m = other.limbs.len();
305        let mut result = vec![0u64; n + m];
306        for i in 0..n {
307            let mut carry: u64 = 0;
308            for j in 0..m {
309                let (lo, hi) = mul_u64(self.limbs[i], other.limbs[j]);
310                let (s1, c1) = result[i + j].overflowing_add(lo);
311                let (s2, c2) = s1.overflowing_add(carry);
312                result[i + j] = s2;
313                carry = hi + (c1 as u64) + (c2 as u64);
314            }
315            result[i + m] = carry;
316        }
317        let mut r = BigInt { limbs: result };
318        r.trim();
319        r
320    }
321}
322
/// Full 64x64 -> 128-bit product of `a` and `b`, split into `(low, high)` words.
#[inline]
fn mul_u64(a: u64, b: u64) -> (u64, u64) {
    let wide = u128::from(a) * u128::from(b);
    let high = (wide >> 64) as u64;
    let low = wide as u64;
    (low, high)
}
329
330// ---------------------------------------------------------------------------
331// Division with remainder
332// ---------------------------------------------------------------------------
333
334impl BigInt {
335    /// (quotient, remainder) = self / divisor
336    /// Uses long division.
337    pub fn div_rem(&self, divisor: &BigInt) -> (BigInt, BigInt) {
338        assert!(!divisor.is_zero(), "BigInt: division by zero");
339        if self < divisor {
340            return (BigInt::zero(), self.clone());
341        }
342        if divisor.limbs.len() == 1 {
343            return self.div_rem_u64(divisor.limbs[0]);
344        }
345        self.div_rem_knuth(divisor)
346    }
347
348    /// Division by a single u64 limb.
349    fn div_rem_u64(&self, d: u64) -> (BigInt, BigInt) {
350        let mut rem: u128 = 0;
351        let mut quotient = vec![0u64; self.limbs.len()];
352        for i in (0..self.limbs.len()).rev() {
353            rem = (rem << 64) | (self.limbs[i] as u128);
354            quotient[i] = (rem / d as u128) as u64;
355            rem %= d as u128;
356        }
357        let mut q = BigInt { limbs: quotient };
358        q.trim();
359        (q, BigInt::from_u64(rem as u64))
360    }
361
362    /// Knuth's Algorithm D for multi-word division.
363    fn div_rem_knuth(&self, divisor: &BigInt) -> (BigInt, BigInt) {
364        // Normalize: shift so that top bit of divisor's leading limb is set.
365        let shift = divisor.limbs.last().unwrap().leading_zeros() as usize;
366        let a = self.shl(shift);
367        let b = divisor.shl(shift);
368
369        let n = b.limbs.len();
370        let m = a.limbs.len() - n;
371
372        let mut q_limbs = vec![0u64; m + 1];
373        // Working copy of dividend with extra limb
374        let mut u = a.limbs.clone();
375        if u.len() <= n + m {
376            u.resize(n + m + 1, 0);
377        }
378
379        let b_top = *b.limbs.last().unwrap() as u128;
380
381        for j in (0..=m).rev() {
382            // Estimate q_hat
383            let u_hi = ((u[j + n] as u128) << 64) | (u[j + n - 1] as u128);
384            let mut q_hat = u_hi / b_top;
385            let mut r_hat = u_hi % b_top;
386
387            // Refine estimate
388            if n >= 2 {
389                let b_second = b.limbs[n - 2] as u128;
390                while q_hat >= (1u128 << 64) || q_hat * b_second > (r_hat << 64) | (u[j + n - 2] as u128) {
391                    q_hat -= 1;
392                    r_hat += b_top;
393                    if r_hat >= (1u128 << 64) {
394                        break;
395                    }
396                }
397            }
398
399            // Multiply and subtract
400            let mut borrow: i128 = 0;
401            for i in 0..n {
402                let prod = q_hat * (b.limbs[i] as u128);
403                let diff = (u[j + i] as i128) - borrow - (prod as u64 as i128);
404                u[j + i] = diff as u64;
405                borrow = (prod >> 64) as i128 - (diff >> 64) as i128;
406            }
407            let diff = (u[j + n] as i128) - borrow;
408            u[j + n] = diff as u64;
409
410            q_limbs[j] = q_hat as u64;
411
412            // If we subtracted too much, add back
413            if diff < 0 {
414                q_limbs[j] -= 1;
415                let mut carry: u64 = 0;
416                for i in 0..n {
417                    let (s1, c1) = u[j + i].overflowing_add(b.limbs[i]);
418                    let (s2, c2) = s1.overflowing_add(carry);
419                    u[j + i] = s2;
420                    carry = (c1 as u64) + (c2 as u64);
421                }
422                u[j + n] = u[j + n].wrapping_add(carry);
423            }
424        }
425
426        let mut q = BigInt { limbs: q_limbs };
427        q.trim();
428        // Remainder: take first n limbs of u and shift back
429        u.truncate(n);
430        let mut r = BigInt { limbs: u };
431        r.trim();
432        r = r.shr(shift);
433        (q, r)
434    }
435
436    /// Left shift by `bits` bit positions.
437    pub fn shl(&self, bits: usize) -> BigInt {
438        if bits == 0 {
439            return self.clone();
440        }
441        let limb_shift = bits / 64;
442        let bit_shift = bits % 64;
443        let mut result = vec![0u64; self.limbs.len() + limb_shift + 1];
444        let mut carry: u64 = 0;
445        for i in 0..self.limbs.len() {
446            if bit_shift == 0 {
447                result[i + limb_shift] = self.limbs[i];
448            } else {
449                result[i + limb_shift] = (self.limbs[i] << bit_shift) | carry;
450                carry = self.limbs[i] >> (64 - bit_shift);
451            }
452        }
453        if carry != 0 {
454            result[self.limbs.len() + limb_shift] = carry;
455        }
456        let mut r = BigInt { limbs: result };
457        r.trim();
458        r
459    }
460
461    /// Right shift by `bits` bit positions.
462    pub fn shr(&self, bits: usize) -> BigInt {
463        if bits == 0 {
464            return self.clone();
465        }
466        let limb_shift = bits / 64;
467        let bit_shift = bits % 64;
468        if limb_shift >= self.limbs.len() {
469            return BigInt::zero();
470        }
471        let new_len = self.limbs.len() - limb_shift;
472        let mut result = vec![0u64; new_len];
473        for i in 0..new_len {
474            let src = i + limb_shift;
475            result[i] = if bit_shift == 0 {
476                self.limbs[src]
477            } else {
478                let lo = self.limbs[src] >> bit_shift;
479                let hi = if src + 1 < self.limbs.len() {
480                    self.limbs[src + 1] << (64 - bit_shift)
481                } else {
482                    0
483                };
484                lo | hi
485            };
486        }
487        let mut r = BigInt { limbs: result };
488        r.trim();
489        r
490    }
491
492    /// self mod other
493    pub fn rem(&self, modulus: &BigInt) -> BigInt {
494        self.div_rem(modulus).1
495    }
496}
497
498// ---------------------------------------------------------------------------
499// Montgomery multiplication
500// ---------------------------------------------------------------------------
501
/// Parameters for Montgomery modular arithmetic modulo an odd `n`.
///
/// Throughout, `R = 2^(64 * n_limbs)`. `R` itself is implicit and never
/// stored; only its residues below are kept.
pub struct MontParams {
    /// The modulus n.
    pub n: BigInt,
    /// Number of limbs of `n` (taken from `n.limbs.len()`); fixes `R`.
    pub n_limbs: usize,
    /// -n^{-1} mod 2^64, the per-word reduction constant.
    pub n_inv_neg: u64,
    /// R mod n, i.e. the value 1 in Montgomery form.
    pub r_mod_n: BigInt,
    /// R^2 mod n, used to convert values into Montgomery form.
    pub r2_mod_n: BigInt,
}
516
517impl MontParams {
518    /// Create Montgomery parameters for modulus n.
519    pub fn new(n: &BigInt) -> Self {
520        let n_limbs = n.limbs.len();
521
522        // Compute -n^{-1} mod 2^64 using Newton's method.
523        // n must be odd for Montgomery to work.
524        debug_assert!(n.is_odd(), "Montgomery requires odd modulus");
525        let n0 = n.limbs[0];
526        let n_inv_neg = mod_inv_u64_neg(n0);
527
528        // R = 2^{64*n_limbs}, compute R mod n and R^2 mod n.
529        // R mod n: we can compute by (1 << (64*n_limbs)) mod n
530        let mut r_val = BigInt::zero();
531        r_val.set_bit(64 * n_limbs);
532        let r_mod_n = r_val.rem(n);
533
534        let r2_val = r_val.mul(&r_val);
535        let r2_mod_n = r2_val.rem(n);
536
537        MontParams {
538            n: n.clone(),
539            n_limbs,
540            n_inv_neg,
541            r_mod_n,
542            r2_mod_n,
543        }
544    }
545
546    /// Convert a into Montgomery form: a * R mod n.
547    pub fn to_mont(&self, a: &BigInt) -> BigInt {
548        self.mont_mul(a, &self.r2_mod_n)
549    }
550
551    /// Convert from Montgomery form: a * R^{-1} mod n.
552    pub fn from_mont(&self, a: &BigInt) -> BigInt {
553        self.mont_mul(a, &BigInt::from_u64(1))
554    }
555
556    /// Montgomery multiplication: (a * b * R^{-1}) mod n.
557    /// Uses the CIOS (Coarsely Integrated Operand Scanning) method.
558    pub fn mont_mul(&self, a: &BigInt, b: &BigInt) -> BigInt {
559        let n = self.n_limbs;
560        // Working array t with n+2 limbs (extra for carries).
561        let mut t = vec![0u64; n + 2];
562
563        for i in 0..n {
564            let bi = if i < b.limbs.len() { b.limbs[i] } else { 0 };
565            // t = t + a * b[i]
566            let mut carry: u64 = 0;
567            for j in 0..n {
568                let aj = if j < a.limbs.len() { a.limbs[j] } else { 0 };
569                let (lo, hi) = mul_u64(aj, bi);
570                let (s1, c1) = t[j].overflowing_add(lo);
571                let (s2, c2) = s1.overflowing_add(carry);
572                t[j] = s2;
573                carry = hi + (c1 as u64) + (c2 as u64);
574            }
575            let (s, c) = t[n].overflowing_add(carry);
576            t[n] = s;
577            t[n + 1] = c as u64;
578
579            // m = t[0] * n_inv_neg mod 2^64
580            let m = t[0].wrapping_mul(self.n_inv_neg);
581
582            // t = (t + m * N) >> 64
583            let mut carry: u64 = 0;
584            {
585                let (lo, hi) = mul_u64(m, self.n.limbs[0]);
586                let (s1, c1) = t[0].overflowing_add(lo);
587                let (_s2, c2) = s1.overflowing_add(carry);
588                // We discard t[0] (shifting right by 64).
589                carry = hi + (c1 as u64) + (c2 as u64);
590            }
591            for j in 1..n {
592                let nj = self.n.limbs[j];
593                let (lo, hi) = mul_u64(m, nj);
594                let (s1, c1) = t[j].overflowing_add(lo);
595                let (s2, c2) = s1.overflowing_add(carry);
596                t[j - 1] = s2;
597                carry = hi + (c1 as u64) + (c2 as u64);
598            }
599            let (s1, c1) = t[n].overflowing_add(carry);
600            t[n - 1] = s1;
601            t[n] = t[n + 1] + (c1 as u64);
602            t[n + 1] = 0;
603        }
604
605        // Result in t[0..n], may need final subtraction.
606        t.truncate(n + 1);
607        let mut result = BigInt { limbs: t };
608        result.trim();
609        if result >= self.n {
610            result = result.sub(&self.n);
611        }
612        result.trim();
613        result
614    }
615
616    /// Constant-time modular exponentiation using Montgomery ladder.
617    /// Computes base^exp mod n.
618    pub fn mod_exp(&self, base: &BigInt, exp: &BigInt) -> BigInt {
619        let base_mont = self.to_mont(base);
620        let bits = exp.bit_len();
621        if bits == 0 {
622            return BigInt::from_u64(1);
623        }
624
625        // Montgomery ladder: constant-time (always does same operations).
626        let mut r0 = self.r_mod_n.clone(); // 1 in Montgomery form = R mod n
627        let mut r1 = base_mont.clone();
628
629        for i in (0..bits).rev() {
630            if exp.bit(i) {
631                r0 = self.mont_mul(&r0, &r1);
632                r1 = self.mont_mul(&r1, &r1);
633            } else {
634                r1 = self.mont_mul(&r0, &r1);
635                r0 = self.mont_mul(&r0, &r0);
636            }
637        }
638
639        self.from_mont(&r0)
640    }
641}
642
/// Returns `-(n0^{-1}) mod 2^64`, the Montgomery word constant for an odd
/// low limb `n0`.
///
/// Newton/Hensel lifting: if `x * n0 ≡ 1 (mod 2^k)`, then
/// `x * (2 - n0 * x)` satisfies the congruence mod `2^(2k)`. Starting from
/// `x = 1` (correct mod 2 for odd `n0`), six steps reach all 64 bits.
fn mod_inv_u64_neg(n0: u64) -> u64 {
    let inverse = (0..6).fold(1u64, |x, _| {
        let correction = 2u64.wrapping_sub(n0.wrapping_mul(x));
        x.wrapping_mul(correction)
    });
    inverse.wrapping_neg()
}
655
656// ---------------------------------------------------------------------------
657// Modular exponentiation (convenience, non-Montgomery)
658// ---------------------------------------------------------------------------
659
660impl BigInt {
661    /// Modular exponentiation: self^exp mod modulus.
662    /// Uses Montgomery multiplication internally for constant-time operation.
663    pub fn mod_exp(&self, exp: &BigInt, modulus: &BigInt) -> BigInt {
664        let params = MontParams::new(modulus);
665        params.mod_exp(self, exp)
666    }
667
668    /// Modular inverse: self^{-1} mod modulus, using extended GCD.
669    /// Returns None if gcd(self, modulus) != 1.
670    pub fn mod_inv(&self, modulus: &BigInt) -> Option<BigInt> {
671        // Extended Euclidean algorithm with signed coefficients.
672        let (g, x, _neg) = extended_gcd(self, modulus);
673        if g != BigInt::from_u64(1) {
674            return None;
675        }
676        Some(x)
677    }
678}
679
680/// Extended GCD returning (gcd, x, y) where a*x + b*y = gcd.
681/// The returned x is in range [0, b).
682fn extended_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, bool) {
683    // We use an iterative version with signed bookkeeping.
684    // Instead of true signed BigInts, we track sign bits separately.
685
686    if a.is_zero() {
687        return (b.clone(), BigInt::zero(), false);
688    }
689
690    // Iterative extended GCD
691    let mut old_r = a.clone();
692    let mut r = b.clone();
693    let mut old_s = BigInt::from_u64(1);
694    let mut s = BigInt::zero();
695    let mut old_s_neg = false; // sign of old_s
696    let mut s_neg = false; // sign of s
697
698    while !r.is_zero() {
699        let (q, remainder) = old_r.div_rem(&r);
700
701        old_r = r;
702        r = remainder;
703
704        // new_s = old_s - q * s
705        let qs = q.mul(&s);
706        // We need signed subtraction: old_s_sign * old_s - s_sign * s * q
707        let (new_s, new_s_neg) = signed_sub(&old_s, old_s_neg, &qs, s_neg);
708        old_s = s;
709        old_s_neg = s_neg;
710        s = new_s;
711        s_neg = new_s_neg;
712    }
713
714    // old_s might be negative; if so, add b to get into range [0, b).
715    let x = if old_s_neg { b.sub(&old_s.rem(b)) } else { old_s.rem(b) };
716
717    (old_r, x, false)
718}
719
720/// Signed subtraction: (|a|, a_neg) - (|b|, b_neg) = (|result|, result_neg)
721fn signed_sub(a: &BigInt, a_neg: bool, b: &BigInt, b_neg: bool) -> (BigInt, bool) {
722    // a_val - b_val where a_val = (-1)^a_neg * a, b_val = (-1)^b_neg * b
723    // = (-1)^a_neg * a + (-1)^(!b_neg) * b
724    if a_neg == b_neg {
725        // Same sign: subtract magnitudes
726        if a >= b { (a.sub(b), a_neg) } else { (b.sub(a), !a_neg) }
727    } else {
728        // Different signs: add magnitudes, sign is a's sign
729        (a.add(b), a_neg)
730    }
731}
732
733// ---------------------------------------------------------------------------
734// Miller-Rabin primality test
735// ---------------------------------------------------------------------------
736
737impl BigInt {
738    /// Miller-Rabin primality test with `rounds` iterations.
739    /// Uses the provided RNG to generate random witnesses.
740    pub fn is_probably_prime(&self, rounds: usize, rng: &mut dyn FnMut(&mut [u8])) -> bool {
741        // Handle small cases.
742        if self.limbs.len() == 1 {
743            let v = self.limbs[0];
744            if v < 2 {
745                return false;
746            }
747            if v == 2 || v == 3 {
748                return true;
749            }
750            if v % 2 == 0 {
751                return false;
752            }
753        }
754        if self.is_even() {
755            return false;
756        }
757
758        let one = BigInt::from_u64(1);
759        let two = BigInt::from_u64(2);
760        let n_minus_1 = self.sub(&one);
761        let n_minus_2 = self.sub(&two);
762
763        // Write n-1 = 2^s * d with d odd.
764        let mut d = n_minus_1.clone();
765        let mut s: usize = 0;
766        while d.is_even() {
767            d = d.shr(1);
768            s += 1;
769        }
770
771        let mont = MontParams::new(self);
772
773        'next_round: for _ in 0..rounds {
774            // Pick random a in [2, n-2].
775            let a = loop {
776                let candidate = BigInt::random(self.bit_len(), rng);
777                if candidate >= two && candidate <= n_minus_2 {
778                    break candidate;
779                }
780            };
781
782            let mut x = mont.mod_exp(&a, &d);
783
784            if x == one || x == n_minus_1 {
785                continue 'next_round;
786            }
787
788            for _ in 0..s - 1 {
789                x = mont.mod_exp(&x, &two);
790                if x == n_minus_1 {
791                    continue 'next_round;
792                }
793            }
794
795            return false; // composite
796        }
797
798        true // probably prime
799    }
800
801    /// Generate a random probable prime of `bits` bits.
802    pub fn random_prime(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> BigInt {
803        loop {
804            let candidate = BigInt::random_odd(bits, rng);
805            // Quick trial division for small primes.
806            let small_primes: &[u64] = &[
807                3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103,
808                107, 109, 113,
809            ];
810            let mut skip = false;
811            for &p in small_primes {
812                let (_, rem) = candidate.div_rem(&BigInt::from_u64(p));
813                if rem.is_zero() {
814                    if candidate == BigInt::from_u64(p) {
815                        return candidate;
816                    }
817                    skip = true;
818                    break;
819                }
820            }
821            if skip {
822                continue;
823            }
824
825            // Miller-Rabin with sufficient rounds for the bit size.
826            let rounds = if bits >= 1024 { 4 } else { 8 };
827            if candidate.is_probably_prime(rounds, rng) {
828                return candidate;
829            }
830        }
831    }
832}
833
834// ---------------------------------------------------------------------------
835// Tests
836// ---------------------------------------------------------------------------
837
#[cfg(test)]
mod tests {
    use super::*;

    /// 2^127 - 1, a Mersenne prime. As a two-limb modulus it exercises the
    /// Knuth division path and multi-limb Montgomery arithmetic.
    fn m127() -> BigInt {
        BigInt {
            limbs: vec![u64::MAX, u64::MAX >> 1],
        }
    }

    #[test]
    fn test_add_sub() {
        let a = BigInt::from_u64(u64::MAX);
        let b = BigInt::from_u64(1);
        let c = a.add(&b);
        assert_eq!(c.limbs.len(), 2);
        assert_eq!(c.limbs[0], 0);
        assert_eq!(c.limbs[1], 1);
        let d = c.sub(&b);
        assert_eq!(d, a);
    }

    #[test]
    fn test_mul() {
        let a = BigInt::from_u64(0xFFFFFFFF);
        let b = BigInt::from_u64(0xFFFFFFFF);
        let c = a.mul(&b);
        // 0xFFFFFFFF * 0xFFFFFFFF = 0xFFFFFFFE00000001 (single limb after trim)
        assert_eq!(c, BigInt::from_u64(0xFFFFFFFE00000001));
        assert_eq!(c.limbs.len(), 1);
    }

    #[test]
    fn test_div_rem() {
        let a = BigInt::from_u64(100);
        let b = BigInt::from_u64(7);
        let (q, r) = a.div_rem(&b);
        assert_eq!(q, BigInt::from_u64(14));
        assert_eq!(r, BigInt::from_u64(2));
    }

    #[test]
    fn test_div_rem_multi_limb() {
        // 2^128 / (2^64 + 2^63) = floor(2^65 / 3), remainder 2^64.
        let a = BigInt { limbs: vec![0, 0, 1] };
        let d = BigInt { limbs: vec![1 << 63, 1] };
        let (q, r) = a.div_rem(&d);
        assert_eq!(q, BigInt::from_u64(0xAAAA_AAAA_AAAA_AAAA));
        assert_eq!(r, BigInt { limbs: vec![0, 1] });
        assert_eq!(q.mul(&d).add(&r), a);
        assert!(r < d);
    }

    #[test]
    fn test_shift_roundtrip() {
        let n = BigInt {
            limbs: vec![0x0123_4567_89AB_CDEF, 0xFEDC_BA98],
        };
        for &bits in &[0usize, 1, 63, 64, 65, 130] {
            assert_eq!(n.shl(bits).shr(bits), n);
        }
        assert_eq!(n.shl(64).limbs, vec![0, 0x0123_4567_89AB_CDEF, 0xFEDC_BA98]);
        assert_eq!(n.shr(200), BigInt::zero());
    }

    #[test]
    fn test_mod_exp() {
        // 3^10 mod 7 = 59049 mod 7 = 4
        let base = BigInt::from_u64(3);
        let exp = BigInt::from_u64(10);
        let modulus = BigInt::from_u64(7);
        let result = base.mod_exp(&exp, &modulus);
        assert_eq!(result, BigInt::from_u64(4));
    }

    #[test]
    fn test_mod_exp_multi_limb() {
        let p = m127();
        // 2^127 ≡ 1 (mod 2^127 - 1)
        assert_eq!(
            BigInt::from_u64(2).mod_exp(&BigInt::from_u64(127), &p),
            BigInt::from_u64(1)
        );
        // Fermat: 3^(p-1) ≡ 1 (mod p) for prime p
        assert_eq!(BigInt::from_u64(3).mod_exp(&p.sub_one(), &p), BigInt::from_u64(1));
        // x^0 = 1
        assert_eq!(
            BigInt::from_u64(12345).mod_exp(&BigInt::zero(), &p),
            BigInt::from_u64(1)
        );
    }

    #[test]
    fn test_mod_inv() {
        // 3^{-1} mod 7 = 5 (since 3*5 = 15 ≡ 1 mod 7)
        let a = BigInt::from_u64(3);
        let m = BigInt::from_u64(7);
        let inv = a.mod_inv(&m).unwrap();
        assert_eq!(inv, BigInt::from_u64(5));
    }

    #[test]
    fn test_mod_inv_multi_limb() {
        let p = m127();
        let a = BigInt::from_u64(0x1234_5678_9ABC_DEF0);
        let inv = a.mod_inv(&p).unwrap();
        assert!(inv < p);
        assert_eq!(a.mul(&inv).rem(&p), BigInt::from_u64(1));
        // Not invertible: gcd(6, 9) = 3.
        assert!(BigInt::from_u64(6).mod_inv(&BigInt::from_u64(9)).is_none());
    }

    #[test]
    fn test_from_be_bytes_roundtrip() {
        let bytes = vec![0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01];
        let n = BigInt::from_be_bytes(&bytes);
        let out = n.to_be_bytes(8);
        assert_eq!(out, bytes);
    }

    #[test]
    fn test_to_be_bytes_padding() {
        let n = BigInt::from_be_bytes(&[0, 0, 0x01, 0x02]);
        assert_eq!(n, BigInt::from_u64(0x0102));
        assert_eq!(n.to_be_bytes(0), vec![0x01, 0x02]);
        assert_eq!(n.to_be_bytes(4), vec![0, 0, 0x01, 0x02]);
        assert_eq!(n.byte_len(), 2);
    }

    #[test]
    fn test_bit_ops() {
        let mut n = BigInt::zero();
        n.set_bit(65);
        assert!(n.bit(65));
        assert!(!n.bit(64));
        assert_eq!(n.bit_len(), 66);
    }

    /// Deterministic LCG byte source, so the tests are reproducible.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        let mut state: u64 = 0xdeadbeefcafebabe;
        move |buf: &mut [u8]| {
            for b in buf.iter_mut() {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                *b = (state >> 33) as u8;
            }
        }
    }

    #[test]
    fn test_primality() {
        let mut rng = test_rng();
        // 7 is prime
        assert!(BigInt::from_u64(7).is_probably_prime(10, &mut rng));
        // 15 is not prime
        assert!(!BigInt::from_u64(15).is_probably_prime(10, &mut rng));
        // 104729 is prime
        assert!(BigInt::from_u64(104729).is_probably_prime(10, &mut rng));
        // 2^127 - 1 is prime (two-limb modulus)
        assert!(m127().is_probably_prime(4, &mut rng));
    }
}