-
Notifications
You must be signed in to change notification settings - Fork 102
kernel-abi-check: add support for checking that a kernel does not use non-stable Torch ABI symbols #591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
kernel-abi-check: add support for checking that a kernel does not use non-stable Torch ABI symbols #591
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,8 +6,8 @@ use eyre::{Context, Result}; | |
| use object::{File, Object}; | ||
|
|
||
| use kernel_abi_check::{ | ||
| check_macos, check_manylinux, check_python_abi, MacOSViolation, ManylinuxViolation, | ||
| PythonAbiViolation, Version, | ||
| check_macos, check_manylinux, check_python_abi, check_torch_stable_abi, MacOSViolation, | ||
| ManylinuxViolation, PythonAbiViolation, TorchStableAbiViolation, Version, | ||
| }; | ||
|
|
||
| /// CLI tool to check library versions | ||
|
|
@@ -28,6 +28,10 @@ struct Cli { | |
| /// Python ABI version. | ||
| #[arg(short, long, value_name = "VERSION", default_value = "3.9")] | ||
| python_abi: Version, | ||
|
|
||
| /// Torch stable ABI version. | ||
| #[arg(long, value_name = "VERSION")] | ||
| torch_stable_abi: Option<Version>, | ||
| } | ||
|
|
||
| fn main() -> Result<()> { | ||
|
|
@@ -84,9 +88,41 @@ fn main() -> Result<()> { | |
| eprintln!("✅ No compatibility issues found"); | ||
| } | ||
|
|
||
| if let Some(torch_stable_abi) = args.torch_stable_abi { | ||
| eprintln!("🔥 Checking for compatibility with Torch stable ABI version {torch_stable_abi}"); | ||
| let torch_stable_abi_violations = | ||
| check_torch_stable_abi(&torch_stable_abi, file.format(), file.symbols())?; | ||
| print_torch_stable_abi_violations(&torch_stable_abi_violations, &torch_stable_abi); | ||
|
|
||
| if !torch_stable_abi_violations.is_empty() { | ||
| return Err(eyre::eyre!("Torch stable ABI compatibility issues found")); | ||
| } else { | ||
| eprintln!("✅ No Torch stable ABI compatibility issues found"); | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn print_torch_stable_abi_violations( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ouf, this is superb! |
||
| violations: &BTreeSet<TorchStableAbiViolation>, | ||
| torch_abi: &Version, | ||
| ) { | ||
| if !violations.is_empty() { | ||
| eprintln!("\n⛔ Non-stable Torch ABI symbols found (incompatible with Torch stable ABI {torch_abi}):\n"); | ||
| for violation in violations { | ||
| match violation { | ||
| TorchStableAbiViolation::IncompatibleStableAbiSymbol { name, added } => { | ||
| eprintln!("{name}: {added}"); | ||
| } | ||
| TorchStableAbiViolation::NonStableAbiSymbol { name } => { | ||
| eprintln!("{name}"); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn print_manylinux_violations( | ||
| violations: &BTreeSet<ManylinuxViolation>, | ||
| manylinux_version: &str, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| use std::collections::{BTreeSet, HashMap}; | ||
| use std::str::FromStr; | ||
|
|
||
| use cpp_demangle::Symbol as CppSymbol; | ||
| use eyre::Result; | ||
| use object::{BinaryFormat, ObjectSymbol, Symbol}; | ||
| use once_cell::sync::Lazy; | ||
|
|
||
| use crate::version::Version; | ||
|
|
||
| // https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/torch/csrc/stable/c/shim_function_versions.txt | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit): Could use the permalink.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't make it a permalink, because we typically want to use the URL to get the latest version and update our copy. |
||
| static SHIM_FUNCTION_VERSIONS_RAW: &str = include_str!("shim_function_versions.txt"); | ||
|
|
||
| /// Maps shim function names to the minimum Torch version that introduced them. | ||
| /// Functions absent from this map were available before 2.10.0. | ||
| pub static TORCH_SHIM_VERSIONS: Lazy<HashMap<String, Version>> = Lazy::new(|| { | ||
| let mut map = HashMap::new(); | ||
| for line in SHIM_FUNCTION_VERSIONS_RAW.lines() { | ||
| // Skip blank lines and comments. | ||
| let line = line.trim(); | ||
| if line.is_empty() || line.starts_with('#') { | ||
| continue; | ||
| } | ||
| if let Some((name, version_token)) = line.split_once(':') { | ||
| let name = name.trim().to_owned(); | ||
| // TORCH_VERSION_2_10_0 -> "2.10.0" | ||
| let version_str = version_token | ||
| .trim() | ||
| .strip_prefix("TORCH_VERSION_") | ||
| .expect("unexpected version token format") | ||
| .replace('_', "."); | ||
| let version = Version::from_str(&version_str) | ||
| .expect("invalid version in shim_function_versions.txt"); | ||
| map.insert(name, version); | ||
| } | ||
| } | ||
| map | ||
| }); | ||
|
|
||
| /// Torch stable ABI violation. | ||
| #[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)] | ||
| pub enum TorchStableAbiViolation { | ||
| /// Symbol is newer than the specified Torch Stable ABI version. | ||
| IncompatibleStableAbiSymbol { name: String, added: Version }, | ||
|
|
||
| /// Symbol is not part of the Torch stable ABI. | ||
| NonStableAbiSymbol { name: String }, | ||
| } | ||
|
|
||
| /// Check for violations of the Torch stable ABI policy. | ||
| pub fn check_torch_stable_abi<'a>( | ||
| torch_stable_abi: &Version, | ||
| binary_format: BinaryFormat, | ||
| symbols: impl IntoIterator<Item = Symbol<'a, 'a>>, | ||
| ) -> Result<BTreeSet<TorchStableAbiViolation>> { | ||
| let mut violations = BTreeSet::new(); | ||
|
|
||
| for symbol in symbols { | ||
| if !symbol.is_undefined() { | ||
| continue; | ||
| } | ||
|
|
||
| let mut symbol_name = symbol.name()?; | ||
| if matches!(binary_format, BinaryFormat::MachO) { | ||
| // Mach-O C symbol mangling adds an underscore. | ||
| symbol_name = symbol_name.strip_prefix("_").unwrap_or(symbol_name); | ||
| } | ||
|
|
||
| // If this is a C shim symbol, check if it is valid for this version. | ||
| if let Some(symbol_version) = TORCH_SHIM_VERSIONS.get(symbol_name) { | ||
| if symbol_version > torch_stable_abi { | ||
| violations.insert(TorchStableAbiViolation::IncompatibleStableAbiSymbol { | ||
| name: symbol_name.to_owned(), | ||
| added: symbol_version.clone(), | ||
| }); | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // Try to demangle the symbol as a C++ symbol. If that fails, it's probably an | ||
| // unrelated C symbol. | ||
| let cpp_symbol = match CppSymbol::new(symbol_name) { | ||
| Ok(cpp_symbol) => cpp_symbol, | ||
| Err(_) => { | ||
| continue; | ||
| } | ||
| }; | ||
| let demangled = cpp_symbol.demangle()?; | ||
|
|
||
| // Check if Torch symbols are from the stable ABI. | ||
| if demangled.starts_with("torch::stable::") { | ||
| // This branch fulfills two purposes: (1) prevent stable ABI | ||
| // C++ symbols from being reported by the filter below. (2) Once a | ||
| // versioned list of symbols is available, check versions. | ||
| } else if demangled.starts_with("c10::") | ||
| || demangled.starts_with("at::") | ||
| || demangled.starts_with("torch::") | ||
| { | ||
| violations.insert(TorchStableAbiViolation::NonStableAbiSymbol { name: demangled }); | ||
| } | ||
| } | ||
|
|
||
| Ok(violations) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| # Auto-generated file listing shim functions and their minimum required versions | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we ensure this is also the latest stable?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question for the manylinux and Python ABI3 files. We should probably have a CI job to check for this and maybe make a PR automatically when one of the files is outdated.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can look into it after the PR is merged: #593 |
||
| # Format: function_name: TORCH_VERSION_MAJOR_MINOR_PATCH | ||
| # | ||
| # This file is automatically updated by the stable_shim_usage_linter. | ||
| # If a function is not in this file, it was available before 2.10.0. | ||
| # DO NOT EDIT MANUALLY. | ||
|
|
||
| ParallelFunc: TORCH_VERSION_2_10_0 | ||
| StableListHandle: TORCH_VERSION_2_10_0 | ||
| StableListOpaque: TORCH_VERSION_2_10_0 | ||
| StringHandle: TORCH_VERSION_2_10_0 | ||
| StringOpaque: TORCH_VERSION_2_10_0 | ||
| aoti_torch_aten_full: TORCH_VERSION_2_10_0 | ||
| aoti_torch_aten_subtract_Tensor: TORCH_VERSION_2_10_0 | ||
| aoti_torch_cuda__scaled_grouped_mm: TORCH_VERSION_2_10_0 | ||
| torch_c10_cuda_check_msg: TORCH_VERSION_2_10_0 | ||
| torch_c10_cuda_free_error_msg: TORCH_VERSION_2_10_0 | ||
| torch_call_dispatcher: TORCH_VERSION_2_10_0 | ||
| torch_cuda_stream_synchronize: TORCH_VERSION_2_10_0 | ||
| torch_delete_list: TORCH_VERSION_2_10_0 | ||
| torch_delete_string: TORCH_VERSION_2_10_0 | ||
| torch_get_const_data_ptr: TORCH_VERSION_2_10_0 | ||
| torch_get_cuda_stream_from_pool: TORCH_VERSION_2_10_0 | ||
| torch_get_current_cuda_blas_handle: TORCH_VERSION_2_10_0 | ||
| torch_get_mutable_data_ptr: TORCH_VERSION_2_10_0 | ||
| torch_get_num_threads: TORCH_VERSION_2_10_0 | ||
| torch_get_thread_idx: TORCH_VERSION_2_10_0 | ||
| torch_library_impl: TORCH_VERSION_2_10_0 | ||
| torch_list_get_item: TORCH_VERSION_2_10_0 | ||
| torch_list_push_back: TORCH_VERSION_2_10_0 | ||
| torch_list_set_item: TORCH_VERSION_2_10_0 | ||
| torch_list_size: TORCH_VERSION_2_10_0 | ||
| torch_new_list_reserve_size: TORCH_VERSION_2_10_0 | ||
| torch_new_string_handle: TORCH_VERSION_2_10_0 | ||
| torch_parallel_for: TORCH_VERSION_2_10_0 | ||
| torch_parse_device_string: TORCH_VERSION_2_10_0 | ||
| torch_set_current_cuda_stream: TORCH_VERSION_2_10_0 | ||
| torch_set_requires_grad: TORCH_VERSION_2_10_0 | ||
| torch_string_c_str: TORCH_VERSION_2_10_0 | ||
| torch_string_length: TORCH_VERSION_2_10_0 | ||
| aoti_torch_cpu_nonzero_static: TORCH_VERSION_2_11_0 | ||
| aoti_torch_cuda__flash_attention_forward_quantized: TORCH_VERSION_2_11_0 | ||
| aoti_torch_cuda__scaled_dot_product_flash_attention_quantized: TORCH_VERSION_2_11_0 | ||
| aoti_torch_cuda_mm_dtype_out: TORCH_VERSION_2_11_0 | ||
| aoti_torch_cuda_nonzero_static: TORCH_VERSION_2_11_0 | ||
| aoti_torch_mps_nonzero_static: TORCH_VERSION_2_11_0 | ||
| aoti_torch_xpu_mm_dtype_out: TORCH_VERSION_2_11_0 | ||
| torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_11_0 | ||
| torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_11_0 | ||
| torch_from_blob: TORCH_VERSION_2_11_0 | ||
| aoti_torch_cpu__grouped_mm: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_rand_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_rand_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_randint_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_randint_like_low_dtype: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_randn_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_randn_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda__flash_attention_forward_no_dropout_inplace: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda__flash_attention_forward_no_dropout_inplace_v2: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda__flash_attention_forward_v2: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda__grouped_mm: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_rand_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_rand_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_randint_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_randint_like_low_dtype: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_randn_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cuda_randn_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_12_0 | ||
| aoti_torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps__grouped_mm: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_rand_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_rand_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_randint_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_randint_like_low_dtype: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_randn_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_mps_randn_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu__grouped_mm: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_rand_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_rand_like_generator: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_randint_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_randint_like_low_dtype: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_randn_like: TORCH_VERSION_2_12_0 | ||
| aoti_torch_xpu_randn_like_generator: TORCH_VERSION_2_12_0 | ||
| torch_library_def_with_tags: TORCH_VERSION_2_12_0 | ||
| torch_tag_core: TORCH_VERSION_2_12_0 | ||
| torch_tag_cudagraph_unsafe: TORCH_VERSION_2_12_0 | ||
| torch_tag_data_dependent_output: TORCH_VERSION_2_12_0 | ||
| torch_tag_dynamic_output_shape: TORCH_VERSION_2_12_0 | ||
| torch_tag_flexible_layout: TORCH_VERSION_2_12_0 | ||
| torch_tag_generated: TORCH_VERSION_2_12_0 | ||
| torch_tag_inplace_view: TORCH_VERSION_2_12_0 | ||
| torch_tag_maybe_aliasing_or_mutating: TORCH_VERSION_2_12_0 | ||
| torch_tag_needs_contiguous_strides: TORCH_VERSION_2_12_0 | ||
| torch_tag_needs_exact_strides: TORCH_VERSION_2_12_0 | ||
| torch_tag_needs_fixed_stride_order: TORCH_VERSION_2_12_0 | ||
| torch_tag_nondeterministic_bitwise: TORCH_VERSION_2_12_0 | ||
| torch_tag_nondeterministic_seeded: TORCH_VERSION_2_12_0 | ||
| torch_tag_out_variant: TORCH_VERSION_2_12_0 | ||
| torch_tag_pointwise: TORCH_VERSION_2_12_0 | ||
| torch_tag_pt2_compliant_tag: TORCH_VERSION_2_12_0 | ||
| torch_tag_reduction: TORCH_VERSION_2_12_0 | ||
| torch_tag_view_copy: TORCH_VERSION_2_12_0 | ||
| aoti_torch_cpu_grid_sampler_3d: TORCH_VERSION_2_13_0 | ||
| aoti_torch_cpu_grid_sampler_3d_backward: TORCH_VERSION_2_13_0 | ||
| aoti_torch_cuda_grid_sampler_3d: TORCH_VERSION_2_13_0 | ||
| aoti_torch_cuda_grid_sampler_3d_backward: TORCH_VERSION_2_13_0 | ||
| aoti_torch_mps__scaled_dot_product_attention_math_for_mps_v2: TORCH_VERSION_2_13_0 | ||
| aoti_torch_mps_grid_sampler_3d: TORCH_VERSION_2_13_0 | ||
| aoti_torch_mps_grid_sampler_3d_backward: TORCH_VERSION_2_13_0 | ||
| torch_delete_stable_ivalue: TORCH_VERSION_2_13_0 | ||
| torch_exception_get_what: TORCH_VERSION_2_13_0 | ||
| torch_exception_get_what_without_backtrace: TORCH_VERSION_2_13_0 | ||
| torch_library_set_python_module: TORCH_VERSION_2_13_0 | ||
| torch_new_stable_ivalue: TORCH_VERSION_2_13_0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already printed with a stop sign below.