diff --git a/.gitignore b/.gitignore index c47cb50116..c1151def5f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ out # Virtual Environment of Python .venv +uv.lock diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 94f94eab33..47bab0e311 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -25,11 +25,11 @@ functionality. | [AveragePool2d][12] | ✅ | ✅ | | [BatchNormalization][14] | ✅ | ✅ | | [Bernoulli][15] | ✅ | ✅ | -| [BitShift][16] | ❌ | ✅ | -| [BitwiseAnd][17] | ❌ | ✅ | -| [BitwiseNot][18] | ❌ | ✅ | -| [BitwiseOr][19] | ❌ | ✅ | -| [BitwiseXor][20] | ❌ | ✅ | +| [BitShift][16] | ✅ | ✅ | +| [BitwiseAnd][17] | ✅ | ✅ | +| [BitwiseNot][18] | ✅ | ✅ | +| [BitwiseOr][19] | ✅ | ✅ | +| [BitwiseXor][20] | ✅ | ✅ | | [BlackmanWindow][21] | ❌ | ❌ | | [Cast][22] | ✅ | ✅ | | [CastLike][23] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index d6621e9f01..af89438d1d 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -14,6 +14,27 @@ fn main() { .input("tests/avg_pool1d/avg_pool1d.onnx") .input("tests/avg_pool2d/avg_pool2d.onnx") .input("tests/batch_norm/batch_norm.onnx") + .input("tests/bitshift/bitshift_left.onnx") + .input("tests/bitshift/bitshift_left_scalar.onnx") + .input("tests/bitshift/scalar_bitshift_left.onnx") + .input("tests/bitshift/scalar_bitshift_left_scalar.onnx") + .input("tests/bitshift/bitshift_right.onnx") + .input("tests/bitshift/bitshift_right_scalar.onnx") + .input("tests/bitshift/scalar_bitshift_right.onnx") + .input("tests/bitshift/scalar_bitshift_right_scalar.onnx") + .input("tests/bitwise_and/bitwise_and.onnx") + .input("tests/bitwise_and/bitwise_and_scalar.onnx") + .input("tests/bitwise_and/scalar_bitwise_and.onnx") + .input("tests/bitwise_and/scalar_bitwise_and_scalar.onnx") + .input("tests/bitwise_not/bitwise_not.onnx") + .input("tests/bitwise_or/bitwise_or.onnx") + .input("tests/bitwise_or/bitwise_or_scalar.onnx") + .input("tests/bitwise_or/scalar_bitwise_or.onnx") + .input("tests/bitwise_or/scalar_bitwise_or_scalar.onnx") + .input("tests/bitwise_xor/bitwise_xor.onnx") + .input("tests/bitwise_xor/bitwise_xor_scalar.onnx") + .input("tests/bitwise_xor/scalar_bitwise_xor.onnx") + .input("tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx") .input("tests/bernoulli/bernoulli.onnx") .input("tests/cast/cast.onnx") .input("tests/ceil/ceil.onnx") diff --git a/crates/burn-import/onnx-tests/tests/bernoulli/mod.rs b/crates/burn-import/onnx-tests/tests/bernoulli/mod.rs index d415469aa2..f3cc9ab89c 100644 --- a/crates/burn-import/onnx-tests/tests/bernoulli/mod.rs +++ b/crates/burn-import/onnx-tests/tests/bernoulli/mod.rs @@ -5,7 +5,7 @@ include_models!(bernoulli); mod tests { use super::*; use burn::tensor::Shape; - use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem}; + use burn::tensor::Tensor; type Backend = burn_ndarray::NdArray; diff --git a/crates/burn-import/onnx-tests/tests/bitshift/bitshift.py b/crates/burn-import/onnx-tests/tests/bitshift/bitshift.py new file mode 100644 index 0000000000..bc8fc31b06 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitshift/bitshift.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# used to generate all bitshift ONNX models + +import onnx + +def build_model(name, input1_shape, input2_shape, output_shape, direction): + """ + Build a BitShift ONNX model with specified input/output shapes and direction. + + Args: + name: Name of the model (used for file naming) + input1_shape: Shape of first input ([] for scalar) + input2_shape: Shape of second input ([] for scalar) + output_shape: Shape of output ([] for scalar) + direction: "LEFT" or "RIGHT" + """ + op_type = "BitShift" + + nodes = [ + onnx.helper.make_node( + op_type, + inputs=["input1", "input2"], + outputs=["output"], + name=f"/{op_type}", + direction=direction + ), + ] + + inputs = [ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input1_shape + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input2_shape + ), + ), + ] + + outputs = [ + onnx.helper.make_value_info( + name="output", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=output_shape + ), + ) + ] + + model = onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 18)], + graph=onnx.helper.make_graph( + name="main_graph", + nodes=nodes, + inputs=inputs, + outputs=outputs, + initializer=[] + ), + ) + + onnx.checker.check_model(model) + onnx.save(model, f"{name}.onnx") + print(f"Finished exporting model to {name}.onnx") + +if __name__ == "__main__": + # Define all model configurations + configs = [ + # (name, input1_shape, input2_shape, output_shape, direction) + ("bitshift_left", [4], [4], [4], "LEFT"), + ("bitshift_right", [4], [4], [4], "RIGHT"), + ("bitshift_left_scalar", [4], [], [4], "LEFT"), + ("bitshift_right_scalar", [4], [], [4], "RIGHT"), + ("scalar_bitshift_left", [], [4], [4], "LEFT"), + ("scalar_bitshift_right", [], [4], [4], "RIGHT"), + ("scalar_bitshift_left_scalar", [], [], [], "LEFT"), + ("scalar_bitshift_right_scalar", [], [], [], "RIGHT"), + ] + + for config in configs: + build_model(*config) \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left.onnx b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left.onnx new file mode 100644 index 0000000000..39e53a1a49 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left_scalar.onnx new file mode 100644 index 0000000000..4b406cfd03 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_left_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right.onnx b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right.onnx new file mode 100644 index 0000000000..b9c990d3e9 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right_scalar.onnx new file mode 100644 index 0000000000..aa05c05eb8 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/bitshift_right_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/mod.rs b/crates/burn-import/onnx-tests/tests/bitshift/mod.rs new file mode 100644 index 0000000000..b6dded20f0 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitshift/mod.rs @@ -0,0 +1,132 @@ +// Include the models for this node type +use crate::include_models; +include_models!( + bitshift_left, + bitshift_left_scalar, + scalar_bitshift_left, + scalar_bitshift_left_scalar, + bitshift_right, + bitshift_right_scalar, + scalar_bitshift_right, + scalar_bitshift_right_scalar +); + +#[cfg(test)] +mod tests { + use super::*; + use burn::tensor::{Int, Tensor, TensorData}; + + type Backend = burn_ndarray::NdArray; + + #[test] + fn bitshift_left_tensors() { + // Initialize the model with weights (loaded from the exported file) + let device = Default::default(); + let model: bitshift_left::Model = bitshift_left::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let input2 = Tensor::::from_ints([1, 1, 2, 2], &device); + let output = model.forward(input1, input2); + let expected = TensorData::from([2i64, 4, 12, 16]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitshift_left_scalar_tensor() { + // Initialize the model with weights (loaded from the exported file) + let device = Default::default(); + let model: bitshift_left_scalar::Model = bitshift_left_scalar::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let scalar = 2; + let output = model.forward(input1, scalar); + let expected = TensorData::from([4i64, 8, 12, 16]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitshift_right_tensors() { + let device = Default::default(); + let model: bitshift_right::Model = bitshift_right::Model::new(&device); + + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let input2 = Tensor::::from_ints([1, 1, 2, 2], &device); + let output = model.forward(input1, input2); + let expected = TensorData::from([0i64, 1, 0, 1]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitshift_right_scalar_tensor() { + // Initialize the model with weights (loaded from the exported file) + let device = Default::default(); + let model: bitshift_right_scalar::Model = + bitshift_right_scalar::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let scalar = 2; + let output = model.forward(input1, scalar); + let expected = TensorData::from([0i64, 0, 0, 1]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitshift_left_tensor() { + let device = Default::default(); + let model: scalar_bitshift_left::Model = scalar_bitshift_left::Model::new(&device); + // Run the model + let scalar = 4; + let shift_amounts = Tensor::::from_ints([1, 1, 2, 2], &device); + let output = model.forward(scalar, shift_amounts); + // 4 << 1 = 8, 4 << 1 = 8, 4 << 2 = 16, 4 << 2 = 16 + let expected = TensorData::from([8i64, 8, 16, 16]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitshift_right_tensor() { + let device = Default::default(); + let model: scalar_bitshift_right::Model = + scalar_bitshift_right::Model::new(&device); + // Run the model + let scalar = 8; + let shift_amounts = Tensor::::from_ints([1, 2, 3, 4], &device); + let output = model.forward(scalar, shift_amounts); + // 8 >> 1 = 4, 8 >> 2 = 2, 8 >> 3 = 1, 8 >> 4 = 0 + let expected = TensorData::from([4i64, 2, 1, 0]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitshift_left_scalar() { + let device = Default::default(); + let model: scalar_bitshift_left_scalar::Model = + scalar_bitshift_left_scalar::Model::new(&device); + // Run the model + let lhs = 4; + let rhs = 2; + let output = model.forward(lhs, rhs); + // 4 << 2 = 16 + assert_eq!(output, 16); + } + + #[test] + fn scalar_bitshift_right_scalar() { + let device = Default::default(); + let model: scalar_bitshift_right_scalar::Model = + scalar_bitshift_right_scalar::Model::new(&device); + // Run the model + let lhs = 16; + let rhs = 2; + let output = model.forward(lhs, rhs); + // 16 >> 2 = 4 + assert_eq!(output, 4); + } +} diff --git a/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left.onnx b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left.onnx new file mode 100644 index 0000000000..92872b9032 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left_scalar.onnx new file mode 100644 index 0000000000..ee09388800 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_left_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right.onnx b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right.onnx new file mode 100644 index 0000000000..23d45460d3 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right_scalar.onnx new file mode 100644 index 0000000000..5e5e39e743 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitshift/scalar_bitshift_right_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.onnx b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.onnx new file mode 100644 index 0000000000..7de7b53296 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.py b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.py new file mode 100644 index 0000000000..ca7a16cfa0 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# used to generate all bitwise_and ONNX models + +import onnx + +def build_model(name, input1_shape, input2_shape, output_shape): + """ + Build a BitwiseAnd ONNX model with specified input/output shapes. + + Args: + name: Name of the model (used for file naming) + input1_shape: Shape of first input ([] for scalar) + input2_shape: Shape of second input ([] for scalar) + output_shape: Shape of output ([] for scalar) + """ + op_type = "BitwiseAnd" + + nodes = [ + onnx.helper.make_node( + op_type, + inputs=["input1", "input2"], + outputs=["output"], + name=f"/{op_type}" + ), + ] + + inputs = [ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input1_shape + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input2_shape + ), + ), + ] + + outputs = [ + onnx.helper.make_value_info( + name="output", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=output_shape + ), + ) + ] + + model = onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 18)], + graph=onnx.helper.make_graph( + name="main_graph", + nodes=nodes, + inputs=inputs, + outputs=outputs, + initializer=[] + ), + ) + + onnx.checker.check_model(model) + onnx.save(model, f"{name}.onnx") + print(f"Finished exporting model to {name}.onnx") + +if __name__ == "__main__": + # Define all model configurations + configs = [ + # (name, input1_shape, input2_shape, output_shape) + ("bitwise_and", [4], [4], [4]), + ("bitwise_and_scalar", [4], [], [4]), + ("scalar_bitwise_and", [], [4], [4]), + ("scalar_bitwise_and_scalar", [], [], []), + ] + + for config in configs: + build_model(*config) \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and_scalar.onnx new file mode 100644 index 0000000000..fb960631ee Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_and/bitwise_and_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/mod.rs b/crates/burn-import/onnx-tests/tests/bitwise_and/mod.rs new file mode 100644 index 0000000000..18d6656e98 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_and/mod.rs @@ -0,0 +1,66 @@ +// Include the models for this node type +use crate::include_models; +include_models!( + bitwise_and, + bitwise_and_scalar, + scalar_bitwise_and, + scalar_bitwise_and_scalar +); + +#[cfg(test)] +mod tests { + use super::*; + use burn::tensor::{Int, Tensor, TensorData}; + + type Backend = burn_ndarray::NdArray; + + #[test] + fn bitwise_and_tensors() { + let device = Default::default(); + let model: bitwise_and::Model = bitwise_and::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let input2 = Tensor::::from_ints([1, 1, 2, 2], &device); + let output = model.forward(input1, input2); + let expected = TensorData::from([1i64, 0, 2, 0]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitwise_and_scalar_tensor() { + let device = Default::default(); + let model: bitwise_and_scalar::Model = bitwise_and_scalar::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([1, 2, 3, 4], &device); + let scalar = 2; + let output = model.forward(input1, scalar); + let expected = TensorData::from([0i64, 2, 2, 0]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_and_tensor() { + let device = Default::default(); + let model: scalar_bitwise_and::Model = scalar_bitwise_and::Model::new(&device); + // Run the model + let scalar = 2; + let input2 = Tensor::::from_ints([1, 2, 3, 4], &device); + let output = model.forward(scalar, input2); + // Bitwise AND is commutative, so result should be same as tensor-scalar + let expected = TensorData::from([0i64, 2, 2, 0]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_and_scalar() { + let device = Default::default(); + let model: scalar_bitwise_and_scalar::Model = + scalar_bitwise_and_scalar::Model::new(&device); + // Run the model + let lhs = 7; // 0b111 + let rhs = 3; // 0b011 + let output = model.forward(lhs, rhs); + // 7 & 3 = 3 + assert_eq!(output, 3); + } +} diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and.onnx b/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and.onnx new file mode 100644 index 0000000000..f835767c09 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and_scalar.onnx new file mode 100644 index 0000000000..3a9a2a06c7 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_and/scalar_bitwise_and_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.onnx b/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.onnx new file mode 100644 index 0000000000..064610b28d Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.py b/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.py new file mode 100644 index 0000000000..dbde9d485a --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_not/bitwise_not.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/bitwise_not/bitwise_not.onnx + +import onnx + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 18)], + graph=onnx.helper.make_graph( + name="main_graph", + nodes=[ + onnx.helper.make_node( + "BitwiseNot", + inputs=["input"], + outputs=["output"], + name="/BitwiseNot" + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=[1, 4] + ), + ), + ], + outputs=[ + onnx.helper.make_value_info( + name="output", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=[1, 4] + ), + ) + ] + ), + ) + +def main(): + onnx_model = build_model() + file_name = "bitwise_not.onnx" + + onnx.checker.check_model(onnx_model) # Ensure valid ONNX + onnx.save(onnx_model, file_name) # Save the model + print(f"Finished exporting model to {file_name}") + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/bitwise_not/mod.rs b/crates/burn-import/onnx-tests/tests/bitwise_not/mod.rs new file mode 100644 index 0000000000..12aec646b5 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_not/mod.rs @@ -0,0 +1,22 @@ +// Include the models for this node type +use crate::include_models; +include_models!(bitwise_not); + +#[cfg(test)] +mod tests { + use super::*; + use burn::tensor::{Int, Tensor, TensorData}; + + type Backend = burn_ndarray::NdArray; + + #[test] + fn bitwise_not_tensors() { + let device = Default::default(); + let model: bitwise_not::Model = bitwise_not::Model::new(&device); + // Run the model + let input = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let output = model.forward(input); + let expected = TensorData::from([[-2i64, -3, -4, -5]]); + output.to_data().assert_eq(&expected, true); + } +} diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.onnx b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.onnx new file mode 100644 index 0000000000..83ec97b3da Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.py b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.py new file mode 100644 index 0000000000..c0f4d75f9b --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or.py @@ -0,0 +1,76 @@ +import torch +import onnx + + +def build_model(scalar=False, scalar_first=False): + if scalar_first: + input1_shape = [] + input2_shape = [1, 4] + else: + input1_shape = [1, 4] + input2_shape = [1, 4] if not scalar else [] + + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 18)], + graph=onnx.helper.make_graph( + name="main_graph", + nodes=[ + onnx.helper.make_node( + "BitwiseOr", + inputs=["input1", "input2"], + outputs=["output"], + name="/BitwiseOr" + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input1_shape + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input2_shape + ), + ), + ], + outputs=[ + onnx.helper.make_value_info( + name="output", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=[1, 4] + ), + ) + ] + ), + ) + + +def main(): + onnx_model = build_model() + file_name = "bitwise_or.onnx" + + onnx.checker.check_model(onnx_model) # Ensure valid ONNX + onnx.save(onnx_model, file_name) # Save the model + print(f"Finished exporting model to {file_name}") + + onnx_scalar_model = build_model(scalar=True) + scalar_file_name = "bitwise_or_scalar.onnx" + + onnx.checker.check_model(onnx_scalar_model) # Ensure valid ONNX + onnx.save(onnx_scalar_model, scalar_file_name) # Save the model + print(f"Finished exporting scalar model to {scalar_file_name}") + + # Scalar-Tensor version + onnx_scalar_first_model = build_model(scalar_first=True) + scalar_first_file_name = "scalar_bitwise_or.onnx" + + onnx.checker.check_model(onnx_scalar_first_model) # Ensure valid ONNX + onnx.save(onnx_scalar_first_model, scalar_first_file_name) # Save the model + print(f"Finished exporting scalar-first model to {scalar_first_file_name}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or_scalar.onnx new file mode 100644 index 0000000000..885f2a045c Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_or/bitwise_or_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/mod.rs b/crates/burn-import/onnx-tests/tests/bitwise_or/mod.rs new file mode 100644 index 0000000000..4cbfd644bb --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_or/mod.rs @@ -0,0 +1,69 @@ +// Include the models for this node type +use crate::include_models; +include_models!( + bitwise_or, + bitwise_or_scalar, + scalar_bitwise_or, + scalar_bitwise_or_scalar +); + +#[cfg(test)] +mod tests { + use super::*; + use burn::tensor::{Int, Tensor, TensorData}; + + type Backend = burn_ndarray::NdArray; + + #[test] + fn bitwise_or_tensors() { + // Initialize the model + let device = Default::default(); + let model: bitwise_or::Model = bitwise_or::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let input2 = Tensor::::from_ints([[1, 1, 2, 2]], &device); + let output = model.forward(input1, input2); + let expected = TensorData::from([[1i64, 3, 3, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitwise_or_scalar_tensor() { + // Initialize the model + let device = Default::default(); + let model: bitwise_or_scalar::Model = bitwise_or_scalar::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let scalar = 2; + let output = model.forward(input1, scalar); + let expected = TensorData::from([[3i64, 2, 3, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_or_tensor() { + // Initialize the model + let device = Default::default(); + let model: scalar_bitwise_or::Model = scalar_bitwise_or::Model::new(&device); + // Run the model + let scalar = 2; + let input2 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let output = model.forward(scalar, input2); + // Bitwise OR is commutative, so result should be same as tensor-scalar + let expected = TensorData::from([[3i64, 2, 3, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_or_scalar() { + let device = Default::default(); + let model: scalar_bitwise_or_scalar::Model = + scalar_bitwise_or_scalar::Model::new(&device); + // Run the model + let lhs = 5; // 0b101 + let rhs = 3; // 0b011 + let output = model.forward(lhs, rhs); + // 5 | 3 = 7 (0b111) + assert_eq!(output, 7); + } +} diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or.onnx b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or.onnx new file mode 100644 index 0000000000..4738b94e0c Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.onnx new file mode 100644 index 0000000000..8a35005cff Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.py b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.py new file mode 100644 index 0000000000..102750ea2e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_or/scalar_bitwise_or_scalar.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import onnx +import onnx.helper as helper +import onnx.checker as checker +import onnx.numpy_helper +import numpy as np + +def build_model(): + # Scalar inputs + input1 = helper.make_tensor_value_info("input1", onnx.TensorProto.INT32, []) + input2 = helper.make_tensor_value_info("input2", onnx.TensorProto.INT32, []) + output = helper.make_tensor_value_info("output", onnx.TensorProto.INT32, []) + + # Create bitwise OR node + or_node = helper.make_node( + "BitwiseOr", + inputs=["input1", "input2"], + outputs=["output"] + ) + + # Create the graph + graph_def = helper.make_graph( + [or_node], + "scalar_bitwise_or_scalar", + [input1, input2], + [output], + ) + + # Create the model + model_def = helper.make_model(graph_def, producer_name="scalar_bitwise_or_scalar") + checker.check_model(model_def) + + return model_def + +if __name__ == "__main__": + model = build_model() + onnx.save(model, "scalar_bitwise_or_scalar.onnx") \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.onnx b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.onnx new file mode 100644 index 0000000000..151c641477 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.py b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.py new file mode 100644 index 0000000000..253496feee --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# used to generate model: onnx-tests/tests/bitwise_xor/bitwise_xor.onnx + +import onnx + +def build_model(scalar=False, scalar_first=False): + if scalar_first: + input1_shape = [] + input2_shape = [1, 4] + else: + input1_shape = [1, 4] + input2_shape = [1, 4] if not scalar else [] + + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 18)], + graph=onnx.helper.make_graph( + name="main_graph", + nodes=[ + onnx.helper.make_node( + "BitwiseXor", + inputs=["input1", "input2"], + outputs=["output"], + name="/BitwiseXor" + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input1_shape + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=input2_shape + ), + ), + ], + outputs=[ + onnx.helper.make_value_info( + name="output", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT32, shape=[1, 4] + ), + ) + ] + ), + ) + + +def main(): + onnx_model = build_model() + file_name = "bitwise_xor.onnx" + + onnx.checker.check_model(onnx_model) # Ensure valid ONNX + onnx.save(onnx_model, file_name) # Save the model + print(f"Finished exporting model to {file_name}") + + onnx_scalar_model = build_model(scalar=True) + scalar_file_name = "bitwise_xor_scalar.onnx" + + onnx.checker.check_model(onnx_scalar_model) # Ensure valid ONNX + onnx.save(onnx_scalar_model, scalar_file_name) # Save the model + print(f"Finished exporting scalar model to {scalar_file_name}") + + # Scalar-Tensor version + onnx_scalar_first_model = build_model(scalar_first=True) + scalar_first_file_name = "scalar_bitwise_xor.onnx" + + onnx.checker.check_model(onnx_scalar_first_model) # Ensure valid ONNX + onnx.save(onnx_scalar_first_model, scalar_first_file_name) # Save the model + print(f"Finished exporting scalar-first model to {scalar_first_file_name}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor_scalar.onnx new file mode 100644 index 0000000000..6adba4564c Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_xor/bitwise_xor_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/mod.rs b/crates/burn-import/onnx-tests/tests/bitwise_xor/mod.rs new file mode 100644 index 0000000000..9dd8048823 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_xor/mod.rs @@ -0,0 +1,66 @@ +// Include the models for this node type +use crate::include_models; +include_models!( + bitwise_xor, + bitwise_xor_scalar, + scalar_bitwise_xor, + scalar_bitwise_xor_scalar +); + +#[cfg(test)] +mod tests { + use super::*; + use burn::tensor::{Int, Tensor, TensorData}; + + type Backend = burn_ndarray::NdArray; + + #[test] + fn bitwise_xor_tensors() { + let device = Default::default(); + let model: bitwise_xor::Model = bitwise_xor::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let input2 = Tensor::::from_ints([[1, 1, 2, 2]], &device); + let output = model.forward(input1, input2); + let expected = TensorData::from([[0i64, 3, 1, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn bitwise_xor_scalar_tensor() { + let device = Default::default(); + let model: bitwise_xor_scalar::Model = bitwise_xor_scalar::Model::new(&device); + // Run the model + let input1 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let scalar = 2; + let output = model.forward(input1, scalar); + let expected = TensorData::from([[3i64, 0, 1, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_xor_tensor() { + let device = Default::default(); + let model: scalar_bitwise_xor::Model = scalar_bitwise_xor::Model::new(&device); + // Run the model + let scalar = 2; + let input2 = Tensor::::from_ints([[1, 2, 3, 4]], &device); + let output = model.forward(scalar, input2); + // Bitwise XOR is commutative, so result should be same as tensor-scalar + let expected = TensorData::from([[3i64, 0, 1, 6]]); + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn scalar_bitwise_xor_scalar() { + let device = Default::default(); + let model: scalar_bitwise_xor_scalar::Model = + scalar_bitwise_xor_scalar::Model::new(&device); + // Run the model + let lhs = 5; // 0b101 + let rhs = 3; // 0b011 + let output = model.forward(lhs, rhs); + // 5 ^ 3 = 6 (0b110) + assert_eq!(output, 6); + } +} diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor.onnx b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor.onnx new file mode 100644 index 0000000000..bb038bc8d4 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx new file mode 100644 index 0000000000..fdb0adb6bc Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.py b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.py new file mode 100644 index 0000000000..4fd073c31f --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/bitwise_xor/scalar_bitwise_xor_scalar.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import onnx +import onnx.helper as helper +import onnx.checker as checker +import onnx.numpy_helper +import numpy as np + +def build_model(): + # Scalar inputs + input1 = helper.make_tensor_value_info("input1", onnx.TensorProto.INT32, []) + input2 = helper.make_tensor_value_info("input2", onnx.TensorProto.INT32, []) + output = helper.make_tensor_value_info("output", onnx.TensorProto.INT32, []) + + # Create bitwise XOR node + xor_node = helper.make_node( + "BitwiseXor", + inputs=["input1", "input2"], + outputs=["output"] + ) + + # Create the graph + graph_def = helper.make_graph( + [xor_node], + "scalar_bitwise_xor_scalar", + [input1, input2], + [output], + ) + + # Create the model + model_def = helper.make_model(graph_def, producer_name="scalar_bitwise_xor_scalar") + checker.check_model(model_def) + + return model_def + +if __name__ == "__main__": + model = build_model() + onnx.save(model, "scalar_bitwise_xor_scalar.onnx") \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/test_mod.rs b/crates/burn-import/onnx-tests/tests/test_mod.rs index 341b300116..14f365ed56 100644 --- a/crates/burn-import/onnx-tests/tests/test_mod.rs +++ b/crates/burn-import/onnx-tests/tests/test_mod.rs @@ -12,6 +12,11 @@ pub mod argmin; pub mod avg_pool; pub mod batch_norm; pub mod bernoulli; +pub mod bitshift; +pub mod bitwise_and; +pub mod bitwise_not; +pub mod bitwise_or; +pub mod bitwise_xor; pub mod cast; pub mod ceil; pub mod clip; diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index ffa1fb3fe9..3f39b45213 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -2,23 +2,25 @@ use std::marker::PhantomData; use super::{ argmax::ArgMaxNode, argmin::ArgMinNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, - batch_norm::BatchNormNode, bernoulli::BernoulliNode, binary::BinaryNode, ceil::CeilNode, - clip::ClipNode, concat::ConcatNode, constant::ConstantNode, - constant_of_shape::ConstantOfShapeNode, conv_transpose_1d::ConvTranspose1dNode, - conv_transpose_2d::ConvTranspose2dNode, conv_transpose_3d::ConvTranspose3dNode, - conv1d::Conv1dNode, conv2d::Conv2dNode, conv3d::Conv3dNode, depth_to_space::DepthToSpaceNode, - dropout::DropoutNode, expand::ExpandNode, floor::FloorNode, gather::GatherNode, - gather_elements::GatherElementsNode, gemm::GemmNode, global_avg_pool::GlobalAvgPoolNode, - group_norm::GroupNormNode, instance_norm::InstanceNormNode, layer_norm::LayerNormNode, - linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, - max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode, pad::PadNode, prelu::PReluNode, - random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode, - random_uniform::RandomUniformNode, random_uniform_like::RandomUniformLikeNode, - range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, round::RoundNode, slice::SliceNode, - space_to_depth::SpaceToDepthNode, split::SplitNode, squeeze::SqueezeNode, sum::SumNode, - tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + batch_norm::BatchNormNode, bernoulli::BernoulliNode, binary::BinaryNode, + bitshift::BitShiftNode, bitwiseand::BitwiseAndNode, bitwisenot::BitwiseNotNode, + bitwiseor::BitwiseOrNode, bitwisexor::BitwiseXorNode, ceil::CeilNode, clip::ClipNode, + concat::ConcatNode, constant::ConstantNode, constant_of_shape::ConstantOfShapeNode, + conv_transpose_1d::ConvTranspose1dNode, conv_transpose_2d::ConvTranspose2dNode, + conv_transpose_3d::ConvTranspose3dNode, conv1d::Conv1dNode, conv2d::Conv2dNode, + conv3d::Conv3dNode, depth_to_space::DepthToSpaceNode, dropout::DropoutNode, expand::ExpandNode, + floor::FloorNode, gather::GatherNode, gather_elements::GatherElementsNode, gemm::GemmNode, + global_avg_pool::GlobalAvgPoolNode, group_norm::GroupNormNode, instance_norm::InstanceNormNode, + layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, + max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode, + pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, + random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode, + random_uniform_like::RandomUniformLikeNode, range::RangeNode, reshape::ReshapeNode, + resize::ResizeNode, round::RoundNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode, + sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode, + unsqueeze::UnsqueezeNode, }; -use crate::burn::{BurnImports, Scope, Type}; +use crate::burn::{BurnImports, Scope, Type, node::space_to_depth::SpaceToDepthNode}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use serde::Serialize; @@ -92,6 +94,11 @@ pub enum Node { BatchNorm(BatchNormNode), Bernoulli(BernoulliNode), Binary(BinaryNode), + BitShift(BitShiftNode), + BitwiseAnd(BitwiseAndNode), + BitwiseOr(BitwiseOrNode), + BitwiseNot(BitwiseNotNode), + BitwiseXor(BitwiseXorNode), Clip(ClipNode), Concat(ConcatNode), Constant(ConstantNode), @@ -157,6 +164,11 @@ macro_rules! match_all { Node::BatchNorm(node) => $func(node), Node::Bernoulli(node) => $func(node), Node::Binary(node) => $func(node), + Node::BitShift(node) => $func(node), + Node::BitwiseAnd(node) => $func(node), + Node::BitwiseOr(node) => $func(node), + Node::BitwiseNot(node) => $func(node), + Node::BitwiseXor(node) => $func(node), Node::Clip(node) => $func(node), Node::Concat(node) => $func(node), Node::Constant(node) => $func(node), @@ -230,6 +242,11 @@ impl Node { Node::BatchNorm(_) => "batch_norm", Node::Bernoulli(_) => "bernoulli", Node::Binary(binary) => binary.binary_type.as_str(), + Node::BitShift(_) => "bitshift", + Node::BitwiseAnd(_) => "bitwiseand", + Node::BitwiseOr(_) => "bitwiseor", + Node::BitwiseNot(_) => "bitwisenot", + Node::BitwiseXor(_) => "bitwisexor", Node::Concat(_) => "concat", Node::Clip(_) => "clip", Node::Constant(_) => "constant", diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index 3be9a8f7dc..0adb72772c 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -176,6 +176,7 @@ impl BinaryNode { Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function)) } + pub(crate) fn powf(lhs: Type, rhs: Type, output: Type) -> Self { let function = match (&lhs, &rhs) { (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.powf(#rhs) }, diff --git a/crates/burn-import/src/burn/node/bitshift.rs b/crates/burn-import/src/burn/node/bitshift.rs new file mode 100644 index 0000000000..92436c467d --- /dev/null +++ b/crates/burn-import/src/burn/node/bitshift.rs @@ -0,0 +1,205 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{BurnImports, Scope, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Direction { + Left, + Right, +} + +#[derive(Debug, Clone, new)] +pub struct BitShiftNode { + pub inputs: Vec, + pub output: Type, + pub direction: Direction, +} + +impl NodeCodegen for BitShiftNode { + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + self.inputs.clone() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let output = &self.output.name(); + + let operation = match (&self.inputs[0], &self.inputs[1]) { + (Type::Tensor(lhs_tensor), Type::Tensor(rhs_tensor)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + match self.direction { + Direction::Left => quote! { #lhs.bitwise_left_shift(#rhs) }, + Direction::Right => quote! { #lhs.bitwise_right_shift(#rhs) }, + } + } + (Type::Tensor(lhs_tensor), Type::Scalar(rhs_scalar)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = &rhs_scalar.name; + match self.direction { + Direction::Left => quote! { #lhs.bitwise_left_shift_scalar(#rhs.elem()) }, + Direction::Right => quote! { #lhs.bitwise_right_shift_scalar(#rhs.elem()) }, + } + } + (Type::Scalar(lhs_scalar), Type::Tensor(rhs_tensor)) => { + let lhs = &lhs_scalar.name; + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + // For scalar op tensor, we need to broadcast the scalar to a tensor first + let shift_op = match self.direction { + Direction::Left => quote! { _scalar_tensor.bitwise_left_shift(#rhs) }, + Direction::Right => quote! { _scalar_tensor.bitwise_right_shift(#rhs) }, + }; + quote! { + { + let _scalar_tensor = Tensor::full(#rhs.shape(), #lhs, &#rhs.device()); + #shift_op + } + } + } + (Type::Scalar(lhs_scalar), Type::Scalar(rhs_scalar)) => { + let lhs = &lhs_scalar.name; + let rhs = &rhs_scalar.name; + match self.direction { + Direction::Left => quote! { #lhs << #rhs }, + Direction::Right => quote! { #lhs >> #rhs }, + } + } + _ => panic!("BitShiftNode only supports tensor and scalar inputs"), + }; + + quote! { + let #output = #operation; + } + } + + fn into_node(self) -> Node { + Node::BitShift(self) + } + + fn register_imports(&self, imports: &mut BurnImports) { + // Register ElementConversion for scalar operations + for input in &self.inputs { + if matches!(input, Type::Scalar(_)) { + imports.register("burn::tensor::ElementConversion"); + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + TensorType, + graph::BurnGraph, + node::{bitshift::BitShiftNode, test::assert_tokens}, + }; + + #[test] + fn test_codegen_bitshift_left() { + let mut graph = BurnGraph::::default(); + + graph.register(BitShiftNode::new( + vec![ + Type::Tensor(TensorType::new_int("input1", 1)), + Type::Tensor(TensorType::new_int("input2", 1)), + ], + Type::Tensor(TensorType::new_int("output", 1)), + Direction::Left, + )); + + graph.register_input_output( + vec!["input1".to_string(), "input2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input1: Tensor, input2: Tensor) -> Tensor { + let output = input1.bitwise_left_shift(input2); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_bitshift_right() { + let mut graph = BurnGraph::::default(); + + graph.register(BitShiftNode::new( + vec![ + Type::Tensor(TensorType::new_int("input1", 1)), + Type::Tensor(TensorType::new_int("input2", 1)), + ], + Type::Tensor(TensorType::new_int("output", 1)), + Direction::Right, + )); + + graph.register_input_output( + vec!["input1".to_string(), "input2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input1: Tensor, input2: Tensor) -> Tensor { + let output = input1.bitwise_right_shift(input2); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/bitwiseand.rs b/crates/burn-import/src/burn/node/bitwiseand.rs new file mode 100644 index 0000000000..f3d258a788 --- /dev/null +++ b/crates/burn-import/src/burn/node/bitwiseand.rs @@ -0,0 +1,144 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{BurnImports, Scope, TensorKind, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct BitwiseAndNode { + pub inputs: Vec, + pub output: Type, +} + +impl NodeCodegen for BitwiseAndNode { + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + self.inputs.clone() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let output = &self.output.name(); + + let operation = match (&self.inputs[0], &self.inputs[1]) { + (Type::Tensor(lhs_tensor), Type::Tensor(rhs_tensor)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + quote! { #lhs.bitwise_and(#rhs) } + } + (Type::Tensor(lhs_tensor), Type::Scalar(rhs_scalar)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = &rhs_scalar.name; + quote! { #lhs.bitwise_and_scalar(#rhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Tensor(rhs_tensor)) => { + let lhs = &lhs_scalar.name; + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + // Bitwise AND is commutative, so we can swap the order + quote! { #rhs.bitwise_and_scalar(#lhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Scalar(rhs_scalar)) => { + let lhs = &lhs_scalar.name; + let rhs = &rhs_scalar.name; + quote! { #lhs & #rhs } + } + _ => panic!("BitwiseAndNode only supports tensor and scalar inputs"), + }; + + quote! { + let #output = #operation; + } + } + + fn into_node(self) -> Node { + match &self.output { + Type::Tensor(tensor) => { + if tensor.kind != TensorKind::Int { + panic!("BitwiseAndNode only supports Int tensor outputs"); + } + } + Type::Scalar(scalar) => { + if !matches!( + scalar.kind, + crate::burn::ScalarKind::Int32 | crate::burn::ScalarKind::Int64 + ) { + panic!("BitwiseAndNode only supports Int scalar outputs"); + } + } + _ => panic!("BitwiseAndNode only supports tensor and scalar outputs"), + } + Node::BitwiseAnd(self) + } + + fn register_imports(&self, imports: &mut BurnImports) { + // Register ElementConversion for scalar operations + for input in &self.inputs { + if matches!(input, Type::Scalar(_)) { + imports.register("burn::tensor::ElementConversion"); + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + TensorType, + graph::BurnGraph, + node::{bitwiseand::BitwiseAndNode, test::assert_tokens}, + }; + + #[test] + fn test_codegen_bitwise_and() { + let mut graph = BurnGraph::::default(); + + graph.register(BitwiseAndNode { + inputs: vec![ + Type::Tensor(TensorType::new_int("input1", 1)), + Type::Tensor(TensorType::new_int("input2", 1)), + ], + output: Type::Tensor(TensorType::new_int("output", 1)), + }); + graph.register_input_output( + vec!["input1".to_string(), "input2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input1: Tensor, input2: Tensor) -> Tensor { + let output = input1.bitwise_and(input2); + output + } + } + }; + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/bitwisenot.rs b/crates/burn-import/src/burn/node/bitwisenot.rs new file mode 100644 index 0000000000..16ef93913b --- /dev/null +++ b/crates/burn-import/src/burn/node/bitwisenot.rs @@ -0,0 +1,96 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorKind, TensorType, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct BitwiseNotNode { + pub input: TensorType, + pub output: TensorType, +} + +impl NodeCodegen for BitwiseNotNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![{ + if self.input.kind != TensorKind::Int { + panic!("BitwiseNotNode only supports Int TensorType inputs"); + } + Type::Tensor(self.input.clone()) + }] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + quote! { + let #output = #input.bitwise_not(); + } + } + + fn into_node(self) -> Node { + if self.output.kind != TensorKind::Int { + panic!("BitwiseNotNode only supports Int TensorType outputs"); + } + Node::BitwiseNot(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + TensorType, + graph::BurnGraph, + node::{bitwisenot::BitwiseNotNode, test::assert_tokens}, + }; + + #[test] + fn test_codegen_bitwise_not() { + let mut graph = BurnGraph::::default(); + + graph.register(BitwiseNotNode { + input: TensorType::new_int("input", 2), + output: TensorType::new_int("output", 2), + }); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor} + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.bitwise_not(); + output + } + } + }; + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/bitwiseor.rs b/crates/burn-import/src/burn/node/bitwiseor.rs new file mode 100644 index 0000000000..81ac6ae897 --- /dev/null +++ b/crates/burn-import/src/burn/node/bitwiseor.rs @@ -0,0 +1,144 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{BurnImports, Scope, TensorKind, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct BitwiseOrNode { + pub inputs: Vec, + pub output: Type, +} + +impl NodeCodegen for BitwiseOrNode { + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + self.inputs.clone() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let output = &self.output.name(); + + let operation = match (&self.inputs[0], &self.inputs[1]) { + (Type::Tensor(lhs_tensor), Type::Tensor(rhs_tensor)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + quote! { #lhs.bitwise_or(#rhs) } + } + (Type::Tensor(lhs_tensor), Type::Scalar(rhs_scalar)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = &rhs_scalar.name; + quote! { #lhs.bitwise_or_scalar(#rhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Tensor(rhs_tensor)) => { + let lhs = &lhs_scalar.name; + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + // Bitwise OR is commutative, so we can swap the order + quote! { #rhs.bitwise_or_scalar(#lhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Scalar(rhs_scalar)) => { + let lhs = &lhs_scalar.name; + let rhs = &rhs_scalar.name; + quote! { #lhs | #rhs } + } + _ => panic!("BitwiseOrNode only supports tensor and scalar inputs"), + }; + + quote! { + let #output = #operation; + } + } + + fn into_node(self) -> Node { + match &self.output { + Type::Tensor(tensor) => { + if tensor.kind != TensorKind::Int { + panic!("BitwiseOrNode only supports Int tensor outputs"); + } + } + Type::Scalar(scalar) => { + if !matches!( + scalar.kind, + crate::burn::ScalarKind::Int32 | crate::burn::ScalarKind::Int64 + ) { + panic!("BitwiseOrNode only supports Int scalar outputs"); + } + } + _ => panic!("BitwiseOrNode only supports tensor and scalar outputs"), + } + Node::BitwiseOr(self) + } + + fn register_imports(&self, imports: &mut BurnImports) { + // Register ElementConversion for scalar operations + for input in &self.inputs { + if matches!(input, Type::Scalar(_)) { + imports.register("burn::tensor::ElementConversion"); + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + TensorType, + graph::BurnGraph, + node::{bitwiseor::BitwiseOrNode, test::assert_tokens}, + }; + + #[test] + fn test_codegen_bitwise_or() { + let mut graph = BurnGraph::::default(); + + graph.register(BitwiseOrNode { + inputs: vec![ + Type::Tensor(TensorType::new_int("input1", 2)), + Type::Tensor(TensorType::new_int("input2", 2)), + ], + output: Type::Tensor(TensorType::new_int("output", 2)), + }); + graph.register_input_output( + vec!["input1".to_string(), "input2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input1: Tensor, input2: Tensor) -> Tensor { + let output = input1.bitwise_or(input2); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/bitwisexor.rs b/crates/burn-import/src/burn/node/bitwisexor.rs new file mode 100644 index 0000000000..079eaad912 --- /dev/null +++ b/crates/burn-import/src/burn/node/bitwisexor.rs @@ -0,0 +1,149 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{BurnImports, Scope, TensorKind, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct BitwiseXorNode { + pub inputs: Vec, + pub output: Type, +} + +impl NodeCodegen for BitwiseXorNode { + fn output_types(&self) -> Vec { + vec![self.output.clone()] + } + + fn input_types(&self) -> Vec { + self.inputs.clone() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let output = &self.output.name(); + + let operation = match (&self.inputs[0], &self.inputs[1]) { + (Type::Tensor(lhs_tensor), Type::Tensor(rhs_tensor)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + quote! { #lhs.bitwise_xor(#rhs) } + } + (Type::Tensor(lhs_tensor), Type::Scalar(rhs_scalar)) => { + let lhs = scope.tensor_use_owned(lhs_tensor, node_position); + let rhs = &rhs_scalar.name; + quote! { #lhs.bitwise_xor_scalar(#rhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Tensor(rhs_tensor)) => { + let lhs = &lhs_scalar.name; + let rhs = scope.tensor_use_owned(rhs_tensor, node_position); + // Bitwise XOR is commutative, so we can swap the order + quote! { #rhs.bitwise_xor_scalar(#lhs.elem()) } + } + (Type::Scalar(lhs_scalar), Type::Scalar(rhs_scalar)) => { + let lhs = &lhs_scalar.name; + let rhs = &rhs_scalar.name; + quote! { #lhs ^ #rhs } + } + _ => panic!("BitwiseXorNode only supports tensor and scalar inputs"), + }; + + quote! { + let #output = #operation; + } + } + + fn into_node(self) -> Node { + match &self.output { + Type::Tensor(tensor) => { + if tensor.kind != TensorKind::Int { + panic!("BitwiseXorNode only supports Int tensor outputs"); + } + } + Type::Scalar(scalar) => { + if !matches!( + scalar.kind, + crate::burn::ScalarKind::Int32 | crate::burn::ScalarKind::Int64 + ) { + panic!("BitwiseXorNode only supports Int scalar outputs"); + } + } + _ => panic!("BitwiseXorNode only supports tensor and scalar outputs"), + } + Node::BitwiseXor(self) + } + + fn register_imports(&self, imports: &mut BurnImports) { + // Register ElementConversion for scalar operations + for input in &self.inputs { + if matches!(input, Type::Scalar(_)) { + imports.register("burn::tensor::ElementConversion"); + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + TensorType, + graph::BurnGraph, + node::{bitwisexor::BitwiseXorNode, test::assert_tokens}, + }; + + #[test] + fn test_codegen_bitwise_xor() { + let mut graph = BurnGraph::::default(); + + graph.register(BitwiseXorNode { + inputs: vec![ + Type::Tensor(TensorType::new_int("input1", 2)), + Type::Tensor(TensorType::new_int("input2", 2)), + ], + output: Type::Tensor(TensorType::new_int("output", 2)), + }); + graph.register_input_output( + vec!["input1".to_string(), "input2".to_string()], + vec!["output".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + input1: Tensor, + input2: Tensor + ) -> Tensor { + let output = input1.bitwise_xor(input2); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 0bf7468619..72a6e880b6 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -7,6 +7,11 @@ pub(crate) mod avg_pool2d; pub(crate) mod batch_norm; pub(crate) mod bernoulli; pub(crate) mod binary; +pub(crate) mod bitshift; +pub(crate) mod bitwiseand; +pub(crate) mod bitwisenot; +pub(crate) mod bitwiseor; +pub(crate) mod bitwisexor; pub(crate) mod ceil; pub(crate) mod clip; pub(crate) mod concat; diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index f397c36a85..21f728e771 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -23,6 +23,11 @@ use crate::{ batch_norm::BatchNormNode, bernoulli::BernoulliNode, binary::BinaryNode, + bitshift::{BitShiftNode, Direction}, + bitwiseand::BitwiseAndNode, + bitwisenot::BitwiseNotNode, + bitwiseor::BitwiseOrNode, + bitwisexor::BitwiseXorNode, ceil::CeilNode, clip::ClipNode, concat::ConcatNode, @@ -106,6 +111,8 @@ use onnx_ir::{ util::shape_config, }; +use onnx_ir::node::bitshift::bitshift_config; + pub use crate::burn::graph::RecordType; use crate::burn::node::mean::MeanNode; @@ -292,6 +299,11 @@ impl ParsedOnnxGraph { match node.node_type { NodeType::Add => graph.register(Self::add_conversion(node)), NodeType::ArgMax => graph.register(Self::argmax_conversion(node)), + NodeType::BitShift => graph.register(Self::bitshift_conversion(node)), + NodeType::BitwiseAnd => graph.register(Self::bitwise_and_conversion(node)), + NodeType::BitwiseOr => graph.register(Self::bitwise_or_conversion(node)), + NodeType::BitwiseXor => graph.register(Self::bitwise_xor_conversion(node)), + NodeType::BitwiseNot => graph.register(Self::bitwise_not_conversion(node)), NodeType::ArgMin => graph.register(Self::argmin_conversion(node)), NodeType::Bernoulli => graph.register(Self::bernoulli_conversion(node)), NodeType::Sub => graph.register(Self::sub_conversion(node)), @@ -676,6 +688,48 @@ impl ParsedOnnxGraph { BinaryNode::equal(lhs, rhs, output) } + fn bitshift_conversion(node: Node) -> BitShiftNode { + let inputs = node.inputs.iter().map(Type::from).collect(); + let output = Type::from(node.outputs.first().unwrap()); + let onnx_direction = bitshift_config(&node); + + // Map ONNX direction to burn-import Direction + let direction = match onnx_direction { + onnx_ir::node::bitshift::Direction::Left => Direction::Left, + onnx_ir::node::bitshift::Direction::Right => Direction::Right, + }; + + BitShiftNode::new(inputs, output, direction) + } + + fn bitwise_and_conversion(node: Node) -> BitwiseAndNode { + let inputs = node.inputs.iter().map(Type::from).collect(); + let output = Type::from(node.outputs.first().unwrap()); + + BitwiseAndNode::new(inputs, output) + } + + fn bitwise_or_conversion(node: Node) -> BitwiseOrNode { + let inputs = node.inputs.iter().map(Type::from).collect(); + let output = Type::from(node.outputs.first().unwrap()); + + BitwiseOrNode::new(inputs, output) + } + + fn bitwise_xor_conversion(node: Node) -> BitwiseXorNode { + let inputs = node.inputs.iter().map(Type::from).collect(); + let output = Type::from(node.outputs.first().unwrap()); + + BitwiseXorNode::new(inputs, output) + } + + fn bitwise_not_conversion(node: Node) -> BitwiseNotNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + + BitwiseNotNode::new(input, output) + } + fn max_conversion(node: Node) -> BinaryNode { let lhs = Type::from(node.inputs.first().unwrap()); let rhs = Type::from(node.inputs.get(1).unwrap()); diff --git a/crates/onnx-ir/src/node/bitshift.rs b/crates/onnx-ir/src/node/bitshift.rs new file mode 100644 index 0000000000..f00af8c001 --- /dev/null +++ b/crates/onnx-ir/src/node/bitshift.rs @@ -0,0 +1,77 @@ +use crate::ir::Node; + +pub use self::Direction as BitShiftDirection; + +/// Direction for BitShift operation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Direction { + Left, + Right, +} + +impl Direction { + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "left" => Ok(Direction::Left), + "right" => Ok(Direction::Right), + _ => Err(format!("Invalid bit shift direction: {s}")), + } + } +} + +/// Configuration for BitShift operation +pub fn bitshift_config(node: &Node) -> Direction { + let direction_str = node + .attrs + .get("direction") + .map(|val| val.clone().into_string()) + .unwrap_or_else(|| "left".to_string()); + + Direction::from_str(&direction_str) + .unwrap_or_else(|e| panic!("Failed to parse bitshift direction: {e}")) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::NodeType; + use crate::node::test_utils::NodeBuilder; + + #[test] + fn test_bitshift_config_with_direction_left() { + let node = NodeBuilder::new(NodeType::BitShift, "test_bitshift") + .input_tensor_i32("X", 2, None) + .input_tensor_i32("Y", 2, None) + .output_tensor_i32("Z", 2, None) + .attr_string("direction", "left") + .build(); + + let config = bitshift_config(&node); + assert_eq!(config, Direction::Left); + } + + #[test] + fn test_bitshift_config_with_direction_right() { + let node = NodeBuilder::new(NodeType::BitShift, "test_bitshift") + .input_tensor_i32("X", 2, None) + .input_tensor_i32("Y", 2, None) + .output_tensor_i32("Z", 2, None) + .attr_string("direction", "right") + .build(); + + let config = bitshift_config(&node); + assert_eq!(config, Direction::Right); + } + + #[test] + fn test_bitshift_config_default_direction() { + let node = NodeBuilder::new(NodeType::BitShift, "test_bitshift") + .input_tensor_i32("X", 2, None) + .input_tensor_i32("Y", 2, None) + .output_tensor_i32("Z", 2, None) + .build(); + + let config = bitshift_config(&node); + assert_eq!(config, Direction::Left); + } +} diff --git a/crates/onnx-ir/src/node/mod.rs b/crates/onnx-ir/src/node/mod.rs index c1e502f629..dba9cd993e 100644 --- a/crates/onnx-ir/src/node/mod.rs +++ b/crates/onnx-ir/src/node/mod.rs @@ -16,6 +16,7 @@ pub mod avg_pool1d; pub mod avg_pool2d; pub mod batch_norm; pub mod bernoulli; +pub mod bitshift; pub mod cast; pub mod clip; pub mod comparison; diff --git a/crates/onnx-ir/src/rank_inference.rs b/crates/onnx-ir/src/rank_inference.rs index 142f5df4ee..0d52234eac 100644 --- a/crates/onnx-ir/src/rank_inference.rs +++ b/crates/onnx-ir/src/rank_inference.rs @@ -32,6 +32,11 @@ pub fn rank_inference(node: &mut Node) { NodeType::AveragePool1d => same_as_input(node), NodeType::AveragePool2d => same_as_input(node), NodeType::BatchNormalization => same_as_input(node), + NodeType::BitShift => same_as_input_broadcast(node), + NodeType::BitwiseAnd => same_as_input_broadcast(node), + NodeType::BitwiseNot => same_as_input(node), + NodeType::BitwiseOr => same_as_input_broadcast(node), + NodeType::BitwiseXor => same_as_input_broadcast(node), NodeType::Bernoulli => bernoulli_update_output(node), NodeType::Cast => cast_update_outputs(node), NodeType::Ceil => same_as_input(node),