Skip to main content

quantica/slh_dsa/
wots.rs

//! WOTS+ one-time signature scheme (FIPS 205, Algorithms 3-8).
2//!
3//! WOTS+ (Winternitz One-Time Signature Plus) is the foundational one-time signature
4//! scheme used at the leaves of every XMSS tree in SLH-DSA. It signs a single `n`-byte
5//! message by:
6//!
7//! 1. Splitting the message into base-`w` digits and computing a checksum.
8//! 2. For each digit `d_i`, evaluating a hash chain `F^(d_i)(sk_i)` where `sk_i` is
9//!    derived from the secret seed via PRF.
10//! 3. Verification completes each chain to step `w - 1` and compresses all endpoints
11//!    into a single `n`-byte public key using `T_l`.
12//!
13//! WOTS+ is purely hash-based: its one-time security relies solely on the second-preimage
14//! resistance of the hash function `F`.
15
16use super::address::{Adrs, WOTS_PK, WOTS_PRF};
17use super::hash;
18use super::params::Params;
19use alloc::vec::Vec;
20
/// Split a byte string into `out_len` big-endian digits of `b` bits each.
///
/// Implements Algorithm 4 of FIPS 205 (`base_2^b`). The input `x` is read as one
/// long bit string, most-significant bit first; the leading `out_len * b` bits are
/// cut into consecutive `b`-bit groups and every group is returned as an integer in
/// `[0, 2^b)`. Bits of `x` beyond the last requested digit are ignored.
///
/// For the standard WOTS+ parameters (`lg_w = 4`, so `b = 4`) this extracts nibbles.
/// It is also used by FORS to extract `a`-bit indices from the message digest.
///
/// The bit accumulator is 64 bits wide, so every digit width that fits the `u32`
/// return type (`b <= 32`) is decoded exactly.
///
/// # Panics
///
/// * if `b > 32`, because such a digit cannot be represented in a `u32`;
/// * if `x` holds fewer than `ceil(out_len * b / 8)` bytes.
pub fn base_2b(x: &[u8], b: usize, out_len: usize) -> Vec<u32> {
    assert!(b <= 32, "base_2b: digit width {} does not fit in u32", b);

    // `b <= 32`, so the shift is always in range for a u64.
    let mask: u64 = (1u64 << b) - 1;
    let mut bytes = x.iter();
    let mut acc: u64 = 0; // bits read so far; only the low `pending` bits are still unused
    let mut pending = 0usize; // number of buffered bits not yet emitted as a digit
    let mut digits = Vec::with_capacity(out_len);

    for _ in 0..out_len {
        // Pull in whole bytes until a full digit is buffered. `pending` never
        // exceeds `b + 7 <= 39`, so no live bit is shifted out of the u64.
        while pending < b {
            let byte = *bytes
                .next()
                .expect("base_2b: input too short for the requested digits");
            acc = (acc << 8) | u64::from(byte);
            pending += 8;
        }
        pending -= b;
        // The mask guarantees the value fits in 32 bits, so the cast is lossless.
        digits.push(((acc >> pending) & mask) as u32);
    }
    digits
}
47
48/// Compute the WOTS+ checksum and return `len2` base-`w` digits.
49///
50/// Implements Algorithm 1 of FIPS 205. The checksum ensures that an attacker cannot
51/// forge a signature by only advancing hash chains forward: decreasing any message
52/// digit necessarily increases a checksum digit.
53fn gen_len2<P: Params>(msg_digits: &[u32]) -> Vec<u32> {
54    let w = P::W as u32;
55    // csum = sum of (w - 1 - digit) for each digit in msg
56    let mut csum: u32 = 0;
57    for &d in msg_digits.iter() {
58        csum += w - 1 - d;
59    }
60    // Shift csum left by (8 - ((len2 * lg_w) % 8)) % 8 bits
61    let shift = (8 - ((P::LEN2 * P::LG_W) % 8)) % 8;
62    csum <<= shift;
63
64    // Convert csum to bytes, then extract len2 base-w digits
65    // csum fits in ceil((len2 * lg_w + shift) / 8) bytes
66    let num_bytes = ((P::LEN2 * P::LG_W) + shift + 7) / 8;
67    let csum_bytes = to_byte(csum as u64, num_bytes);
68    base_2b(&csum_bytes, P::LG_W, P::LEN2)
69}
70
71/// Apply the WOTS+ chain function `s` times starting from step `i`.
72///
73/// Implements Algorithm 5 of FIPS 205. Computes `F^s(X)` by iterating the tweakable
74/// hash function `F` a total of `s` times, with hash addresses ranging from `i` to
75/// `i + s - 1`. This recursive implementation is provided for clarity; the internal
76/// signing and verification routines use `chain_iter` (private helper) for efficiency.
77///
78/// Returns an `n`-byte hash chain value.
79pub fn chain<P: Params>(x: &[u8], i: u32, s: u32, pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
80    if s == 0 {
81        return x.to_vec();
82    }
83    let mut tmp = chain::<P>(x, i, s - 1, pk_seed, adrs);
84    adrs.set_hash_address(i + s - 1);
85    tmp = hash::f_hash::<P>(pk_seed, adrs, &tmp);
86    tmp
87}
88
89/// Iterative (non-recursive) variant of the WOTS+ chain function.
90///
91/// Functionally equivalent to [`chain`] but avoids deep recursion by using a simple loop.
92/// This is the version used by all internal signing and verification routines.
93fn chain_iter<P: Params>(x: &[u8], i: u32, s: u32, pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
94    // Use two stack buffers to avoid heap allocation in the inner loop.
95    // MAX_N = 32 covers all SLH-DSA parameter sets.
96    let n = P::N;
97    let mut buf_a = [0u8; 32];
98    let mut buf_b = [0u8; 32];
99    buf_a[..n].copy_from_slice(&x[..n]);
100
101    for j in 0..s {
102        adrs.set_hash_address(i + j);
103        if j % 2 == 0 {
104            hash::f_hash_into::<P>(pk_seed, adrs, &buf_a[..n], &mut buf_b[..n]);
105        } else {
106            hash::f_hash_into::<P>(pk_seed, adrs, &buf_b[..n], &mut buf_a[..n]);
107        }
108    }
109    if s == 0 {
110        buf_a[..n].to_vec()
111    } else if s % 2 == 1 {
112        buf_b[..n].to_vec()
113    } else {
114        buf_a[..n].to_vec()
115    }
116}
117
118/// Generate a WOTS+ public key.
119///
120/// Implements Algorithm 6 of FIPS 205. Derives `len` secret chain values from `SK.seed`
121/// via PRF, evaluates each chain to its full length (`w - 1` steps), and compresses
122/// all `len` endpoints into a single `n`-byte public key using `T_len`.
123///
124/// The address `adrs` must have its key pair address set before calling.
125pub fn wots_pk_gen<P: Params>(sk_seed: &[u8], pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
126    let mut sk_adrs = adrs.clone();
127    sk_adrs.set_type_and_clear(WOTS_PRF);
128    sk_adrs.set_key_pair_address(adrs.get_key_pair_address());
129
130    let mut tmp = Vec::with_capacity(P::LEN * P::N);
131
132    for i in 0..P::LEN {
133        sk_adrs.set_chain_address(i as u32);
134        let sk = hash::prf::<P>(pk_seed, sk_seed, &sk_adrs);
135        adrs.set_chain_address(i as u32);
136        let pk_i = chain_iter::<P>(&sk, 0, (P::W - 1) as u32, pk_seed, adrs);
137        tmp.extend_from_slice(&pk_i);
138    }
139
140    // Compress: T_len(PK.seed, ADRS', tmp)
141    let mut wots_pk_adrs = adrs.clone();
142    wots_pk_adrs.set_type_and_clear(WOTS_PK);
143    wots_pk_adrs.set_key_pair_address(adrs.get_key_pair_address());
144
145    hash::t_l::<P>(pk_seed, &wots_pk_adrs, &tmp)
146}
147
148/// Sign an `n`-byte message using WOTS+.
149///
150/// Implements Algorithm 7 of FIPS 205. Converts the message `m` to base-`w` digits,
151/// appends the checksum digits, and for each digit `d_i` outputs the chain value
152/// `F^(d_i)(sk_i)`. The resulting signature is `len * n` bytes.
153///
154/// This is a one-time signature: using the same WOTS+ key pair to sign two different
155/// messages leaks enough information to allow forgery.
156pub fn wots_sign<P: Params>(m: &[u8], sk_seed: &[u8], pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
157    let mut sig = alloc::vec![0u8; P::LEN * P::N];
158    wots_sign_into::<P>(m, sk_seed, pk_seed, adrs, &mut sig);
159    sig
160}
161
162/// Streaming variant of [`wots_sign`] — writes the `LEN * N`-byte
163/// signature into the start of `out` (which must be at least that
164/// size) instead of returning a freshly-allocated `Vec<u8>`.
165///
166/// Used by the streaming sign path (`xmss_sign_into` → `ht_sign_into`
167/// → `slh_sign_into`) to avoid the transient per-layer heap buffers.
168pub fn wots_sign_into<P: Params>(m: &[u8], sk_seed: &[u8], pk_seed: &[u8], adrs: &mut Adrs, out: &mut [u8]) {
169    debug_assert!(out.len() >= P::LEN * P::N);
170
171    // Convert M to base-w, then append the checksum digits.
172    let msg_digits = base_2b(m, P::LG_W, P::LEN1);
173    let csum_digits = gen_len2::<P>(&msg_digits);
174    let mut digits = msg_digits;
175    digits.extend_from_slice(&csum_digits);
176
177    let mut sk_adrs = adrs.clone();
178    sk_adrs.set_type_and_clear(WOTS_PRF);
179    sk_adrs.set_key_pair_address(adrs.get_key_pair_address());
180
181    for i in 0..P::LEN {
182        sk_adrs.set_chain_address(i as u32);
183        let sk = hash::prf::<P>(pk_seed, sk_seed, &sk_adrs);
184        adrs.set_chain_address(i as u32);
185        let sig_i = chain_iter::<P>(&sk, 0, digits[i], pk_seed, adrs);
186        out[i * P::N..(i + 1) * P::N].copy_from_slice(&sig_i);
187    }
188}
189
190/// Compute a WOTS+ public key candidate from a signature.
191///
192/// Implements Algorithm 8 of FIPS 205. For each digit `d_i` of the message (including
193/// checksum), completes the hash chain from the signature value (at step `d_i`) to the
194/// full chain endpoint (step `w - 1`), then compresses all endpoints with `T_len`.
195///
196/// If the signature is valid, the returned value equals the original public key.
197pub fn wots_pk_from_sig<P: Params>(sig: &[u8], m: &[u8], pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
198    let msg_digits = base_2b(m, P::LG_W, P::LEN1);
199    let csum_digits = gen_len2::<P>(&msg_digits);
200    let mut digits = msg_digits;
201    digits.extend_from_slice(&csum_digits);
202
203    let mut tmp = Vec::with_capacity(P::LEN * P::N);
204
205    for i in 0..P::LEN {
206        adrs.set_chain_address(i as u32);
207        let sig_i = &sig[i * P::N..(i + 1) * P::N];
208        let w_minus_1 = (P::W - 1) as u32;
209        let pk_i = chain_iter::<P>(sig_i, digits[i], w_minus_1 - digits[i], pk_seed, adrs);
210        tmp.extend_from_slice(&pk_i);
211    }
212
213    let mut wots_pk_adrs = adrs.clone();
214    wots_pk_adrs.set_type_and_clear(WOTS_PK);
215    wots_pk_adrs.set_key_pair_address(adrs.get_key_pair_address());
216
217    hash::t_l::<P>(pk_seed, &wots_pk_adrs, &tmp)
218}
219
/// Encode an integer as a big-endian byte vector of exactly `n` bytes.
///
/// Corresponds to `toByte` (Algorithm 3 of FIPS 205). The least-significant byte of
/// `val` ends up in the last position:
///
/// * when `n` is larger than 8, the value is zero-padded on the left;
/// * when `n` is smaller than 8, only the low `n` bytes of `val` are kept and the
///   higher bytes are silently dropped.
///
/// Used internally to serialise the checksum before it is split into base-`w` digits.
pub fn to_byte(val: u64, n: usize) -> Vec<u8> {
    let be = val.to_be_bytes();
    let mut encoded = Vec::with_capacity(n);
    if n >= be.len() {
        // Left-pad with zeros, then append all eight value bytes.
        encoded.resize(n - be.len(), 0u8);
        encoded.extend_from_slice(&be);
    } else {
        // Keep only the `n` least-significant bytes.
        encoded.extend_from_slice(&be[be.len() - n..]);
    }
    encoded
}
235
/// Stack-allocated counterpart of [`to_byte`] for widths of at most 8 bytes.
///
/// Returns an 8-byte array whose last `n` bytes hold the big-endian encoding of the
/// low `n` bytes of `val`; the leading `8 - n` bytes are zero. Callers that need
/// exactly `n` bytes can take the sub-slice `[8 - n..]` of the result.
///
/// # Panics
///
/// Panics if `n > 8`, since such a width cannot be represented in the returned
/// array. The check is explicit so that debug and release builds behave the same
/// (a bare `8 - n` would trap in debug builds but wrap in release builds and
/// silently yield an all-zero array).
pub fn to_byte_stack(val: u64, n: usize) -> [u8; 8] {
    assert!(n <= 8, "to_byte_stack: n must be at most 8, got {}", n);

    let mut result = val.to_be_bytes();
    // Bytes above the requested width are discarded, matching `to_byte`'s truncation.
    for byte in result.iter_mut().take(8 - n) {
        *byte = 0;
    }
    result
}