1use alloc::vec::Vec;
6
7use core::ptr::NonNull;
8use core::str::FromStr;
9
10#[cfg(not(feature = "std"))]
11use cstr_core::CStr;
12#[cfg(feature = "std")]
13use std::ffi::CStr;
14
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
18use crate::ffi::kem as ffi;
19use crate::newtype_buffer;
20use crate::*;
21
22newtype_buffer!(PublicKey, PublicKeyRef);
23newtype_buffer!(SecretKey, SecretKeyRef);
24newtype_buffer!(Ciphertext, CiphertextRef);
25newtype_buffer!(SharedSecret, SharedSecretRef);
26newtype_buffer!(KeypairSeed, KeypairSeedRef);
27
/// Generates the `Algorithm` enum, its mapping to liboqs identifiers, its
/// `FromStr` impl, and a per-algorithm test module.
///
/// Each entry has the form `("<cargo feature>") Variant: ffi_identifier`:
/// the feature string gates the generated liboqs round-trip tests, `Variant`
/// becomes an `Algorithm` member, and `ffi_identifier` names the liboqs
/// identifier constant in `crate::ffi::kem`.
macro_rules! implement_kems {
    { $(($feat: literal) $kem: ident: $oqs_id: ident),* $(,)? } => (

        /// KEM algorithms known to this crate.
        ///
        /// Whether a variant is usable depends on the linked liboqs build;
        /// check with [`Algorithm::is_enabled`].
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
        #[allow(missing_docs)]
        pub enum Algorithm {
            $(
                $kem,
            )*
        }

        /// Map `algorithm` to the liboqs identifier pointer expected by the C API.
        ///
        /// The returned pointer is later read with `CStr::from_ptr`, so the
        /// `ffi::OQS_KEM_alg_*` constants are relied upon to be NUL-terminated
        /// and to have static lifetime.
        fn algorithm_to_id(algorithm: Algorithm) -> *const libc::c_char {
            let id: &[u8] = match algorithm {
                $(
                    Algorithm::$kem => &ffi::$oqs_id[..],
                )*
            };
            // Take the slice's thin data pointer directly instead of casting
            // through a fat `*const [u8]`.
            id.as_ptr() as *const libc::c_char
        }

        impl FromStr for Algorithm {
            type Err = crate::Error;

            /// Parse an algorithm from its liboqs name, i.e. the exact
            /// (case-sensitive) string returned by [`Algorithm::name`].
            ///
            /// # Errors
            /// Returns `crate::Error::AlgorithmParsingError` when `s` matches
            /// no variant's name.
            fn from_str(s: &str) -> Result<Self> {
                $(
                    if s == Algorithm::$kem.name() {
                        return Ok(Algorithm::$kem);
                    }
                )*
                Err(crate::Error::AlgorithmParsingError)
            }
        }

        $(
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod $kem {
                use super::*;

                #[test]
                #[cfg(feature = $feat)]
                fn test_encaps_decaps() -> Result<()> {
                    crate::init();

                    let alg = Algorithm::$kem;
                    let kem = Kem::new(alg)?;
                    let (pk, sk) = kem.keypair()?;
                    let (ct, ss1) = kem.encapsulate(&pk)?;
                    let ss2 = kem.decapsulate(&sk, &ct)?;
                    assert_eq!(ss1, ss2, "shared secret not equal!");
                    Ok(())
                }

                #[test]
                #[cfg(feature = $feat)]
                fn test_encaps_decaps_derand() -> Result<()> {
                    use crate::ffi::rand::OQS_randombytes;
                    crate::init();

                    let alg = Algorithm::$kem;
                    let kem = Kem::new(alg)?;
                    let seed_len = kem.length_keypair_seed();

                    // Start from an initialised buffer so no `set_len` over
                    // uninitialised memory is needed.
                    let mut bytes = Vec::new();
                    bytes.resize(seed_len, 0u8);
                    if seed_len > 0 {
                        // SAFETY: `bytes` holds exactly `seed_len` initialised
                        // bytes, so the write stays in bounds.
                        unsafe { OQS_randombytes(bytes.as_mut_ptr(), seed_len) };
                    }
                    let seed = KeypairSeed { bytes };

                    let result = kem.keypair_derand(&seed);
                    if seed_len == 0 {
                        // An algorithm without a keypair seed must reject
                        // deterministic key generation with `Error::Error`.
                        return match result {
                            Err(Error::Error) => Ok(()),
                            _ => Err(Error::Error),
                        };
                    }
                    let (pk, sk) = result?;
                    let (ct, ss1) = kem.encapsulate(&pk)?;
                    let ss2 = kem.decapsulate(&sk, &ct)?;
                    assert_eq!(ss1, ss2, "shared secret not equal!");
                    Ok(())
                }

                #[test]
                fn test_enabled() {
                    crate::init();
                    if cfg!(feature = $feat) {
                        assert!(Algorithm::$kem.is_enabled());
                    } else {
                        assert!(!Algorithm::$kem.is_enabled())
                    }
                }

                #[test]
                fn test_name() {
                    let algo = Algorithm::$kem;
                    let name = algo.name();
                    #[cfg(feature = "std")]
                    assert_eq!(name, algo.to_string());
                    assert!(!name.is_empty());
                }

                #[test]
                fn test_get_algorithm_back() {
                    let algorithm = Algorithm::$kem;
                    if algorithm.is_enabled() {
                        let kem = Kem::new(algorithm).unwrap();
                        assert_eq!(algorithm, kem.algorithm());
                    }
                }

                #[test]
                fn test_version() {
                    if let Ok(kem) = Kem::new(Algorithm::$kem) {
                        let version = kem.version();
                        assert!(!version.is_empty());
                    }
                }

                #[test]
                fn test_from_str() {
                    let algorithm = Algorithm::$kem;
                    let name = algorithm.name();
                    let parsed = Algorithm::from_str(name).unwrap();
                    assert_eq!(algorithm, parsed);
                }
            }
        )*
    )
}
168
// Registry of every KEM exposed by this module. Each entry reads
// `("<cargo feature>") <Algorithm variant>: <ffi identifier constant>`:
// - the feature string gates the generated encaps/decaps tests and decides
//   whether `test_enabled` expects the algorithm to be enabled;
// - the variant becomes a member of `Algorithm`;
// - the ffi constant is the liboqs identifier returned by `Algorithm::to_id`.
//
// Entry order is the declaration order of the `Algorithm` variants.
// NOTE(review): with the `serde` feature, formats that encode enum variants by
// index would change if entries were reordered, so append new algorithms at
// the end rather than re-sorting. Confirm against the crate's compatibility policy.
implement_kems! {
    ("bike") BikeL1: OQS_KEM_alg_bike_l1,
    ("bike") BikeL3: OQS_KEM_alg_bike_l3,
    ("bike") BikeL5: OQS_KEM_alg_bike_l5,
    ("classic_mceliece") ClassicMcEliece348864: OQS_KEM_alg_classic_mceliece_348864,
    ("classic_mceliece") ClassicMcEliece348864f: OQS_KEM_alg_classic_mceliece_348864f,
    ("classic_mceliece") ClassicMcEliece460896: OQS_KEM_alg_classic_mceliece_460896,
    ("classic_mceliece") ClassicMcEliece460896f: OQS_KEM_alg_classic_mceliece_460896f,
    ("classic_mceliece") ClassicMcEliece6688128: OQS_KEM_alg_classic_mceliece_6688128,
    ("classic_mceliece") ClassicMcEliece6688128f: OQS_KEM_alg_classic_mceliece_6688128f,
    ("classic_mceliece") ClassicMcEliece6960119: OQS_KEM_alg_classic_mceliece_6960119,
    ("classic_mceliece") ClassicMcEliece6960119f: OQS_KEM_alg_classic_mceliece_6960119f,
    ("classic_mceliece") ClassicMcEliece8192128: OQS_KEM_alg_classic_mceliece_8192128,
    ("classic_mceliece") ClassicMcEliece8192128f: OQS_KEM_alg_classic_mceliece_8192128f,
    ("hqc") Hqc128: OQS_KEM_alg_hqc_128,
    ("hqc") Hqc192: OQS_KEM_alg_hqc_192,
    ("hqc") Hqc256: OQS_KEM_alg_hqc_256,
    ("kyber") Kyber512: OQS_KEM_alg_kyber_512,
    ("kyber") Kyber768: OQS_KEM_alg_kyber_768,
    ("kyber") Kyber1024: OQS_KEM_alg_kyber_1024,
    ("ml_kem") MlKem512: OQS_KEM_alg_ml_kem_512,
    ("ml_kem") MlKem768: OQS_KEM_alg_ml_kem_768,
    ("ml_kem") MlKem1024: OQS_KEM_alg_ml_kem_1024,
    ("ntruprime") NtruPrimeSntrup761: OQS_KEM_alg_ntruprime_sntrup761,
    ("frodokem") FrodoKem640Aes: OQS_KEM_alg_frodokem_640_aes,
    ("frodokem") FrodoKem640Shake: OQS_KEM_alg_frodokem_640_shake,
    ("frodokem") FrodoKem976Aes: OQS_KEM_alg_frodokem_976_aes,
    ("frodokem") FrodoKem976Shake: OQS_KEM_alg_frodokem_976_shake,
    ("frodokem") FrodoKem1344Aes: OQS_KEM_alg_frodokem_1344_aes,
    ("frodokem") FrodoKem1344Shake: OQS_KEM_alg_frodokem_1344_shake,
}
200
201impl Algorithm {
202 pub fn is_enabled(self) -> bool {
205 unsafe { ffi::OQS_KEM_alg_is_enabled(algorithm_to_id(self)) == 1 }
206 }
207
208 pub fn to_id(self) -> *const libc::c_char {
212 algorithm_to_id(self)
213 }
214
215 pub fn name(&self) -> &'static str {
219 let id = unsafe { CStr::from_ptr(self.to_id()) };
221 id.to_str().expect("OQS algorithm names must be UTF-8")
222 }
223}
224
225#[cfg(feature = "std")]
226impl std::fmt::Display for Algorithm {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 self.name().fmt(f)
229 }
230}
231
232pub struct Kem {
246 algorithm: Algorithm,
247 kem: NonNull<ffi::OQS_KEM>,
248}
249
// SAFETY: NOTE(review), justification to be confirmed against liboqs docs.
// `Kem` holds only a non-null pointer to the `OQS_KEM` object returned by
// `OQS_KEM_new`, plus a `Copy` algorithm tag. Every method in this file
// accesses that object through a shared reference (`NonNull::as_ref`): it
// reads fields or calls the function pointers stored in it. Only `Drop` frees
// the object, and it does so once. Sharing or moving a `Kem` across threads is
// therefore sound provided liboqs' keypair/encaps/decaps routines (and the RNG
// they use) are thread-safe. This file does not itself establish that.
unsafe impl Sync for Kem {}
unsafe impl Send for Kem {}
252
253impl Drop for Kem {
254 fn drop(&mut self) {
255 unsafe { ffi::OQS_KEM_free(self.kem.as_ptr()) };
256 }
257}
258
259impl core::convert::TryFrom<Algorithm> for Kem {
260 type Error = crate::Error;
261 fn try_from(alg: Algorithm) -> Result<Kem> {
262 Kem::new(alg)
263 }
264}
265
266impl Kem {
267 pub fn new(algorithm: Algorithm) -> Result<Self> {
269 let kem = unsafe { ffi::OQS_KEM_new(algorithm_to_id(algorithm)) };
270 NonNull::new(kem).map_or_else(
271 || Err(Error::AlgorithmDisabled),
272 |kem| Ok(Self { algorithm, kem }),
273 )
274 }
275
276 pub fn algorithm(&self) -> Algorithm {
278 self.algorithm
279 }
280
281 pub fn version(&self) -> &'static str {
283 let kem = unsafe { self.kem.as_ref() };
284 let cstr = unsafe { CStr::from_ptr(kem.alg_version) };
286 cstr.to_str()
287 .expect("Algorithm version strings must be UTF-8")
288 }
289
290 pub fn claimed_nist_level(&self) -> u8 {
292 let kem = unsafe { self.kem.as_ref() };
293 kem.claimed_nist_level
294 }
295
296 pub fn is_ind_cca(&self) -> bool {
298 let kem = unsafe { self.kem.as_ref() };
299 kem.ind_cca
300 }
301
302 pub fn length_public_key(&self) -> usize {
304 let kem = unsafe { self.kem.as_ref() };
305 kem.length_public_key
306 }
307
308 pub fn length_secret_key(&self) -> usize {
310 let kem = unsafe { self.kem.as_ref() };
311 kem.length_secret_key
312 }
313
314 pub fn length_ciphertext(&self) -> usize {
316 let kem = unsafe { self.kem.as_ref() };
317 kem.length_ciphertext
318 }
319
320 pub fn length_shared_secret(&self) -> usize {
322 let kem = unsafe { self.kem.as_ref() };
323 kem.length_shared_secret
324 }
325
326 pub fn length_keypair_seed(&self) -> usize {
328 let kem = unsafe { self.kem.as_ref() };
329 kem.length_keypair_seed
330 }
331
332 pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SecretKeyRef<'a>> {
336 if self.length_secret_key() != buf.len() {
337 None
338 } else {
339 Some(SecretKeyRef::new(buf))
340 }
341 }
342
343 pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<PublicKeyRef<'a>> {
347 if self.length_public_key() != buf.len() {
348 None
349 } else {
350 Some(PublicKeyRef::new(buf))
351 }
352 }
353
354 pub fn ciphertext_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<CiphertextRef<'a>> {
358 if self.length_ciphertext() != buf.len() {
359 None
360 } else {
361 Some(CiphertextRef::new(buf))
362 }
363 }
364
365 pub fn shared_secret_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SharedSecretRef<'a>> {
369 if self.length_shared_secret() != buf.len() {
370 None
371 } else {
372 Some(SharedSecretRef::new(buf))
373 }
374 }
375
376 pub fn keypair_seed_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<KeypairSeedRef<'a>> {
380 if self.length_keypair_seed() != buf.len() {
381 None
382 } else {
383 Some(KeypairSeedRef::new(buf))
384 }
385 }
386
387 pub fn keypair(&self) -> Result<(PublicKey, SecretKey)> {
389 let kem = unsafe { self.kem.as_ref() };
390 let func = kem.keypair.unwrap();
391 let mut pk = PublicKey {
392 bytes: Vec::with_capacity(kem.length_public_key),
393 };
394 let mut sk = SecretKey {
395 bytes: Vec::with_capacity(kem.length_secret_key),
396 };
397 let status = unsafe { func(pk.bytes.as_mut_ptr(), sk.bytes.as_mut_ptr()) };
398 status_to_result(status)?;
399 unsafe {
402 pk.bytes.set_len(kem.length_public_key);
403 sk.bytes.set_len(kem.length_secret_key);
404 }
405 Ok((pk, sk))
406 }
407
408 pub fn keypair_derand<'a, S: Into<KeypairSeedRef<'a>>>(
410 &self,
411 seed: S,
412 ) -> Result<(PublicKey, SecretKey)> {
413 let seed = seed.into();
414 if seed.bytes.len() != self.length_keypair_seed() {
415 return Err(Error::InvalidLength);
416 }
417 let kem = unsafe { self.kem.as_ref() };
418 let func = kem.keypair_derand.unwrap();
419 let mut pk = PublicKey {
420 bytes: Vec::with_capacity(kem.length_public_key),
421 };
422 let mut sk = SecretKey {
423 bytes: Vec::with_capacity(kem.length_secret_key),
424 };
425 let status = unsafe {
426 func(
427 pk.bytes.as_mut_ptr(),
428 sk.bytes.as_mut_ptr(),
429 seed.bytes.as_ptr(),
430 )
431 };
432 status_to_result(status)?;
433 unsafe {
436 pk.bytes.set_len(kem.length_public_key);
437 sk.bytes.set_len(kem.length_secret_key);
438 }
439 Ok((pk, sk))
440 }
441
442 pub fn encapsulate<'a, P: Into<PublicKeyRef<'a>>>(
444 &self,
445 pk: P,
446 ) -> Result<(Ciphertext, SharedSecret)> {
447 let pk = pk.into();
448 if pk.bytes.len() != self.length_public_key() {
449 return Err(Error::InvalidLength);
450 }
451 let kem = unsafe { self.kem.as_ref() };
452 let func = kem.encaps.unwrap();
453 let mut ct = Ciphertext {
454 bytes: Vec::with_capacity(kem.length_ciphertext),
455 };
456 let mut ss = SharedSecret {
457 bytes: Vec::with_capacity(kem.length_shared_secret),
458 };
459 let status = unsafe {
461 func(
462 ct.bytes.as_mut_ptr(),
463 ss.bytes.as_mut_ptr(),
464 pk.bytes.as_ptr(),
465 )
466 };
467 status_to_result(status)?;
468 unsafe {
471 ct.bytes.set_len(kem.length_ciphertext);
472 ss.bytes.set_len(kem.length_shared_secret);
473 }
474 Ok((ct, ss))
475 }
476
477 pub fn decapsulate<'a, 'b, S: Into<SecretKeyRef<'a>>, C: Into<CiphertextRef<'b>>>(
479 &self,
480 sk: S,
481 ct: C,
482 ) -> Result<SharedSecret> {
483 let kem = unsafe { self.kem.as_ref() };
484 let sk = sk.into();
485 let ct = ct.into();
486 if sk.bytes.len() != self.length_secret_key() || ct.bytes.len() != self.length_ciphertext()
487 {
488 return Err(Error::InvalidLength);
489 }
490 let mut ss = SharedSecret {
491 bytes: Vec::with_capacity(kem.length_shared_secret),
492 };
493 let func = kem.decaps.unwrap();
494 let status = unsafe { func(ss.bytes.as_mut_ptr(), ct.bytes.as_ptr(), sk.bytes.as_ptr()) };
496 status_to_result(status)?;
497 unsafe { ss.bytes.set_len(kem.length_shared_secret) };
500 Ok(ss)
501 }
502}