Skip to content

Commit 0f0e6ab

Browse files
committed
feat: allow toolchain without host compatible variant
1 parent b31a9a4 commit 0f0e6ab

File tree

2 files changed

+200
-28
lines changed

2 files changed

+200
-28
lines changed

python/private/python.bzl

+98-16
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ load(":full_version.bzl", "full_version")
2121
load(":python_register_toolchains.bzl", "python_register_toolchains")
2222
load(":pythons_hub.bzl", "hub_repo")
2323
load(":repo_utils.bzl", "repo_utils")
24-
load(":toolchains_repo.bzl", "host_toolchain", "multi_toolchain_aliases", "sorted_host_platforms")
24+
load(
25+
":toolchains_repo.bzl",
26+
"host_toolchain",
27+
"multi_toolchain_aliases",
28+
"sorted_host_platform_names",
29+
"sorted_host_platforms",
30+
)
2531
load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER")
2632
load(":version.bzl", "version")
2733

@@ -267,6 +273,8 @@ def parse_modules(*, module_ctx, _fail = fail):
267273
def _python_impl(module_ctx):
268274
py = parse_modules(module_ctx = module_ctx)
269275

276+
all_host_compatible_impls = []
277+
270278
# list of structs; see inline struct call within the loop below.
271279
toolchain_impls = []
272280

@@ -319,22 +327,96 @@ def _python_impl(module_ctx):
319327
))
320328
if _is_compatible_with_host(module_ctx, platform_info):
321329
host_platforms[platform_name] = platform_info
330+
host_compat_entry = struct(
331+
full_python_version = full_python_version,
332+
platform = platform_info,
333+
platform_name = platform_name,
334+
impl_repo_name = repo_name,
335+
)
336+
all_host_compatible_impls.setdefault(full_python_version, []).append(
337+
host_compat_entry,
338+
)
339+
all_host_compatible_impls.setdefault(
340+
toolchain_info.python_version,
341+
[],
342+
).append(host_compat_entry)
343+
344+
host_repo_name = toolchain_info.name + "_host"
345+
if not host_platforms:
346+
needed_host_repos[toolchain_info.python_version] = struct(
347+
compatible_version = toolchain_info.python_version,
348+
full_python_version = full_python_version,
349+
)
350+
else:
351+
host_platforms = sorted_host_platforms(host_platforms)
352+
host_toolchain(
353+
name = host_repo_name,
354+
# NOTE: Order matters. The first found to be compatible is (usually) used.
355+
platforms = host_platforms.keys(),
356+
os_names = {
357+
str(i): platform_info.os_name
358+
for i, platform_info in enumerate(host_platforms.values())
359+
},
360+
archs = {
361+
str(i): platform_info.arch
362+
for i, platform_info in enumerate(host_platforms.values())
363+
},
364+
python_version = full_python_version,
365+
)
322366

323-
host_platforms = sorted_host_platforms(host_platforms)
324-
host_toolchain(
325-
name = toolchain_info.name + "_host",
326-
# NOTE: Order matters. The first found to be compatible is (usually) used.
327-
platforms = host_platforms.keys(),
328-
os_names = {
329-
str(i): platform_info.os_name
330-
for i, platform_info in enumerate(host_platforms.values())
331-
},
332-
archs = {
333-
str(i): platform_info.arch
334-
for i, platform_info in enumerate(host_platforms.values())
335-
},
336-
python_version = full_python_version,
337-
)
367+
"""
368+
We want to define e.g. python_3_13_host backed by e.g. python_3_13 (full
369+
version 3.13.3)
370+
but we didn't see a host-compatible option when we initially did so.
371+
So search from 3.13.3 down to 3.13.0, looking for compatible options.
372+
"""
373+
374+
def vt(s):
375+
return tuple([int(x) for x in s.split(".")])
376+
377+
if needed_host_repos:
378+
for key, entries in all_host_compatible_impls.items():
379+
all_host_compatible_impls[key] = sorted(
380+
entries,
381+
reverse = True,
382+
key = lambda e: vt(e.full_python_version),
383+
)
384+
385+
for host_repo_name, info in needed_host_repos.items():
386+
choices = []
387+
for entry in all_host_compatible_impls[info.compatible_version]:
388+
# todo: numeric version comparison
389+
if vt(entry.full_python_version) <= vt(info.full_python_version):
390+
choices.append(entry)
391+
if choices:
392+
platforms_keys = [
393+
# We have to prepend the offset because the same platform
394+
# name might occur across different versions
395+
"{}_{}".format(i, entry.platform_name)
396+
for i, entry in enumerate(choices)
397+
]
398+
platform_keys = sorted_host_platforms_names(platform_keys)
399+
400+
# AH no, this won't quite work.
401+
# Multiple versions will have the same platform_name string.
402+
# Thus when it builds the platform_map, some will get clobbered
403+
# Maybe just throw an offset prefix onto it for uniqueness?
404+
host_toolchain(
405+
name = host_repo_name,
406+
platforms = platform_keys,
407+
impl_repo_suffixes = {
408+
# Internally, it prepends its own name arg, so we must remove
409+
# the common prefix
410+
str(i): entry.impl_repo_name.removeprefix(host_repo_name)
411+
for i, entry in enumerate(choices)
412+
},
413+
os_names = {str(i): entry.os_name for entry in enumerate(choices)},
414+
archs = {str(i): entry.arch for entry in enumerate(choices)},
415+
python_versions = {str(i): entry.full_python_version for entry in enumerate(choices)},
416+
)
417+
else:
418+
# todo: figure out what to do. Define nothing, if we can.
419+
fail("No host-compatible found")
338420

339421
# List of the base names ("python_3_10") for the toolchain repos
340422
base_toolchain_repo_names = []

python/private/toolchains_repo.bzl

+102-12
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def _host_toolchain_impl(rctx):
313313
rctx.file("BUILD.bazel", _HOST_TOOLCHAIN_BUILD_CONTENT)
314314

315315
os_name = repo_utils.get_platforms_os_name(rctx)
316-
host_platform = _get_host_platform(
316+
host_platform = _get_host_impl_repo_suffix(
317317
rctx = rctx,
318318
logger = repo_utils.logger(rctx),
319319
python_version = rctx.attr.python_version,
@@ -386,7 +386,27 @@ Creates a repository with a shorter name meant to be used in the repository_ctx,
386386
which needs to have `symlinks` for the interpreter. This is separate from the
387387
toolchain_aliases repo because referencing the `python` interpreter target from
388388
this repo causes an eager fetch of the toolchain for the host platform.
389-
""",
389+
390+
This repo has three different ways in which it is called:
391+
392+
1. Workspace. The `platforms` attribute is set, which are keys into the
393+
PLATFORMS global. It assumes `name` + <matching platform name> is a
394+
valid repo name which it can use as the backing repo.
395+
396+
2. Bzlmod, created along side when python_register_toolchains is called
397+
and expected to use one of the repos created as part of that
398+
python_register_toolchains call.
399+
Because the bzlmod extension decides the platform mapping, it is given
400+
the `platforms`, `os_names`, and `archs` attributes to figure out which
401+
to use.
402+
403+
3. Bzlmod, created when the initial python_register_toolchains didn't
404+
have a host-compatible runtime, so a different host-compatible
405+
implementation was used.
406+
This is like the normal bzlmod creation, except the python_versions
407+
and suffixes may vary between choices, so the `impl_repo_suffixes` and
408+
`python_versions` attributes are specified.
409+
""",
390410
attrs = {
391411
"archs": attr.string_dict(
392412
doc = """
@@ -398,8 +418,43 @@ If set, overrides the platform metadata. Keyed by index in `platforms`
398418
If set, overrides the platform metadata. Keyed by index in `platforms`
399419
""",
400420
),
401-
"platforms": attr.string_list(mandatory = True),
402-
"python_version": attr.string(mandatory = True),
421+
"platforms": attr.string_list(
422+
mandatory = True,
423+
doc = """
424+
Platform names and backing repo-suffix.
425+
426+
NOTE: The order of this list matters. The first platform that is compatible
427+
with the host will be selected; this can be customized by using the
428+
`RULES_PYTHON_REPO_TOOLCHAIN_*` env vars.
429+
430+
When os_names aren't set, they act as the key into the PLATFORMS
431+
dict to determine if a platform is compatible with the host. When
432+
os_names is set, then it is a (mostly) arbitrary platform name string
433+
(and platform metadata comes from the os_names/archs args).
434+
435+
The string is used as a suffix to create the name of the repo that
436+
should be pointed to. i.e. `name` + <selected platform string> should
437+
result in a valid repo (e.g. created by python_register_toolchains()).
438+
Under bzlmod, this also means the same extension must create the
439+
repo named `name+suffix` and the host_toolchain repo.
440+
""",
441+
),
442+
"python_version": attr.string(
443+
mandatory = True,
444+
doc = "Full python version, Major.Minor.Micro",
445+
),
446+
"python_versions": attr.string_dict(
447+
doc = """
448+
If set, the Python version for the corresponding selected platform.
449+
Keyed by index in `platforms`. Values Major.Minor.Patch
450+
""",
451+
),
452+
"impl_repo_suffixes": attr.string_dict(
453+
doc = """
454+
If set, the suffix to append to `name` to identify the backing repo that is used.
455+
Keyed by index in `platforms`.
456+
""",
457+
),
403458
"_rule_name": attr.string(default = "host_toolchain"),
404459
"_rules_python_workspace": attr.label(default = Label("//:WORKSPACE")),
405460
},
@@ -438,8 +493,8 @@ multi_toolchain_aliases = repository_rule(
438493
def sanitize_platform_name(platform):
439494
return platform.replace("-", "_")
440495

441-
def sorted_host_platforms(platform_map):
442-
"""Sort the keys in the platform map to give correct precedence.
496+
def sorted_host_platform_names(platform_names):
497+
"""Sort platform names to give correct precedence.
443498
444499
The order of keys in the platform mapping matters for the host toolchain
445500
selection. When multiple runtimes are compatible with the host, we take the
@@ -455,6 +510,29 @@ def sorted_host_platforms(platform_map):
455510
in a dict that autoformatters like to clobber and whose only documentation
456511
is an innocuous-looking formatter disable directive.
457512
513+
Args:
514+
platform_names: a list of platform names
515+
516+
Returns:
517+
list[str] the same values, but in the desired order.
518+
"""
519+
520+
def platform_keyer(name):
521+
# Ascending sort: lower is higher precedence
522+
pref = 0
523+
if name.endswith("-" + FREETHREADED):
524+
pref = 1
525+
elif name.endswith("-" + MUSL):
526+
pref = 2
527+
return (pref, name)
528+
529+
return sorted(platform_map.keys(), key = platform_keyer)
530+
531+
def sorted_host_platforms(platform_map):
532+
"""Sort the keys in the platform map to give correct precedence.
533+
534+
See sorted_host_platform_names for explanation.
535+
458536
Args:
459537
platform_map: a mapping of platforms and their metadata.
460538
@@ -472,13 +550,12 @@ def sorted_host_platforms(platform_map):
472550
pref = 2
473551
return (pref, name)
474552

475-
sorted_platform_keys = sorted(platform_map.keys(), key = platform_keyer)
476553
return {
477554
key: platform_map[key]
478-
for key in sorted_platform_keys
555+
for key in sorted_host_platform_names(platform_map.keys())
479556
}
480557

481-
def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platforms):
558+
def _get_host_impl_repo_suffix(*, rctx, logger, python_version, os_name, cpu_name, platforms):
482559
"""Gets the backing repo suffix for the host platform.
483560
484561
Args:
@@ -498,6 +575,8 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
498575
platform_map[platform_name] = struct(
499576
os_name = rctx.attr.os_names[key],
500577
arch = rctx.attr.archs[key],
578+
python_version = rctx.attr.python_versions.get(key),
579+
impl_repo_suffix = rctx.attr.impl_repo_suffixes.get(key),
501580
)
502581
else:
503582
platform_map = sorted_host_platforms(PLATFORMS)
@@ -507,11 +586,13 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
507586
meta = platform_map[platform]
508587

509588
if meta.os_name == os_name and meta.arch == cpu_name:
510-
candidates.append(platform)
589+
candidates.append((platform, meta))
511590

512591
if len(candidates) == 1:
513-
return candidates[0]
592+
platform_name, meta = candidates[0]
593+
return getattr(meta, "impl_repo_suffix", platform_name)
514594

595+
# todo: have this handle multiple python versions
515596
if candidates:
516597
env_var = "RULES_PYTHON_REPO_TOOLCHAIN_{}_{}_{}".format(
517598
python_version.replace(".", "_"),
@@ -525,12 +606,21 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
525606
candidates,
526607
))
527608
elif preference not in candidates:
609+
# todo: need to map names like 3_13_0_linux_x86_64 back to
610+
# the input values. Ah, er, wait
611+
# Is this working?
612+
# The return value is appended to this repo's name.
613+
# This repo's name is e.g. python_3_13.
614+
# the net result would be
615+
# python_3_10_3_13_0_linux_x86_64
616+
# which isn't a valid name
528617
return logger.fail("Please choose a preferred interpreter out of the following platforms: {}".format(candidates))
529618
else:
530619
candidates = [preference]
531620

532621
if candidates:
533-
return candidates[0]
622+
platform_name, meta = candidates[0]
623+
return getattr(meta, "impl_repo_suffix", platform_name)
534624

535625
return logger.fail("Could not find a compatible 'host' python for '{os_name}', '{cpu_name}' from the loaded platforms: {platforms}".format(
536626
os_name = os_name,

0 commit comments

Comments
 (0)