Skip to content

Commit f613821

Browse files
committed
fixes #12837 - make AES-GCM tag length configurable
Valid AES-GCM tag lengths as per NIST 800-38D, default is 16 (which is also backwards compatible.). Signed-off-by: David Lamparter <[email protected]>
1 parent 297a506 commit f613821

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

src/rust/src/backend/aead.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,12 @@ struct AesGcm {
600600
#[pyo3::pymethods]
601601
impl AesGcm {
602602
#[new]
603-
fn new(py: pyo3::Python<'_>, key: pyo3::Py<pyo3::PyAny>) -> CryptographyResult<AesGcm> {
603+
#[pyo3(signature = (key, tag_length=None))]
604+
fn new(
605+
py: pyo3::Python<'_>,
606+
key: pyo3::Py<pyo3::PyAny>,
607+
tag_length: Option<usize>,
608+
) -> CryptographyResult<AesGcm> {
604609
let key_buf = key.extract::<CffiBuf<'_>>(py)?;
605610
let cipher = match key_buf.as_bytes().len() {
606611
16 => openssl::cipher::Cipher::aes_128_gcm(),
@@ -614,6 +619,12 @@ impl AesGcm {
614619
))
615620
}
616621
};
622+
let tag_length = tag_length.unwrap_or(16);
623+
if ![4, 8, 12, 13, 14, 15, 16].contains(&tag_length) {
624+
return Err(CryptographyError::from(
625+
pyo3::exceptions::PyValueError::new_err("Invalid tag_length"),
626+
));
627+
}
617628

618629
cfg_if::cfg_if! {
619630
if #[cfg(any(
@@ -624,11 +635,11 @@ impl AesGcm {
624635
not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER),
625636
))] {
626637
Ok(AesGcm {
627-
ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), 16, false)?,
638+
ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), tag_length, false)?,
628639
})
629640
} else {
630641
Ok(AesGcm {
631-
ctx: LazyEvpCipherAead::new(cipher, key, 16, false, false),
642+
ctx: LazyEvpCipherAead::new(cipher, key, tag_length, false, false),
632643
})
633644

634645
}

tests/hazmat/primitives/test_aead.py

+19
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,13 @@ def test_bad_generate_key(self, backend):
487487
with pytest.raises(ValueError):
488488
AESGCM.generate_key(129)
489489

490+
def test_bad_tag_length(self, backend):
491+
with pytest.raises(TypeError):
492+
AESGCM(b"X" * 32, object()) # type:ignore[arg-type]
493+
494+
with pytest.raises(ValueError):
495+
AESGCM(b"X" * 32, 17)
496+
490497
def test_associated_data_none_equal_to_empty_bytestring(self, backend):
491498
key = AESGCM.generate_key(128)
492499
aesgcm = AESGCM(key)
@@ -498,6 +505,18 @@ def test_associated_data_none_equal_to_empty_bytestring(self, backend):
498505
pt2 = aesgcm.decrypt(nonce, ct2, b"")
499506
assert pt1 == pt2
500507

508+
@pytest.mark.parametrize("length", [4, 15])
509+
def test_short_tags(self, length, backend):
510+
key = AESGCM.generate_key(128)
511+
aesgcm_ref = AESGCM(key)
512+
aesgcm_cut = AESGCM(key, length)
513+
nonce = os.urandom(12)
514+
ct_ref = aesgcm_ref.encrypt(nonce, b"some_data", b"some_aad")
515+
ct_cut = aesgcm_cut.encrypt(nonce, b"some_data", b"some_aad")
516+
assert ct_cut == ct_ref[: -16 + length]
517+
pt_cut = aesgcm_cut.decrypt(nonce, ct_cut, b"some_aad")
518+
assert pt_cut == b"some_data"
519+
501520
def test_buffer_protocol(self, backend):
502521
key = AESGCM.generate_key(128)
503522
aesgcm = AESGCM(key)

0 commit comments

Comments
 (0)