Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions kernel-abi-check/kernel-abi-check/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository = "https://github.com/huggingface/kernel-builder"
[dependencies]
clap = { version = "4", features = ["derive"] }
color-eyre = "0.6"
cpp_demangle = "0.5"
eyre = "0.6"
itertools = "0.14.0"
object = "0.36.7"
Expand Down
3 changes: 3 additions & 0 deletions kernel-abi-check/kernel-abi-check/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@ pub use manylinux::{check_manylinux, ManylinuxViolation};
mod python_abi;
pub use python_abi::{check_python_abi, PythonAbiViolation};

mod torch_stable_abi;
pub use torch_stable_abi::{check_torch_stable_abi, TorchStableAbiViolation};

mod version;
pub use version::Version;
40 changes: 38 additions & 2 deletions kernel-abi-check/kernel-abi-check/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use eyre::{Context, Result};
use object::{File, Object};

use kernel_abi_check::{
check_macos, check_manylinux, check_python_abi, MacOSViolation, ManylinuxViolation,
PythonAbiViolation, Version,
check_macos, check_manylinux, check_python_abi, check_torch_stable_abi, MacOSViolation,
ManylinuxViolation, PythonAbiViolation, TorchStableAbiViolation, Version,
};

/// CLI tool to check library versions
Expand All @@ -28,6 +28,10 @@ struct Cli {
/// Python ABI version.
#[arg(short, long, value_name = "VERSION", default_value = "3.9")]
python_abi: Version,

/// Torch stable ABI version.
#[arg(long, value_name = "VERSION")]
torch_stable_abi: Option<Version>,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -84,9 +88,41 @@ fn main() -> Result<()> {
eprintln!("✅ No compatibility issues found");
}

if let Some(torch_stable_abi) = args.torch_stable_abi {
eprintln!("🔥 Checking for compatibility with Torch stable ABI version {torch_stable_abi}");
let torch_stable_abi_violations =
check_torch_stable_abi(&torch_stable_abi, file.format(), file.symbols())?;
print_torch_stable_abi_violations(&torch_stable_abi_violations, &torch_stable_abi);

if !torch_stable_abi_violations.is_empty() {
return Err(eyre::eyre!("Torch stable ABI compatibility issues found"));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return Err(eyre::eyre!("Torch stable ABI compatibility issues found"));
return Err(eyre::eyre!("Torch stable ABI compatibility issues found"));

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already printed with a stop sign below.

} else {
eprintln!("✅ No Torch stable ABI compatibility issues found");
}
}

Ok(())
}

fn print_torch_stable_abi_violations(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouf, this is superb!

violations: &BTreeSet<TorchStableAbiViolation>,
torch_abi: &Version,
) {
if !violations.is_empty() {
eprintln!("\n⛔ Non-stable Torch ABI symbols found (incompatible with Torch stable ABI {torch_abi}):\n");
for violation in violations {
match violation {
TorchStableAbiViolation::IncompatibleStableAbiSymbol { name, added } => {
eprintln!("{name}: {added}");
}
TorchStableAbiViolation::NonStableAbiSymbol { name } => {
eprintln!("{name}");
}
}
}
}
}

fn print_manylinux_violations(
violations: &BTreeSet<ManylinuxViolation>,
manylinux_version: &str,
Expand Down
104 changes: 104 additions & 0 deletions kernel-abi-check/kernel-abi-check/src/torch_stable_abi/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use std::collections::{BTreeSet, HashMap};
use std::str::FromStr;

use cpp_demangle::Symbol as CppSymbol;
use eyre::Result;
use object::{BinaryFormat, ObjectSymbol, Symbol};
use once_cell::sync::Lazy;

use crate::version::Version;

// https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/torch/csrc/stable/c/shim_function_versions.txt
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): Could use the permalink.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't make it a permalink, because we typically want to use the URL to get the latest version and update our copy.

/// Contents of `shim_function_versions.txt`, embedded into the binary at
/// compile time via `include_str!`.
static SHIM_FUNCTION_VERSIONS_RAW: &str = include_str!("shim_function_versions.txt");

/// Maps shim function names to the minimum Torch version that introduced them.
/// Functions absent from this map were available before 2.10.0.
///
/// The map is built on first access from the embedded
/// `shim_function_versions.txt`. Every entry has the form
/// `function_name: TORCH_VERSION_MAJOR_MINOR_PATCH`; blank lines and lines
/// starting with `#` are ignored.
///
/// # Panics
///
/// Panics on first access if an entry has no `:` separator, if its version
/// token lacks the `TORCH_VERSION_` prefix, or if the version does not parse.
/// The table is compiled into the binary, so any of these means the embedded
/// file is broken. Dropping such an entry silently would make the function
/// look as if it predates 2.10.0.
pub static TORCH_SHIM_VERSIONS: Lazy<HashMap<String, Version>> = Lazy::new(|| {
    SHIM_FUNCTION_VERSIONS_RAW
        .lines()
        .map(str::trim)
        // Skip blank lines and comments.
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(|line| {
            let (name, version_token) = line.split_once(':').unwrap_or_else(|| {
                panic!("malformed entry in shim_function_versions.txt: {line:?}")
            });
            // TORCH_VERSION_2_10_0 -> "2.10.0"
            let version_str = version_token
                .trim()
                .strip_prefix("TORCH_VERSION_")
                .unwrap_or_else(|| panic!("unexpected version token format: {line:?}"))
                .replace('_', ".");
            let version = Version::from_str(&version_str)
                .expect("invalid version in shim_function_versions.txt");
            (name.trim().to_owned(), version)
        })
        .collect()
});

/// Torch stable ABI violation.
///
/// Values are ordered (`Ord`) so they can be collected into a `BTreeSet`.
#[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)]
pub enum TorchStableAbiViolation {
/// C shim symbol that was added in a Torch version newer than the
/// requested Torch stable ABI version. `added` is the version that
/// introduced the symbol.
IncompatibleStableAbiSymbol { name: String, added: Version },

/// C++ symbol in the `c10::`, `at::` or `torch::` namespaces that is not
/// part of the Torch stable ABI (`torch::stable::`). `name` is the
/// demangled symbol.
NonStableAbiSymbol { name: String },
}

/// Check the undefined symbols of a binary against the Torch stable ABI.
///
/// Defined symbols are skipped. Each undefined symbol is classified as follows:
///
/// * A name listed in [`TORCH_SHIM_VERSIONS`] is a C shim function. It is
///   reported as [`TorchStableAbiViolation::IncompatibleStableAbiSymbol`] when
///   it was added after `torch_stable_abi`, and is otherwise accepted.
/// * Any other name that parses as a mangled C++ symbol is demangled. It is
///   reported as [`TorchStableAbiViolation::NonStableAbiSymbol`] when the
///   demangled text starts with `c10::`, `at::` or `torch::` but not with
///   `torch::stable::`.
/// * Names that do not parse as C++ symbols are ignored.
///
/// For Mach-O binaries a single leading underscore is stripped from the
/// symbol name before any lookup.
///
/// # Errors
///
/// Returns an error if a symbol name cannot be read or if a parsed C++ symbol
/// fails to demangle.
pub fn check_torch_stable_abi<'a>(
    torch_stable_abi: &Version,
    binary_format: BinaryFormat,
    symbols: impl IntoIterator<Item = Symbol<'a, 'a>>,
) -> Result<BTreeSet<TorchStableAbiViolation>> {
    // Namespaces that are off-limits unless the symbol lives in `torch::stable::`.
    const TORCH_NAMESPACES: [&str; 3] = ["c10::", "at::", "torch::"];

    let is_macho = matches!(binary_format, BinaryFormat::MachO);
    let mut found = BTreeSet::new();

    for sym in symbols.into_iter().filter(|sym| sym.is_undefined()) {
        let raw_name = sym.name()?;
        // Mach-O C symbol mangling adds an underscore.
        let name = if is_macho {
            raw_name.strip_prefix('_').unwrap_or(raw_name)
        } else {
            raw_name
        };

        // C shim symbols carry a minimum version and are never demangled.
        if let Some(added) = TORCH_SHIM_VERSIONS.get(name) {
            if added > torch_stable_abi {
                found.insert(TorchStableAbiViolation::IncompatibleStableAbiSymbol {
                    name: name.to_owned(),
                    added: added.clone(),
                });
            }
            continue;
        }

        // A name that does not parse as mangled C++ is probably an unrelated
        // C symbol.
        let demangled = match CppSymbol::new(name) {
            Ok(parsed) => parsed.demangle()?,
            Err(_) => continue,
        };

        // The stable namespace is tested separately for two reasons: (1) so
        // that stable ABI C++ symbols are not caught by the broader `torch::`
        // prefix, and (2) to leave a place for version checks once a versioned
        // list of stable C++ symbols is available.
        //
        // NOTE(review): prefix matching assumes the demangled text begins with
        // the qualified name; function-template instantiations may demangle
        // with a leading return type - confirm whether those need handling.
        let in_stable_namespace = demangled.starts_with("torch::stable::");
        let in_torch_namespace = TORCH_NAMESPACES
            .iter()
            .any(|namespace| demangled.starts_with(*namespace));

        if !in_stable_namespace && in_torch_namespace {
            found.insert(TorchStableAbiViolation::NonStableAbiSymbol { name: demangled });
        }
    }

    Ok(found)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Auto-generated file listing shim functions and their minimum required versions
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we ensure this is also the latest stable?

Copy link
Copy Markdown
Member Author

@danieldk danieldk May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question for the manylinux and Python ABI3 files. We should probably have a CI job to check for this and maybe make a PR automatically once one of the files is outdated.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can look into it after the PR is merged: #593

# Format: function_name: TORCH_VERSION_MAJOR_MINOR_PATCH
#
# This file is automatically updated by the stable_shim_usage_linter.
# If a function is not in this file, it was available before 2.10.0.
# DO NOT EDIT MANUALLY.

ParallelFunc: TORCH_VERSION_2_10_0
StableListHandle: TORCH_VERSION_2_10_0
StableListOpaque: TORCH_VERSION_2_10_0
StringHandle: TORCH_VERSION_2_10_0
StringOpaque: TORCH_VERSION_2_10_0
aoti_torch_aten_full: TORCH_VERSION_2_10_0
aoti_torch_aten_subtract_Tensor: TORCH_VERSION_2_10_0
aoti_torch_cuda__scaled_grouped_mm: TORCH_VERSION_2_10_0
torch_c10_cuda_check_msg: TORCH_VERSION_2_10_0
torch_c10_cuda_free_error_msg: TORCH_VERSION_2_10_0
torch_call_dispatcher: TORCH_VERSION_2_10_0
torch_cuda_stream_synchronize: TORCH_VERSION_2_10_0
torch_delete_list: TORCH_VERSION_2_10_0
torch_delete_string: TORCH_VERSION_2_10_0
torch_get_const_data_ptr: TORCH_VERSION_2_10_0
torch_get_cuda_stream_from_pool: TORCH_VERSION_2_10_0
torch_get_current_cuda_blas_handle: TORCH_VERSION_2_10_0
torch_get_mutable_data_ptr: TORCH_VERSION_2_10_0
torch_get_num_threads: TORCH_VERSION_2_10_0
torch_get_thread_idx: TORCH_VERSION_2_10_0
torch_library_impl: TORCH_VERSION_2_10_0
torch_list_get_item: TORCH_VERSION_2_10_0
torch_list_push_back: TORCH_VERSION_2_10_0
torch_list_set_item: TORCH_VERSION_2_10_0
torch_list_size: TORCH_VERSION_2_10_0
torch_new_list_reserve_size: TORCH_VERSION_2_10_0
torch_new_string_handle: TORCH_VERSION_2_10_0
torch_parallel_for: TORCH_VERSION_2_10_0
torch_parse_device_string: TORCH_VERSION_2_10_0
torch_set_current_cuda_stream: TORCH_VERSION_2_10_0
torch_set_requires_grad: TORCH_VERSION_2_10_0
torch_string_c_str: TORCH_VERSION_2_10_0
torch_string_length: TORCH_VERSION_2_10_0
aoti_torch_cpu_nonzero_static: TORCH_VERSION_2_11_0
aoti_torch_cuda__flash_attention_forward_quantized: TORCH_VERSION_2_11_0
aoti_torch_cuda__scaled_dot_product_flash_attention_quantized: TORCH_VERSION_2_11_0
aoti_torch_cuda_mm_dtype_out: TORCH_VERSION_2_11_0
aoti_torch_cuda_nonzero_static: TORCH_VERSION_2_11_0
aoti_torch_mps_nonzero_static: TORCH_VERSION_2_11_0
aoti_torch_xpu_mm_dtype_out: TORCH_VERSION_2_11_0
torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_11_0
torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_11_0
torch_from_blob: TORCH_VERSION_2_11_0
aoti_torch_cpu__grouped_mm: TORCH_VERSION_2_12_0
aoti_torch_cpu_rand_like: TORCH_VERSION_2_12_0
aoti_torch_cpu_rand_like_generator: TORCH_VERSION_2_12_0
aoti_torch_cpu_randint_like: TORCH_VERSION_2_12_0
aoti_torch_cpu_randint_like_low_dtype: TORCH_VERSION_2_12_0
aoti_torch_cpu_randn_like: TORCH_VERSION_2_12_0
aoti_torch_cpu_randn_like_generator: TORCH_VERSION_2_12_0
aoti_torch_cuda__flash_attention_forward_no_dropout_inplace: TORCH_VERSION_2_12_0
aoti_torch_cuda__flash_attention_forward_no_dropout_inplace_v2: TORCH_VERSION_2_12_0
aoti_torch_cuda__flash_attention_forward_v2: TORCH_VERSION_2_12_0
aoti_torch_cuda__grouped_mm: TORCH_VERSION_2_12_0
aoti_torch_cuda_rand_like: TORCH_VERSION_2_12_0
aoti_torch_cuda_rand_like_generator: TORCH_VERSION_2_12_0
aoti_torch_cuda_randint_like: TORCH_VERSION_2_12_0
aoti_torch_cuda_randint_like_low_dtype: TORCH_VERSION_2_12_0
aoti_torch_cuda_randn_like: TORCH_VERSION_2_12_0
aoti_torch_cuda_randn_like_generator: TORCH_VERSION_2_12_0
aoti_torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_12_0
aoti_torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_12_0
aoti_torch_mps__grouped_mm: TORCH_VERSION_2_12_0
aoti_torch_mps_rand_like: TORCH_VERSION_2_12_0
aoti_torch_mps_rand_like_generator: TORCH_VERSION_2_12_0
aoti_torch_mps_randint_like: TORCH_VERSION_2_12_0
aoti_torch_mps_randint_like_low_dtype: TORCH_VERSION_2_12_0
aoti_torch_mps_randn_like: TORCH_VERSION_2_12_0
aoti_torch_mps_randn_like_generator: TORCH_VERSION_2_12_0
aoti_torch_xpu__grouped_mm: TORCH_VERSION_2_12_0
aoti_torch_xpu_rand_like: TORCH_VERSION_2_12_0
aoti_torch_xpu_rand_like_generator: TORCH_VERSION_2_12_0
aoti_torch_xpu_randint_like: TORCH_VERSION_2_12_0
aoti_torch_xpu_randint_like_low_dtype: TORCH_VERSION_2_12_0
aoti_torch_xpu_randn_like: TORCH_VERSION_2_12_0
aoti_torch_xpu_randn_like_generator: TORCH_VERSION_2_12_0
torch_library_def_with_tags: TORCH_VERSION_2_12_0
torch_tag_core: TORCH_VERSION_2_12_0
torch_tag_cudagraph_unsafe: TORCH_VERSION_2_12_0
torch_tag_data_dependent_output: TORCH_VERSION_2_12_0
torch_tag_dynamic_output_shape: TORCH_VERSION_2_12_0
torch_tag_flexible_layout: TORCH_VERSION_2_12_0
torch_tag_generated: TORCH_VERSION_2_12_0
torch_tag_inplace_view: TORCH_VERSION_2_12_0
torch_tag_maybe_aliasing_or_mutating: TORCH_VERSION_2_12_0
torch_tag_needs_contiguous_strides: TORCH_VERSION_2_12_0
torch_tag_needs_exact_strides: TORCH_VERSION_2_12_0
torch_tag_needs_fixed_stride_order: TORCH_VERSION_2_12_0
torch_tag_nondeterministic_bitwise: TORCH_VERSION_2_12_0
torch_tag_nondeterministic_seeded: TORCH_VERSION_2_12_0
torch_tag_out_variant: TORCH_VERSION_2_12_0
torch_tag_pointwise: TORCH_VERSION_2_12_0
torch_tag_pt2_compliant_tag: TORCH_VERSION_2_12_0
torch_tag_reduction: TORCH_VERSION_2_12_0
torch_tag_view_copy: TORCH_VERSION_2_12_0
aoti_torch_cpu_grid_sampler_3d: TORCH_VERSION_2_13_0
aoti_torch_cpu_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
aoti_torch_cuda_grid_sampler_3d: TORCH_VERSION_2_13_0
aoti_torch_cuda_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
aoti_torch_mps__scaled_dot_product_attention_math_for_mps_v2: TORCH_VERSION_2_13_0
aoti_torch_mps_grid_sampler_3d: TORCH_VERSION_2_13_0
aoti_torch_mps_grid_sampler_3d_backward: TORCH_VERSION_2_13_0
torch_delete_stable_ivalue: TORCH_VERSION_2_13_0
torch_exception_get_what: TORCH_VERSION_2_13_0
torch_exception_get_what_without_backtrace: TORCH_VERSION_2_13_0
torch_library_set_python_module: TORCH_VERSION_2_13_0
torch_new_stable_ivalue: TORCH_VERSION_2_13_0
2 changes: 2 additions & 0 deletions nix-builder/lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ rec {
backendPythonDeps
;

inherit (kernelConfig) torchStableAbiVersion;

kernelName = kernelConfig.name;
doAbiCheck = true;
};
Expand Down
5 changes: 5 additions & 0 deletions nix-builder/lib/extension/torch/arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
# Whether to strip rpath for non-nix use.
stripRPath ? false,

# The Torch stable ABI version to check for.
torchStableAbiVersion ? null,

# Revision to bake into the ops name.
rev,

Expand Down Expand Up @@ -274,6 +277,8 @@ stdenv.mkDerivation (prevAttrs: {

doInstallCheck = true;

inherit torchStableAbiVersion;

# We need access to the host system on Darwin for the Metal compiler.
__noChroot = metalSupport;

Expand Down
1 change: 1 addition & 0 deletions nix-builder/pkgs/kernel-abi-check/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ rustPlatform.buildRustPackage {
|| file.name == "Cargo.lock"
|| file.name == "manylinux-policy.json"
|| file.hasExt "rs"
|| file.name == "shim_function_versions.txt"
|| file.name == "stable_abi.toml";
in
import ../crate-dirs.nix {
Expand Down
11 changes: 10 additions & 1 deletion nix-builder/pkgs/kernel-abi-check/kernel-abi-check-hook.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
#!/bin/sh

echo "Sourcing kernel-abi-check-hook.sh"

# Run kernel-abi-check on every '*.so' file found under "$out/".
#
# The check is skipped unless doAbiCheck is set to a non-empty value. When
# torchStableAbiVersion is set to a non-empty value it is forwarded to
# kernel-abi-check as --torch-stable-abi=<version>.
_checkAbiHook() {
  if [ -z "${doAbiCheck:-}" ]; then
    echo "Skipping ABI check"
  else
    if [ -z "${torchStableAbiVersion:-}" ]; then
      _torchStableAbiFlag=""
    else
      _torchStableAbiFlag="--torch-stable-abi=${torchStableAbiVersion}"
    fi

    echo "Checking of ABI compatibility"
    # A single xargs invocation receives the file list and the optional flag.
    # ${_torchStableAbiFlag} is left unquoted on purpose: when it is empty it
    # must expand to no argument at all rather than to an empty string.
    find "$out/" -name '*.so' -print0 | \
      xargs -0 -n1 kernel-abi-check ${_torchStableAbiFlag}
  fi
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ buildPythonPackage {
|| file.hasExt "pyi"
|| file.name == "pyproject.toml"
|| file.hasExt "rs"
|| file.name == "shim_function_versions.txt"
|| file.name == "stable_abi.toml";
in
import ../../crate-dirs.nix {
Expand Down
Loading