quantica/slh_dsa/hypertree.rs
1//! Hypertree: a multi-layer tree-of-XMSS-trees structure (FIPS 205, Algorithms 12-13).
2//!
3//! The hypertree is a `d`-layer construction where each layer consists of XMSS trees
4//! of height `h'`. The leaves of each upper-layer XMSS tree certify the roots of
5//! lower-layer trees, forming a hierarchy with total height `h = d * h'`.
6//!
7//! This structure allows the scheme to support `2^h` FORS signing key pairs while
8//! keeping each individual XMSS tree small (only `2^h'` leaves). The hypertree
9//! signature consists of `d` XMSS signatures, one per layer, chained from the
10//! bottom (layer 0) to the top (layer `d - 1`).
11
12use super::address::Adrs;
13use super::params::Params;
14use super::xmss;
15use alloc::vec::Vec;
16
17/// Sign a message using the hypertree.
18///
19/// Implements Algorithm 12 of FIPS 205. Signs the `n`-byte message `m` (typically
20/// a FORS public key) at the position identified by `idx_tree` and `idx_leaf`.
21///
22/// The signing process starts at layer 0, signs `m` with the XMSS tree at
23/// `(layer=0, tree=idx_tree)` using leaf `idx_leaf`, then propagates the resulting
24/// root up through layers 1 to `d - 1`, each time signing the previous layer's root.
25///
26/// Returns `SIG_HT`, the concatenation of `d` XMSS signatures, totaling
27/// `d * (len + h') * n` bytes.
28pub fn ht_sign<P: Params>(m: &[u8], sk_seed: &[u8], pk_seed: &[u8], idx_tree: u64, idx_leaf: u32) -> Vec<u8> {
29 let xmss_sig_len = (P::LEN + P::H_PRIME) * P::N;
30 let mut sig_ht = alloc::vec![0u8; P::D * xmss_sig_len];
31 ht_sign_into::<P>(m, sk_seed, pk_seed, idx_tree, idx_leaf, &mut sig_ht);
32 sig_ht
33}
34
35/// Streaming variant of [`ht_sign`] — writes the `D * (LEN + H') * N`
36/// byte hypertree signature into the start of `out` (which must be
37/// at least that size).
38///
39/// Each of the `D` XMSS signatures is written directly into its slot
40/// inside `out`; the layer-`j` root is recovered from the slot that
41/// was just written (`xmss_pk_from_sig`) rather than from a
42/// temporary per-layer `Vec<u8>`.
43pub fn ht_sign_into<P: Params>(m: &[u8], sk_seed: &[u8], pk_seed: &[u8], idx_tree: u64, idx_leaf: u32, out: &mut [u8]) {
44 let xmss_sig_len = (P::LEN + P::H_PRIME) * P::N;
45 debug_assert!(out.len() >= P::D * xmss_sig_len);
46
47 let mut adrs = Adrs::new();
48 adrs.set_tree_address(idx_tree);
49
50 // Sign M at layer 0 directly into slot 0.
51 adrs.set_layer_address(0);
52 {
53 let slot = &mut out[..xmss_sig_len];
54 xmss::xmss_sign_into::<P>(m, sk_seed, idx_leaf, pk_seed, &mut adrs, slot);
55 }
56 // Recover the layer-0 root from the slot we just wrote.
57 let mut root = xmss::xmss_pk_from_sig::<P>(idx_leaf, &out[..xmss_sig_len], m, pk_seed, &mut adrs);
58
59 let mut current_idx_tree = idx_tree;
60 for j in 1..P::D {
61 let current_idx_leaf = (current_idx_tree & ((1u64 << P::H_PRIME) - 1)) as u32;
62 current_idx_tree >>= P::H_PRIME;
63
64 adrs.set_layer_address(j as u32);
65 adrs.set_tree_address(current_idx_tree);
66
67 let slot_start = j * xmss_sig_len;
68 let slot_end = slot_start + xmss_sig_len;
69 {
70 let slot = &mut out[slot_start..slot_end];
71 xmss::xmss_sign_into::<P>(&root, sk_seed, current_idx_leaf, pk_seed, &mut adrs, slot);
72 }
73
74 if j < P::D - 1 {
75 root = xmss::xmss_pk_from_sig::<P>(current_idx_leaf, &out[slot_start..slot_end], &root, pk_seed, &mut adrs);
76 }
77 }
78}
79
80/// Verify a hypertree signature.
81///
82/// Implements Algorithm 13 of FIPS 205. Recomputes the XMSS root at each of the `d`
83/// layers, starting from the signed message `m` at layer 0 and propagating upward.
84/// Returns `true` if the reconstructed top-layer root matches `pk_root`.
85///
86/// Each XMSS signature in `sig_ht` is `(len + h') * n` bytes; the full hypertree
87/// signature contains `d` such signatures.
88pub fn ht_verify<P: Params>(
89 m: &[u8],
90 sig_ht: &[u8],
91 pk_seed: &[u8],
92 idx_tree: u64,
93 idx_leaf: u32,
94 pk_root: &[u8],
95) -> bool {
96 let mut adrs = Adrs::new();
97 adrs.set_tree_address(idx_tree);
98
99 // Each XMSS signature is (len + h') * n bytes
100 let xmss_sig_len = (P::LEN + P::H_PRIME) * P::N;
101
102 // Verify layer 0
103 adrs.set_layer_address(0);
104 let sig_tmp = &sig_ht[..xmss_sig_len];
105 let mut node = xmss::xmss_pk_from_sig::<P>(idx_leaf, sig_tmp, m, pk_seed, &mut adrs);
106
107 // Verify remaining layers
108 let mut current_idx_tree = idx_tree;
109 for j in 1..P::D {
110 let current_idx_leaf = (current_idx_tree & ((1u64 << P::H_PRIME) - 1)) as u32;
111 current_idx_tree >>= P::H_PRIME;
112
113 adrs.set_layer_address(j as u32);
114 adrs.set_tree_address(current_idx_tree);
115
116 let sig_tmp = &sig_ht[j * xmss_sig_len..(j + 1) * xmss_sig_len];
117 node = xmss::xmss_pk_from_sig::<P>(current_idx_leaf, sig_tmp, &node, pk_seed, &mut adrs);
118 }
119
120 node == pk_root
121}