
Incorrect broadcast elimination when both operands to elementwise are broadcasted #521

Closed
@pranavm-nvidia

Description

tp.ones in Tripy works by creating a scalar constant and broadcasting it up to the requested shape.

There appears to be a bug in how tensorrt.broadcast is optimized away when both operands are broadcast in this way:

x = tp.ones((2, 2))
y = tp.ones((2, 2))

print(x + y)

Input MLIR:

module @"outs_%t7_1" {
  func.func @main() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32 = tensorrt.constant dense<2> : tensor<2xi32>
    %0 = tensorrt.broadcast %cst_f32 broadcast_dims<> shape(%cst_i32 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
    %cst_i32_1 = tensorrt.constant dense<2> : tensor<2xi32>
    %1 = tensorrt.broadcast %cst_f32_0 broadcast_dims<> shape(%cst_i32_1 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
    %2 = tensorrt.element_wise <kSUM>(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %2 : tensor<?x?xf32>
  }
}

This turns into:

tensorrt.module @trt_engines {
  func.func @tensorrt_cluster() -> tensor<?x?xf32> {
    %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1xf32>
    %0 = tensorrt.element_wise <kSUM>(%cst_f32, %cst_f32 : tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
}

This incorrectly results in an output with a volume of 1:

tensor([[2.0000]], dtype=float32, loc=gpu:0, shape=(1, 1))

In this case, the broadcast instead needs to be rotated over the elementwise operation: it is fine to sum the scalar constants first, but the result must then be broadcast to the runtime shape so the output still has the correct volume.
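As a rough sketch of the intended rewrite, using NumPy broadcasting as a stand-in for the TensorRT semantics (the shapes and values mirror the MLIR above; none of this is the actual compiler pass):

```python
import numpy as np

shape = (2, 2)  # the static shape behind the tensor<?x?xf32> in the example

# Stand-ins for the two tensor<f32> constants, each broadcast to (2, 2)
# as in the input MLIR before optimization.
x = np.broadcast_to(np.array(1.0, dtype=np.float32), shape)
y = np.broadcast_to(np.array(1.0, dtype=np.float32), shape)
expected = x + y  # correct result: shape (2, 2), every element 2.0

# The buggy rewrite folds both broadcasts away and sums the 1x1 constants
# directly, producing the observed volume-1 output.
buggy = np.ones((1, 1), dtype=np.float32) + np.ones((1, 1), dtype=np.float32)

# Shape-preserving rewrite: rotate the broadcast over the elementwise op,
# i.e. sum the constants first, then broadcast the single result.
summed = np.array(1.0, dtype=np.float32) + np.array(1.0, dtype=np.float32)
rotated = np.broadcast_to(summed, shape)

print(expected.shape, buggy.shape, rotated.shape)
```

The rotated form performs one elementwise op on the small constants and a single broadcast of the result, so it keeps the folding benefit while producing the correctly shaped output.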

Metadata

Assignees: No one assigned
Labels: mlir-tensorrt (mlir-tensorrt project)
Projects: None
Milestone: None