Skip to content

Commit 6511c75

Browse files
committed
Warning only when (major, minor) versions differ
Signed-off-by: Vlad Gheorghiu <[email protected]>
1 parent 3be2109 commit 6511c75

File tree

2 files changed

+67
-25
lines changed

2 files changed

+67
-25
lines changed

CHANGES.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Pre-release
22

33
- Added type checking and automatic linting/formatting, https://github.com/open-quantum-safe/liboqs-python/pull/97
4+
- Added a utility function for de-structuring version strings in `oqs.py`
5+
- `version(version_str: str) -> tuple[str, str, str]:` - Returns a tuple
6+
containing the
7+
(major, minor, patch) versions
8+
- A warning is issued only if the liboqs-python version's major and minor
9+
numbers differ from those of liboqs, ignoring the patch version
410

511
# Version 0.12.0 - January 15, 2025
612

oqs/oqs.py

+61-25
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
logger.setLevel(logging.INFO)
3535
logger.addHandler(logging.StreamHandler(stdout))
3636

37+
# Expected return value from native OQS functions
38+
OQS_SUCCESS: Final[int] = 0
39+
OQS_ERROR: Final[int] = -1
40+
3741

3842
def oqs_python_version() -> Union[str, None]:
3943
"""liboqs-python version string."""
@@ -50,12 +54,14 @@ def oqs_python_version() -> Union[str, None]:
5054
OQS_VERSION = oqs_python_version()
5155

5256

53-
def _countdown(seconds: int) -> None:
54-
while seconds > 0:
55-
logger.info("Installing in %s seconds...", seconds)
56-
stdout.flush()
57-
seconds -= 1
58-
time.sleep(1)
57+
def version(version_str: str) -> tuple[str, str, str]:
58+
parts = version_str.split(".")
59+
60+
major = parts[0] if len(parts) > 0 else ""
61+
minor = parts[1] if len(parts) > 1 else ""
62+
patch = parts[2] if len(parts) > 2 else ""
63+
64+
return major, minor, patch
5965

6066

6167
def _load_shared_obj(
@@ -100,6 +106,14 @@ def _load_shared_obj(
100106
raise RuntimeError(msg)
101107

102108

109+
def _countdown(seconds: int) -> None:
110+
while seconds > 0:
111+
logger.info("Installing in %s seconds...", seconds)
112+
stdout.flush()
113+
seconds -= 1
114+
time.sleep(1)
115+
116+
103117
def _install_liboqs(
104118
target_directory: Path,
105119
oqs_version_to_install: Union[str, None] = None,
@@ -188,7 +202,9 @@ def _load_liboqs() -> ct.CDLL:
188202
assert liboqs # noqa: S101
189203
except RuntimeError:
190204
# We don't have liboqs, so we try to install it automatically
191-
_install_liboqs(target_directory=oqs_install_dir, oqs_version_to_install=OQS_VERSION)
205+
_install_liboqs(
206+
target_directory=oqs_install_dir, oqs_version_to_install=OQS_VERSION
207+
)
192208
# Try loading it again
193209
try:
194210
liboqs = _load_shared_obj(
@@ -206,11 +222,6 @@ def _load_liboqs() -> ct.CDLL:
206222
_liboqs = _load_liboqs()
207223

208224

209-
# Expected return value from native OQS functions
210-
OQS_SUCCESS: Final[int] = 0
211-
OQS_ERROR: Final[int] = -1
212-
213-
214225
def native() -> ct.CDLL:
215226
"""Handle to native liboqs handler."""
216227
return _liboqs
@@ -226,13 +237,24 @@ def oqs_version() -> str:
226237
return ct.c_char_p(native().OQS_version()).value.decode("UTF-8") # type: ignore[union-attr]
227238

228239

229-
# Warn the user if the liboqs version differs from liboqs-python version
230-
if oqs_version() != oqs_python_version():
231-
warnings.warn(
232-
f"liboqs version {oqs_version()} differs from liboqs-python version "
233-
f"{oqs_python_version()}",
234-
stacklevel=2,
240+
oqs_ver = oqs_version()
241+
oqs_ver_major, oqs_ver_minor, oqs_ver_patch = version(oqs_ver)
242+
243+
244+
oqs_python_ver = oqs_python_version()
245+
if oqs_python_ver:
246+
oqs_python_ver_major, oqs_python_ver_minor, oqs_python_ver_patch = version(
247+
oqs_python_ver
235248
)
249+
# Warn the user if the liboqs version differs from liboqs-python version
250+
if not (
251+
oqs_ver_major == oqs_python_ver_major and oqs_ver_minor == oqs_python_ver_minor
252+
):
253+
warnings.warn(
254+
f"liboqs version (major, minor) {oqs_version()} differs from liboqs-python version "
255+
f"{oqs_python_version()}",
256+
stacklevel=2,
257+
)
236258

237259

238260
class MechanismNotSupportedError(Exception):
@@ -281,7 +303,9 @@ class KeyEncapsulation(ct.Structure):
281303
("decaps_cb", ct.c_void_p),
282304
]
283305

284-
def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None:
306+
def __init__(
307+
self, alg_name: str, secret_key: Union[int, bytes, None] = None
308+
) -> None:
285309
"""
286310
Create new KeyEncapsulation with the given algorithm.
287311
@@ -435,9 +459,15 @@ def is_kem_enabled(alg_name: str) -> bool:
435459
return native().OQS_KEM_alg_is_enabled(ct.create_string_buffer(alg_name.encode()))
436460

437461

438-
_KEM_alg_ids = [native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())]
439-
_supported_KEMs: tuple[str, ...] = tuple([i.decode() for i in _KEM_alg_ids]) # noqa: N816
440-
_enabled_KEMs: tuple[str, ...] = tuple([i for i in _supported_KEMs if is_kem_enabled(i)]) # noqa: N816
462+
_KEM_alg_ids = [
463+
native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())
464+
]
465+
_supported_KEMs: tuple[str, ...] = tuple(
466+
[i.decode() for i in _KEM_alg_ids]
467+
) # noqa: N816
468+
_enabled_KEMs: tuple[str, ...] = tuple(
469+
[i for i in _supported_KEMs if is_kem_enabled(i)]
470+
) # noqa: N816
441471

442472

443473
def get_enabled_kem_mechanisms() -> tuple[str, ...]:
@@ -478,7 +508,9 @@ class Signature(ct.Structure):
478508
("verify_cb", ct.c_void_p),
479509
]
480510

481-
def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None:
511+
def __init__(
512+
self, alg_name: str, secret_key: Union[int, bytes, None] = None
513+
) -> None:
482514
"""
483515
Create new Signature with the given algorithm.
484516
@@ -723,9 +755,13 @@ def is_sig_enabled(alg_name: str) -> bool:
723755
return native().OQS_SIG_alg_is_enabled(ct.create_string_buffer(alg_name.encode()))
724756

725757

726-
_sig_alg_ids = [native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())]
758+
_sig_alg_ids = [
759+
native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())
760+
]
727761
_supported_sigs: tuple[str, ...] = tuple([i.decode() for i in _sig_alg_ids])
728-
_enabled_sigs: tuple[str, ...] = tuple([i for i in _supported_sigs if is_sig_enabled(i)])
762+
_enabled_sigs: tuple[str, ...] = tuple(
763+
[i for i in _supported_sigs if is_sig_enabled(i)]
764+
)
729765

730766

731767
def get_enabled_sig_mechanisms() -> tuple[str, ...]:

0 commit comments

Comments
 (0)