
Adding bitwise ONNX ops #3120


Open

wants to merge 31 commits into main from adding-bitwise-ops

31 commits
e6b1e5f
Started integrating bitwise operators to burn
AshAnand34 Apr 29, 2025
a62cc9c
Started developing the bitwise operators for burn
AshAnand34 Apr 29, 2025
8f00bdc
Created python files to create test ONNX models for the bitwise opera…
AshAnand34 Apr 29, 2025
2dbe899
Fixing bitshift.rs to account for direction
AshAnand34 Apr 29, 2025
5378d38
Merge remote-tracking branch 'origin/main' into adding-bitwise-ops
AshAnand34 Apr 29, 2025
48c3aa9
Added test onnx files to build and added rank inference for the bitwi…
AshAnand34 Apr 29, 2025
4b91855
Made fixes to test onnx files due to wrong node inits and restricted …
AshAnand34 Apr 30, 2025
8d997f9
Fixed test_onnx.rs to include the right onnx file names for bitshift
AshAnand34 Apr 30, 2025
69d61c7
Created onnx model tests for bitwise operators
AshAnand34 Apr 30, 2025
3fce6eb
rank fixing
AshAnand34 Apr 30, 2025
e867f61
Fixed unit tests for bitwise operators
AshAnand34 Apr 30, 2025
5b57b09
More unit test fixes in bitwise operators
AshAnand34 Apr 30, 2025
6420833
created onnx files for scalar versions of the bitwise operators
AshAnand34 May 1, 2025
dccbc2a
Integrated scalar versions of onnx model to the onnx tests
AshAnand34 May 1, 2025
60599d0
Added scalar argtypes to to_burn
AshAnand34 May 1, 2025
92603ab
Merge branch 'tracel-ai:main' into adding-bitwise-ops
AshAnand34 May 1, 2025
fa5574f
Added bitshift config plus minor fixes
AshAnand34 May 2, 2025
871f295
Merge branch 'tracel-ai:main' into adding-bitwise-ops
AshAnand34 May 2, 2025
ede2f88
Merge branch 'tracel-ai:main' into adding-bitwise-ops
AshAnand34 May 7, 2025
41ee35f
Fixing formatting errors
AshAnand34 May 7, 2025
e53950c
resolve cubecl latest version
AshAnand34 May 7, 2025
135317c
Merge remote-tracking branch 'upstream/main' into pr/3120
antimora Jul 11, 2025
00c1eb1
Remove unused op_configuration module import
antimora Jul 11, 2025
1d89f45
Refactor bitwise node handling for scalar support
antimora Jul 11, 2025
cf6463d
Refactor BitShift ONNX test generation and enable tests
antimora Jul 11, 2025
c56f54e
Fix formatting and logging in node modules
antimora Jul 11, 2025
a96a632
Refactor bitwise operations to use Tensor methods
antimora Jul 11, 2025
32e6ab1
Revert powf and powi changes
antimora Jul 11, 2025
120df41
Fix formating
antimora Jul 11, 2025
fde17a8
Add missing newline at end of bitshift.rs
antimora Jul 11, 2025
6247a7d
Update ONNX ops support status in documentation
antimora Jul 11, 2025
12 changes: 6 additions & 6 deletions crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -25,11 +25,11 @@ functionality.
| [AveragePool2d][12] | ✅ | ✅ |
| [BatchNormalization][14] | ✅ | ✅ |
| [Bernoulli][15] | ❌ | ✅ |
| [BitShift][16] | ❌ | ✅ |
| [BitwiseAnd][17] | ❌ | ✅ |
| [BitwiseNot][18] | ❌ | ✅ |
| [BitwiseOr][19] | ❌ | ✅ |
| [BitwiseXor][20] | ❌ | ✅ |
| [BitShift][16] | ✅ | ✅ |
| [BitwiseAnd][17] | ✅ | ✅ |
| [BitwiseNot][18] | ✅ | ✅ |
| [BitwiseOr][19] | ✅ | ✅ |
| [BitwiseXor][20] | ✅ | ✅ |
| [BlackmanWindow][21] | ❌ | ❌ |
| [Cast][22] | ✅ | ✅ |
| [CastLike][23] | ❌ | ❌ |
@@ -92,7 +92,7 @@ functionality.
| [If][77] | ❌ | ✅ |
| [Im][78] | ❌ | ❌ |
| [InstanceNormalization][79] | ✅ | ✅ |
| [IsInf][80] | ❌ | ❌ |
| [IsInf][80] | ❌ | ✅ |
| [IsNaN][81] | ❌ | ✅ |
| [LayerNormalization][82] | ✅ | ✅ |
| [LeakyRelu][83] | ✅ | ✅ |
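For context, the first status column tracks burn-import support (what this PR adds) and the second tracks support in burn itself. A minimal sketch of the Int-tensor surface the second column refers to is below; the method names (bitwise_and, bitwise_or, bitwise_xor, bitwise_not) are assumptions about burn's tensor API, not something shown in this diff:

// Sketch only: the exact burn method names used here are assumptions.
use burn::tensor::{Int, Tensor};

type Backend = burn_ndarray::NdArray<f32>;

fn bitwise_tensor_api_sketch() {
    let device = Default::default();
    let x = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
    let y = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);

    let _and = x.clone().bitwise_and(y.clone()); // [1, 0, 2, 0]
    let _or = x.clone().bitwise_or(y.clone());   // [1, 3, 3, 6]
    let _xor = x.clone().bitwise_xor(y.clone()); // [0, 3, 1, 6]
    let _not = x.bitwise_not();                  // [-2, -3, -4, -5]
}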
11 changes: 11 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -14,6 +14,17 @@ fn main() {
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/bitshift/bitshift_left.onnx")
.input("tests/bitshift/bitshift_left_scalar.onnx")
.input("tests/bitshift/bitshift_right.onnx")
.input("tests/bitshift/bitshift_right_scalar.onnx")
.input("tests/bitwise_and/bitwise_and.onnx")
.input("tests/bitwise_and/bitwise_and_scalar.onnx")
.input("tests/bitwise_not/bitwise_not.onnx")
.input("tests/bitwise_or/bitwise_or.onnx")
.input("tests/bitwise_or/bitwise_or_scalar.onnx")
.input("tests/bitwise_xor/bitwise_xor.onnx")
.input("tests/bitwise_xor/bitwise_xor_scalar.onnx")
.input("tests/cast/cast.onnx")
.input("tests/ceil/ceil.onnx")
.input("tests/clip/clip.onnx")
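Each .onnx file registered above is converted to Rust at build time, and the test modules later in this diff pull the generated code in through the crate's include_models! macro. The macro's actual definition is not part of this diff; a hedged sketch of the usual pattern follows, where the "model/<name>.rs" output path is an assumption:

// Hedged sketch: the real include_models! macro in onnx-tests is not shown in
// this diff, and the "model/<name>.rs" OUT_DIR layout is an assumption.
macro_rules! include_models {
    ($($model:ident),* $(,)?) => {
        $(
            pub mod $model {
                include!(concat!(env!("OUT_DIR"), "/model/", stringify!($model), ".rs"));
            }
        )*
    };
}

// Usage mirroring the new bitshift test module:
// include_models!(bitshift_left, bitshift_left_scalar, bitshift_right, bitshift_right_scalar);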
75 changes: 75 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitshift/bitshift.py
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# used to generate models: onnx-tests/tests/bitshift/bitshift_left.onnx,
# bitshift_left_scalar.onnx, bitshift_right.onnx and bitshift_right_scalar.onnx

import onnx

def build_model(direction: str = "LEFT", scalar_shift: bool = False):
    op_type = "BitShift"
    direction_attr = "LEFT" if direction == "LEFT" else "RIGHT"

    nodes = [
        onnx.helper.make_node(
            op_type,
            inputs=["x", "shift"],
            outputs=["output"],
            name=f"/{op_type}",
            direction=direction_attr
        ),
    ]

    # The tensor and scalar versions differ only in the shape of the "shift"
    # input: the scalar version uses a rank-0 tensor, which burn-import maps
    # to a scalar argument at runtime.
    inputs = [
        onnx.helper.make_value_info(
            name="x",
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=onnx.TensorProto.INT32, shape=[4]
            ),
        ),
        onnx.helper.make_value_info(
            name="shift",
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=onnx.TensorProto.INT32, shape=[] if scalar_shift else [4]
            ),
        ),
    ]

    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 18)],
        graph=onnx.helper.make_graph(
            name="main_graph",
            nodes=nodes,
            inputs=inputs,
            outputs=[
                onnx.helper.make_value_info(
                    name="output",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT32, shape=[4]
                    ),
                )
            ],
            initializer=[]
        ),
    )

def export_bitshift(direction: str = "LEFT"):
    # Regular tensor version
    onnx_model = build_model(direction, scalar_shift=False)
    file_name = f"bitshift_{direction.lower()}.onnx"

    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, file_name)
    print(f"Finished exporting model to {file_name}")

    # Scalar version
    onnx_model_scalar = build_model(direction, scalar_shift=True)
    file_name_scalar = f"bitshift_{direction.lower()}_scalar.onnx"

    onnx.checker.check_model(onnx_model_scalar)
    onnx.save(onnx_model_scalar, file_name_scalar)
    print(f"Finished exporting model to {file_name_scalar}")

if __name__ == "__main__":
    for direction in ["LEFT", "RIGHT"]:
        export_bitshift(direction)
Binary files not shown (the four generated bitshift ONNX models).
73 changes: 73 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitshift/mod.rs
@@ -0,0 +1,73 @@
// Include the models for this node type
use crate::include_models;
include_models!(
    bitshift_left,
    bitshift_left_scalar,
    bitshift_right,
    bitshift_right_scalar
);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Int, Tensor, TensorData};

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn bitshift_left_tensors() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_left::Model<Backend> = bitshift_left::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(input1, input2);
        let expected = TensorData::from([2i64, 4, 12, 16]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_left_scalar_tensor() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_left_scalar::Model<Backend> = bitshift_left_scalar::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let scalar = 2;
        let output = model.forward(input1, scalar);
        let expected = TensorData::from([4i64, 8, 12, 16]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_right_tensors() {
        let device = Default::default();
        let model: bitshift_right::Model<Backend> = bitshift_right::Model::new(&device);

        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(input1, input2);
        let expected = TensorData::from([0i64, 1, 0, 1]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_right_scalar_tensor() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_right_scalar::Model<Backend> =
            bitshift_right_scalar::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let scalar = 2;
        let output = model.forward(input1, scalar);
        let expected = TensorData::from([0i64, 0, 0, 1]);

        output.to_data().assert_eq(&expected, true);
    }
}
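The generated model code these tests exercise lives in OUT_DIR and is not part of the diff. Going by the commit "Refactor bitwise operations to use Tensor methods", the generated forward presumably reduces to an Int-tensor shift call; a rough sketch under that assumption, where the method name bitwise_left_shift and the struct layout are guesses:

// Rough sketch of what a generated bitshift_left model might reduce to.
// The actual code is produced by burn-import at build time; the method name
// `bitwise_left_shift` and the struct layout here are assumptions.
use burn::tensor::{Int, Tensor, backend::Backend};

pub struct Model<B: Backend> {
    phantom: core::marker::PhantomData<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(_device: &B::Device) -> Self {
        Self {
            phantom: core::marker::PhantomData,
        }
    }

    pub fn forward(&self, x: Tensor<B, 1, Int>, shift: Tensor<B, 1, Int>) -> Tensor<B, 1, Int> {
        // ONNX BitShift with direction="LEFT" maps to an element-wise left shift.
        x.bitwise_left_shift(shift)
    }
}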
18 changes: 18 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.onnx
Binary ONNX protobuf (a main_graph containing a single BitwiseAnd node with inputs x, y and one output, exported by PyTorch 2.5.1); raw content not shown.
39 changes: 39 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.py
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# used to generate models: onnx-tests/tests/bitwise_and/bitwise_and.onnx
# and bitwise_and_scalar.onnx

import torch
import onnx

def export_bitwise_and():
    class BitwiseAndModel(torch.nn.Module):
        def forward(self, x, y):
            if isinstance(y, int):
                # If y is a scalar, convert it to a tensor
                y = torch.tensor([y], dtype=x.dtype)
            return torch.bitwise_and(x, y)

    model = BitwiseAndModel()
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
    y = torch.tensor([4, 3, 2, 1], dtype=torch.int32)
    torch.onnx.export(
        model,
        (x, y),
        "bitwise_and.onnx",
        opset_version=18,
        input_names=["x", "y"],
        output_names=["output"],
    )

    # Scalar version
    and_scalar = 2  # Scalar operand value
    torch.onnx.export(
        model,
        (x, and_scalar),
        "bitwise_and_scalar.onnx",
        opset_version=18,
        input_names=["x", "y"],
        output_names=["output"],
    )

if __name__ == "__main__":
    export_bitwise_and()
Binary file not shown.
35 changes: 35 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitwise_and/mod.rs
@@ -0,0 +1,35 @@
// Include the models for this node type
use crate::include_models;
include_models!(bitwise_and, bitwise_and_scalar);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Int, Tensor, TensorData};

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn bitwise_and_tensors() {
        let device = Default::default();
        let model: bitwise_and::Model<Backend> = bitwise_and::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(input1, input2);
        let expected = TensorData::from([1i64, 0, 2, 0]);
        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitwise_and_scalar_tensor() {
        let device = Default::default();
        let model: bitwise_and_scalar::Model<Backend> = bitwise_and_scalar::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let scalar = 2;
        let output = model.forward(input1, scalar);
        let expected = TensorData::from([0i64, 2, 2, 0]);
        output.to_data().assert_eq(&expected, true);
    }
}
Binary file not shown.
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/bitwise_not/bitwise_not.onnx

import onnx

def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 18)],
        graph=onnx.helper.make_graph(
            name="main_graph",
            nodes=[
                onnx.helper.make_node(
                    "BitwiseNot",
                    inputs=["input"],
                    outputs=["output"],
                    name="/BitwiseNot"
                ),
            ],
            inputs=[
                onnx.helper.make_value_info(
                    name="input",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT32, shape=[1, 4]
                    ),
                ),
            ],
            outputs=[
                onnx.helper.make_value_info(
                    name="output",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT32, shape=[1, 4]
                    ),
                )
            ]
        ),
    )

def main():
    onnx_model = build_model()
    file_name = "bitwise_not.onnx"

    onnx.checker.check_model(onnx_model)  # Ensure valid ONNX
    onnx.save(onnx_model, file_name)  # Save the model
    print(f"Finished exporting model to {file_name}")

if __name__ == "__main__":
    main()
22 changes: 22 additions & 0 deletions crates/burn-import/onnx-tests/tests/bitwise_not/mod.rs
@@ -0,0 +1,22 @@
// Include the models for this node type
use crate::include_models;
include_models!(bitwise_not);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Int, Tensor, TensorData};

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn bitwise_not_tensors() {
        let device = Default::default();
        let model: bitwise_not::Model<Backend> = bitwise_not::Model::new(&device);
        // Run the model
        let input = Tensor::<Backend, 2, Int>::from_ints([[1, 2, 3, 4]], &device);
        let output = model.forward(input);
        let expected = TensorData::from([[-2i64, -3, -4, -5]]);
        output.to_data().assert_eq(&expected, true);
    }
}
Binary file not shown.