Skip to content

refactor: explicitly define host platform ordering #2890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions python/private/python.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ load(":full_version.bzl", "full_version")
load(":python_register_toolchains.bzl", "python_register_toolchains")
load(":pythons_hub.bzl", "hub_repo")
load(":repo_utils.bzl", "repo_utils")
load(":toolchains_repo.bzl", "multi_toolchain_aliases")
load(":toolchains_repo.bzl", "host_toolchain", "multi_toolchain_aliases", "sorted_host_platforms")
load(":util.bzl", "IS_BAZEL_6_4_OR_HIGHER")
load(":version.bzl", "version")

Expand Down Expand Up @@ -267,11 +267,18 @@ def parse_modules(*, module_ctx, _fail = fail):
def _python_impl(module_ctx):
py = parse_modules(module_ctx = module_ctx)

# dict[str version, list[str] platforms]; where version is full
# python version string ("3.4.5"), and platforms are keys from
# the PLATFORMS global.
loaded_platforms = {}
for toolchain_info in py.toolchains:
# list of structs; see inline struct call within the loop below.
toolchain_impls = []

# list[str] of the base names of toolchain repos
base_toolchain_repo_names = []

# Create the underlying python_repository repos that contain the
# python runtimes and their toolchain implementation definitions.
for i, toolchain_info in enumerate(py.toolchains):
is_last = (i + 1) == len(py.toolchains)
base_toolchain_repo_names.append(toolchain_info.name)

# Ensure that we pass the full version here.
full_python_version = full_version(
version = toolchain_info.python_version,
Expand All @@ -286,12 +293,48 @@ def _python_impl(module_ctx):
kwargs.update(py.config.kwargs.get(toolchain_info.python_version, {}))
kwargs.update(py.config.kwargs.get(full_python_version, {}))
kwargs.update(py.config.default)
toolchain_registered_platforms = python_register_toolchains(
register_result = python_register_toolchains(
name = toolchain_info.name,
_internal_bzlmod_toolchain_call = True,
**kwargs
)
loaded_platforms[full_python_version] = toolchain_registered_platforms

host_platforms = {}
for repo_name, (platform_name, platform_info) in register_result.impl_repos.items():
toolchain_impls.append(struct(
# str: The base name to use for the toolchain() target
name = repo_name,
# str: The repo name the toolchain() target points to.
impl_repo_name = repo_name,
# str: platform key in the passed-in platforms dict
platform_name = platform_name,
# struct: platform_info() struct
platform = platform_info,
# str: Major.Minor.Micro python version
full_python_version = full_python_version,
# bool: whether to implicitly add the python version constraint
# to the toolchain's target_settings.
# The last toolchain is the default; it can't have version constraints
set_python_version_constraint = is_last,
))
if _is_compatible_with_host(module_ctx, platform_info):
host_platforms[platform_name] = platform_info

host_platforms = sorted_host_platforms(host_platforms)
host_toolchain(
name = toolchain_info.name + "_host",
# NOTE: Order matters. The first found to be compatible is (usually) used.
platforms = host_platforms.keys(),
os_names = {
str(i): platform_info.os_name
for i, platform_info in enumerate(host_platforms.values())
},
archs = {
str(i): platform_info.arch
for i, platform_info in enumerate(host_platforms.values())
},
python_version = full_python_version,
)

# List of the base names ("python_3_10") for the toolchain repos
base_toolchain_repo_names = []
Expand Down Expand Up @@ -329,31 +372,23 @@ def _python_impl(module_ctx):

# Split the toolchain info into separate objects so they can be passed onto
# the repository rule.
for i, t in enumerate(py.toolchains):
is_last = (i + 1) == len(py.toolchains)
base_name = t.name
base_toolchain_repo_names.append(base_name)
fv = full_version(version = t.python_version, minor_mapping = py.config.minor_mapping)
platforms = loaded_platforms[fv]
for platform_name, platform_info in platforms.items():
key = str(len(toolchain_names))

full_name = "{}_{}".format(base_name, platform_name)
toolchain_names.append(full_name)
toolchain_repo_names[key] = full_name
toolchain_tcw_map[key] = platform_info.compatible_with

# The target_settings attribute may not be present for users
# patching python/versions.bzl.
toolchain_ts_map[key] = getattr(platform_info, "target_settings", [])
toolchain_platform_keys[key] = platform_name
toolchain_python_versions[key] = fv

# The last toolchain is the default; it can't have version constraints
# Despite the implication of the arg name, the values are strs, not bools
toolchain_set_python_version_constraints[key] = (
"True" if not is_last else "False"
)
for entry in toolchain_impls:
key = str(len(toolchain_names))

toolchain_names.append(entry.name)
toolchain_repo_names[key] = entry.impl_repo_name
toolchain_tcw_map[key] = entry.platform.compatible_with

# The target_settings attribute may not be present for users
# patching python/versions.bzl.
toolchain_ts_map[key] = getattr(entry.platform, "target_settings", [])
toolchain_platform_keys[key] = entry.platform_name
toolchain_python_versions[key] = entry.full_python_version

# Repo rules can't accept dict[str, bool], so encode them as a string value.
toolchain_set_python_version_constraints[key] = (
"True" if entry.set_python_version_constraint else "False"
)

hub_repo(
name = "pythons_hub",
Expand Down Expand Up @@ -391,6 +426,11 @@ def _python_impl(module_ctx):
else:
return None

def _is_compatible_with_host(mctx, platform_info):
    """Tell whether a platform's OS and architecture match the host's.

    Args:
        mctx: context object forwarded to the `repo_utils` host lookups.
        platform_info: struct exposing `os_name` and `arch` attributes.

    Returns:
        True only when both `platform_info.os_name` and `platform_info.arch`
        equal the values reported by `repo_utils` for the host.
    """

    # Resolve both host values up front, before any comparison is made.
    host_os = repo_utils.get_platforms_os_name(mctx)
    host_cpu = repo_utils.get_platforms_cpu_name(mctx)

    if platform_info.os_name != host_os:
        return False
    return platform_info.arch == host_cpu

def _one_or_the_same(first, second, *, onerror = None):
if not first:
return second
Expand Down
40 changes: 23 additions & 17 deletions python/private/python_register_toolchains.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ def python_register_toolchains(
))
register_coverage_tool = False

loaded_platforms = {}
for platform in platforms.keys():
# list[str] of the platform names that were used
loaded_platforms = []

# dict[str repo name, tuple[str, platform_info]]
impl_repos = {}
for platform, platform_info in platforms.items():
sha256 = tool_versions[python_version]["sha256"].get(platform, None)
if not sha256:
continue

loaded_platforms[platform] = platforms[platform]
loaded_platforms.append(platform)
(release_filename, urls, strip_prefix, patches, patch_strip) = get_release_info(platform, python_version, base_url, tool_versions)

# allow passing in a tool version
Expand All @@ -137,11 +141,10 @@ def python_register_toolchains(
)],
)

impl_repo_name = "{}_{}".format(name, platform)
impl_repos[impl_repo_name] = (platform, platform_info)
python_repository(
name = "{name}_{platform}".format(
name = name,
platform = platform,
),
name = impl_repo_name,
sha256 = sha256,
patches = patches,
patch_strip = patch_strip,
Expand All @@ -167,28 +170,31 @@ def python_register_toolchains(
platform = platform,
))

host_toolchain(
name = name + "_host",
platforms = loaded_platforms.keys(),
python_version = python_version,
)

toolchain_aliases(
name = name,
python_version = python_version,
user_repository_name = name,
platforms = loaded_platforms.keys(),
platforms = loaded_platforms,
)

# in bzlmod we write out our own toolchain repos
# in bzlmod we write out our own toolchain repos and host repos
if bzlmod_toolchain_call:
return loaded_platforms
return struct(
# dict[str name, tuple[str platform_name, platform_info]]
impl_repos = impl_repos,
)

host_toolchain(
name = name + "_host",
platforms = loaded_platforms,
python_version = python_version,
)

toolchains_repo(
name = toolchain_repo_name,
python_version = python_version,
set_python_version_constraint = set_python_version_constraint,
user_repository_name = name,
platforms = loaded_platforms.keys(),
platforms = loaded_platforms,
)
return None
67 changes: 66 additions & 1 deletion python/private/toolchains_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ platform-specific repositories.

load(
"//python:versions.bzl",
"FREETHREADED",
"MUSL",
"PLATFORMS",
"WINDOWS_NAME",
)
Expand Down Expand Up @@ -375,6 +377,8 @@ def _host_toolchain_impl(rctx):
if not rctx.delete(python_tester):
fail("Failed to delete the python tester")

# NOTE: The term "toolchain" is a misnomer for this rule. This doesn't define
# a repo with toolchains or toolchain implementations.
host_toolchain = repository_rule(
_host_toolchain_impl,
doc = """\
Expand All @@ -384,6 +388,16 @@ toolchain_aliases repo because referencing the `python` interpreter target from
this repo causes an eager fetch of the toolchain for the host platform.
""",
attrs = {
"archs": attr.string_dict(
doc = """
If set, overrides the platform metadata. Keyed by index in `platforms`
""",
),
"os_names": attr.string_dict(
doc = """
If set, overrides the platform metadata. Keyed by index in `platforms`
""",
),
"platforms": attr.string_list(mandatory = True),
"python_version": attr.string(mandatory = True),
"_rule_name": attr.string(default = "host_toolchain"),
Expand Down Expand Up @@ -421,6 +435,46 @@ multi_toolchain_aliases = repository_rule(
},
)

def sorted_host_platforms(platform_map):
    """Return a copy of the platform map whose keys are in precedence order.

    Host toolchain selection walks the platforms in iteration order and
    (usually) takes the first one that is compatible with the host; the
    `RULES_PYTHON_REPO_TOOLCHAIN_*` environment variables can also influence
    the choice. Historically the required ordering was implied by how the
    platform dict literal happened to be written, which is fragile: an
    autoformatter can reorder it, and the only hint that order mattered was
    an innocuous-looking formatter-disable directive.

    This function makes the precedence explicit instead:
      1. Regular platforms
      2. Platforms with the "-freethreaded" suffix
      3. Platforms with the "-musl" suffix

    Within each group, platforms are ordered by name.

    Args:
        platform_map: a mapping of platform names to their metadata.

    Returns:
        dict; the same keys and values as `platform_map`, inserted in
        precedence order so that iterating it visits the preferred
        platforms first.
    """

    def _precedence(platform_name):
        # The sort is ascending, so a smaller rank means higher precedence.
        # The name is the second tuple element and breaks ties within a rank.
        if platform_name.endswith("-" + FREETHREADED):
            rank = 1
        elif platform_name.endswith("-" + MUSL):
            rank = 2
        else:
            rank = 0
        return (rank, platform_name)

    # Dicts iterate in insertion order, so inserting in sorted order is what
    # carries the precedence to callers.
    ordered = {}
    for platform_name in sorted(platform_map.keys(), key = _precedence):
        ordered[platform_name] = platform_map[platform_name]
    return ordered

def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platforms):
"""Gets the host platform.

Expand All @@ -434,9 +488,20 @@ def _get_host_platform(*, rctx, logger, python_version, os_name, cpu_name, platf
Returns:
The host platform.
"""
if rctx.attr.os_names:
platform_map = {}
for i, platform_name in enumerate(platforms):
key = str(i)
platform_map[platform_name] = struct(
os_name = rctx.attr.os_names[key],
arch = rctx.attr.archs[key],
)
else:
platform_map = sorted_host_platforms(PLATFORMS)

candidates = []
for platform in platforms:
meta = PLATFORMS[platform]
meta = platform_map[platform]

if meta.os_name == os_name and meta.arch == cpu_name:
candidates.append(platform)
Expand Down
Loading