@@ -158,19 +158,47 @@ impl Algorithm {
158158 }
159159}
160160
161- /// Contains a KEM algorithm
161+ /// KEM algorithm
162+ ///
163+ /// # Example
164+ /// ```rust
165+ /// use oqs;
166+ /// oqs::init();
167+ /// let kem = oqs::kem::Kem::default();
168+ /// let (pk, sk) = kem.keypair().unwrap();
169+ /// let (ct, ss) = kem.encapsulate(&pk).unwrap();
170+ /// let ss2 = kem.decapsulate(&sk, &ct).unwrap();
171+ /// assert_eq!(ss, ss2);
172+ /// ```
162173pub struct Kem {
163174 kem : NonNull < ffi:: OQS_KEM > ,
164175}
165176
166177unsafe impl Sync for Kem { }
178+ unsafe impl Send for Kem { }
167179
168180impl Drop for Kem {
169181 fn drop ( & mut self ) {
170182 unsafe { ffi:: OQS_KEM_free ( self . kem . as_ptr ( ) ) } ;
171183 }
172184}
173185
186+ impl core:: default:: Default for Kem {
187+ /// Get the default KEM algorithm in liboqs
188+ ///
189+ /// Panics if the default algorithm is not enabled in liboqs.
190+ fn default ( ) -> Self {
191+ Kem :: new ( Algorithm :: Default ) . expect ( "Expected default algorithm to be enabled" )
192+ }
193+ }
194+
195+ impl core:: convert:: TryFrom < Algorithm > for Kem {
196+ type Error = crate :: Error ;
197+ fn try_from ( alg : Algorithm ) -> Result < Kem > {
198+ Kem :: new ( alg)
199+ }
200+ }
201+
174202impl Kem {
175203 /// Construct a new algorithm
176204 pub fn new ( algorithm : Algorithm ) -> Result < Self > {
@@ -229,31 +257,47 @@ impl Kem {
229257 }
230258
231259 /// Obtain a secret key objects from bytes
232- pub fn secret_key_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> SecretKeyRef < ' a > {
233- let kem = unsafe { self . kem . as_ref ( ) } ;
234- assert_eq ! ( buf. len( ) , kem. length_secret_key) ;
235- SecretKeyRef :: new ( buf)
260+ ///
261+ /// Returns None if the secret key is not the correct length.
262+ pub fn secret_key_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> Option < SecretKeyRef < ' a > > {
263+ if self . length_secret_key ( ) != buf. len ( ) {
264+ None
265+ } else {
266+ Some ( SecretKeyRef :: new ( buf) )
267+ }
236268 }
237269
238270 /// Obtain a public key from bytes
239- pub fn public_key_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> PublicKeyRef < ' a > {
240- let kem = unsafe { self . kem . as_ref ( ) } ;
241- assert_eq ! ( buf. len( ) , kem. length_public_key) ;
242- PublicKeyRef :: new ( buf)
271+ ///
272+ /// Returns None if the public key is not the correct length.
273+ pub fn public_key_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> Option < PublicKeyRef < ' a > > {
274+ if self . length_public_key ( ) != buf. len ( ) {
275+ None
276+ } else {
277+ Some ( PublicKeyRef :: new ( buf) )
278+ }
243279 }
244280
245281 /// Obtain a ciphertext from bytes
246- pub fn ciphertext_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> CiphertextRef < ' a > {
247- let kem = unsafe { self . kem . as_ref ( ) } ;
248- assert_eq ! ( buf. len( ) , kem. length_ciphertext) ;
249- CiphertextRef :: new ( buf)
282+ ///
283+ /// Returns None if the ciphertext is not the correct length.
284+ pub fn ciphertext_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> Option < CiphertextRef < ' a > > {
285+ if self . length_ciphertext ( ) != buf. len ( ) {
286+ None
287+ } else {
288+ Some ( CiphertextRef :: new ( buf) )
289+ }
250290 }
251291
252292 /// Obtain a secret key from bytes
253- pub fn shared_secret_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> SharedSecretRef < ' a > {
254- let kem = unsafe { self . kem . as_ref ( ) } ;
255- assert_eq ! ( buf. len( ) , kem. length_shared_secret) ;
256- SharedSecretRef :: new ( buf)
293+ ///
294+ /// Returns None if the shared secret is not the correct length.
295+ pub fn shared_secret_from_bytes < ' a > ( & self , buf : & ' a [ u8 ] ) -> Option < SharedSecretRef < ' a > > {
296+ if self . length_shared_secret ( ) != buf. len ( ) {
297+ None
298+ } else {
299+ Some ( SharedSecretRef :: new ( buf) )
300+ }
257301 }
258302
259303 /// Generate a new keypair
@@ -267,12 +311,13 @@ impl Kem {
267311 bytes : Vec :: with_capacity ( kem. length_secret_key ) ,
268312 } ;
269313 let status = unsafe { func ( pk. bytes . as_mut_ptr ( ) , sk. bytes . as_mut_ptr ( ) ) } ;
314+ status_to_result ( status) ?;
270315 // update the lengths of the vecs
316+ // this is safe to do, as we have initialised them now.
271317 unsafe {
272318 pk. bytes . set_len ( kem. length_public_key ) ;
273319 sk. bytes . set_len ( kem. length_secret_key ) ;
274320 }
275- status_to_result ( status) ?;
276321 Ok ( ( pk, sk) )
277322 }
278323
@@ -282,15 +327,18 @@ impl Kem {
282327 pk : P ,
283328 ) -> Result < ( Ciphertext , SharedSecret ) > {
284329 let pk = pk. into ( ) ;
330+ if pk. bytes . len ( ) != self . length_public_key ( ) {
331+ return Err ( Error :: InvalidLength ) ;
332+ }
285333 let kem = unsafe { self . kem . as_ref ( ) } ;
286- debug_assert_eq ! ( pk. len( ) , kem. length_public_key) ;
287334 let func = kem. encaps . unwrap ( ) ;
288335 let mut ct = Ciphertext {
289336 bytes : Vec :: with_capacity ( kem. length_ciphertext ) ,
290337 } ;
291338 let mut ss = SharedSecret {
292339 bytes : Vec :: with_capacity ( kem. length_shared_secret ) ,
293340 } ;
341+ // call encapsulate
294342 let status = unsafe {
295343 func (
296344 ct. bytes . as_mut_ptr ( ) ,
@@ -299,6 +347,8 @@ impl Kem {
299347 )
300348 } ;
301349 status_to_result ( status) ?;
350+ // update the lengths of the vecs
351+ // this is safe to do, as we have initialised them now.
302352 unsafe {
303353 ct. bytes . set_len ( kem. length_ciphertext ) ;
304354 ss. bytes . set_len ( kem. length_shared_secret ) ;
@@ -315,14 +365,19 @@ impl Kem {
315365 let kem = unsafe { self . kem . as_ref ( ) } ;
316366 let sk = sk. into ( ) ;
317367 let ct = ct. into ( ) ;
318- debug_assert_eq ! ( sk. len( ) , kem. length_secret_key) ;
319- debug_assert_eq ! ( ct. len( ) , kem. length_ciphertext) ;
368+ if sk. bytes . len ( ) != self . length_secret_key ( ) || ct. bytes . len ( ) != self . length_ciphertext ( )
369+ {
370+ return Err ( Error :: InvalidLength ) ;
371+ }
320372 let mut ss = SharedSecret {
321373 bytes : Vec :: with_capacity ( kem. length_shared_secret ) ,
322374 } ;
323375 let func = kem. decaps . unwrap ( ) ;
376+ // Call decapsulate
324377 let status = unsafe { func ( ss. bytes . as_mut_ptr ( ) , ct. bytes . as_ptr ( ) , sk. bytes . as_ptr ( ) ) } ;
325378 status_to_result ( status) ?;
379+ // update the lengths of the vecs
380+ // this is safe to do, as we have initialised them now.
326381 unsafe { ss. bytes . set_len ( kem. length_shared_secret ) } ;
327382 Ok ( ss)
328383 }
0 commit comments