From 5e4527d7cc5b015f1d1e1796c37419dde9506c4b Mon Sep 17 00:00:00 2001 From: Keith Smiley Date: Thu, 30 May 2024 13:04:53 -0700 Subject: [PATCH] Allow overriding bazel executable It's possible that users don't have a `bazel` in their path depending on your project configuration. This allows passing `--bazel whatever` to override to a custom wrapper, or anything else. --- refresh.template.py | 47 ++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/refresh.template.py b/refresh.template.py index 3b337014..25fede7b 100644 --- a/refresh.template.py +++ b/refresh.template.py @@ -18,6 +18,7 @@ # Similarly, when upgrading, please search for that MIN_PY= tag. +import argparse import concurrent.futures import enum import functools # MIN_PY=3.9: Replace `functools.lru_cache(maxsize=None)` with `functools.cache`. @@ -96,7 +97,7 @@ def _print_header_finding_warning_once(): @functools.lru_cache(maxsize=None) -def _get_bazel_version(): +def _get_bazel_version(bazel: str): """Gets the Bazel version as a tuple of (major, minor, patch). The rolling release and the release candidate are treated as the LTS release. @@ -104,7 +105,7 @@ def _get_bazel_version(): If the version can't be determined, returns (0, 0, 0). """ bazel_version_process = subprocess.run( - ['bazel', 'version'], + [bazel, 'version'], # MIN_PY=3.7: Replace PIPEs with capture_output. stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -127,10 +128,10 @@ def _get_bazel_version(): @functools.lru_cache(maxsize=None) -def _get_bazel_cached_action_keys(): +def _get_bazel_cached_action_keys(bazel: str): """Gets the set of actionKeys cached in bazel-out.""" action_cache_process = subprocess.run( - ['bazel', 'dump', '--action_cache'], + [bazel, 'dump', '--action_cache'], # MIN_PY=3.7: Replace PIPEs with capture_output. stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -232,7 +233,7 @@ def _is_nvcc(path: str): return os.path.basename(path).startswith('nvcc') -def _get_headers_gcc(compile_action, source_path: str, action_key: str): +def _get_headers_gcc(bazel: str, compile_action, source_path: str, action_key: str): """Gets the headers used by a particular compile command that uses gcc arguments formatting (including clang.) Relatively slow. Requires running the C preprocessor if we can't hit Bazel's cache. @@ -240,7 +241,7 @@ def _get_headers_gcc(compile_action, source_path: str, action_key: str): # Flags reference here: https://clang.llvm.org/docs/ClangCommandLineReference.html # Check to see if Bazel has an (approximately) fresh cache of the included headers, and if so, use them to avoid a slow preprocessing step. - if action_key in _get_bazel_cached_action_keys(): # Safe because Bazel only holds one cached action key per path, and the key contains the path. + if action_key in _get_bazel_cached_action_keys(bazel): # Safe because Bazel only holds one cached action key per path, and the key contains the path. for i, arg in enumerate(compile_action.arguments): if arg.startswith('-MF'): if len(arg) > 3: # Either appended, like -MF @@ -517,7 +518,7 @@ def _file_is_in_main_workspace_and_not_external(file_str: str): return True -def _get_headers(compile_action, source_path: str): +def _get_headers(bazel: str, compile_action, source_path: str): """Gets the headers used by a particular compile command. Relatively slow. Requires running the C preprocessor. @@ -588,7 +589,7 @@ def _get_headers(compile_action, source_path: str): if compile_action.arguments[0].endswith('cl.exe'): # cl.exe and also clang-cl.exe headers, should_cache = _get_headers_msvc(compile_action, source_path) else: - headers, should_cache = _get_headers_gcc(compile_action, source_path, compile_action.actionKey) + headers, should_cache = _get_headers_gcc(bazel, compile_action, source_path, compile_action.actionKey) # Cache for future use if output_file and should_cache: @@ -610,7 +611,7 @@ def _get_headers(compile_action, source_path: str): _get_headers.has_logged = False -def _get_files(compile_action): +def _get_files(bazel: str, compile_action): """Gets the ({source files}, {header files}) clangd should be told the command applies to.""" # Getting the source file is a little trickier than it might seem. @@ -670,7 +671,7 @@ def _get_files(compile_action): if os.path.splitext(source_file)[1] in _get_files.assembly_source_extensions: return {source_file}, set() - header_files = _get_headers(compile_action, source_file) + header_files = _get_headers(bazel, compile_action, source_file) # Ambiguous .h headers need a language specified if they aren't C, or clangd sometimes makes mistakes # Delete this and unused extension variables when clangd >= 16 is released, since their underlying issues are resolved at HEAD @@ -1097,7 +1098,7 @@ def _nvcc_patch(compile_args: typing.List[str]) -> typing.List[str]: } -def _get_cpp_command_for_files(compile_action): +def _get_cpp_command_for_files(bazel: str, compile_action): """Reformat compile_action into a compile command clangd can understand. Undo Bazel-isms and figures out which files clangd should apply the command to. @@ -1113,7 +1114,7 @@ def _get_cpp_command_for_files(compile_action): # Android and Linux and grailbio LLVM toolchains: Fine as is; no special patching needed. compile_action.arguments = _all_platform_patch(compile_action.arguments) - source_files, header_files = _get_files(compile_action) + source_files, header_files = _get_files(bazel, compile_action) # Done after getting files since we may execute NVCC to get the files. compile_action.arguments = _nvcc_patch(compile_action.arguments) @@ -1121,7 +1122,7 @@ def _get_cpp_command_for_files(compile_action): return source_files, header_files, compile_action.arguments -def _convert_compile_commands(aquery_output): +def _convert_compile_commands(bazel: str, aquery_output): """Converts from Bazel's aquery format to de-Bazeled compile_commands.json entries. Input: jsonproto output from aquery, pre-filtered to (Objective-)C(++) and CUDA compile actions for a given build. @@ -1145,7 +1146,7 @@ def _convert_compile_commands(aquery_output): with concurrent.futures.ThreadPoolExecutor( max_workers=min(32, (os.cpu_count() or 1) + 4) # Backport. Default in MIN_PY=3.8. See "using very large resources implicitly on many-core machines" in https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor ) as threadpool: - outputs = threadpool.map(_get_cpp_command_for_files, aquery_output.actions) + outputs = threadpool.map(functools.partial(_get_cpp_command_for_files, bazel), aquery_output.actions) # Yield as compile_commands.json entries header_files_already_written = set() @@ -1171,12 +1172,12 @@ def _convert_compile_commands(aquery_output): } -def _get_commands(target: str, flags: str): +def _get_commands(bazel: str, target: str, flags: str, passthrough_flags: typing.List[str]): """Yields compile_commands.json entries for a given target and flags, gracefully tolerating errors.""" # Log clear completion messages log_info(f">>> Analyzing commands used in {target}") - additional_flags = shlex.split(flags) + sys.argv[1:] + additional_flags = shlex.split(flags) + passthrough_flags # Detect anything that looks like a build target in the flags, and issue a warning. # Note that positional arguments after -- are all interpreted as target patterns. (If it's at the end, then no worries.) @@ -1200,7 +1201,7 @@ def _get_commands(target: str, flags: str): # For efficiency, have bazel filter out external targets (and therefore actions) before they even get turned into actions or serialized and sent to us. Note: this is a different mechanism than is used for excluding just external headers. target_statment = f"filter('^(//|@//)',{target_statment})" aquery_args = [ - 'bazel', + bazel, 'aquery', # Aquery docs if you need em: https://docs.bazel.build/versions/master/aquery.html # Aquery output proto reference: https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/analysis_v2.proto @@ -1227,7 +1228,7 @@ def _get_commands(target: str, flags: str): '--features=-layering_check', ] - if _get_bazel_version() >= (6, 1, 0): + if _get_bazel_version(bazel) >= (6, 1, 0): aquery_args += ['--host_features=-compiler_param_file', '--host_features=-layering_check'] aquery_args += additional_flags @@ -1269,7 +1270,7 @@ def _get_commands(target: str, flags: str): Continuing gracefully...""") return - yield from _convert_compile_commands(parsed_aquery_output) + yield from _convert_compile_commands(bazel, parsed_aquery_output) # Log clear completion messages @@ -1393,12 +1394,18 @@ def _ensure_cwd_is_workspace_root(): # Although this can fail (OSError/FileNotFoundError/PermissionError/NotADirectoryError), there's no easy way to recover, so we'll happily crash. os.chdir(workspace_root) +def _build_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--bazel", default="bazel") + return parser def main(): _ensure_cwd_is_workspace_root() _ensure_gitignore_entries_exist() _ensure_external_workspaces_link_exists() + args, passthrough_flags = _build_parser().parse_known_args() + target_flag_pairs = [ # Begin: template filled by Bazel {target_flag_pairs} @@ -1407,7 +1414,7 @@ def main(): compile_command_entries = [] for (target, flags) in target_flag_pairs: - compile_command_entries.extend(_get_commands(target, flags)) + compile_command_entries.extend(_get_commands(args.bazel, target, flags, passthrough_flags)) if not compile_command_entries: log_error(""">>> Not (over)writing compile_commands.json, since no commands were extracted and an empty file is of no use.