// oqs/kem.rs

//! Key encapsulation mechanism (KEM) API
//!
//! See [`Kem`] for the main functionality.
//! [`Algorithm`] lists the available algorithms.
5use alloc::vec::Vec;
6
7use core::ptr::NonNull;
8use core::str::FromStr;
9
10#[cfg(not(feature = "std"))]
11use cstr_core::CStr;
12#[cfg(feature = "std")]
13use std::ffi::CStr;
14
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
18use crate::ffi::kem as ffi;
19use crate::newtype_buffer;
20use crate::*;
21
22newtype_buffer!(PublicKey, PublicKeyRef);
23newtype_buffer!(SecretKey, SecretKeyRef);
24newtype_buffer!(Ciphertext, CiphertextRef);
25newtype_buffer!(SharedSecret, SharedSecretRef);
26newtype_buffer!(KeypairSeed, KeypairSeedRef);
27
macro_rules! implement_kems {
    { $(($feat: literal) $kem: ident: $oqs_id: ident),* $(,)? } => (

        /// Algorithms supported by liboqs.
        ///
        /// A variant existing here does not mean the algorithm is available in the
        /// linked liboqs build; use [`Algorithm::is_enabled`] to check at runtime.
        ///
        /// Optional support for `serde` if that feature is enabled.
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
        #[allow(missing_docs)]
        pub enum Algorithm {
            $(
                $kem,
            )*
        }

        /// Map an [`Algorithm`] to a pointer to its liboqs identifier.
        ///
        /// The pointer refers to the matching `ffi::OQS_KEM_alg_*` byte constant,
        /// which the FFI layer provides as a NUL-terminated C string.
        fn algorithm_to_id(algorithm: Algorithm) -> *const libc::c_char {
            let id: &[u8] = match algorithm {
                $(
                    Algorithm::$kem => &ffi::$oqs_id[..],
                )*
            };
            id.as_ptr() as *const libc::c_char
        }

        impl FromStr for Algorithm {
            type Err = crate::Error;

            /// Parse an algorithm from its liboqs name, as returned by
            /// [`Algorithm::name`].
            ///
            /// Returns [`crate::Error::AlgorithmParsingError`] if no variant matches.
            fn from_str(s: &str) -> Result<Self> {
                $(
                    if s == Algorithm::$kem.name() {
                        return Ok(Algorithm::$kem);
                    }
                )*
                Err(crate::Error::AlgorithmParsingError)
            }
        }

        $(
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod $kem {
                use super::*;

                #[test]
                #[cfg(feature = $feat)]
                fn test_encaps_decaps() -> Result<()> {
                    crate::init();

                    let alg = Algorithm::$kem;
                    let kem = Kem::new(alg)?;
                    let (pk, sk) = kem.keypair()?;
                    let (ct, ss1) = kem.encapsulate(&pk)?;
                    let ss2 = kem.decapsulate(&sk, &ct)?;
                    assert_eq!(ss1, ss2, "shared secret not equal!");
                    Ok(())
                }

                #[test]
                #[cfg(feature = $feat)]
                fn test_encaps_decaps_derand() -> Result<()> {
                    use crate::ffi::rand::OQS_randombytes;
                    crate::init();

                    let alg = Algorithm::$kem;
                    let kem = Kem::new(alg)?;
                    let seed_len = kem.length_keypair_seed();
                    let mut seed = KeypairSeed {
                        bytes: Vec::with_capacity(seed_len),
                    };
                    unsafe {
                        // On some systems, getentropy fails if given a zero-length array
                        if seed_len > 0 {
                            OQS_randombytes(seed.bytes.as_mut_ptr(), seed_len);
                        }
                        seed.bytes.set_len(seed_len);
                    }
                    let result = kem.keypair_derand(&seed);
                    // A zero seed length marks a KEM without the derandomized keypair
                    // API; it is expected to fail with exactly `Error::Error`.
                    if seed_len == 0 {
                        return match result {
                            Err(Error::Error) => Ok(()),
                            _ => Err(Error::Error),
                        };
                    }
                    let (pk, sk) = result?;
                    let (ct, ss1) = kem.encapsulate(&pk)?;
                    let ss2 = kem.decapsulate(&sk, &ct)?;
                    assert_eq!(ss1, ss2, "shared secret not equal!");
                    Ok(())
                }

                #[test]
                fn test_enabled() {
                    crate::init();
                    if cfg!(feature = $feat) {
                        assert!(Algorithm::$kem.is_enabled());
                    } else {
                        assert!(!Algorithm::$kem.is_enabled())
                    }
                }

                #[test]
                fn test_name() {
                    let algo = Algorithm::$kem;
                    // Just make sure the name impl does not panic or crash.
                    let name = algo.name();
                    #[cfg(feature = "std")]
                    assert_eq!(name, algo.to_string());
                    // ... And actually contains something.
                    assert!(!name.is_empty());
                }

                #[test]
                fn test_get_algorithm_back() {
                    let algorithm = Algorithm::$kem;
                    if algorithm.is_enabled() {
                        let kem = Kem::new(algorithm).unwrap();
                        assert_eq!(algorithm, kem.algorithm());
                    }
                }

                #[test]
                fn test_version() {
                    if let Ok(kem) = Kem::new(Algorithm::$kem) {
                        // Just make sure the version can be called without panic
                        let version = kem.version();
                        // ... And actually contains something.
                        assert!(!version.is_empty());
                    }
                }

                #[test]
                fn test_from_str() {
                    let algorithm = Algorithm::$kem;
                    let name = algorithm.name();
                    let parsed = Algorithm::from_str(name).unwrap();
                    assert_eq!(algorithm, parsed);
                }
            }
        )*
    )
}
168
// Each entry reads `("cargo-feature") VariantName: ffi_algorithm_id_constant`.
// The feature name gates the per-algorithm tests generated by `implement_kems!`;
// the constant is the liboqs identifier used by `algorithm_to_id`.
// NOTE(review): entries become enum variants in this order, so reordering them
// changes the variants' discriminants — append new algorithms instead.
implement_kems! {
    ("bike") BikeL1: OQS_KEM_alg_bike_l1,
    ("bike") BikeL3: OQS_KEM_alg_bike_l3,
    ("bike") BikeL5: OQS_KEM_alg_bike_l5,
    ("classic_mceliece") ClassicMcEliece348864: OQS_KEM_alg_classic_mceliece_348864,
    ("classic_mceliece") ClassicMcEliece348864f: OQS_KEM_alg_classic_mceliece_348864f,
    ("classic_mceliece") ClassicMcEliece460896: OQS_KEM_alg_classic_mceliece_460896,
    ("classic_mceliece") ClassicMcEliece460896f: OQS_KEM_alg_classic_mceliece_460896f,
    ("classic_mceliece") ClassicMcEliece6688128: OQS_KEM_alg_classic_mceliece_6688128,
    ("classic_mceliece") ClassicMcEliece6688128f: OQS_KEM_alg_classic_mceliece_6688128f,
    ("classic_mceliece") ClassicMcEliece6960119: OQS_KEM_alg_classic_mceliece_6960119,
    ("classic_mceliece") ClassicMcEliece6960119f: OQS_KEM_alg_classic_mceliece_6960119f,
    ("classic_mceliece") ClassicMcEliece8192128: OQS_KEM_alg_classic_mceliece_8192128,
    ("classic_mceliece") ClassicMcEliece8192128f: OQS_KEM_alg_classic_mceliece_8192128f,
    ("hqc") Hqc128: OQS_KEM_alg_hqc_128,
    ("hqc") Hqc192: OQS_KEM_alg_hqc_192,
    ("hqc") Hqc256: OQS_KEM_alg_hqc_256,
    ("kyber") Kyber512: OQS_KEM_alg_kyber_512,
    ("kyber") Kyber768: OQS_KEM_alg_kyber_768,
    ("kyber") Kyber1024: OQS_KEM_alg_kyber_1024,
    ("ml_kem") MlKem512: OQS_KEM_alg_ml_kem_512,
    ("ml_kem") MlKem768: OQS_KEM_alg_ml_kem_768,
    ("ml_kem") MlKem1024: OQS_KEM_alg_ml_kem_1024,
    ("ntruprime") NtruPrimeSntrup761: OQS_KEM_alg_ntruprime_sntrup761,
    ("frodokem") FrodoKem640Aes: OQS_KEM_alg_frodokem_640_aes,
    ("frodokem") FrodoKem640Shake: OQS_KEM_alg_frodokem_640_shake,
    ("frodokem") FrodoKem976Aes: OQS_KEM_alg_frodokem_976_aes,
    ("frodokem") FrodoKem976Shake: OQS_KEM_alg_frodokem_976_shake,
    ("frodokem") FrodoKem1344Aes: OQS_KEM_alg_frodokem_1344_aes,
    ("frodokem") FrodoKem1344Shake: OQS_KEM_alg_frodokem_1344_shake,
}
200
201impl Algorithm {
202    /// Returns true if this algorithm is enabled in the linked version
203    /// of liboqs
204    pub fn is_enabled(self) -> bool {
205        unsafe { ffi::OQS_KEM_alg_is_enabled(algorithm_to_id(self)) == 1 }
206    }
207
208    /// Provides a pointer to the id of the algorithm
209    ///
210    /// For use with the FFI api methods
211    pub fn to_id(self) -> *const libc::c_char {
212        algorithm_to_id(self)
213    }
214
215    /// Returns the algorithm's name as a static Rust string.
216    ///
217    /// This is the same as the `to_id`, but as a safe Rust string.
218    pub fn name(&self) -> &'static str {
219        // SAFETY: The id from ffi must be a proper null terminated C string
220        let id = unsafe { CStr::from_ptr(self.to_id()) };
221        id.to_str().expect("OQS algorithm names must be UTF-8")
222    }
223}
224
#[cfg(feature = "std")]
impl std::fmt::Display for Algorithm {
    /// Writes the algorithm's liboqs name, honouring width/fill/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
231
232/// KEM algorithm
233///
234/// # Example
235/// ```rust
236/// # if !cfg!(feature = "ml_kem") { return; }
237/// use oqs;
238/// oqs::init();
239/// let kem = oqs::kem::Kem::new(oqs::kem::Algorithm::MlKem512).unwrap();
240/// let (pk, sk) = kem.keypair().unwrap();
241/// let (ct, ss) = kem.encapsulate(&pk).unwrap();
242/// let ss2 = kem.decapsulate(&sk, &ct).unwrap();
243/// assert_eq!(ss, ss2);
244/// ```
245pub struct Kem {
246    algorithm: Algorithm,
247    kem: NonNull<ffi::OQS_KEM>,
248}
249
250unsafe impl Sync for Kem {}
251unsafe impl Send for Kem {}
252
253impl Drop for Kem {
254    fn drop(&mut self) {
255        unsafe { ffi::OQS_KEM_free(self.kem.as_ptr()) };
256    }
257}
258
259impl core::convert::TryFrom<Algorithm> for Kem {
260    type Error = crate::Error;
261    fn try_from(alg: Algorithm) -> Result<Kem> {
262        Kem::new(alg)
263    }
264}
265
266impl Kem {
267    /// Construct a new algorithm
268    pub fn new(algorithm: Algorithm) -> Result<Self> {
269        let kem = unsafe { ffi::OQS_KEM_new(algorithm_to_id(algorithm)) };
270        NonNull::new(kem).map_or_else(
271            || Err(Error::AlgorithmDisabled),
272            |kem| Ok(Self { algorithm, kem }),
273        )
274    }
275
276    /// Get the algorithm used by this `Kem`
277    pub fn algorithm(&self) -> Algorithm {
278        self.algorithm
279    }
280
281    /// Get the version of the implementation
282    pub fn version(&self) -> &'static str {
283        let kem = unsafe { self.kem.as_ref() };
284        // SAFETY: The alg_version from ffi must be a proper null terminated C string
285        let cstr = unsafe { CStr::from_ptr(kem.alg_version) };
286        cstr.to_str()
287            .expect("Algorithm version strings must be UTF-8")
288    }
289
290    /// Get the claimed nist level
291    pub fn claimed_nist_level(&self) -> u8 {
292        let kem = unsafe { self.kem.as_ref() };
293        kem.claimed_nist_level
294    }
295
296    /// Is the algorithm ind_cca secure
297    pub fn is_ind_cca(&self) -> bool {
298        let kem = unsafe { self.kem.as_ref() };
299        kem.ind_cca
300    }
301
302    /// Get the length of the public key
303    pub fn length_public_key(&self) -> usize {
304        let kem = unsafe { self.kem.as_ref() };
305        kem.length_public_key
306    }
307
308    /// Get the length of the secret key
309    pub fn length_secret_key(&self) -> usize {
310        let kem = unsafe { self.kem.as_ref() };
311        kem.length_secret_key
312    }
313
314    /// Get the length of the ciphertext
315    pub fn length_ciphertext(&self) -> usize {
316        let kem = unsafe { self.kem.as_ref() };
317        kem.length_ciphertext
318    }
319
320    /// Get the length of a shared secret
321    pub fn length_shared_secret(&self) -> usize {
322        let kem = unsafe { self.kem.as_ref() };
323        kem.length_shared_secret
324    }
325
326    /// Get the length of a keypair seed
327    pub fn length_keypair_seed(&self) -> usize {
328        let kem = unsafe { self.kem.as_ref() };
329        kem.length_keypair_seed
330    }
331
332    /// Obtain a secret key objects from bytes
333    ///
334    /// Returns None if the secret key is not the correct length.
335    pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SecretKeyRef<'a>> {
336        if self.length_secret_key() != buf.len() {
337            None
338        } else {
339            Some(SecretKeyRef::new(buf))
340        }
341    }
342
343    /// Obtain a public key from bytes
344    ///
345    /// Returns None if the public key is not the correct length.
346    pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<PublicKeyRef<'a>> {
347        if self.length_public_key() != buf.len() {
348            None
349        } else {
350            Some(PublicKeyRef::new(buf))
351        }
352    }
353
354    /// Obtain a ciphertext from bytes
355    ///
356    /// Returns None if the ciphertext is not the correct length.
357    pub fn ciphertext_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<CiphertextRef<'a>> {
358        if self.length_ciphertext() != buf.len() {
359            None
360        } else {
361            Some(CiphertextRef::new(buf))
362        }
363    }
364
365    /// Obtain a secret key from bytes
366    ///
367    /// Returns None if the shared secret is not the correct length.
368    pub fn shared_secret_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SharedSecretRef<'a>> {
369        if self.length_shared_secret() != buf.len() {
370            None
371        } else {
372            Some(SharedSecretRef::new(buf))
373        }
374    }
375
376    /// Obtain a keypair seed from bytes
377    ///
378    /// Returns None if the shared secret is not the correct length.
379    pub fn keypair_seed_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<KeypairSeedRef<'a>> {
380        if self.length_keypair_seed() != buf.len() {
381            None
382        } else {
383            Some(KeypairSeedRef::new(buf))
384        }
385    }
386
387    /// Generate a new keypair
388    pub fn keypair(&self) -> Result<(PublicKey, SecretKey)> {
389        let kem = unsafe { self.kem.as_ref() };
390        let func = kem.keypair.unwrap();
391        let mut pk = PublicKey {
392            bytes: Vec::with_capacity(kem.length_public_key),
393        };
394        let mut sk = SecretKey {
395            bytes: Vec::with_capacity(kem.length_secret_key),
396        };
397        let status = unsafe { func(pk.bytes.as_mut_ptr(), sk.bytes.as_mut_ptr()) };
398        status_to_result(status)?;
399        // update the lengths of the vecs
400        // this is safe to do, as we have initialised them now.
401        unsafe {
402            pk.bytes.set_len(kem.length_public_key);
403            sk.bytes.set_len(kem.length_secret_key);
404        }
405        Ok((pk, sk))
406    }
407
408    /// Generate a new keypair from a seed
409    pub fn keypair_derand<'a, S: Into<KeypairSeedRef<'a>>>(
410        &self,
411        seed: S,
412    ) -> Result<(PublicKey, SecretKey)> {
413        let seed = seed.into();
414        if seed.bytes.len() != self.length_keypair_seed() {
415            return Err(Error::InvalidLength);
416        }
417        let kem = unsafe { self.kem.as_ref() };
418        let func = kem.keypair_derand.unwrap();
419        let mut pk = PublicKey {
420            bytes: Vec::with_capacity(kem.length_public_key),
421        };
422        let mut sk = SecretKey {
423            bytes: Vec::with_capacity(kem.length_secret_key),
424        };
425        let status = unsafe {
426            func(
427                pk.bytes.as_mut_ptr(),
428                sk.bytes.as_mut_ptr(),
429                seed.bytes.as_ptr(),
430            )
431        };
432        status_to_result(status)?;
433        // update the lengths of the vecs
434        // this is safe to do, as we have initialised them now.
435        unsafe {
436            pk.bytes.set_len(kem.length_public_key);
437            sk.bytes.set_len(kem.length_secret_key);
438        }
439        Ok((pk, sk))
440    }
441
442    /// Encapsulate to the provided public key
443    pub fn encapsulate<'a, P: Into<PublicKeyRef<'a>>>(
444        &self,
445        pk: P,
446    ) -> Result<(Ciphertext, SharedSecret)> {
447        let pk = pk.into();
448        if pk.bytes.len() != self.length_public_key() {
449            return Err(Error::InvalidLength);
450        }
451        let kem = unsafe { self.kem.as_ref() };
452        let func = kem.encaps.unwrap();
453        let mut ct = Ciphertext {
454            bytes: Vec::with_capacity(kem.length_ciphertext),
455        };
456        let mut ss = SharedSecret {
457            bytes: Vec::with_capacity(kem.length_shared_secret),
458        };
459        // call encapsulate
460        let status = unsafe {
461            func(
462                ct.bytes.as_mut_ptr(),
463                ss.bytes.as_mut_ptr(),
464                pk.bytes.as_ptr(),
465            )
466        };
467        status_to_result(status)?;
468        // update the lengths of the vecs
469        // this is safe to do, as we have initialised them now.
470        unsafe {
471            ct.bytes.set_len(kem.length_ciphertext);
472            ss.bytes.set_len(kem.length_shared_secret);
473        }
474        Ok((ct, ss))
475    }
476
477    /// Decapsulate the provided ciphertext
478    pub fn decapsulate<'a, 'b, S: Into<SecretKeyRef<'a>>, C: Into<CiphertextRef<'b>>>(
479        &self,
480        sk: S,
481        ct: C,
482    ) -> Result<SharedSecret> {
483        let kem = unsafe { self.kem.as_ref() };
484        let sk = sk.into();
485        let ct = ct.into();
486        if sk.bytes.len() != self.length_secret_key() || ct.bytes.len() != self.length_ciphertext()
487        {
488            return Err(Error::InvalidLength);
489        }
490        let mut ss = SharedSecret {
491            bytes: Vec::with_capacity(kem.length_shared_secret),
492        };
493        let func = kem.decaps.unwrap();
494        // Call decapsulate
495        let status = unsafe { func(ss.bytes.as_mut_ptr(), ct.bytes.as_ptr(), sk.bytes.as_ptr()) };
496        status_to_result(status)?;
497        // update the lengths of the vecs
498        // this is safe to do, as we have initialised them now.
499        unsafe { ss.bytes.set_len(kem.length_shared_secret) };
500        Ok(ss)
501    }
502}