Skip to content

Commit 6504ae5

Browse files
authored
kernel-abi-check: add support for checking that a kernel does not use non-stable Torch ABI symbols (#591)
* kernel-abi-check: support checking Torch stable ABI * nix-builder: hook up stable Torch stable ABI check
1 parent 9cfa642 commit 6504ae5

11 files changed

Lines changed: 289 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kernel-abi-check/kernel-abi-check/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ repository = "https://github.com/huggingface/kernel-builder"
1212
[dependencies]
1313
clap = { version = "4", features = ["derive"] }
1414
color-eyre = "0.6"
15+
cpp_demangle = "0.5"
1516
eyre = "0.6"
1617
itertools = "0.14.0"
1718
object = "0.36.7"

kernel-abi-check/kernel-abi-check/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ pub use manylinux::{check_manylinux, ManylinuxViolation};
99
mod python_abi;
1010
pub use python_abi::{check_python_abi, PythonAbiViolation};
1111

12+
mod torch_stable_abi;
13+
pub use torch_stable_abi::{check_torch_stable_abi, TorchStableAbiViolation};
14+
1215
mod version;
1316
pub use version::Version;

kernel-abi-check/kernel-abi-check/src/main.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use eyre::{Context, Result};
66
use object::{File, Object};
77

88
use kernel_abi_check::{
9-
check_macos, check_manylinux, check_python_abi, MacOSViolation, ManylinuxViolation,
10-
PythonAbiViolation, Version,
9+
check_macos, check_manylinux, check_python_abi, check_torch_stable_abi, MacOSViolation,
10+
ManylinuxViolation, PythonAbiViolation, TorchStableAbiViolation, Version,
1111
};
1212

1313
/// CLI tool to check library versions
@@ -28,6 +28,10 @@ struct Cli {
2828
/// Python ABI version.
2929
#[arg(short, long, value_name = "VERSION", default_value = "3.9")]
3030
python_abi: Version,
31+
32+
/// Torch stable ABI version.
33+
#[arg(long, value_name = "VERSION")]
34+
torch_stable_abi: Option<Version>,
3135
}
3236

3337
fn main() -> Result<()> {
@@ -84,9 +88,41 @@ fn main() -> Result<()> {
8488
eprintln!("✅ No compatibility issues found");
8589
}
8690

91+
if let Some(torch_stable_abi) = args.torch_stable_abi {
92+
eprintln!("🔥 Checking for compatibility with Torch stable ABI version {torch_stable_abi}");
93+
let torch_stable_abi_violations =
94+
check_torch_stable_abi(&torch_stable_abi, file.format(), file.symbols())?;
95+
print_torch_stable_abi_violations(&torch_stable_abi_violations, &torch_stable_abi);
96+
97+
if !torch_stable_abi_violations.is_empty() {
98+
return Err(eyre::eyre!("Torch stable ABI compatibility issues found"));
99+
} else {
100+
eprintln!("✅ No Torch stable ABI compatibility issues found");
101+
}
102+
}
103+
87104
Ok(())
88105
}
89106

107+
fn print_torch_stable_abi_violations(
108+
violations: &BTreeSet<TorchStableAbiViolation>,
109+
torch_abi: &Version,
110+
) {
111+
if !violations.is_empty() {
112+
eprintln!("\n⛔ Non-stable Torch ABI symbols found (incompatible with Torch stable ABI {torch_abi}):\n");
113+
for violation in violations {
114+
match violation {
115+
TorchStableAbiViolation::IncompatibleStableAbiSymbol { name, added } => {
116+
eprintln!("{name}: {added}");
117+
}
118+
TorchStableAbiViolation::NonStableAbiSymbol { name } => {
119+
eprintln!("{name}");
120+
}
121+
}
122+
}
123+
}
124+
}
125+
90126
fn print_manylinux_violations(
91127
violations: &BTreeSet<ManylinuxViolation>,
92128
manylinux_version: &str,
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use std::collections::{BTreeSet, HashMap};
2+
use std::str::FromStr;
3+
4+
use cpp_demangle::Symbol as CppSymbol;
5+
use eyre::Result;
6+
use object::{BinaryFormat, ObjectSymbol, Symbol};
7+
use once_cell::sync::Lazy;
8+
9+
use crate::version::Version;
10+
11+
// https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/torch/csrc/stable/c/shim_function_versions.txt
12+
static SHIM_FUNCTION_VERSIONS_RAW: &str = include_str!("shim_function_versions.txt");
13+
14+
/// Maps shim function names to the minimum Torch version that introduced them.
15+
/// Functions absent from this map were available before 2.10.0.
16+
pub static TORCH_SHIM_VERSIONS: Lazy<HashMap<String, Version>> = Lazy::new(|| {
17+
let mut map = HashMap::new();
18+
for line in SHIM_FUNCTION_VERSIONS_RAW.lines() {
19+
// Skip blank lines and comments.
20+
let line = line.trim();
21+
if line.is_empty() || line.starts_with('#') {
22+
continue;
23+
}
24+
if let Some((name, version_token)) = line.split_once(':') {
25+
let name = name.trim().to_owned();
26+
// TORCH_VERSION_2_10_0 -> "2.10.0"
27+
let version_str = version_token
28+
.trim()
29+
.strip_prefix("TORCH_VERSION_")
30+
.expect("unexpected version token format")
31+
.replace('_', ".");
32+
let version = Version::from_str(&version_str)
33+
.expect("invalid version in shim_function_versions.txt");
34+
map.insert(name, version);
35+
}
36+
}
37+
map
38+
});
39+
40+
/// Torch stable ABI violation.
41+
#[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)]
42+
pub enum TorchStableAbiViolation {
43+
/// Symbol is newer than the specified Torch Stable ABI version.
44+
IncompatibleStableAbiSymbol { name: String, added: Version },
45+
46+
/// Symbol is a Torch C++ symbol (`c10::`, `at::`, or `torch::`) that is not part of the Torch stable ABI.
47+
NonStableAbiSymbol { name: String },
48+
}
49+
50+
/// Check for violations of the Torch stable ABI policy.
51+
pub fn check_torch_stable_abi<'a>(
52+
torch_stable_abi: &Version,
53+
binary_format: BinaryFormat,
54+
symbols: impl IntoIterator<Item = Symbol<'a, 'a>>,
55+
) -> Result<BTreeSet<TorchStableAbiViolation>> {
56+
let mut violations = BTreeSet::new();
57+
58+
for symbol in symbols {
59+
if !symbol.is_undefined() {
60+
continue;
61+
}
62+
63+
let mut symbol_name = symbol.name()?;
64+
if matches!(binary_format, BinaryFormat::MachO) {
65+
// Mach-O C symbol mangling adds an underscore.
66+
symbol_name = symbol_name.strip_prefix("_").unwrap_or(symbol_name);
67+
}
68+
69+
// If this is a C shim symbol, check if it is valid for this version.
70+
if let Some(symbol_version) = TORCH_SHIM_VERSIONS.get(symbol_name) {
71+
if symbol_version > torch_stable_abi {
72+
violations.insert(TorchStableAbiViolation::IncompatibleStableAbiSymbol {
73+
name: symbol_name.to_owned(),
74+
added: symbol_version.clone(),
75+
});
76+
}
77+
continue;
78+
}
79+
80+
// Try to demangle the symbol as a C++ symbol. If that fails, it's probably an
81+
// unrelated C symbol.
82+
let cpp_symbol = match CppSymbol::new(symbol_name) {
83+
Ok(cpp_symbol) => cpp_symbol,
84+
Err(_) => {
85+
continue;
86+
}
87+
};
88+
let demangled = cpp_symbol.demangle()?;
89+
90+
// Check if Torch symbols are from the stable ABI.
91+
if demangled.starts_with("torch::stable::") {
92+
// This branch fulfills two purposes: (1) avoid that stable ABI
93+
// C++ symbols get reported by the filter below. (2) Once a
94+
// versioned list of symbols is available, check versions.
95+
} else if demangled.starts_with("c10::")
96+
|| demangled.starts_with("at::")
97+
|| demangled.starts_with("torch::")
98+
{
99+
violations.insert(TorchStableAbiViolation::NonStableAbiSymbol { name: demangled });
100+
}
101+
}
102+
103+
Ok(violations)
104+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Auto-generated file listing shim functions and their minimum required versions
2+
# Format: function_name: TORCH_VERSION_MAJOR_MINOR_PATCH
3+
#
4+
# This file is automatically updated by the stable_shim_usage_linter.
5+
# If a function is not in this file, it was available before 2.10.0.
6+
# DO NOT EDIT MANUALLY.
7+
8+
ParallelFunc: TORCH_VERSION_2_10_0
9+
StableListHandle: TORCH_VERSION_2_10_0
10+
StableListOpaque: TORCH_VERSION_2_10_0
11+
StringHandle: TORCH_VERSION_2_10_0
12+
StringOpaque: TORCH_VERSION_2_10_0
13+
aoti_torch_aten_full: TORCH_VERSION_2_10_0
14+
aoti_torch_aten_subtract_Tensor: TORCH_VERSION_2_10_0
15+
aoti_torch_cuda__scaled_grouped_mm: TORCH_VERSION_2_10_0
16+
torch_c10_cuda_check_msg: TORCH_VERSION_2_10_0
17+
torch_c10_cuda_free_error_msg: TORCH_VERSION_2_10_0
18+
torch_call_dispatcher: TORCH_VERSION_2_10_0
19+
torch_cuda_stream_synchronize: TORCH_VERSION_2_10_0
20+
torch_delete_list: TORCH_VERSION_2_10_0
21+
torch_delete_string: TORCH_VERSION_2_10_0
22+
torch_get_const_data_ptr: TORCH_VERSION_2_10_0
23+
torch_get_cuda_stream_from_pool: TORCH_VERSION_2_10_0
24+
torch_get_current_cuda_blas_handle: TORCH_VERSION_2_10_0
25+
torch_get_mutable_data_ptr: TORCH_VERSION_2_10_0
26+
torch_get_num_threads: TORCH_VERSION_2_10_0
27+
torch_get_thread_idx: TORCH_VERSION_2_10_0
28+
torch_library_impl: TORCH_VERSION_2_10_0
29+
torch_list_get_item: TORCH_VERSION_2_10_0
30+
torch_list_push_back: TORCH_VERSION_2_10_0
31+
torch_list_set_item: TORCH_VERSION_2_10_0
32+
torch_list_size: TORCH_VERSION_2_10_0
33+
torch_new_list_reserve_size: TORCH_VERSION_2_10_0
34+
torch_new_string_handle: TORCH_VERSION_2_10_0
35+
torch_parallel_for: TORCH_VERSION_2_10_0
36+
torch_parse_device_string: TORCH_VERSION_2_10_0
37+
torch_set_current_cuda_stream: TORCH_VERSION_2_10_0
38+
torch_set_requires_grad: TORCH_VERSION_2_10_0
39+
torch_string_c_str: TORCH_VERSION_2_10_0
40+
torch_string_length: TORCH_VERSION_2_10_0
41+
aoti_torch_cpu_nonzero_static: TORCH_VERSION_2_11_0
42+
aoti_torch_cuda__flash_attention_forward_quantized: TORCH_VERSION_2_11_0
43+
aoti_torch_cuda__scaled_dot_product_flash_attention_quantized: TORCH_VERSION_2_11_0
44+
aoti_torch_cuda_mm_dtype_out: TORCH_VERSION_2_11_0
45+
aoti_torch_cuda_nonzero_static: TORCH_VERSION_2_11_0
46+
aoti_torch_mps_nonzero_static: TORCH_VERSION_2_11_0
47+
aoti_torch_xpu_mm_dtype_out: TORCH_VERSION_2_11_0
48+
torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_11_0
49+
torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_11_0
50+
torch_from_blob: TORCH_VERSION_2_11_0
51+
aoti_torch_cpu__grouped_mm: TORCH_VERSION_2_12_0
52+
aoti_torch_cpu_rand_like: TORCH_VERSION_2_12_0
53+
aoti_torch_cpu_rand_like_generator: TORCH_VERSION_2_12_0
54+
aoti_torch_cpu_randint_like: TORCH_VERSION_2_12_0
55+
aoti_torch_cpu_randint_like_low_dtype: TORCH_VERSION_2_12_0
56+
aoti_torch_cpu_randn_like: TORCH_VERSION_2_12_0
57+
aoti_torch_cpu_randn_like_generator: TORCH_VERSION_2_12_0
58+
aoti_torch_cuda__flash_attention_forward_no_dropout_inplace: TORCH_VERSION_2_12_0
59+
aoti_torch_cuda__flash_attention_forward_no_dropout_inplace_v2: TORCH_VERSION_2_12_0
60+
aoti_torch_cuda__flash_attention_forward_v2: TORCH_VERSION_2_12_0
61+
aoti_torch_cuda__grouped_mm: TORCH_VERSION_2_12_0
62+
aoti_torch_cuda_rand_like: TORCH_VERSION_2_12_0
63+
aoti_torch_cuda_rand_like_generator: TORCH_VERSION_2_12_0
64+
aoti_torch_cuda_randint_like: TORCH_VERSION_2_12_0
65+
aoti_torch_cuda_randint_like_low_dtype: TORCH_VERSION_2_12_0
66+
aoti_torch_cuda_randn_like: TORCH_VERSION_2_12_0
67+
aoti_torch_cuda_randn_like_generator: TORCH_VERSION_2_12_0
68+
aoti_torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_12_0
69+
aoti_torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_12_0
70+
aoti_torch_mps__grouped_mm: TORCH_VERSION_2_12_0
71+
aoti_torch_mps_rand_like: TORCH_VERSION_2_12_0
72+
aoti_torch_mps_rand_like_generator: TORCH_VERSION_2_12_0
73+
aoti_torch_mps_randint_like: TORCH_VERSION_2_12_0
74+
aoti_torch_mps_randint_like_low_dtype: TORCH_VERSION_2_12_0
75+
aoti_torch_mps_randn_like: TORCH_VERSION_2_12_0
76+
aoti_torch_mps_randn_like_generator: TORCH_VERSION_2_12_0
77+
aoti_torch_xpu__grouped_mm: TORCH_VERSION_2_12_0
78+
aoti_torch_xpu_rand_like: TORCH_VERSION_2_12_0
79+
aoti_torch_xpu_rand_like_generator: TORCH_VERSION_2_12_0
80+
aoti_torch_xpu_randint_like: TORCH_VERSION_2_12_0
81+
aoti_torch_xpu_randint_like_low_dtype: TORCH_VERSION_2_12_0
82+
aoti_torch_xpu_randn_like: TORCH_VERSION_2_12_0
83+
aoti_torch_xpu_randn_like_generator: TORCH_VERSION_2_12_0
84+
torch_library_def_with_tags: TORCH_VERSION_2_12_0
85+
torch_tag_core: TORCH_VERSION_2_12_0
86+
torch_tag_cudagraph_unsafe: TORCH_VERSION_2_12_0
87+
torch_tag_data_dependent_output: TORCH_VERSION_2_12_0
88+
torch_tag_dynamic_output_shape: TORCH_VERSION_2_12_0
89+
torch_tag_flexible_layout: TORCH_VERSION_2_12_0
90+
torch_tag_generated: TORCH_VERSION_2_12_0
91+
torch_tag_inplace_view: TORCH_VERSION_2_12_0
92+
torch_tag_maybe_aliasing_or_mutating: TORCH_VERSION_2_12_0
93+
torch_tag_needs_contiguous_strides: TORCH_VERSION_2_12_0
94+
torch_tag_needs_exact_strides: TORCH_VERSION_2_12_0
95+
torch_tag_needs_fixed_stride_order: TORCH_VERSION_2_12_0
96+
torch_tag_nondeterministic_bitwise: TORCH_VERSION_2_12_0
97+
torch_tag_nondeterministic_seeded: TORCH_VERSION_2_12_0
98+
torch_tag_out_variant: TORCH_VERSION_2_12_0
99+
torch_tag_pointwise: TORCH_VERSION_2_12_0
100+
torch_tag_pt2_compliant_tag: TORCH_VERSION_2_12_0
101+
torch_tag_reduction: TORCH_VERSION_2_12_0
102+
torch_tag_view_copy: TORCH_VERSION_2_12_0
103+
aoti_torch_cpu_grid_sampler_3d: TORCH_VERSION_2_13_0
104+
aoti_torch_cpu_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
105+
aoti_torch_cuda_grid_sampler_3d: TORCH_VERSION_2_13_0
106+
aoti_torch_cuda_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
107+
aoti_torch_mps__scaled_dot_product_attention_math_for_mps_v2: TORCH_VERSION_2_13_0
108+
aoti_torch_mps_grid_sampler_3d: TORCH_VERSION_2_13_0
109+
aoti_torch_mps_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
110+
torch_delete_stable_ivalue: TORCH_VERSION_2_13_0
111+
torch_exception_get_what: TORCH_VERSION_2_13_0
112+
torch_exception_get_what_without_backtrace: TORCH_VERSION_2_13_0
113+
torch_library_set_python_module: TORCH_VERSION_2_13_0
114+
torch_new_stable_ivalue: TORCH_VERSION_2_13_0

nix-builder/lib/build.nix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ rec {
185185
backendPythonDeps
186186
;
187187

188+
inherit (kernelConfig) torchStableAbiVersion;
189+
188190
kernelName = kernelConfig.name;
189191
doAbiCheck = true;
190192
};

nix-builder/lib/extension/torch/arch.nix

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@
6565
# Whether to strip rpath for non-nix use.
6666
stripRPath ? false,
6767

68+
# The Torch stable ABI version to check for.
69+
torchStableAbiVersion ? null,
70+
6871
# Revision to bake into the ops name.
6972
rev,
7073

@@ -274,6 +277,8 @@ stdenv.mkDerivation (prevAttrs: {
274277

275278
doInstallCheck = true;
276279

280+
inherit torchStableAbiVersion;
281+
277282
# We need access to the host system on Darwin for the Metal compiler.
278283
__noChroot = metalSupport;
279284

nix-builder/pkgs/kernel-abi-check/default.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ rustPlatform.buildRustPackage {
2424
|| file.name == "Cargo.lock"
2525
|| file.name == "manylinux-policy.json"
2626
|| file.hasExt "rs"
27+
|| file.name == "shim_function_versions.txt"
2728
|| file.name == "stable_abi.toml";
2829
in
2930
import ../crate-dirs.nix {

nix-builder/pkgs/kernel-abi-check/kernel-abi-check-hook.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
#!/bin/sh
22

3+
echo "Sourcing kernel-abi-check-hook.sh"
4+
35
_checkAbiHook() {
46
if [ -z "${doAbiCheck:-}" ]; then
57
echo "Skipping ABI check"
68
else
9+
10+
if [ -z "${torchStableAbiVersion:-}" ]; then
11+
_torchStableAbiFlag=""
12+
else
13+
_torchStableAbiFlag="--torch-stable-abi=${torchStableAbiVersion}"
14+
fi
15+
716
echo "Checking of ABI compatibility"
817
find "$out/" -name '*.so' -print0 | \
9-
xargs -0 -n1 kernel-abi-check
18+
xargs -0 -n1 kernel-abi-check ${_torchStableAbiFlag}
1019
fi
1120
}
1221

0 commit comments

Comments
 (0)