Skip to content

Commit b409adb

Browse files
committed
feat: implement get-kernel-variants
1 parent 6504ae5 commit b409adb

4 files changed

Lines changed: 67 additions & 0 deletions

File tree

docs/source/api/kernels.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
[[autodoc]] kernels.has_kernel
1616

17+
### get_kernel_variants
18+
19+
[[autodoc]] kernels.get_kernel_variants
20+
1721
### get_loaded_kernels
1822

1923
[[autodoc]] kernels.get_loaded_kernels

docs/source/basic-usage.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,23 @@ is_available = has_kernel("kernels-community/activation", version=1)
4343
print(f"Kernel available: {is_available}")
4444
```
4545

46+
When no compatible kernel is found, [`~kernels.has_kernel`] does not say *why*.
47+
[`~kernels.get_kernel_variants`] returns the full resolution trace instead: one
48+
decision per build variant in the repository, with compatible variants listed
49+
first. Each decision is a `VariantAccepted` or a `VariantRejected`, and rejected
50+
variants carry a human-readable `reason`:
51+
52+
```python
53+
from kernels import get_kernel_variants, VariantAccepted
54+
55+
for decision in get_kernel_variants("kernels-community/activation", version=1):
56+
name = decision.variant.variant_str
57+
if isinstance(decision, VariantAccepted):
58+
print(f"{name}: compatible")
59+
else:
60+
print(f"{name}: rejected ({decision.reason})")
61+
```
62+
4663
## Inspecting Loaded Kernels
4764

4865
[`~kernels.get_loaded_kernels`] returns a snapshot of every kernel that has been loaded

kernels/src/kernels/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@
2828
LoadedKernel,
2929
RepoInfo,
3030
get_kernel,
31+
get_kernel_variants,
3132
get_loaded_kernels,
3233
get_local_kernel,
3334
get_locked_kernel,
3435
has_kernel,
3536
install_kernel,
3637
load_kernel,
3738
)
39+
from kernels.variants import (
40+
VariantAccepted,
41+
VariantRejected,
42+
)
3843

3944
_add_additional_dll_paths()
4045

@@ -54,7 +59,10 @@
5459
"Metadata",
5560
"Mode",
5661
"RepoInfo",
62+
"VariantAccepted",
63+
"VariantRejected",
5764
"get_kernel",
65+
"get_kernel_variants",
5866
"get_loaded_kernels",
5967
"get_local_kernel",
6068
"get_locked_kernel",

kernels/src/kernels/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
from kernels.lockfile import KernelLock, VariantLock
2626
from kernels.status import resolve_status
2727
from kernels.variants import (
28+
Decision,
2829
Variant,
2930
get_variants,
3031
get_variants_local,
3132
resolve_variant,
33+
resolve_variants,
3234
variants_trace_str,
3335
)
3436

@@ -552,6 +554,42 @@ def has_kernel(
552554
)
553555

554556

557+
def get_kernel_variants(
558+
repo_id: str,
559+
revision: str | None = None,
560+
version: int | None = None,
561+
backend: str | None = None,
562+
) -> list[Decision]:
563+
"""
564+
Resolve all build variants of a kernel against the current environment.
565+
566+
The decisions are sorted with compatible variants first, the most preferred
567+
variant leading.
568+
569+
Args:
570+
repo_id (`str`):
571+
The Hub repository containing the kernel.
572+
revision (`str`, *optional*):
573+
The specific revision (branch, tag, or commit) to inspect. Cannot be used together with `version`.
574+
version (`int`, *optional*):
575+
The kernel version to inspect. Cannot be used together with `revision`.
576+
Either `version` or `revision` must be specified.
577+
backend (`str`, *optional*):
578+
The backend to resolve variants for. Can only be `cpu` or the backend that Torch is compiled for.
579+
The backend will be detected automatically if not provided.
580+
581+
Returns:
582+
`list[Decision]`: One `VariantAccepted` or `VariantRejected` per build variant
583+
in the repository, compatible variants first.
584+
"""
585+
revision = select_revision_or_version(repo_id, revision=revision, version=version)
586+
587+
api = _get_hf_api()
588+
variants = get_variants(api, repo_id=repo_id, revision=revision)
589+
_, trace = resolve_variants(variants, backend)
590+
return trace
591+
592+
555593
def load_kernel(
556594
repo_id: str, *, lockfile: Path | None, backend: str | None = None, revision: str | None = None
557595
) -> ModuleType:

0 commit comments

Comments
 (0)