Skip to main content

quantica/slh_dsa/
address.rs

//! ADRS (Address) structure for SLH-DSA (FIPS 205, Section 4.2).
//!
//! The address structure is a 32-byte value organized as 8 big-endian `u32` words.
//! It is used to domain-separate every hash call in SLH-DSA, so that distinct positions
//! in the tree hierarchy produce independent hash outputs.
//!
//! # Layout
//!
//! | Bytes      | Field         | Description                                        |
//! |------------|---------------|----------------------------------------------------|
//! | `[0..4]`   | Layer address | Hypertree layer (0 = bottom)                       |
//! | `[4..16]`  | Tree address  | Index of the XMSS tree within the layer (3 words)  |
//! | `[16..20]` | Type          | Address type (see the constants below)             |
//! | `[20..32]` | Type-specific | Meaning depends on the type field                  |
15
// Address-type tags. `Adrs::set_type_and_clear` writes one of these into the type word
// (bytes 16..20) of an address. The numbering follows FIPS 205 and must stay as-is.

/// Type tag for addresses used while walking a WOTS+ hash chain.
pub const WOTS_HASH: u32 = 0;
/// Type tag for addresses used when compressing a WOTS+ public key.
pub const WOTS_PK: u32 = 1;
/// Type tag for internal nodes of the XMSS (hypertree) Merkle trees.
pub const TREE: u32 = 2;
/// Type tag for nodes of a FORS tree.
pub const FORS_TREE: u32 = 3;
/// Type tag for compressing the k FORS roots into a single FORS public key.
pub const FORS_ROOTS: u32 = 4;
/// Type tag for deriving WOTS+ secret values with the PRF.
pub const WOTS_PRF: u32 = 5;
/// Type tag for deriving FORS secret values with the PRF.
pub const FORS_PRF: u32 = 6;
30
/// A 32-byte address structure used to domain-separate hash function calls in SLH-DSA.
///
/// Every hash invocation in SLH-DSA includes an `Adrs` value that encodes the position
/// within the signing hierarchy (layer, tree, leaf, chain step, etc.). This ensures that
/// hash outputs at different positions are cryptographically independent.
///
/// The type-specific words (bytes 20..32) are interpreted according to the address type
/// (FIPS 205, Section 4.2):
///
/// | Type             | word 5 (20..24)  | word 6 (24..28)  | word 7 (28..32)  |
/// |------------------|------------------|------------------|------------------|
/// | [`WOTS_HASH`]    | key pair address | chain address    | hash address     |
/// | [`WOTS_PK`]      | key pair address | padding (0)      | padding (0)      |
/// | [`TREE`]         | padding (0)      | tree height      | tree index       |
/// | [`FORS_TREE`]    | key pair address | tree height      | tree index       |
/// | [`FORS_ROOTS`]   | key pair address | padding (0)      | padding (0)      |
/// | [`WOTS_PRF`]     | key pair address | chain address    | hash address (0) |
/// | [`FORS_PRF`]     | key pair address | tree height (0)  | tree index       |
///
/// Addresses are public, non-secret data, so ordinary (non-constant-time) equality and
/// hashing are appropriate.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Adrs {
    // Raw big-endian encoding of the eight 32-bit address words.
    data: [u8; 32],
}
47
48impl Adrs {
49    /// Create a new address with all fields initialized to zero.
50    pub fn new() -> Self {
51        Self { data: [0u8; 32] }
52    }
53
54    /// Get a reference to the raw 32-byte address value.
55    ///
56    /// This is passed directly to hash functions as part of their input.
57    pub fn as_bytes(&self) -> &[u8; 32] {
58        &self.data
59    }
60
61    // ---------- Word-level helpers ----------
62
63    fn get_word(&self, offset: usize) -> u32 {
64        u32::from_be_bytes([
65            self.data[offset],
66            self.data[offset + 1],
67            self.data[offset + 2],
68            self.data[offset + 3],
69        ])
70    }
71
72    fn set_word(&mut self, offset: usize, val: u32) {
73        let bytes = val.to_be_bytes();
74        self.data[offset..offset + 4].copy_from_slice(&bytes);
75    }
76
77    // ---------- Layer address (bytes 0..4) ----------
78
79    /// Get the hypertree layer address (0 = bottom layer).
80    pub fn get_layer_address(&self) -> u32 {
81        self.get_word(0)
82    }
83
84    /// Set the hypertree layer address.
85    ///
86    /// Layer 0 is the bottom of the hypertree (closest to the FORS trees);
87    /// layer `d - 1` is the top.
88    pub fn set_layer_address(&mut self, val: u32) {
89        self.set_word(0, val);
90    }
91
92    // ---------- Tree address (bytes 4..16, 12 bytes = 3 words) ----------
93
94    /// Set the tree address from a `u64` index.
95    ///
96    /// The tree address occupies bytes 4..16 (words 1, 2, 3). The `u64` value is stored
97    /// in the lower 8 bytes (words 2 and 3); the upper 4 bytes (word 1) are zeroed.
98    /// This identifies which XMSS tree within the current layer is being addressed.
99    pub fn set_tree_address(&mut self, val: u64) {
100        // Word 1 (bytes 4..8): upper 32 bits beyond u64 range, set to 0.
101        self.set_word(4, 0);
102        // Word 2 (bytes 8..12): high 32 bits of val.
103        self.set_word(8, (val >> 32) as u32);
104        // Word 3 (bytes 12..16): low 32 bits of val.
105        self.set_word(12, val as u32);
106    }
107
108    /// Get the tree address as a `u64` (lower 8 bytes of the 12-byte field).
109    pub fn get_tree_address(&self) -> u64 {
110        let hi = self.get_word(8) as u64;
111        let lo = self.get_word(12) as u64;
112        (hi << 32) | lo
113    }
114
115    // ---------- Type (bytes 16..20) ----------
116
117    /// Set the address type and zero all type-specific fields (bytes 20..32).
118    ///
119    /// This must be called when switching address types to ensure leftover values
120    /// from a previous type do not leak into the new context. Use one of the
121    /// address type constants: [`WOTS_HASH`], [`WOTS_PK`], [`TREE`], [`FORS_TREE`],
122    /// [`FORS_ROOTS`], [`WOTS_PRF`], or [`FORS_PRF`].
123    pub fn set_type_and_clear(&mut self, addr_type: u32) {
124        self.set_word(16, addr_type);
125        // Zero out type-specific fields (bytes 20..32).
126        for i in 20..32 {
127            self.data[i] = 0;
128        }
129    }
130
131    /// Get the current address type.
132    pub fn get_type(&self) -> u32 {
133        self.get_word(16)
134    }
135
136    // ---------- Type-specific fields (bytes 20..32) ----------
137    // For WOTS_HASH / WOTS_PK / WOTS_PRF:
138    //   word 5 (bytes 20..24): key pair address
139    //   word 6 (bytes 24..28): chain address
140    //   word 7 (bytes 28..32): hash address
141    // For TREE:
142    //   word 5 (bytes 20..24): padding (0)
143    //   word 6 (bytes 24..28): tree height
144    //   word 7 (bytes 28..32): tree index
145    // For FORS_TREE / FORS_ROOTS / FORS_PRF:
146    //   word 5 (bytes 20..24): key pair address
147    //   word 6 (bytes 24..28): tree height
148    //   word 7 (bytes 28..32): tree index
149
150    /// Get the key pair address (word 5, bytes 20..24).
151    ///
152    /// Used by WOTS+ and FORS address types to identify which leaf key pair is being
153    /// operated on within the current XMSS tree.
154    pub fn get_key_pair_address(&self) -> u32 {
155        self.get_word(20)
156    }
157
158    /// Set the key pair address (word 5, bytes 20..24).
159    pub fn set_key_pair_address(&mut self, val: u32) {
160        self.set_word(20, val);
161    }
162
163    /// Get the WOTS+ chain address (word 6, bytes 24..28).
164    ///
165    /// Identifies which of the `len` hash chains within a WOTS+ instance is being computed.
166    pub fn get_chain_address(&self) -> u32 {
167        self.get_word(24)
168    }
169
170    /// Set the WOTS+ chain address (word 6, bytes 24..28).
171    pub fn set_chain_address(&mut self, val: u32) {
172        self.set_word(24, val);
173    }
174
175    /// Get the WOTS+ hash address (word 7, bytes 28..32).
176    ///
177    /// Identifies the step within a WOTS+ hash chain (0 to `w - 2`).
178    pub fn get_hash_address(&self) -> u32 {
179        self.get_word(28)
180    }
181
182    /// Set the WOTS+ hash address (word 7, bytes 28..32).
183    pub fn set_hash_address(&mut self, val: u32) {
184        self.set_word(28, val);
185    }
186
187    /// Get the Merkle tree height (word 6, bytes 24..28).
188    ///
189    /// Used by TREE and FORS_TREE address types. Height 0 corresponds to leaves;
190    /// increasing heights move toward the root.
191    pub fn get_tree_height(&self) -> u32 {
192        self.get_word(24)
193    }
194
195    /// Set the Merkle tree height (word 6, bytes 24..28).
196    pub fn set_tree_height(&mut self, val: u32) {
197        self.set_word(24, val);
198    }
199
200    /// Get the Merkle tree node index (word 7, bytes 28..32).
201    ///
202    /// Identifies the node's horizontal position within its tree level.
203    pub fn get_tree_index(&self) -> u32 {
204        self.get_word(28)
205    }
206
207    /// Set the Merkle tree node index (word 7, bytes 28..32).
208    pub fn set_tree_index(&mut self, val: u32) {
209        self.set_word(28, val);
210    }
211}