`ones` in Tripy works by creating a constant and broadcasting it up to the correct shape.
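As a conceptual sketch of this lowering (using NumPy for illustration only; the actual Tripy implementation emits `tensorrt.constant` and `tensorrt.broadcast` ops, as shown in the MLIR below):

```python
import numpy as np

def ones_sketch(shape):
    """Illustrative analogue of how ones() is lowered: a 0-d scalar
    constant is materialized, then broadcast up to the target shape."""
    scalar = np.array(1.0, dtype=np.float32)  # the scalar constant
    return np.broadcast_to(scalar, shape)     # the broadcast step
```
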
There seems to be a bug with how `tensorrt.broadcast` is optimized out in the case where both operands are broadcast in this way:
```python
import tripy as tp

x = tp.ones((2, 2))
y = tp.ones((2, 2))
print(x + y)
```
Input MLIR:
```mlir
module @"outs_%t7_1" {
  func.func @main() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32 = tensorrt.constant dense<2> : tensor<2xi32>
    %0 = tensorrt.broadcast %cst_f32 broadcast_dims<> shape(%cst_i32 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32_1 = tensorrt.constant dense<2> : tensor<2xi32>
    %1 = tensorrt.broadcast %cst_f32_0 broadcast_dims<> shape(%cst_i32_1 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %2 = tensorrt.element_wise <kSUM>(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %2 : tensor<?x?xf32>
  }
}
```
This turns into:
```mlir
tensorrt.module @trt_engines {
  func.func @tensorrt_cluster() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1xf32>
    %0 = tensorrt.element_wise <kSUM>(%cst_f32, %cst_f32 : tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
}
```
This incorrectly results in an output with a volume of 1:
```
tensor([[2.0000]], dtype=float32, loc=gpu:0, shape=(1, 1))
```
In this case, the `broadcast` instead needs to be rotated over the elementwise operation so that the output still has the correct shape.
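The intended rewrite can be sketched in NumPy terms (an illustration of the shape-preserving identity, not the actual MLIR transform): when both operands of an elementwise op are broadcasts of scalar constants to the same shape, the broadcast can be rotated over the op instead of being folded away.

```python
import numpy as np

def before(a, b, shape):
    # What the input IR computes: each scalar is broadcast first,
    # then the elementwise sum is applied at the full shape.
    return np.broadcast_to(a, shape) + np.broadcast_to(b, shape)

def after(a, b, shape):
    # The desired rewrite: sum the scalars, then broadcast the result.
    # This keeps a single broadcast carrying the runtime output shape,
    # rather than folding both broadcasts into (1, 1) constants.
    return np.broadcast_to(a + b, shape)

a = np.array(1.0, dtype=np.float32)
b = np.array(1.0, dtype=np.float32)
assert np.array_equal(before(a, b, (2, 2)), after(a, b, (2, 2)))
# Both forms yield a (2, 2) result filled with 2.0, not a (1, 1) result.
```
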