1use super::bigint::BigInt;
54use super::rsa::{RsaPublicKey, RsaSecretKey, rsa_decrypt_raw, rsa_encrypt_raw};
55use crate::Hasher;
56
57fn mgf1<H: Hasher>(seed: &[u8], len: usize) -> Vec<u8> {
69 let h_len = H::OUTPUT_LEN;
70 let mut out = Vec::with_capacity(len);
71 let mut counter: u32 = 0;
72 while out.len() < len {
73 let mut hasher = H::new();
74 hasher.update(seed);
75 hasher.update(&counter.to_be_bytes());
76 let block = hasher.finalize();
77 let take = (len - out.len()).min(h_len);
78 out.extend_from_slice(&block[..take]);
79 counter += 1;
80 }
81 out.truncate(len);
82 out
83}
84
85fn emsa_pss_encode<H: Hasher>(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Option<Vec<u8>> {
98 let h_len = H::OUTPUT_LEN;
99 let s_len = salt.len();
100 let em_len = (em_bits + 7) / 8;
101
102 if m_hash.len() != h_len {
104 return None;
105 }
106 if em_len < h_len + s_len + 2 {
107 return None;
108 }
109
110 let mut m_prime = Vec::with_capacity(8 + h_len + s_len);
112 m_prime.extend_from_slice(&[0u8; 8]);
113 m_prime.extend_from_slice(m_hash);
114 m_prime.extend_from_slice(salt);
115 let h = H::hash(&m_prime);
116
117 let db_len = em_len - h_len - 1;
120 let mut db = vec![0u8; db_len];
121 let ps_len = db_len - s_len - 1;
122 db[ps_len] = 0x01;
123 db[ps_len + 1..].copy_from_slice(salt);
124
125 let db_mask = mgf1::<H>(&h, db_len);
127 for i in 0..db_len {
128 db[i] ^= db_mask[i];
129 }
130
131 let clear_bits = 8 * em_len - em_bits;
135 if clear_bits > 0 {
136 db[0] &= 0xff_u8 >> clear_bits;
137 }
138
139 let mut em = Vec::with_capacity(em_len);
141 em.extend_from_slice(&db);
142 em.extend_from_slice(&h);
143 em.push(0xbc);
144
145 Some(em)
146}
147
148fn emsa_pss_verify<H: Hasher>(m_hash: &[u8], em: &[u8], em_bits: usize, s_len: usize) -> bool {
153 let h_len = H::OUTPUT_LEN;
154 let em_len = (em_bits + 7) / 8;
155
156 if m_hash.len() != h_len {
158 return false;
159 }
160 if em.len() != em_len {
161 return false;
162 }
163 if em_len < h_len + s_len + 2 {
164 return false;
165 }
166
167 if em[em_len - 1] != 0xbc {
169 return false;
170 }
171
172 let db_len = em_len - h_len - 1;
174 let masked_db = &em[..db_len];
175 let h = &em[db_len..db_len + h_len];
176
177 let clear_bits = 8 * em_len - em_bits;
179 if clear_bits > 0 && (masked_db[0] >> (8 - clear_bits)) != 0 {
180 return false;
181 }
182
183 let db_mask = mgf1::<H>(h, db_len);
186 let mut db = vec![0u8; db_len];
187 for i in 0..db_len {
188 db[i] = masked_db[i] ^ db_mask[i];
189 }
190 if clear_bits > 0 {
191 db[0] &= 0xff_u8 >> clear_bits;
192 }
193
194 let ps_len = db_len - s_len - 1;
196 for byte in &db[..ps_len] {
197 if *byte != 0 {
198 return false;
199 }
200 }
201 if db[ps_len] != 0x01 {
202 return false;
203 }
204
205 let salt = &db[ps_len + 1..];
207
208 let mut m_prime = Vec::with_capacity(8 + h_len + s_len);
210 m_prime.extend_from_slice(&[0u8; 8]);
211 m_prime.extend_from_slice(m_hash);
212 m_prime.extend_from_slice(salt);
213 let h_prime = H::hash(&m_prime);
214
215 let mut diff = 0u8;
218 for (a, b) in h.iter().zip(h_prime.iter()) {
219 diff |= a ^ b;
220 }
221 diff == 0
222}
223
224pub fn pss_sign_with_salt<H: Hasher>(sk: &RsaSecretKey, m_hash: &[u8], salt: &[u8]) -> Option<Vec<u8>> {
239 let k = sk.modulus_byte_len();
240 let mod_bits = sk.n.bit_len();
241 let em_bits = mod_bits - 1;
242 let em = emsa_pss_encode::<H>(m_hash, em_bits, salt)?;
243
244 let m = BigInt::from_be_bytes(&em);
246 let s = rsa_decrypt_raw(sk, &m);
247 Some(s.to_be_bytes(k))
248}
249
250pub fn pss_sign<H: Hasher>(
262 sk: &RsaSecretKey,
263 m_hash: &[u8],
264 s_len: usize,
265 rng: &mut dyn FnMut(&mut [u8]),
266) -> Option<Vec<u8>> {
267 let mut salt = vec![0u8; s_len];
268 if s_len > 0 {
269 rng(&mut salt);
270 }
271 pss_sign_with_salt::<H>(sk, m_hash, &salt)
272}
273
274pub fn pss_sign_msg<H: Hasher>(
276 sk: &RsaSecretKey,
277 msg: &[u8],
278 s_len: usize,
279 rng: &mut dyn FnMut(&mut [u8]),
280) -> Option<Vec<u8>> {
281 let digest = H::hash(msg);
282 pss_sign::<H>(sk, &digest, s_len, rng)
283}
284
285pub fn pss_verify<H: Hasher>(pk: &RsaPublicKey, m_hash: &[u8], s_len: usize, sig: &[u8]) -> bool {
293 let k = pk.modulus_byte_len();
294 if sig.len() != k {
295 return false;
296 }
297 let mod_bits = pk.n.bit_len();
298 if mod_bits == 0 {
299 return false;
300 }
301 let em_bits = mod_bits - 1;
302 let em_len = (em_bits + 7) / 8;
303
304 let s = BigInt::from_be_bytes(sig);
306 let m = rsa_encrypt_raw(pk, &s);
307 let em = m.to_be_bytes(em_len);
308 if em.len() != em_len {
309 return false;
313 }
314
315 emsa_pss_verify::<H>(m_hash, &em, em_bits, s_len)
316}
317
318pub fn pss_verify_msg<H: Hasher>(pk: &RsaPublicKey, msg: &[u8], s_len: usize, sig: &[u8]) -> bool {
320 let digest = H::hash(msg);
321 pss_verify::<H>(pk, &digest, s_len, sig)
322}
323
#[cfg(test)]
mod tests {
    use super::super::rsa::rsa_keygen;
    use super::*;
    use crate::hash::sha256::Sha256;
    use crate::hash::sha384::Sha384;
    use crate::hash::sha512::Sha512;

    /// Deterministic byte source for tests: a 64-bit linear congruential
    /// generator with a fixed seed. Not suitable for real key material.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;
        let mut lcg: u64 = 0xdeadbeefcafebabe;
        move |out: &mut [u8]| {
            for slot in out.iter_mut() {
                lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
                *slot = (lcg >> 33) as u8;
            }
        }
    }

    /// Encoding then verifying with SHA-256 at 2047 bits succeeds.
    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha256() {
        let digest = Sha256::hash(b"hello PSS");
        let salt = [0x5a; 32];
        let em_bits = 2047;
        let encoded = emsa_pss_encode::<Sha256>(&digest, em_bits, &salt).expect("encode");
        assert_eq!(encoded.len(), (em_bits + 7) / 8);
        assert!(emsa_pss_verify::<Sha256>(&digest, &encoded, em_bits, salt.len()));
    }

    /// Encoding then verifying with SHA-384 at 3071 bits succeeds.
    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha384() {
        let digest = Sha384::hash(b"hello PSS sha384");
        let salt = [0x17; 48];
        let em_bits = 3071;
        let encoded = emsa_pss_encode::<Sha384>(&digest, em_bits, &salt).unwrap();
        assert_eq!(encoded.len(), 384);
        assert!(emsa_pss_verify::<Sha384>(&digest, &encoded, em_bits, salt.len()));
    }

    /// Encoding then verifying with SHA-512 at 4095 bits succeeds, including
    /// with an all-zero salt.
    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha512() {
        let digest = Sha512::hash(b"hello PSS sha512");
        let salt = [0x00; 64];
        let em_bits = 4095;
        let encoded = emsa_pss_encode::<Sha512>(&digest, em_bits, &salt).unwrap();
        assert_eq!(encoded.len(), 512);
        assert!(emsa_pss_verify::<Sha512>(&digest, &encoded, em_bits, salt.len()));
    }

    /// Verifying with a salt length other than the one used to encode fails.
    #[test]
    fn test_emsa_pss_verify_wrong_salt_length_rejects() {
        let digest = Sha256::hash(b"msg");
        let salt = [0xAB; 32];
        let encoded = emsa_pss_encode::<Sha256>(&digest, 2047, &salt).unwrap();
        let accepted = emsa_pss_verify::<Sha256>(&digest, &encoded, 2047, 16);
        assert!(!accepted);
    }

    /// Flipping a bit in the first encoded byte makes verification fail.
    #[test]
    fn test_emsa_pss_verify_tampered_rejects() {
        let digest = Sha256::hash(b"msg");
        let salt = [0x12; 32];
        let mut encoded = emsa_pss_encode::<Sha256>(&digest, 2047, &salt).unwrap();
        encoded[0] ^= 0x01;
        let accepted = emsa_pss_verify::<Sha256>(&digest, &encoded, 2047, salt.len());
        assert!(!accepted);
    }

    /// Replacing the 0xbc trailer byte makes verification fail.
    #[test]
    fn test_emsa_pss_verify_missing_bc_byte_rejects() {
        let digest = Sha256::hash(b"msg");
        let salt = [0x12; 32];
        let mut encoded = emsa_pss_encode::<Sha256>(&digest, 2047, &salt).unwrap();
        let trailer = encoded.len() - 1;
        encoded[trailer] = 0xbd;
        let accepted = emsa_pss_verify::<Sha256>(&digest, &encoded, 2047, salt.len());
        assert!(!accepted);
    }

    /// Full sign/verify round trip with a 1024-bit key and SHA-256.
    #[test]
    fn test_pss_sign_verify_roundtrip_sha256() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(1024, &mut rng);
        let message = b"PSS end-to-end with SHA-256";
        let signature = pss_sign_msg::<Sha256>(&secret, message, 32, &mut rng).expect("sign");
        assert_eq!(signature.len(), public.modulus_byte_len());
        assert!(pss_verify_msg::<Sha256>(&public, message, 32, &signature));
    }

    /// With an empty salt the signature is a pure function of key and digest.
    #[test]
    fn test_pss_deterministic_signatures_agree() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(1024, &mut rng);
        let digest = Sha256::hash(b"determinism");
        let first = pss_sign_with_salt::<Sha256>(&secret, &digest, &[]).expect("sign 1");
        let second = pss_sign_with_salt::<Sha256>(&secret, &digest, &[]).expect("sign 2");
        assert_eq!(first, second);
        assert!(pss_verify::<Sha256>(&public, &digest, 0, &first));
    }

    /// Two signatures over the same message with fresh salts differ.
    #[test]
    fn test_pss_random_signatures_differ() {
        let mut rng = test_rng();
        let (_public, secret) = rsa_keygen(1024, &mut rng);
        let message = b"randomisation";
        let first = pss_sign_msg::<Sha256>(&secret, message, 32, &mut rng).unwrap();
        let second = pss_sign_msg::<Sha256>(&secret, message, 32, &mut rng).unwrap();
        assert_ne!(first, second);
    }

    /// A signature does not verify against a different message.
    #[test]
    fn test_pss_verify_rejects_wrong_message() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(1024, &mut rng);
        let signature = pss_sign_msg::<Sha256>(&secret, b"original", 32, &mut rng).unwrap();
        let accepted = pss_verify_msg::<Sha256>(&public, b"tampered", 32, &signature);
        assert!(!accepted);
    }

    /// Flipping a bit in the signature makes verification fail.
    #[test]
    fn test_pss_verify_rejects_tampered_signature() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(1024, &mut rng);
        let message = b"msg";
        let mut signature = pss_sign_msg::<Sha256>(&secret, message, 32, &mut rng).unwrap();
        signature[0] ^= 0x01;
        let accepted = pss_verify_msg::<Sha256>(&public, message, 32, &signature);
        assert!(!accepted);
    }

    /// A signature made with one key does not verify under another key.
    #[test]
    fn test_pss_verify_rejects_wrong_key() {
        let mut rng = test_rng();
        let (_public_a, secret_a) = rsa_keygen(1024, &mut rng);
        let (public_b, _secret_b) = rsa_keygen(1024, &mut rng);
        let signature = pss_sign_msg::<Sha256>(&secret_a, b"msg", 32, &mut rng).unwrap();
        let accepted = pss_verify_msg::<Sha256>(&public_b, b"msg", 32, &signature);
        assert!(!accepted);
    }

    /// Verifying with a different salt length than was signed with fails.
    #[test]
    fn test_pss_verify_rejects_wrong_salt_length() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(1024, &mut rng);
        let signature = pss_sign_msg::<Sha256>(&secret, b"msg", 32, &mut rng).unwrap();
        let accepted = pss_verify_msg::<Sha256>(&public, b"msg", 16, &signature);
        assert!(!accepted);
    }

    /// A 512-bit modulus cannot hold a SHA-256 digest plus a 32-byte salt.
    #[test]
    fn test_pss_sign_rejects_too_small_modulus_for_salt() {
        let mut rng = test_rng();
        let (_public, secret) = rsa_keygen(512, &mut rng);
        let digest = Sha256::hash(b"msg");
        let outcome = pss_sign::<Sha256>(&secret, &digest, 32, &mut rng);
        assert!(outcome.is_none());
    }

    /// A 512-bit modulus works once the salt is shortened to 16 bytes.
    #[test]
    fn test_pss_sign_accepts_short_salt_on_small_modulus() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(512, &mut rng);
        let message = b"small-modulus PSS";
        let signature = pss_sign_msg::<Sha256>(&secret, message, 16, &mut rng).expect("sign");
        assert!(pss_verify_msg::<Sha256>(&public, message, 16, &signature));
    }

    /// A SHA-256 signature does not verify when checked with SHA-384.
    #[test]
    fn test_pss_hash_mismatch_rejected() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(2048, &mut rng);
        let message = b"hash flexibility matters";
        let signature = pss_sign_msg::<Sha256>(&secret, message, 32, &mut rng).unwrap();
        let accepted = pss_verify_msg::<Sha384>(&public, message, 32, &signature);
        assert!(!accepted);
    }

    /// Full sign/verify round trip with a 3072-bit key and SHA-384.
    #[test]
    fn test_pss_sha384_3072_roundtrip() {
        let mut rng = test_rng();
        let (public, secret) = rsa_keygen(3072, &mut rng);
        let message = b"PSS SHA-384 / RSA-3072";
        let signature = pss_sign_msg::<Sha384>(&secret, message, 48, &mut rng).unwrap();
        assert!(pss_verify_msg::<Sha384>(&public, message, 48, &signature));
    }
}