1use alloc::vec::Vec;
6
7use core::ptr::{null, NonNull};
8use core::str::FromStr;
9
10#[cfg(not(feature = "std"))]
11use cstr_core::CStr;
12#[cfg(feature = "std")]
13use std::ffi::CStr;
14
15use crate::ffi::sig as ffi;
16use crate::newtype_buffer;
17use crate::*;
18
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22newtype_buffer!(PublicKey, PublicKeyRef);
23newtype_buffer!(SecretKey, SecretKeyRef);
24newtype_buffer!(Signature, SignatureRef);
25
26pub type Message = [u8];
28pub type CtxStr = [u8];
30
macro_rules! implement_sigs {
    { $(($feat: literal) $sig: ident: $oqs_id: ident),* $(,)? } => (
        /// Signature algorithms that can be requested from liboqs.
        ///
        /// Whether a variant is usable at runtime depends on the liboqs build;
        /// query `Algorithm::is_enabled` before relying on it.
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
        #[allow(missing_docs)]
        pub enum Algorithm {
            $(
                $sig,
            )*
        }

        /// Maps an [`Algorithm`] to a pointer at its liboqs identifier
        /// constant from the `ffi` bindings.
        fn algorithm_to_id(algorithm: Algorithm) -> *const libc::c_char {
            let id_bytes: &[u8] = match algorithm {
                $(
                    Algorithm::$sig => &ffi::$oqs_id[..],
                )*
            };
            id_bytes.as_ptr().cast::<libc::c_char>()
        }

        impl FromStr for Algorithm {
            type Err = crate::Error;

            /// Parses a liboqs algorithm name (as produced by `Algorithm::name`)
            /// back into an [`Algorithm`].
            ///
            /// Unknown names yield `Error::AlgorithmParsingError`.
            fn from_str(s: &str) -> Result<Self> {
                [$(Algorithm::$sig),*]
                    .iter()
                    .copied()
                    .find(|candidate| candidate.name() == s)
                    .ok_or(crate::Error::AlgorithmParsingError)
            }
        }

        $(
            #[cfg(test)]
            #[allow(non_snake_case)]
            mod $sig {
                use super::*;

                #[test]
                #[cfg(feature = $feat)]
                fn test_signing() -> Result<()> {
                    crate::init();
                    let msg = [0u8; 100];
                    let signer = Sig::new(Algorithm::$sig)?;
                    let (pk, sk) = signer.keypair()?;
                    let signature = signer.sign(&msg, &sk)?;
                    signer.verify(&msg, &signature, &pk)
                }

                #[test]
                #[cfg(feature = $feat)]
                fn test_signing_with_empty_context_string() -> Result<()> {
                    crate::init();
                    let msg = [0u8; 100];
                    let empty_ctx: &[u8] = &[];
                    let signer = Sig::new(Algorithm::$sig)?;
                    let (pk, sk) = signer.keypair()?;
                    let signature = signer.sign_with_ctx_str(&msg, empty_ctx, &sk)?;
                    signer.verify_with_ctx_str(&msg, &signature, empty_ctx, &pk)
                }

                #[test]
                #[cfg(feature = $feat)]
                fn test_signing_with_nonempty_context_string() -> Result<()> {
                    crate::init();
                    let msg = [0u8; 100];
                    let ctx = [0u8; 100];
                    let signer = Sig::new(Algorithm::$sig)?;
                    let (pk, sk) = signer.keypair()?;

                    if signer.has_ctx_str_support() {
                        let signature = signer.sign_with_ctx_str(&msg, &ctx, &sk)?;
                        return signer.verify_with_ctx_str(&msg, &signature, &ctx, &pk);
                    }

                    // Without context support, signing with a non-empty context
                    // must fail with the generic error...
                    match signer.sign_with_ctx_str(&msg, &ctx, &sk) {
                        Ok(_) => return Err(Error::Error),
                        Err(Error::Error) => {}
                        Err(other) => return Err(other),
                    }

                    // ...and so must verifying a plain signature against one.
                    let plain = signer.sign(&msg, &sk)?;
                    match signer.verify_with_ctx_str(&msg, &plain, &ctx, &pk) {
                        Err(Error::Error) => Ok(()),
                        Ok(_) => Err(Error::Error),
                        Err(other) => Err(other),
                    }
                }

                #[test]
                fn test_enabled() {
                    crate::init();
                    assert_eq!(Algorithm::$sig.is_enabled(), cfg!(feature = $feat));
                }

                #[test]
                fn test_name() {
                    let algo = Algorithm::$sig;
                    let name = algo.name();
                    assert!(!name.is_empty());

                    #[cfg(feature = "std")]
                    assert_eq!(name, algo.to_string());
                }

                #[test]
                fn test_get_algorithm_back() {
                    let expected = Algorithm::$sig;
                    if !expected.is_enabled() {
                        return;
                    }
                    let instance = Sig::new(expected).unwrap();
                    assert_eq!(instance.algorithm(), expected);
                }

                #[test]
                fn test_version() {
                    if let Ok(instance) = Sig::new(Algorithm::$sig) {
                        assert!(!instance.version().is_empty());
                    }
                }

                #[test]
                fn test_from_str() {
                    let expected = Algorithm::$sig;
                    let parsed: Algorithm = expected.name().parse().unwrap();
                    assert_eq!(parsed, expected);
                }
            }
        )*
    )
}
185
// Registry of every signature algorithm exposed by this module.
// Each entry is `("cargo feature") EnumVariant: liboqs_identifier_constant`;
// the feature gates the per-algorithm signing tests generated by the macro.
// Entry order determines the declaration order of `Algorithm`'s variants, so do
// not reorder existing entries. NOTE(review): reordering would also change the
// variant indices that some serde formats encode — confirm before touching.
implement_sigs! {
    // CROSS (RSDP and RSDP(G) variants at NIST-style 128/192/256 levels).
    ("cross") CrossRsdp128Balanced: OQS_SIG_alg_cross_rsdp_128_balanced,
    ("cross") CrossRsdp128Fast: OQS_SIG_alg_cross_rsdp_128_fast,
    ("cross") CrossRsdp128Small: OQS_SIG_alg_cross_rsdp_128_small,
    ("cross") CrossRsdp192Balanced: OQS_SIG_alg_cross_rsdp_192_balanced,
    ("cross") CrossRsdp192Fast: OQS_SIG_alg_cross_rsdp_192_fast,
    ("cross") CrossRsdp192Small: OQS_SIG_alg_cross_rsdp_192_small,
    ("cross") CrossRsdp256Balanced: OQS_SIG_alg_cross_rsdp_256_balanced,
    ("cross") CrossRsdp256Fast: OQS_SIG_alg_cross_rsdp_256_fast,
    ("cross") CrossRsdp256Small: OQS_SIG_alg_cross_rsdp_256_small,
    ("cross") CrossRsdpg128Balanced: OQS_SIG_alg_cross_rsdpg_128_balanced,
    ("cross") CrossRsdpg128Fast: OQS_SIG_alg_cross_rsdpg_128_fast,
    ("cross") CrossRsdpg128Small: OQS_SIG_alg_cross_rsdpg_128_small,
    ("cross") CrossRsdpg192Balanced: OQS_SIG_alg_cross_rsdpg_192_balanced,
    ("cross") CrossRsdpg192Fast: OQS_SIG_alg_cross_rsdpg_192_fast,
    ("cross") CrossRsdpg192Small: OQS_SIG_alg_cross_rsdpg_192_small,
    ("cross") CrossRsdpg256Balanced: OQS_SIG_alg_cross_rsdpg_256_balanced,
    ("cross") CrossRsdpg256Fast: OQS_SIG_alg_cross_rsdpg_256_fast,
    ("cross") CrossRsdpg256Small: OQS_SIG_alg_cross_rsdpg_256_small,
    // Falcon.
    ("falcon") Falcon512: OQS_SIG_alg_falcon_512,
    ("falcon") Falcon1024: OQS_SIG_alg_falcon_1024,
    // MAYO.
    ("mayo") Mayo1: OQS_SIG_alg_mayo_1,
    ("mayo") Mayo2: OQS_SIG_alg_mayo_2,
    ("mayo") Mayo3: OQS_SIG_alg_mayo_3,
    ("mayo") Mayo5: OQS_SIG_alg_mayo_5,
    // ML-DSA.
    ("ml_dsa") MlDsa44: OQS_SIG_alg_ml_dsa_44,
    ("ml_dsa") MlDsa65: OQS_SIG_alg_ml_dsa_65,
    ("ml_dsa") MlDsa87: OQS_SIG_alg_ml_dsa_87,
    // SPHINCS+ "simple" parameter sets, SHA2 then SHAKE.
    ("sphincs") SphincsSha2128fSimple: OQS_SIG_alg_sphincs_sha2_128f_simple,
    ("sphincs") SphincsSha2128sSimple: OQS_SIG_alg_sphincs_sha2_128s_simple,
    ("sphincs") SphincsSha2192fSimple: OQS_SIG_alg_sphincs_sha2_192f_simple,
    ("sphincs") SphincsSha2192sSimple: OQS_SIG_alg_sphincs_sha2_192s_simple,
    ("sphincs") SphincsSha2256fSimple: OQS_SIG_alg_sphincs_sha2_256f_simple,
    ("sphincs") SphincsSha2256sSimple: OQS_SIG_alg_sphincs_sha2_256s_simple,
    ("sphincs") SphincsShake128fSimple: OQS_SIG_alg_sphincs_shake_128f_simple,
    ("sphincs") SphincsShake128sSimple: OQS_SIG_alg_sphincs_shake_128s_simple,
    ("sphincs") SphincsShake192fSimple: OQS_SIG_alg_sphincs_shake_192f_simple,
    ("sphincs") SphincsShake192sSimple: OQS_SIG_alg_sphincs_shake_192s_simple,
    ("sphincs") SphincsShake256fSimple: OQS_SIG_alg_sphincs_shake_256f_simple,
    ("sphincs") SphincsShake256sSimple: OQS_SIG_alg_sphincs_shake_256s_simple,
    // UOV: base, `_pkc`, and `_pkc_skc` variants for each parameter set.
    ("uov") UovOvIs: OQS_SIG_alg_uov_ov_Is,
    ("uov") UovOvIp: OQS_SIG_alg_uov_ov_Ip,
    ("uov") UovOvIII: OQS_SIG_alg_uov_ov_III,
    ("uov") UovOvV: OQS_SIG_alg_uov_ov_V,
    ("uov") UovOvIsPkc: OQS_SIG_alg_uov_ov_Is_pkc,
    ("uov") UovOvIpPkc: OQS_SIG_alg_uov_ov_Ip_pkc,
    ("uov") UovOvIIIPkc: OQS_SIG_alg_uov_ov_III_pkc,
    ("uov") UovOvVPkc: OQS_SIG_alg_uov_ov_V_pkc,
    ("uov") UovOvIsPkcSkc: OQS_SIG_alg_uov_ov_Is_pkc_skc,
    ("uov") UovOvIpPkcSkc: OQS_SIG_alg_uov_ov_Ip_pkc_skc,
    ("uov") UovOvIIIPkcSkc: OQS_SIG_alg_uov_ov_III_pkc_skc,
    ("uov") UovOvVPkcSkc: OQS_SIG_alg_uov_ov_V_pkc_skc,
}
239
240impl Algorithm {
241 pub fn is_enabled(self) -> bool {
244 unsafe { ffi::OQS_SIG_alg_is_enabled(algorithm_to_id(self)) == 1 }
245 }
246
247 pub fn to_id(self) -> *const libc::c_char {
251 algorithm_to_id(self)
252 }
253
254 pub fn name(&self) -> &'static str {
258 let id = unsafe { CStr::from_ptr(self.to_id()) };
260 id.to_str().expect("OQS algorithm names must be UTF-8")
261 }
262}
263
264pub struct Sig {
278 algorithm: Algorithm,
279 sig: NonNull<ffi::OQS_SIG>,
280}
281
282unsafe impl Sync for Sig {}
283unsafe impl Send for Sig {}
284
285impl Drop for Sig {
286 fn drop(&mut self) {
287 unsafe { ffi::OQS_SIG_free(self.sig.as_ptr()) };
288 }
289}
290
#[cfg(feature = "std")]
impl std::fmt::Display for Algorithm {
    /// Writes the liboqs name of the algorithm, honouring width, fill and
    /// alignment flags exactly like formatting a `&str` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
297
298impl core::convert::TryFrom<Algorithm> for Sig {
299 type Error = crate::Error;
300 fn try_from(alg: Algorithm) -> Result<Sig> {
301 Sig::new(alg)
302 }
303}
304
305impl Sig {
306 pub fn new(algorithm: Algorithm) -> Result<Self> {
310 let sig = unsafe { ffi::OQS_SIG_new(algorithm_to_id(algorithm)) };
311 NonNull::new(sig).map_or_else(
312 || Err(Error::AlgorithmDisabled),
313 |sig| Ok(Self { algorithm, sig }),
314 )
315 }
316
317 pub fn algorithm(&self) -> Algorithm {
319 self.algorithm
320 }
321
322 pub fn version(&self) -> &'static str {
324 let sig = unsafe { self.sig.as_ref() };
325 let cstr = unsafe { CStr::from_ptr(sig.alg_version) };
327 cstr.to_str()
328 .expect("Algorithm version strings must be UTF-8")
329 }
330
331 pub fn claimed_nist_level(&self) -> u8 {
333 let sig = unsafe { self.sig.as_ref() };
334 sig.claimed_nist_level
335 }
336
337 pub fn is_euf_cma(&self) -> bool {
339 let sig = unsafe { self.sig.as_ref() };
340 sig.euf_cma
341 }
342
343 pub fn has_ctx_str_support(&self) -> bool {
345 let sig = unsafe { self.sig.as_ref() };
346 sig.sig_with_ctx_support
347 }
348
349 pub fn length_public_key(&self) -> usize {
351 let sig = unsafe { self.sig.as_ref() };
352 sig.length_public_key
353 }
354
355 pub fn length_secret_key(&self) -> usize {
357 let sig = unsafe { self.sig.as_ref() };
358 sig.length_secret_key
359 }
360
361 pub fn length_signature(&self) -> usize {
363 let sig = unsafe { self.sig.as_ref() };
364 sig.length_signature
365 }
366
367 pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SecretKeyRef<'a>> {
369 if buf.len() != self.length_secret_key() {
370 None
371 } else {
372 Some(SecretKeyRef::new(buf))
373 }
374 }
375
376 pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<PublicKeyRef<'a>> {
378 if buf.len() != self.length_public_key() {
379 None
380 } else {
381 Some(PublicKeyRef::new(buf))
382 }
383 }
384
385 pub fn signature_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SignatureRef<'a>> {
387 if buf.len() > self.length_signature() {
388 None
389 } else {
390 Some(SignatureRef::new(buf))
391 }
392 }
393
394 pub fn keypair(&self) -> Result<(PublicKey, SecretKey)> {
396 let sig = unsafe { self.sig.as_ref() };
397 let func = sig.keypair.unwrap();
398 let mut pk = PublicKey {
399 bytes: Vec::with_capacity(sig.length_public_key),
400 };
401 let mut sk = SecretKey {
402 bytes: Vec::with_capacity(sig.length_secret_key),
403 };
404 let status = unsafe { func(pk.bytes.as_mut_ptr(), sk.bytes.as_mut_ptr()) };
405 unsafe {
407 pk.bytes.set_len(sig.length_public_key);
408 sk.bytes.set_len(sig.length_secret_key);
409 }
410 status_to_result(status)?;
411 Ok((pk, sk))
412 }
413
414 pub fn sign<'a, S: Into<SecretKeyRef<'a>>>(
416 &self,
417 message: &Message,
418 sk: S,
419 ) -> Result<Signature> {
420 let sk = sk.into();
421 let sig = unsafe { self.sig.as_ref() };
422 let func = sig.sign.unwrap();
423 let mut sig = Signature {
424 bytes: Vec::with_capacity(sig.length_signature),
425 };
426 let mut sig_len = 0;
427 let status = unsafe {
428 func(
429 sig.bytes.as_mut_ptr(),
430 &mut sig_len,
431 message.as_ptr(),
432 message.len(),
433 sk.bytes.as_ptr(),
434 )
435 };
436 status_to_result(status)?;
437 unsafe {
439 sig.bytes.set_len(sig_len);
440 }
441 Ok(sig)
442 }
443
444 pub fn sign_with_ctx_str<'a, S: Into<SecretKeyRef<'a>>>(
446 &self,
447 message: &Message,
448 ctx_str: &CtxStr,
449 sk: S,
450 ) -> Result<Signature> {
451 let sk = sk.into();
452 let sig = unsafe { self.sig.as_ref() };
453 let func = sig.sign_with_ctx_str.unwrap();
454 let mut sig = Signature {
455 bytes: Vec::with_capacity(sig.length_signature),
456 };
457 let mut sig_len = 0;
458 let ctx_str_ptr = if !ctx_str.is_empty() {
462 ctx_str.as_ptr()
463 } else {
464 null()
465 };
466 let status = unsafe {
467 func(
468 sig.bytes.as_mut_ptr(),
469 &mut sig_len,
470 message.as_ptr(),
471 message.len(),
472 ctx_str_ptr,
473 ctx_str.len(),
474 sk.bytes.as_ptr(),
475 )
476 };
477 status_to_result(status)?;
478 unsafe {
480 sig.bytes.set_len(sig_len);
481 }
482 Ok(sig)
483 }
484
485 pub fn verify<'a, 'b>(
487 &self,
488 message: &Message,
489 signature: impl Into<SignatureRef<'a>>,
490 pk: impl Into<PublicKeyRef<'b>>,
491 ) -> Result<()> {
492 let signature = signature.into();
493 let pk = pk.into();
494 if signature.bytes.len() > self.length_signature()
495 || pk.bytes.len() != self.length_public_key()
496 {
497 return Err(Error::InvalidLength);
498 }
499 let sig = unsafe { self.sig.as_ref() };
500 let func = sig.verify.unwrap();
501 let status = unsafe {
502 func(
503 message.as_ptr(),
504 message.len(),
505 signature.bytes.as_ptr(),
506 signature.len(),
507 pk.bytes.as_ptr(),
508 )
509 };
510 status_to_result(status)
511 }
512
513 pub fn verify_with_ctx_str<'a, 'b>(
515 &self,
516 message: &Message,
517 signature: impl Into<SignatureRef<'a>>,
518 ctx_str: &CtxStr,
519 pk: impl Into<PublicKeyRef<'b>>,
520 ) -> Result<()> {
521 let signature = signature.into();
522 let pk = pk.into();
523 if signature.bytes.len() > self.length_signature()
524 || pk.bytes.len() != self.length_public_key()
525 {
526 return Err(Error::InvalidLength);
527 }
528 let sig = unsafe { self.sig.as_ref() };
529 let func = sig.verify_with_ctx_str.unwrap();
530 let ctx_str_ptr = if !ctx_str.is_empty() {
534 ctx_str.as_ptr()
535 } else {
536 null()
537 };
538 let status = unsafe {
539 func(
540 message.as_ptr(),
541 message.len(),
542 signature.bytes.as_ptr(),
543 signature.len(),
544 ctx_str_ptr,
545 ctx_str.len(),
546 pk.bytes.as_ptr(),
547 )
548 };
549 status_to_result(status)
550 }
551}