Skip to content

Commit e36105c

Browse files
committed
refactor: make bzlmod pass platform mapping to host repo creation
1 parent 47b115e commit e36105c

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

python/private/python.bzl

+10-3
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,9 @@ def _python_impl(module_ctx):
298298
_internal_bzlmod_toolchain_call = True,
299299
**kwargs
300300
)
301-
host_compatible = []
301+
host_platforms = []
302+
host_os_names = {}
303+
host_archs = {}
302304
for repo_name, (platform_name, platform_info) in register_result.impl_repos.items():
303305
toolchain_impls.append(struct(
304306
# str: The base name to use for the toolchain() target
@@ -317,12 +319,17 @@ def _python_impl(module_ctx):
317319
set_python_version_constraint = is_last,
318320
))
319321
if _is_compatible_with_host(module_ctx, platform_info):
320-
host_compatible.append(platform_name)
322+
host_key = str(len(host_platforms))
323+
host_platforms.append(platform_name)
324+
host_os_names[host_key] = platform_info.os_name
325+
host_archs[host_key] = platform_info.arch
321326

322327
host_toolchain(
323328
name = toolchain_info.name + "_host",
324329
# NOTE: Order matters. The first found to be compatible is (usually) used.
325-
platforms = host_compatible,
330+
platforms = host_platforms,
331+
os_names = host_os_names,
332+
archs = host_archs,
326333
python_version = full_python_version,
327334
)
328335

python/private/toolchains_repo.bzl

+22-1
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,16 @@ this repo causes an eager fetch of the toolchain for the host platform.
388388
attrs = {
389389
"platforms": attr.string_list(mandatory = True),
390390
"python_version": attr.string(mandatory = True),
391+
"os_names": attr.string_dict(
392+
doc = """
393+
If set, overrides the platform metadata. Keyed by index in `platforms`
394+
""",
395+
),
396+
"archs": attr.string_dict(
397+
doc = """
398+
If set, overrides the platform metadata. Keyed by index in `platforms`
399+
""",
400+
),
391401
"_rule_name": attr.string(default = "host_toolchain"),
392402
"_rules_python_workspace": attr.label(default = Label("//:WORKSPACE")),
393403
},
@@ -436,9 +446,20 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
436446
Returns:
437447
The host platform.
438448
"""
449+
if rctx.attr.os_names:
450+
platform_map = {}
451+
for i, platform_name in enumerate(platforms):
452+
key = str(i)
453+
platform_map[platform_name] = struct(
454+
os_name = rctx.attr.os_names[key],
455+
arch = rctx.attr.archs[key],
456+
)
457+
else:
458+
platform_map = PLATFORMS
459+
439460
candidates = []
440461
for platform in platforms:
441-
meta = PLATFORMS[platform]
462+
meta = platform_map[platform]
442463

443464
if meta.os_name == os_name and meta.arch == cpu_name:
444465
candidates.append(platform)

0 commit comments

Comments
 (0)