Skip to main content

quantica/slh_dsa/
xmss.rs

//! XMSS: eXtended Merkle Signature Scheme (FIPS 205, Algorithms 9-11).
//!
//! XMSS combines multiple WOTS+ one-time key pairs into a multi-time signature scheme
//! by organizing them as the leaves of a binary Merkle tree of height `h'`. Each XMSS
//! tree can sign up to `2^h'` messages (one per WOTS+ leaf).
//!
//! In the SLH-DSA hierarchy, XMSS trees are the building blocks of the hypertree:
//! each layer of the hypertree consists of XMSS trees, and the WOTS+ keys at the
//! leaves of upper-layer trees sign the roots of the trees in the layer below.
//!
//! The Merkle tree uses the tweakable hash `H` ([`hash::hash_h`]) for internal nodes
//! and WOTS+ public keys ([`wots::wots_pk_gen`]) as leaf values.
13
14use super::address::{Adrs, TREE, WOTS_HASH};
15use super::hash;
16use super::params::Params;
17use super::wots;
18use alloc::vec::Vec;
19
20/// Compute the root of an XMSS Merkle subtree.
21///
22/// Implements Algorithm 9 of FIPS 205. Computes the root node of a subtree of height
23/// `z` whose leftmost leaf is at index `i` within the current XMSS tree. When `z`
24/// equals `h'` and `i` equals 0, this returns the full XMSS tree root.
25///
26/// For `z = 0`, returns the WOTS+ public key at leaf `i`. For `z > 0`, builds the
27/// tree iteratively: first generating all `2^z` leaf public keys, then hashing pairs
28/// of sibling nodes up to the root using `H`.
29pub fn xmss_node<P: Params>(sk_seed: &[u8], pk_seed: &[u8], i: u32, z: u32, adrs: &mut Adrs) -> Vec<u8> {
30    if z == 0 {
31        // Leaf: compute WOTS+ public key for leaf i
32        adrs.set_type_and_clear(WOTS_HASH);
33        adrs.set_key_pair_address(i);
34        return wots::wots_pk_gen::<P>(sk_seed, pk_seed, adrs);
35    }
36
37    // Iterative Merkle tree computation.
38    // Compute all 2^z leaves, then hash up.
39    let num_leaves = 1u32 << z;
40    let base = i; // i is the leftmost leaf index at this level
41
42    // Compute leaf nodes
43    let mut nodes: Vec<Vec<u8>> = Vec::with_capacity(num_leaves as usize);
44    for j in 0..num_leaves {
45        adrs.set_type_and_clear(WOTS_HASH);
46        adrs.set_key_pair_address(base + j);
47        let leaf = wots::wots_pk_gen::<P>(sk_seed, pk_seed, adrs);
48        nodes.push(leaf);
49    }
50
51    // Hash up the tree
52    for height in 1..=z {
53        let mut new_nodes = Vec::with_capacity(nodes.len() / 2);
54        for j in 0..(nodes.len() / 2) {
55            adrs.set_type_and_clear(TREE);
56            adrs.set_tree_height(height);
57            adrs.set_tree_index(base / (1 << height) + j as u32);
58            let parent = hash::hash_h::<P>(pk_seed, adrs, &nodes[2 * j], &nodes[2 * j + 1]);
59            new_nodes.push(parent);
60        }
61        nodes = new_nodes;
62    }
63
64    nodes.into_iter().next().unwrap()
65}
66
67/// Create an XMSS signature for an `n`-byte message.
68///
69/// Implements Algorithm 10 of FIPS 205. Signs message `m` using the WOTS+ key pair
70/// at leaf index `idx` and produces an authentication path of `h'` sibling nodes
71/// that allows the verifier to recompute the tree root.
72///
73/// The returned signature is `(len + h') * n` bytes: a WOTS+ signature (`len * n`)
74/// followed by the Merkle authentication path (`h' * n`).
75pub fn xmss_sign<P: Params>(m: &[u8], sk_seed: &[u8], idx: u32, pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
76    let mut sig = alloc::vec![0u8; (P::LEN + P::H_PRIME) * P::N];
77    xmss_sign_into::<P>(m, sk_seed, idx, pk_seed, adrs, &mut sig);
78    sig
79}
80
81/// Streaming variant of [`xmss_sign`] — writes the `(LEN + H') * N`
82/// byte signature into the start of `out` (which must be at least
83/// that size) instead of returning a freshly-allocated `Vec<u8>`.
84///
85/// Layout of `out` (after the call): `WOTS sig (LEN * N) || auth path (H' * N)`.
86pub fn xmss_sign_into<P: Params>(m: &[u8], sk_seed: &[u8], idx: u32, pk_seed: &[u8], adrs: &mut Adrs, out: &mut [u8]) {
87    let wots_sig_len = P::LEN * P::N;
88    let total_len = wots_sig_len + P::H_PRIME * P::N;
89    debug_assert!(out.len() >= total_len);
90    let (sig_wots_slot, auth_slot) = out[..total_len].split_at_mut(wots_sig_len);
91
92    // Authentication path: one node per height.
93    for j in 0..P::H_PRIME {
94        let k = (idx >> j) ^ 1;
95        let node = xmss_node::<P>(sk_seed, pk_seed, k * (1 << j), j as u32, adrs);
96        auth_slot[j * P::N..(j + 1) * P::N].copy_from_slice(&node);
97    }
98
99    // WOTS+ signature written straight into the first block.
100    adrs.set_type_and_clear(WOTS_HASH);
101    adrs.set_key_pair_address(idx);
102    wots::wots_sign_into::<P>(m, sk_seed, pk_seed, adrs, sig_wots_slot);
103}
104
105/// Compute an XMSS public key (tree root) from an XMSS signature.
106///
107/// Implements Algorithm 11 of FIPS 205. Recovers the WOTS+ public key from the
108/// WOTS+ signature component, then walks up the authentication path to recompute
109/// the Merkle tree root. If the signature is valid, the returned root matches the
110/// original XMSS public key.
111///
112/// The `sig_xmss` input must be `(len + h') * n` bytes, and `idx` is the leaf index
113/// that was used during signing.
114pub fn xmss_pk_from_sig<P: Params>(idx: u32, sig_xmss: &[u8], m: &[u8], pk_seed: &[u8], adrs: &mut Adrs) -> Vec<u8> {
115    // sig_xmss = sig_wots (len*n bytes) || auth (h'*n bytes)
116    let wots_sig_len = P::LEN * P::N;
117    let sig_wots = &sig_xmss[..wots_sig_len];
118    let auth = &sig_xmss[wots_sig_len..];
119
120    // Compute WOTS+ public key candidate
121    adrs.set_type_and_clear(WOTS_HASH);
122    adrs.set_key_pair_address(idx);
123    let mut node = wots::wots_pk_from_sig::<P>(sig_wots, m, pk_seed, adrs);
124
125    // Climb the authentication path
126    adrs.set_type_and_clear(TREE);
127    adrs.set_tree_index(idx);
128
129    for k in 0..P::H_PRIME {
130        let auth_k = &auth[k * P::N..(k + 1) * P::N];
131        adrs.set_tree_height((k + 1) as u32);
132
133        if ((idx >> k) & 1) == 0 {
134            // idx/2^k is even: current node is left child
135            adrs.set_tree_index(idx >> (k + 1));
136            node = hash::hash_h::<P>(pk_seed, adrs, &node, auth_k);
137        } else {
138            // idx/2^k is odd: current node is right child
139            adrs.set_tree_index(idx >> (k + 1));
140            node = hash::hash_h::<P>(pk_seed, adrs, auth_k, &node);
141        }
142    }
143
144    node
145}