Running the IR below fails at runtime with the following error:
summary = 'InvalidArgument: InvalidArgument: Input argument 0 validation failed against corresponding function signature arg 0. Reason: InvalidArgument: Runtime stride mismatch. Expected [-9223372036854775808, 1] but received [0, 0]'
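The value -9223372036854775808 is INT64_MIN, which MLIR uses as `ShapedType::kDynamic`. Read that way, the compiled signature seems to require a dynamic outer stride and a unit inner stride for `%arg0 : tensor<3x0xf32>`. A contiguous row-major 3x0 buffer would normally have strides [0, 1], which would satisfy that signature. The runtime instead reports [0, 0], presumably because it normalizes strides for zero-element buffers. The Python sketch below models the comparison. `row_major_strides`, `strides_match`, and the `skip_empty` relaxation are hypothetical illustrations, not the runtime's actual validation code.

```python
# Hypothetical model of the runtime stride check; not the project's real code.
K_DYNAMIC = -(2**63)  # INT64_MIN, used by MLIR as ShapedType::kDynamic


def row_major_strides(shape):
    """Element strides of a contiguous row-major (C-order) buffer."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def strides_match(expected, received, shape, skip_empty=False):
    """Compare runtime strides to a signature; K_DYNAMIC entries match anything.

    skip_empty is a hypothetical relaxation: a tensor with a zero-sized
    dimension holds no elements, so its strides are never used to index it.
    """
    if len(expected) != len(received):
        return False
    if skip_empty and 0 in shape:
        return True
    return all(e == K_DYNAMIC or e == r for e, r in zip(expected, received))


shape = [3, 0]                 # %arg0 : tensor<3x0xf32>
expected = [K_DYNAMIC, 1]      # "Expected [-9223372036854775808, 1]"
print(row_major_strides(shape))                      # [0, 1]
print(strides_match(expected, [0, 0], shape))        # False: 1 != 0
print(strides_match(expected, [0, 0], shape, True))  # True: empty tensor
```

Skipping the check for zero-element tensors is only one possible relaxation. Another option would be to have the runtime report canonical strides such as [0, 1] for empty buffers rather than [0, 0].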
module @ins_t_outs_t251_t252_t253_t254_t255_20 {
  func.func @main(%arg0: tensor<3x0xf32> {tensorrt.shape_profile = #tensorrt.shape_profile<min = [3, 0], opt = [3, 0], max = [3, 0]>}) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
    %c = stablehlo.constant dense<0> : tensor<1xi32>
    %c_0 = stablehlo.constant dense<3> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
    %c_2 = stablehlo.constant dense<3> : tensor<1xi32>
    %c_3 = stablehlo.constant dense<0> : tensor<i32>
    %c_4 = stablehlo.constant dense<0> : tensor<1xi32>
    %0 = stablehlo.concatenate %c_2, %c_4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %c_5 = stablehlo.constant dense<2> : tensor<1xi32>
    %1 = stablehlo.real_dynamic_slice %0, %c_1, %c_5, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %c_6 = stablehlo.constant dense<5> : tensor<1xi32>
    %2 = stablehlo.divide %1, %c_6 : (tensor<?xi32>, tensor<1xi32>) -> tensor<1xi32>
    %3 = stablehlo.multiply %2, %c : tensor<1xi32>
    %4 = stablehlo.concatenate %c, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %5 = stablehlo.real_dynamic_slice %0, %c, %c_1, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %6 = stablehlo.multiply %2, %c_1 : tensor<1xi32>
    %7 = stablehlo.concatenate %5, %6, dim = 0 : (tensor<?xi32>, tensor<1xi32>) -> tensor<2xi32>
    %8 = stablehlo.concatenate %c_1, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %9 = stablehlo.real_dynamic_slice %arg0, %4, %7, %8 : (tensor<3x0xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
    %10 = stablehlo.multiply %2, %c_1 : tensor<1xi32>
    %11 = stablehlo.concatenate %c, %10, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %12 = stablehlo.real_dynamic_slice %0, %c, %c_1, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %13 = stablehlo.multiply %2, %c_5 : tensor<1xi32>
    %14 = stablehlo.concatenate %12, %13, dim = 0 : (tensor<?xi32>, tensor<1xi32>) -> tensor<2xi32>
    %15 = stablehlo.concatenate %c_1, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %16 = stablehlo.real_dynamic_slice %arg0, %11, %14, %15 : (tensor<3x0xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
    %17 = stablehlo.multiply %2, %c_5 : tensor<1xi32>
    %18 = stablehlo.concatenate %c, %17, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %19 = stablehlo.real_dynamic_slice %0, %c, %c_1, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %c_7 = stablehlo.constant dense<3> : tensor<1xi32>
    %20 = stablehlo.multiply %2, %c_7 : tensor<1xi32>
    %21 = stablehlo.concatenate %19, %20, dim = 0 : (tensor<?xi32>, tensor<1xi32>) -> tensor<2xi32>
    %22 = stablehlo.concatenate %c_1, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %23 = stablehlo.real_dynamic_slice %arg0, %18, %21, %22 : (tensor<3x0xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
    %24 = stablehlo.multiply %2, %c_7 : tensor<1xi32>
    %25 = stablehlo.concatenate %c, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %26 = stablehlo.real_dynamic_slice %0, %c, %c_1, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %c_8 = stablehlo.constant dense<4> : tensor<1xi32>
    %27 = stablehlo.multiply %2, %c_8 : tensor<1xi32>
    %28 = stablehlo.concatenate %26, %27, dim = 0 : (tensor<?xi32>, tensor<1xi32>) -> tensor<2xi32>
    %29 = stablehlo.concatenate %c_1, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %30 = stablehlo.real_dynamic_slice %arg0, %25, %28, %29 : (tensor<3x0xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
    %31 = stablehlo.multiply %2, %c_8 : tensor<1xi32>
    %32 = stablehlo.concatenate %c, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %33 = stablehlo.real_dynamic_slice %0, %c, %c_1, %c_1 : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    %34 = stablehlo.multiply %2, %c_6 : tensor<1xi32>
    %35 = stablehlo.concatenate %33, %34, dim = 0 : (tensor<?xi32>, tensor<1xi32>) -> tensor<2xi32>
    %36 = stablehlo.concatenate %c_1, %c_1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %37 = stablehlo.real_dynamic_slice %arg0, %32, %35, %36 : (tensor<3x0xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
    return %9, %16, %23, %30, %37 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
  }
}
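Tracing the shape arithmetic by hand suggests the IR splits dimension 1 of `%arg0` into five equal chunks. The chunk size is `dim1 / 5`, and slice i covers columns `[i*chunk, (i+1)*chunk)`. The shape profile pins the input to `[3, 0]`, so the chunk size is 0 and all five results should be empty 3x0 tensors. The input itself is also zero-element, and that is the argument the stride check rejects. The Python sketch below replays only the shape computation, not the data. The helper names `slice_1d` and `slice_result_shape` are illustrative and do not come from any library.

```python
def slice_1d(values, start, limit, stride):
    """Model stablehlo.real_dynamic_slice on a rank-1 shape vector."""
    return values[start:limit:stride]


def slice_result_shape(starts, limits, strides):
    """Result shape of real_dynamic_slice: ceil((limit - start) / stride) per dim."""
    return tuple(-(-(hi - lo) // st) for lo, hi, st in zip(starts, limits, strides))


# %0 = concatenate(%c_2, %c_4): the static shape of %arg0 as a vector, [3, 0]
shape_vec = [3, 0]

# %1 = real_dynamic_slice(%0, [1], [2], [1]) extracts dim 1 (= 0)
dim1 = slice_1d(shape_vec, 1, 2, 1)[0]

# %2 = divide(%1, %c_6): chunk size along dim 1. stablehlo.divide truncates
# toward zero; Python's // floors, which agrees for the non-negative values here.
chunk = dim1 // 5

# %5, %12, %19, %26, %33 = real_dynamic_slice(%0, [0], [1], [1]) extract dim 0 (= 3)
dim0 = slice_1d(shape_vec, 0, 1, 1)[0]

# Results %9, %16, %23, %30, %37: slice i covers columns [i*chunk, (i+1)*chunk)
# of every row, using multipliers 0..4 for starts and 1..5 for limits.
outputs = [
    slice_result_shape([0, chunk * i], [dim0, chunk * (i + 1)], [1, 1])
    for i in range(5)
]
print(outputs)  # [(3, 0), (3, 0), (3, 0), (3, 0), (3, 0)]
```

If dimension 1 were 10 instead, the chunk size would be 2 and each result would be 3x2. The empty case arises only because the profile fixes dimension 1 at 0.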
Fixed last week
This seems to be failing again with the TRT dialect. Re-opening so we can investigate.
parthchadha
akhilg-nv