1use super::bigint::BigInt;
30use super::rsa::{RsaPublicKey, RsaSecretKey, rsa_decrypt_raw, rsa_encrypt_raw};
31use crate::Hasher;
32use crate::hash::sha256::Sha256;
33
34const HASH_LEN: usize = 32; fn sha256(data: &[u8]) -> [u8; HASH_LEN] {
38 let digest = Sha256::hash(data);
39 let mut out = [0u8; HASH_LEN];
40 out.copy_from_slice(&digest);
41 out
42}
43
44fn mgf1_sha256(seed: &[u8], len: usize) -> Vec<u8> {
46 let mut output = Vec::with_capacity(len);
47 let mut counter: u32 = 0;
48 while output.len() < len {
49 let mut h = Sha256::new();
50 h.update(seed);
51 h.update(&counter.to_be_bytes());
52 let block = h.finalize();
53 let take = (len - output.len()).min(block.len());
54 output.extend_from_slice(&block[..take]);
55 counter += 1;
56 }
57 output.truncate(len);
58 output
59}
60
/// XORs the bytes of `b` into `a`, position by position.
///
/// Only the first `min(a.len(), b.len())` bytes of `a` are modified; any
/// surplus bytes in either slice are ignored.
fn xor_in_place(a: &mut [u8], b: &[u8]) {
    a.iter_mut()
        .zip(b)
        .for_each(|(dst, src)| *dst ^= *src);
}
67
68pub fn oaep_encrypt(pk: &RsaPublicKey, msg: &[u8], label: &[u8], rng: &mut dyn FnMut(&mut [u8])) -> Vec<u8> {
72 let k = pk.modulus_byte_len();
73 let max_msg_len = k - 2 * HASH_LEN - 2;
74 assert!(
75 msg.len() <= max_msg_len,
76 "OAEP: message too long (max {} bytes, got {})",
77 max_msg_len,
78 msg.len()
79 );
80
81 let l_hash = sha256(label);
82
83 let db_len = k - HASH_LEN - 1;
85 let mut db = vec![0u8; db_len];
86 db[..HASH_LEN].copy_from_slice(&l_hash);
87 let ps_len = db_len - HASH_LEN - 1 - msg.len();
89 db[HASH_LEN + ps_len] = 0x01;
90 db[HASH_LEN + ps_len + 1..].copy_from_slice(msg);
91
92 let mut seed = [0u8; HASH_LEN];
94 rng(&mut seed);
95
96 let db_mask = mgf1_sha256(&seed, db_len);
98 xor_in_place(&mut db, &db_mask);
99
100 let seed_mask = mgf1_sha256(&db, HASH_LEN);
102 let mut masked_seed = seed;
103 xor_in_place(&mut masked_seed, &seed_mask);
104
105 let mut em = vec![0u8; k];
107 em[0] = 0x00;
108 em[1..1 + HASH_LEN].copy_from_slice(&masked_seed);
109 em[1 + HASH_LEN..].copy_from_slice(&db);
110
111 let m = BigInt::from_be_bytes(&em);
112 let c = rsa_encrypt_raw(pk, &m);
113 c.to_be_bytes(k)
114}
115
116pub fn oaep_decrypt(sk: &RsaSecretKey, ct: &[u8], label: &[u8]) -> Option<Vec<u8>> {
120 let k = sk.modulus_byte_len();
121 if ct.len() != k || k < 2 * HASH_LEN + 2 {
122 return None;
123 }
124
125 let c = BigInt::from_be_bytes(ct);
126 let m = rsa_decrypt_raw(sk, &c);
127 let em = m.to_be_bytes(k);
128
129 if em[0] != 0x00 {
131 return None;
132 }
133
134 let masked_seed = &em[1..1 + HASH_LEN];
135 let masked_db = &em[1 + HASH_LEN..];
136
137 let seed_mask = mgf1_sha256(masked_db, HASH_LEN);
139 let mut seed = [0u8; HASH_LEN];
140 seed.copy_from_slice(masked_seed);
141 xor_in_place(&mut seed, &seed_mask);
142
143 let db_len = k - HASH_LEN - 1;
145 let db_mask = mgf1_sha256(&seed, db_len);
146 let mut db = vec![0u8; db_len];
147 db.copy_from_slice(masked_db);
148 xor_in_place(&mut db, &db_mask);
149
150 let l_hash = sha256(label);
152 let mut valid = true;
153 for i in 0..HASH_LEN {
154 if db[i] != l_hash[i] {
155 valid = false;
156 }
157 }
158
159 let mut sep = None;
161 for i in HASH_LEN..db.len() {
162 if db[i] == 0x01 {
163 sep = Some(i);
164 break;
165 } else if db[i] != 0x00 {
166 valid = false;
167 break;
168 }
169 }
170
171 if !valid {
172 return None;
173 }
174 let sep = sep?;
175
176 Some(db[sep + 1..].to_vec())
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic byte generator (a 64-bit linear congruential
    /// generator with a fixed starting state) so test runs are reproducible.
    /// It is not suitable for anything but tests.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        let mut lcg: u64 = 0xdeadbeefcafebabe;
        move |out: &mut [u8]| {
            out.iter_mut().for_each(|byte| {
                lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
                // Bits 33..41 of the state become the output byte.
                *byte = (lcg >> 33) as u8;
            });
        }
    }

    /// The mask generator is deterministic, honours the requested length, and
    /// is sensitive to the seed.
    #[test]
    fn test_mgf1() {
        let first = mgf1_sha256(b"seed", 64);
        let repeat = mgf1_sha256(b"seed", 64);
        let other_seed = mgf1_sha256(b"other", 64);

        assert_eq!(first.len(), 64);
        assert_eq!(first, repeat);
        assert_ne!(first, other_seed);
    }

    /// Encrypting and then decrypting with the matching key and an empty
    /// label returns the original message.
    #[test]
    fn test_oaep_encrypt_decrypt_roundtrip() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);

        let plaintext = b"Hello, OAEP!";
        let ciphertext = oaep_encrypt(&pk, plaintext, b"", &mut rng);
        let recovered = oaep_decrypt(&sk, &ciphertext, b"").expect("OAEP decryption failed");

        assert_eq!(&recovered, plaintext);
    }

    /// Decrypting with a label different from the one used for encryption is
    /// rejected.
    #[test]
    fn test_oaep_wrong_label() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);

        let plaintext = b"test";
        let ciphertext = oaep_encrypt(&pk, plaintext, b"label_a", &mut rng);
        let outcome = oaep_decrypt(&sk, &ciphertext, b"label_b");

        assert!(outcome.is_none(), "Decryption should fail with wrong label");
    }
}