|
25 | 25 | from kernels.lockfile import KernelLock, VariantLock |
26 | 26 | from kernels.status import resolve_status |
27 | 27 | from kernels.variants import ( |
| 28 | + Decision, |
28 | 29 | Variant, |
29 | 30 | get_variants, |
30 | 31 | get_variants_local, |
31 | 32 | resolve_variant, |
| 33 | + resolve_variants, |
32 | 34 | variants_trace_str, |
33 | 35 | ) |
34 | 36 |
|
@@ -552,6 +554,42 @@ def has_kernel( |
552 | 554 | ) |
553 | 555 |
|
554 | 556 |
|
| 557 | +def get_kernel_variants( |
| 558 | + repo_id: str, |
| 559 | + revision: str | None = None, |
| 560 | + version: int | None = None, |
| 561 | + backend: str | None = None, |
| 562 | +) -> list[Decision]: |
| 563 | + """ |
| 564 | + Resolve all build variants of a kernel against the current environment. |
| 565 | +
|
| 566 | + The decisions are sorted with compatible variants first, the most preferred |
| 567 | + variant leading. |
| 568 | +
|
| 569 | + Args: |
| 570 | + repo_id (`str`): |
| 571 | + The Hub repository containing the kernel. |
| 572 | + revision (`str`, *optional*): |
| 573 | + The specific revision (branch, tag, or commit) to inspect. Cannot be used together with `version`. |
| 574 | + version (`int`, *optional*): |
| 575 | + The kernel version to inspect. Cannot be used together with `revision`. |
| 576 | + Either `version` or `revision` must be specified. |
| 577 | + backend (`str`, *optional*): |
| 578 | + The backend to resolve variants for. Can only be `cpu` or the backend that Torch is compiled for. |
| 579 | + The backend will be detected automatically if not provided. |
| 580 | +
|
| 581 | + Returns: |
| 582 | + `list[Decision]`: One `VariantAccepted` or `VariantRejected` per build variant |
| 583 | + in the repository, compatible variants first. |
| 584 | + """ |
| 585 | + revision = select_revision_or_version(repo_id, revision=revision, version=version) |
| 586 | + |
| 587 | + api = _get_hf_api() |
| 588 | + variants = get_variants(api, repo_id=repo_id, revision=revision) |
| 589 | + _, trace = resolve_variants(variants, backend) |
| 590 | + return trace |
| 591 | + |
| 592 | + |
555 | 593 | def load_kernel( |
556 | 594 | repo_id: str, *, lockfile: Path | None, backend: str | None = None, revision: str | None = None |
557 | 595 | ) -> ModuleType: |
|
0 commit comments