Skip to content

Commit cffdecb

Browse files
authored
Merge pull request #13 from open-quantum-safe/panic_less
Improve the API by not panicing as much
2 parents e7bd47d + a3d3a93 commit cffdecb

File tree

4 files changed

+134
-39
lines changed

4 files changed

+134
-39
lines changed

oqs/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "oqs"
3-
version = "0.2.0"
3+
version = "0.3.0"
44
authors = ["Thom Wiggers <[email protected]>"]
55
edition = "2018"
66
description = "A Rusty interface to Open-Quantum-Safe's liboqs"

oqs/src/kem.rs

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,47 @@ impl Algorithm {
158158
}
159159
}
160160

161-
/// Contains a KEM algorithm
161+
/// KEM algorithm
162+
///
163+
/// # Example
164+
/// ```rust
165+
/// use oqs;
166+
/// oqs::init();
167+
/// let kem = oqs::kem::Kem::default();
168+
/// let (pk, sk) = kem.keypair().unwrap();
169+
/// let (ct, ss) = kem.encapsulate(&pk).unwrap();
170+
/// let ss2 = kem.decapsulate(&sk, &ct).unwrap();
171+
/// assert_eq!(ss, ss2);
172+
/// ```
162173
pub struct Kem {
163174
kem: NonNull<ffi::OQS_KEM>,
164175
}
165176

166177
unsafe impl Sync for Kem {}
178+
unsafe impl Send for Kem {}
167179

168180
impl Drop for Kem {
169181
fn drop(&mut self) {
170182
unsafe { ffi::OQS_KEM_free(self.kem.as_ptr()) };
171183
}
172184
}
173185

186+
impl core::default::Default for Kem {
187+
/// Get the default KEM algorithm in liboqs
188+
///
189+
/// Panics if the default algorithm is not enabled in liboqs.
190+
fn default() -> Self {
191+
Kem::new(Algorithm::Default).expect("Expected default algorithm to be enabled")
192+
}
193+
}
194+
195+
impl core::convert::TryFrom<Algorithm> for Kem {
196+
type Error = crate::Error;
197+
fn try_from(alg: Algorithm) -> Result<Kem> {
198+
Kem::new(alg)
199+
}
200+
}
201+
174202
impl Kem {
175203
/// Construct a new algorithm
176204
pub fn new(algorithm: Algorithm) -> Result<Self> {
@@ -229,31 +257,47 @@ impl Kem {
229257
}
230258

231259
/// Obtain a secret key objects from bytes
232-
pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> SecretKeyRef<'a> {
233-
let kem = unsafe { self.kem.as_ref() };
234-
assert_eq!(buf.len(), kem.length_secret_key);
235-
SecretKeyRef::new(buf)
260+
///
261+
/// Returns None if the secret key is not the correct length.
262+
pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SecretKeyRef<'a>> {
263+
if self.length_secret_key() != buf.len() {
264+
None
265+
} else {
266+
Some(SecretKeyRef::new(buf))
267+
}
236268
}
237269

238270
/// Obtain a public key from bytes
239-
pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> PublicKeyRef<'a> {
240-
let kem = unsafe { self.kem.as_ref() };
241-
assert_eq!(buf.len(), kem.length_public_key);
242-
PublicKeyRef::new(buf)
271+
///
272+
/// Returns None if the public key is not the correct length.
273+
pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<PublicKeyRef<'a>> {
274+
if self.length_public_key() != buf.len() {
275+
None
276+
} else {
277+
Some(PublicKeyRef::new(buf))
278+
}
243279
}
244280

245281
/// Obtain a ciphertext from bytes
246-
pub fn ciphertext_from_bytes<'a>(&self, buf: &'a [u8]) -> CiphertextRef<'a> {
247-
let kem = unsafe { self.kem.as_ref() };
248-
assert_eq!(buf.len(), kem.length_ciphertext);
249-
CiphertextRef::new(buf)
282+
///
283+
/// Returns None if the ciphertext is not the correct length.
284+
pub fn ciphertext_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<CiphertextRef<'a>> {
285+
if self.length_ciphertext() != buf.len() {
286+
None
287+
} else {
288+
Some(CiphertextRef::new(buf))
289+
}
250290
}
251291

252292
/// Obtain a secret key from bytes
253-
pub fn shared_secret_from_bytes<'a>(&self, buf: &'a [u8]) -> SharedSecretRef<'a> {
254-
let kem = unsafe { self.kem.as_ref() };
255-
assert_eq!(buf.len(), kem.length_shared_secret);
256-
SharedSecretRef::new(buf)
293+
///
294+
/// Returns None if the shared secret is not the correct length.
295+
pub fn shared_secret_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SharedSecretRef<'a>> {
296+
if self.length_shared_secret() != buf.len() {
297+
None
298+
} else {
299+
Some(SharedSecretRef::new(buf))
300+
}
257301
}
258302

259303
/// Generate a new keypair
@@ -267,12 +311,13 @@ impl Kem {
267311
bytes: Vec::with_capacity(kem.length_secret_key),
268312
};
269313
let status = unsafe { func(pk.bytes.as_mut_ptr(), sk.bytes.as_mut_ptr()) };
314+
status_to_result(status)?;
270315
// update the lengths of the vecs
316+
// this is safe to do, as we have initialised them now.
271317
unsafe {
272318
pk.bytes.set_len(kem.length_public_key);
273319
sk.bytes.set_len(kem.length_secret_key);
274320
}
275-
status_to_result(status)?;
276321
Ok((pk, sk))
277322
}
278323

@@ -282,15 +327,18 @@ impl Kem {
282327
pk: P,
283328
) -> Result<(Ciphertext, SharedSecret)> {
284329
let pk = pk.into();
330+
if pk.bytes.len() != self.length_public_key() {
331+
return Err(Error::InvalidLength);
332+
}
285333
let kem = unsafe { self.kem.as_ref() };
286-
debug_assert_eq!(pk.len(), kem.length_public_key);
287334
let func = kem.encaps.unwrap();
288335
let mut ct = Ciphertext {
289336
bytes: Vec::with_capacity(kem.length_ciphertext),
290337
};
291338
let mut ss = SharedSecret {
292339
bytes: Vec::with_capacity(kem.length_shared_secret),
293340
};
341+
// call encapsulate
294342
let status = unsafe {
295343
func(
296344
ct.bytes.as_mut_ptr(),
@@ -299,6 +347,8 @@ impl Kem {
299347
)
300348
};
301349
status_to_result(status)?;
350+
// update the lengths of the vecs
351+
// this is safe to do, as we have initialised them now.
302352
unsafe {
303353
ct.bytes.set_len(kem.length_ciphertext);
304354
ss.bytes.set_len(kem.length_shared_secret);
@@ -315,14 +365,19 @@ impl Kem {
315365
let kem = unsafe { self.kem.as_ref() };
316366
let sk = sk.into();
317367
let ct = ct.into();
318-
debug_assert_eq!(sk.len(), kem.length_secret_key);
319-
debug_assert_eq!(ct.len(), kem.length_ciphertext);
368+
if sk.bytes.len() != self.length_secret_key() || ct.bytes.len() != self.length_ciphertext()
369+
{
370+
return Err(Error::InvalidLength);
371+
}
320372
let mut ss = SharedSecret {
321373
bytes: Vec::with_capacity(kem.length_shared_secret),
322374
};
323375
let func = kem.decaps.unwrap();
376+
// Call decapsulate
324377
let status = unsafe { func(ss.bytes.as_mut_ptr(), ct.bytes.as_ptr(), sk.bytes.as_ptr()) };
325378
status_to_result(status)?;
379+
// update the lengths of the vecs
380+
// this is safe to do, as we have initialised them now.
326381
unsafe { ss.bytes.set_len(kem.length_shared_secret) };
327382
Ok(ss)
328383
}

oqs/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ pub enum Error {
8585
Error,
8686
/// Error occurred in OpenSSL functions external to liboqs
8787
ErrorExternalOpenSSL,
88+
/// Invalid length of a public object
89+
InvalidLength,
8890
}
8991
#[cfg(not(feature = "no_std"))]
9092
impl std::error::Error for Error {}

oqs/src/sig.rs

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,17 +163,47 @@ impl Algorithm {
163163
}
164164
}
165165

166-
/// Algorithm wrapper
166+
/// Signature scheme
167+
///
168+
/// # Example
169+
/// ```rust
170+
/// use oqs;
171+
/// oqs::init();
172+
/// let scheme = oqs::sig::Sig::default();
173+
/// let message = [0u8; 100];
174+
/// let (pk, sk) = scheme.keypair().unwrap();
175+
/// let signature = scheme.sign(&message, &sk).unwrap();
176+
/// assert!(scheme.verify(&message, &signature, &pk).is_ok());
177+
/// ```
167178
pub struct Sig {
168179
sig: NonNull<ffi::OQS_SIG>,
169180
}
170181

182+
unsafe impl Sync for Sig {}
183+
unsafe impl Send for Sig {}
184+
171185
impl Drop for Sig {
172186
fn drop(&mut self) {
173187
unsafe { ffi::OQS_SIG_free(self.sig.as_ptr()) };
174188
}
175189
}
176190

191+
impl core::convert::TryFrom<Algorithm> for Sig {
192+
type Error = crate::Error;
193+
fn try_from(alg: Algorithm) -> Result<Sig> {
194+
Sig::new(alg)
195+
}
196+
}
197+
198+
impl core::default::Default for Sig {
199+
/// Get the default Signature scheme
200+
///
201+
/// Panics if the default algorithm is not enabled in liboqs.
202+
fn default() -> Self {
203+
Sig::new(Algorithm::default()).expect("Expected default algorithm to be enabled")
204+
}
205+
}
206+
177207
impl Sig {
178208
/// Construct a new algorithm
179209
///
@@ -228,24 +258,30 @@ impl Sig {
228258
}
229259

230260
/// Construct a secret key object from bytes
231-
pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> SecretKeyRef<'a> {
232-
let sig = unsafe { self.sig.as_ref() };
233-
assert_eq!(buf.len(), sig.length_secret_key);
234-
SecretKeyRef::new(buf)
261+
pub fn secret_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SecretKeyRef<'a>> {
262+
if buf.len() != self.length_secret_key() {
263+
None
264+
} else {
265+
Some(SecretKeyRef::new(buf))
266+
}
235267
}
236268

237269
/// Construct a public key object from bytes
238-
pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> PublicKeyRef<'a> {
239-
let sig = unsafe { self.sig.as_ref() };
240-
assert_eq!(buf.len(), sig.length_public_key);
241-
PublicKeyRef::new(buf)
270+
pub fn public_key_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<PublicKeyRef<'a>> {
271+
if buf.len() != self.length_public_key() {
272+
None
273+
} else {
274+
Some(PublicKeyRef::new(buf))
275+
}
242276
}
243277

244278
/// Construct a signature object from bytes
245-
pub fn signature_from_bytes<'a>(&self, buf: &'a [u8]) -> SignatureRef<'a> {
246-
let sig = unsafe { self.sig.as_ref() };
247-
assert!(buf.len() <= sig.length_signature);
248-
SignatureRef::new(buf)
279+
pub fn signature_from_bytes<'a>(&self, buf: &'a [u8]) -> Option<SignatureRef<'a>> {
280+
if buf.len() != self.length_signature() {
281+
None
282+
} else {
283+
Some(SignatureRef::new(buf))
284+
}
249285
}
250286

251287
/// Generate a new keypair
@@ -291,6 +327,7 @@ impl Sig {
291327
)
292328
};
293329
status_to_result(status)?;
330+
// This is safe to do as it's initialised now.
294331
unsafe {
295332
sig.bytes.set_len(sig_len);
296333
}
@@ -306,11 +343,12 @@ impl Sig {
306343
) -> Result<()> {
307344
let signature = signature.into();
308345
let pk = pk.into();
346+
if signature.bytes.len() > self.length_signature()
347+
|| pk.bytes.len() != self.length_public_key()
348+
{
349+
return Err(Error::InvalidLength);
350+
}
309351
let sig = unsafe { self.sig.as_ref() };
310-
assert!(
311-
signature.len() <= sig.length_signature,
312-
"Signature is too long?"
313-
);
314352
let func = sig.verify.unwrap();
315353
let status = unsafe {
316354
func(

0 commit comments

Comments
 (0)