diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 03ea773a25130..bc2d8117930bc 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -11,4 +11,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/wrapper-validation-action@v2 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index ce8fb3160954e..936ab0de899a2 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,7 +7,7 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: github/issue-labeler@v3.3 + - uses: github/issue-labeler@v3.4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" configuration-path: .github/labeler.yml diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c03399f4693be..5bc21595bf882 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -37,7 +37,7 @@ jobs: wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip unzip build/docfx/docfx.zip -d build/docfx - name: Install NuGet - uses: nuget/setup-nuget@v1 + uses: nuget/setup-nuget@v2 - name: Build Documentation run: | build/docfx/docfx metadata csharp/ApiDocs/docfx.json diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 708842e59f9f2..3e553049a186e 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -30,7 +30,7 @@ jobs: java-version: '11' distribution: 'adopt' - name: Build with Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 with: build-root-directory: java gradle-executable: java/gradlew diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8c..181f3fb17d332 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 262f758b8b954..0b14371f2b8f3 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -117,8 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) @@ -985,9 +984,12 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") endforeach() - if ((NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") OR (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda")) + if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") endif() + if (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() endif() if (onnxruntime_USE_ROCM) # flags are detected with CXX language mode, some flags are not supported with hipclang @@ -1588,7 +1590,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 2c7bf9f1c2f5c..a56864ebf4644 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build -if (NOT onnxruntime_MINIMAL_BUILD) +# Enable stream for all the non-minimal build, except for DML. There's currently a bug +# in the allocation planner when reusing buffers and more than one streams are used that +# make it possible (although rarely) to reach a reference count of 0 for a buffer that is +# still being used. Since DML doesn't benefit from multiple streams, disabling it is the +# safest option for now. +# https://github.com/microsoft/onnxruntime/issues/19480 +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) add_compile_definitions(ORT_ENABLE_STREAM) endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index c6c9d8f4894c5..7e7819ac31a19 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -66,11 +66,7 @@ if(onnxruntime_USE_CUDA) set(PROVIDERS_CUDA onnxruntime_providers_cuda) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) - else() - set(PROVIDERS_COREML onnxruntime_providers_coreml) - endif() + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_NNAPI_BUILTIN) set(PROVIDERS_NNAPI onnxruntime_providers_nnapi) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 2ca4a22aca7d2..c9f35e5337f9b 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -7,6 +7,27 @@ endif() add_compile_definitions(USE_COREML=1) +# Check if we can build the coremltools code for creating an mlpackage with an mlprogram. +# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. +set(_enable_ML_PROGRAM ON) +if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) + message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") + set(_enable_ML_PROGRAM OFF) +elseif(LINUX) + # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. + find_library(LibUUID_LIBRARY NAMES uuid) + find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) + if (NOT LibUUID_INCLUDE_DIR) + message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ") + set(_enable_ML_PROGRAM OFF) + endif() +endif() + +if (_enable_ML_PROGRAM) + add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) +endif() + # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") @@ -19,8 +40,8 @@ target_compile_definitions(coreml_proto PUBLIC $) set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") -set(_src_sub_dir "coreml_proto/") +set(_src_sub_dir "coreml_proto/") onnxruntime_protobuf_generate( APPEND_PATH GEN_SRC_SUB_DIR ${_src_sub_dir} @@ -55,6 +76,10 @@ file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" ) +file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h" +) + file(GLOB onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" @@ -67,15 +92,38 @@ file(GLOB_RECURSE "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" ) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" + +if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them + # build on Windows and Linux. + file(GLOB + onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" + ) + + # Add helpers to create mlpackage + file(GLOB + onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/Utils/JsonMap.?pp" ) + + set(coremltools_srcs + ${onnxruntime_providers_coreml_milblob_cc_srcs} + ${onnxruntime_providers_coreml_modelpackage_cc_srcs} + ) + + source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) endif() # Add CoreML objective c++ source code -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +if (APPLE) file(GLOB onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" @@ -83,26 +131,79 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) +else() + # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies + # by using stub implementations on non-Apple platforms. + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc" + ) endif() set(onnxruntime_providers_coreml_cc_srcs ${onnxruntime_providers_coreml_cc_srcs_top} ${onnxruntime_providers_coreml_cc_srcs_nested} ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_coreml_objcc_srcs} ) -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers}) + onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + ${onnxruntime_providers_coreml_public_headers} + ${onnxruntime_providers_coreml_cc_srcs} + ${coremltools_srcs} ) + onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 + safeint_interface ) -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml coreml_proto) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto) +add_dependencies(onnxruntime_providers_coreml coreml_proto) + +if (APPLE) + target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) endif() + +if (_enable_ML_PROGRAM) + # Setup coremltools fp16 and json dependencies for creating an mlpackage. + # + # These are also used by external/xnnpack.cmake. fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) + onnxruntime_fetchcontent_makeavailable(psimd) + set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) + FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) + set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") + set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + onnxruntime_fetchcontent_makeavailable(fp16) + + # need to tweak the include paths to match what the coreml source code expects + target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ + ) + + add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + + if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) + endif() +endif() + +if (APPLE) + target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") +endif() + add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 456344aa34d95..3f20787e87425 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -473,6 +473,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -543,6 +546,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -646,6 +650,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_phi2_src} + $/onnxruntime/transformers/models/phi2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d485abe6bb1a6..85a9bf50460d3 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -44,12 +44,7 @@ set(contrib_ops_excluded_files "bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" - "diffusion/group_norm_impl_kernel.cuh" - "diffusion/group_norm_common_base.h" - "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5b4a007d6b974..3ed695327c183 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -111,7 +111,9 @@ function(AddTest) target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") - target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") + if (${HAS_NOERROR}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -565,11 +567,7 @@ if(onnxruntime_USE_ROCM) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_ACL) @@ -674,15 +672,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75a..57cecd3e66adb 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -827,6 +827,7 @@ if (winml_is_inbox) get_target_property(compile_options ${target} COMPILE_OPTIONS) get_target_property(include_directories ${target} INCLUDE_DIRECTORIES) get_target_property(link_libraries ${target} LINK_LIBRARIES) + get_target_property(link_flags ${target} LINK_FLAGS) get_target_property(link_options ${target} LINK_OPTIONS) add_library(${new_target} SHARED ${sources}) @@ -835,6 +836,7 @@ if (winml_is_inbox) target_compile_options(${new_target} PRIVATE ${compile_options}) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") target_link_options(${new_target} PRIVATE ${link_options}) endfunction() diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..f523e97293427 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
beginning_timestamp_token_id : int
+
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-
The id of the token that indicates decoding starts.
+
The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-
no_speech_token : int
+
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
no_timestamps_token_id : int
+
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+
start_of_lm_token_id : int
+
The id of the token that indicates LM starts
+
transcribe_token_id : int
+
The id of the transcribe task
+
translate_token_id : int
+
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5812,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2ea557b7d61fe..8ff2135c6b1f6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -765,7 +765,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -784,7 +784,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/docs/python/conf.py b/docs/python/conf.py index 7ab2d42aa15e1..438c21570eaac 100644 --- a/docs/python/conf.py +++ b/docs/python/conf.py @@ -2,12 +2,10 @@ # Licensed under the MIT License. # pylint: disable=C0103 -# -*- coding: utf-8 -*- -# -# Configuration file for the Sphinx documentation builder. +"""Configuration file for the Sphinx documentation builder.""" import os -import shutil # noqa: F401 +import shutil import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..", "_common")) @@ -127,7 +125,5 @@ def setup(app): urllib.request.urlretrieve(url, dest) loc = os.path.split(dest)[-1] if not os.path.exists(loc): - import shutil # noqa: F811 - shutil.copy(dest, loc) return app diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 22827d43b200f..b9b8a25286b7b 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -753,7 +753,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. - @remarks check_outer_scope of true is not supported in a minimal build */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 3cdbb07099cab..1023d50310181 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -165,7 +165,8 @@ class GraphViewer { if a const initializer is part of the underlying Graph but not part of this GraphViewer, it will still be returned instead of nullptr */ - const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, + bool check_outer_scope = true) const; /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */ const Node* ParentNode() const noexcept { return graph_->ParentNode(); } diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 03715eb5b78b2..55abb90b981f5 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -28,9 +28,12 @@ enum COREMLFlags { // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, + // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, + COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 1370f5c4c5e10..108173474db46 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext { bool cudnn_conv1d_pad_to_nc1d = false; bool enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; + bool use_tf32 = true; void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -52,6 +53,7 @@ struct CudaContext : public CustomOpContext { cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); } template diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 82bb8ba83be4a..6d53760ab60b5 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 { // The strict mode has better accuracy but lower performance. int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not + int use_tf32 = 1; // use TF32 }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index c0e6328f27122..1fef077860be3 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -18,4 +18,5 @@ enum CudaResource : int { cudnn_conv1d_pad_to_nc1d_t, enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, + use_tf32_t, }; \ No newline at end of file diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. + */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7fef2dc784b7b..9925197e4507c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -673,7 +673,7 @@ private void runProvider(OrtProvider provider) throws OrtException { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); } else { - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 1ed883ace36e5..0e3bc15ba9c70 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. +let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..d5da33640dc7d 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored diff --git a/js/common/package-lock.json b/js/common/package-lock.json index a5ada877b916a..3988ac80707e0 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 @@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": { "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index 64ab2736adbe3..cd2612aab4984 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 4dca90d7415cf..bbb0c4f3d1e22 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2557971eb4ded..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -41,6 +41,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| FastGelu | com.microsoft(1+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..c0517ce363644 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -92,6 +92,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index d737a28654220..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -13,12 +13,14 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose' import {cumsum, parseCumSumAttributes} from './ops/cumsum'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; +import {fastGelu} from './ops/fast-gelu'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -72,6 +74,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], ['Expand', [expand]], + ['FastGelu', [fastGelu]], ['Floor', [unaryOps.floor]], ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], @@ -90,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index a81a7a8f1df5c..089fecd758e30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl(`vec4<${dataType}>`, dataType)} + ${erfImpl(dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts new file mode 100644 index 0000000000000..f50a6a3f011fe --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import * as unary from './unary-op'; + +// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. + +const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => { + const dataType = inputTensors[0].dataType; + const outputSize = ShapeUtil.size(inputTensors[0].dims); + const biasLength = ShapeUtil.size(inputTensors[1].dims); + // can only use vec4 when bias length is multiple of 4 + const useVec4 = biasLength % 4 === 0; + const getShaderSource = (shaderHelper: ShaderHelper): string => { + const x = inputVariable('x', dataType, [1], 4); + const bias = inputVariable('bias', dataType, [1], 4); + const y = outputVariable('y', dataType, [1], 4); + + const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + + const singleElementBias = (i: 0|1|2|3) => ` + let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; + let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; + const biasGetExpression = useVec4 ? + ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : + `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; + + return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} + + ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')} + + let x = ${x.getByOffset('global_idx')}; + ${biasGetExpression} + let x_in = x + bias; + ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))} + }`; + }; + + return { + name: 'FastGeluWithBias', + shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + getShaderSource, + getRunData: (inputs) => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + programUniforms: + [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], + dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} + }) + }; +}; + +export const fastGelu = (context: ComputeContext): void => { + if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) { + unary.fastGelu(context); + } else { + context.compute(createFastGeluProgramInfo(context.inputs)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3f73d9cb7c5bc..d5f97213e49ce 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -85,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..ead7635cf3ac4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const a = inputs[0]; + const b = inputs[1]; + const scales = inputs[2]; + const aRank = a.dims.length; + const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); + const outputSize = ShapeUtil.size(outputShape); + + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + programUniforms.push(...createTensorShapeVariables(a.dims)); + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); + programUniforms.push(...createTensorShapeVariables(scales.dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); + const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const wordPerBlob = blobSize / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ + var result = array<${dataType}, 8>(); + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + result[i] = ${dataType}(extractBits(value, offset, count)); + offset += count; + } + return result; + } + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var value: ${dataType} = 0.0; + let output_indices = ${output.offsetToIndices('global_idx')}; + var a_indices: ${a.type.indices} = output_indices; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_idex = n * ${nBlocksPerCol}; + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'n')}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_idex')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point: ${dataType} = ${ + zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${wordPerBlob}; word++) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_value = ${b.getByIndices('b_indices')}; + let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var i: u32 = 0; i < 8; i++) { + ${a.indicesSet('a_indices', aRank - 1, 'offset')}; + let a_value = ${a.getByIndices('a_indices')}; + let b_quantized_value = b_quantized_values[i]; + let b_dequantized_value = (b_quantized_value - zero_point) * scale; + value += a_value * b_dequantized_value; + offset++; + } + word_offset += 8; + } + scale_idex++; + ${ + zeroPoints ? ` + if (zero_point_offset == 28) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + } else { + zero_point_offset += 4; + }` : + ''} + block_offset += uniforms.block_size; + } + ${output.setByOffset('global_idx', 'value')}; + } + `; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 1accfac18b876..5f105c745739e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string, varType = 'f32') => ` +export const erfImpl = (varType = 'f32') => ` const r0: ${varType} = 0.3275911; const r1: ${varType} = 0.254829592; const r2: ${varType} = -0.284496736; @@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741; const r4: ${varType} = -1.453152027; const r5: ${varType} = 1.061405429; -fn erf_vf32(v: ${dataType}) -> ${dataType} { +fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); @@ -194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl(`vec4<${dataType}>`, dataType))); + context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; +export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`; + export const tanh = (context: ComputeContext): void => { // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression)); +}; + +export const fastGeluImpl = (varType = 'f32') => ` +const fast_gelu_a: ${varType} = 0.5; +const fast_gelu_b: ${varType} = 0.7978845608028654; +const fast_gelu_c: ${varType} = 0.035677408136300125; + +fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { + return ${tanhExpression('v')}; +} +`; + +export const fastGeluExpression = (x: string) => + `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + +export const fastGelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); + context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, + context.inputs[0].dataType)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': @@ -169,7 +176,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value diff --git a/js/web/package.json b/js/web/package.json index a502c2b6b032d..55c3a3238bafc 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -69,11 +69,14 @@ "exports": { ".": { "node": "./dist/ort.node.min.js", + "types": "./types.d.ts", "default": { "import": "./dist/esm/ort.min.js", "require": "./dist/cjs/ort.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.js", + "types": "./types.d.ts", "default": "./dist/ort.min.js" } } @@ -81,34 +84,41 @@ "./experimental": { "import": "./dist/esm/ort.all.min.js", "require": "./dist/cjs/ort.all.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.all.js", + "types": "./types.d.ts", "default": "./dist/ort.all.min.js" } }, "./wasm": { "import": "./dist/esm/ort.wasm.min.js", "require": "./dist/cjs/ort.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { "import": "./dist/esm/ort.wasm-core.min.js", "require": "./dist/cjs/ort.wasm-core.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { "import": "./dist/esm/ort.webgl.min.js", "require": "./dist/cjs/ort.webgl.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgl.min.js" }, "./webgpu": { "import": "./dist/esm/ort.webgpu.min.js", "require": "./dist/cjs/ort.webgpu.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" }, "./training": { "import": "./dist/esm/ort.training.wasm.min.js", "require": "./dist/cjs/ort.training.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.training.wasm.min.js" } }, diff --git a/js/web/test/data/ops/fast-gelu.jsonc b/js/web/test/data/ops/fast-gelu.jsonc new file mode 100644 index 0000000000000..2550173e95402 --- /dev/null +++ b/js/web/test/data/ops/fast-gelu.jsonc @@ -0,0 +1,211 @@ +[ + { + "name": "FastGelu test without bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.841192], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581, + 1.0617, 1.17393, 1.28671, 1.39957 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "FastGelu test with bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + }, + { + "data": [0.5], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.39957], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [3]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [2]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2, 3], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869, + 4.39999, 3.49938 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [7]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7], + "dims": [7], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989, + 4.09996, 3.59959 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[4x4], [8]", + "inputs": [ + { + "data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0], + "dims": [4, 4], + "type": "float32" + }, + { + "data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957, + 1.39957, 4.39999, 1.0617, -0.149419, 3.09737 + ], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..c57c431afb3ce --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1527 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, -675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 56db28b0a379c..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1352,7 +1352,9 @@ "equal.jsonc", "exp.jsonc", "expand.jsonc", + "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", @@ -1361,6 +1363,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index b01d474788f25..ecc7d4b4a09a5 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -330,6 +326,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -347,10 +347,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -364,7 +360,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -392,13 +388,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -409,10 +405,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b36451..014fc57f21558 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index a015b6fd60c8f..6ff18176ebeb2 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN */ @property BOOL onlyEnableForDevicesWithANE; +/** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with + * dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. + */ +@property BOOL onlyAllowStaticInputShapes; + +/** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + */ +@property BOOL createMLProgram; + @end @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6340fdea1c3a7..58b47d68eea63 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | - (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0); + (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | + (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | + (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( [self CXXAPIOrtSessionOptions], flags)); return YES; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 8afeb874750b4..a34f41d2938c6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -64,6 +64,7 @@ struct AttentionParameters { bool pass_past_in_kv; float mask_filter_value; float scale; + bool use_tf32; AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -82,6 +83,7 @@ struct PackedAttentionParameters { int token_count; bool has_relative_position_bias; bool broadcast_res_pos_bias; + bool use_tf32; }; // Parameters deduced from node attributes and inputs/outputs. diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 03d4e89ac20fe..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -268,7 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -330,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bf6431cf1afb2..7a807342ad685 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -84,6 +84,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + // Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr std::vector bias_dims{weights->Shape().GetDims()[1]}; const TensorShape bias_shape{bias_dims}; @@ -251,7 +253,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 54c9a5da1e9da..c20f42c4d06bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -461,7 +461,8 @@ Status UnfusedAttention( total_sequence_length, sequence_length, qk_head_size, &alpha, data.k, qk_head_size, present_size_per_batch_k, data.q, qk_head_size, sequence_length * qk_head_size, - &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, + device_prop, parameters.use_tf32)); DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); @@ -514,7 +515,7 @@ Status UnfusedAttention( v_head_size, sequence_length, total_sequence_length, &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 3f703ae3d05e6..ceee17c2a2d01 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -273,13 +273,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data()), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (h2, h1)*(h1, S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(q_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) // calcualte k, v @@ -298,13 +298,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, @@ -318,13 +318,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(key->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } } else { @@ -342,13 +342,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { kv_sequence_length = cache_sequence_length; @@ -372,6 +372,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { device_prop, #ifdef USE_ROCM GetTuningContext(), +#else + UseTF32(), #endif context->GetComputeStream(), cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index 1dc22a9c8ea98..e24d9da94c964 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -37,7 +37,8 @@ Status DecoderQkvToContext( T* workspace_buffer, T* output, T* new_key_cache, - T* new_value_cache) { + T* new_value_cache, + bool use_tf32) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -128,14 +129,14 @@ Status DecoderQkvToContext( kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } constexpr bool is_unidirectional = false; @@ -163,14 +164,14 @@ Status DecoderQkvToContext( head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } // scratch3 is BxNxSxH, transpose to output SxBxNxH @@ -180,6 +181,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const cudaDeviceProp& device_prop, + bool use_tf32, Stream* stream, cublasHandle_t& cublas, const size_t element_size, @@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } else { return DecoderQkvToContext( device_prop, @@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h index 9db9ccb45e330..f9667a613e648 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -11,6 +11,7 @@ namespace cuda { Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties + bool use_tf32, // Use TF32 Stream* stream, // ORT Stream cublasHandle_t& cublas, // Cublas handle const size_t element_size, // Element size of input tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 72ede2e22b557..07a6fbd60e171 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index e556ae4a490e9..9c5d0e9834f6f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -136,7 +136,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, weights_data, n, input_data, k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); } else { // q const CudaT* q_weight = weights_data; @@ -145,7 +145,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, q_weight, n, input_data, k, - &zero, q_data, n, device_prop)); + &zero, q_data, n, device_prop, UseTF32())); // k const CudaT* k_weight = q_weight + static_cast(hidden_size) * hidden_size; CudaT* k_data = q_data + static_cast(batch_size) * sequence_length * hidden_size; @@ -153,7 +153,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, k_weight, n, input_data, k, - &zero, k_data, n, device_prop)); + &zero, k_data, n, device_prop, UseTF32())); // v const CudaT* v_weight = k_weight + static_cast(hidden_size) * hidden_size; @@ -162,7 +162,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, v_weight, n, input_data, k, - &zero, v_data, n, device_prop)); + &zero, v_data, n, device_prop, UseTF32())); } // Wait for async copy of batch_global_num @@ -195,7 +195,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(global_weights->Data()), n, input_data, k, - &zero, global_gemm_buffer, n, device_prop)); + &zero, global_gemm_buffer, n, device_prop, UseTF32())); } else { // global q const CudaT* global_q_weight = global_weights_data; @@ -205,7 +205,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_q_weight, n, input_data, k, - &zero, global_q, n, device_prop)); + &zero, global_q, n, device_prop, UseTF32())); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -226,7 +226,8 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { hidden_size, // ldc static_cast(max_num_global) * hidden_size, // strideC batch_size, // batch count - device_prop)); + device_prop, + UseTF32())); } // global k const CudaT* global_k_weight = global_weights_data + static_cast(hidden_size) * hidden_size; @@ -235,7 +236,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_k_weight, n, input_data, k, - &zero, global_k, n, device_prop)); + &zero, global_k, n, device_prop, UseTF32())); // global v const CudaT* global_v_weight = global_k_weight + static_cast(hidden_size) * hidden_size; @@ -244,7 +245,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_v_weight, n, input_data, k, - &zero, global_v, n, device_prop)); + &zero, global_v, n, device_prop, UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index f00239460071b..c9c66b73b3e9d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -1005,7 +1005,6 @@ Status LaunchLongformerAttentionKernel( bool disable_compact_memory, bool use_merged_qkv_weights, bool use_half4) { - CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH); size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f978f50c6851f..2ef011cdd9a21 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index ec8b1d051b3d9..55deed55dfd33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -268,6 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -308,12 +309,12 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = this->GetCublasHandle(context); // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias - // The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel. + // The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel. CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 3b52320839403..ce7ac3796dbe1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); @@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 1b026e64778e3..b4a162989978c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -228,6 +228,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 83af018a97ea6..49029da12a308 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -775,7 +775,7 @@ Status UnfusedAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); // Q, K and V are ready now DUMP_TENSOR_INIT(); @@ -808,7 +808,7 @@ Status UnfusedAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 92ba808dd85c2..05f55d9106d0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -200,7 +200,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c D, BNS, head_size, &one, reinterpret_cast(weight_tensor.template Data()), (int)D, reinterpret_cast(workspace.get()), (int)head_size, - &zero, gemm_output, ld_gemm_output, device_prop)); + &zero, gemm_output, ld_gemm_output, device_prop, UseTF32())); auto status = LaunchGatedRelativePositionBiasKernel( device_prop, stream, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 8f368251f12c7..be8c0dc86c135 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); @@ -318,6 +319,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 705f2d49fe2bf..001b6070d5e1a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -106,6 +106,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_tensor = context->Input(8); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bbcb7de99781f..0534ed6dc7fc0 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -117,7 +117,8 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 5b0e61e197014..015df70c8ec3c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -135,7 +135,8 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bba30805ae1be..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/contrib_ops/js/fast_gelu.cc b/onnxruntime/contrib_ops/js/fast_gelu.cc new file mode 100644 index 0000000000000..62c538318160d --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.h b/onnxruntime/contrib_ops/js/fast_gelu.h new file mode 100644 index 0000000000000..68c7892741c66 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(FastGelu, FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 498a9f5679eb5..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,12 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -24,13 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index e82e15a304f4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16e..d0a0d09fcbae3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dcec..4cb371fdcf960 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c1..ceb53ed442abc 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561f..7cff640db2f34 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. - int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e63676..142aaf14e8d2d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5def..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a9363..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997d..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,36 +45,50 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, - params->cPerGroup, - params->epsilon}; + params->channels_per_group, + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289e..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,16 +12,22 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,12 +84,16 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,27 +102,27 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. -with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed326..e6831f764b418 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f5..382a3951f3a83 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index e16b90ded3381..5e4cd9f62f11b 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -482,7 +482,7 @@ class BFCArena : public IAllocator { Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); } - char bins_space_[sizeof(Bin) * kNumBins]; + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; // The size of the current region allocation. SafeInt curr_region_allocation_bytes_; diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986cf6..32a5f749af084 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..18d210ffd48f7 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. + void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,7 +13,9 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +45,49 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 8583474a1e391..8bf013ed009d5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,6 +259,16 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); + } else { + ONNX_NAMESPACE::TensorShapeProto output_shape; + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + int64_t hidden_size = query_dims[2].dim_value(); + int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads); + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + output_shape.add_dim()->set_dim_value(head_size * num_heads); + updateOutputShape(ctx, 0, output_shape); } } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..e33ce20737f80 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 902839bee04ba..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); @@ -212,6 +218,8 @@ const std::string& GraphViewer::Description() const noexcept { bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { + value = nullptr; + // if we are using filtered subgraph, the initializer has to be part of the subgraph if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend()) return false; diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc new file mode 100644 index 0000000000000..21266d356a020 --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/gather_slice_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, + int64_t& axis, int64_t& indices_n_dims) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + const NodeArg& input_arg = *(node.InputDefs()[1]); + + if (!optimizer_utils::IsScalar(input_arg)) return false; + + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + + if (!indices_init) return false; + + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); + + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + indices_n_dims = indices_init->dims_size(); + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const { + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; + } + + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; + } + } + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; + } + + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); + + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } + + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } + + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } + + for (int64_t step : steps) { + if (step != 1) { + return false; + } + } + } + + return true; +} + +/* +GatherToSplitFusion is to fuse: + Node + |-> Gather(index=0, axis=axis) + |-> Gather(index=1, axis=axis) + |-> Slice(index=2, axis=axis) +To + Node + |-> Split(index=0) +So that we can use one kernel to finish the job. +*/ + +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedVector output_args; + + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + + if (p_node == nullptr) continue; + + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; + + size_t output_count = node.GetOutputEdgesCount(); + + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + // |... or (other ops) + + // Get the output into node args + if (output_count < 3) continue; + + output_args.push_back(node.OutputDefs()[0]); + } + + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; + + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); + + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; + + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs(3); + + InlinedVector> nodes_to_fuse; + size_t gather_node_count = 2, slice_node_count = 0; + + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; + + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; + } + + if (axis < 0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; + } + + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; + } + + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + split_outputs[gather_node_count--] = gather_output_args; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; + } + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + split_outputs[slice_node_count++] = slice_output_args; + } + } + + // condition check + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + + InlinedVector split_output_types; + + for (size_t i = 0; i < consumer_count; ++i) { + split_output_types.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); + } + + // Generate the Split Node + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(3)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); + InlinedVector split_value{{slice_dim, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); + + split_node.AddAttribute("axis", split_axis); + + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h new file mode 100644 index 0000000000000..1c5c307efed7f --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class GatherSliceToSplitFusion +Fuse (2 Gather nodes + 1 Slice) to 1 split node. +*/ + +class GatherSliceToSplitFusion : public GraphTransformer { + private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; + + public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -308,6 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 159e3b23d1ab0..b6ad4fde6c1f7 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"}; +static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}; // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. scale and zero point is constant scalar diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a0713db43db8..0eb34cbfbc9eb 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -32,6 +32,9 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" +#if defined(_M_X64) && !defined(_M_ARM64EC) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +#include "core/platform/windows/hardware_core_enumerator.h" +#endif #include #include @@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const { Sleep(static_cast(micros) / 1000); } +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" +#endif int WindowsEnv::DefaultNumCores() { return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); } int WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) + // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has + // a hybrid architecture that some CPU cores runs significant slower than the others. If we distribute our compute work + // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number + // of threads to exclude the slowest cores out. + // The following code is based on assumptions that: + // 1. All Intel hybrid CPUs should have 3 levels of cache. + // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should + // not be used. + // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. + // However, no matter what the code should not cause any crash. The worst is it might return 1 that + // thread pools will not be created, which is just a perf issue and does not impact usability. + // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability + int regs[4]; + __cpuid(regs, 0); + bool bIsIntel = + (kVendorID_Intel[0] == regs[1]) && + (kVendorID_Intel[1] == regs[2]) && + (kVendorID_Intel[2] == regs[3]); + if (bIsIntel && regs[0] >= 7) { + // Query Structured Extended Feature Flags Enumeration Leaf + __cpuid(regs, 0x7); + // The bit 15 of EDX indicates if the processor is identified as a hybrid part. + bool ishybrid = regs[3] & (1 << 15); + if (ishybrid) { + // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. + // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. + // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines. + return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); + } else { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } + } else +#endif + { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } } std::vector WindowsEnv::GetDefaultThreadAffinities() const { diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc new file mode 100644 index 0000000000000..121c59808ae59 --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "hardware_core_enumerator.h" +#include +#include +#include + +namespace onnxruntime { + +struct LogicalProcessorInformation { + std::unique_ptr Buffer; + size_t Length; +}; + +struct CoreCounter { + uint32_t PhysicalCores = 0; + uint32_t SocDieCores = 0; +}; + +static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { + DWORD length = 0; + DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length); + + assert(rc == FALSE); + + auto processorInformationBytes = std::make_unique(length); + + rc = GetLogicalProcessorInformationEx( + relationship, reinterpret_cast(processorInformationBytes.get()), &length); + + assert(rc == TRUE); + + return {std::move(processorInformationBytes), length}; +} + +uint32_t CountSetBits(DWORD input) { + uint32_t c; + for (c = 0; input; c++) { + input &= input - 1; + } + return c; +} + +static CoreCounter GetNumberOPhysicalAndEngineeringCores() { + auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll); + + CoreCounter cores; + DWORD dwLevel2GroupMask = 0; + DWORD dwLevel3GroupMask = 0; + size_t read = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL; + + while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) { + currentProcessorInfo = + reinterpret_cast(logicalProcessorInformation.Buffer.get() + read); + if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) { + break; + } + + switch (currentProcessorInfo->Relationship) { + case RelationProcessorCore: + cores.PhysicalCores++; + break; + case RelationCache: + if (currentProcessorInfo->Cache.Level == 2) { + dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } else if (currentProcessorInfo->Cache.Level == 3) { + dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } + break; + } + + read += currentProcessorInfo->Size; + } + + cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + return cores; +} + +uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { + // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. + // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. + auto cores = GetNumberOPhysicalAndEngineeringCores(); + // We want to use the number of physical cores, but exclude soc cores + return cores.PhysicalCores - cores.SocDieCores; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h new file mode 100644 index 0000000000000..93b50f452afcd --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +struct HardwareCoreEnumerator { + HardwareCoreEnumerator() = delete; + static uint32_t DefaultIntraOpNumThreads(); +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. #pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index e9cd4af94e5fd..c9adba9e579d0 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -3,12 +3,28 @@ #pragma once -// TODO come up with a more intuitive way of limiting this to Apple platform builds -// E.g., putting CoreML EP files that should be enabled iff `defined(__APPLE__)` in a separate directory. -#if !defined(__APPLE__) -#error "This file should only be included when building on Apple platforms." +#include "onnxruntime_config.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push + +// Disable warning from protobuf code. +// +// In file included from coreml_proto/Model.pb.h:30: +// In file included from _deps/protobuf-src/src/google/protobuf/extension_set.h:53: +// _deps/protobuf-src/src/google/protobuf/parse_context.h:328:47: +// error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif #endif +// Model.pb.h is generated in the build output directory from the CoreML protobuf files in +// onnxruntime/core/providers/coreml/coremltools/mlmodel/format #include "coreml_proto/Model.pb.h" +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 897856256cc79..bc3ba4432e66d 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -22,22 +22,35 @@ namespace onnxruntime { namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags) { +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags) { return OpBuilderInputParams{graph_viewer, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0}; + coreml_version, + (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, + (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; } -bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { +const IOpBuilder* GetOpBuilder(const Node& node) { const auto& op_builders = GetOpBuilders(); - if (Contains(op_builders, node.OpType())) { - const auto* op_builder = op_builders.at(node.OpType()); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) { + return it->second; + } + + return nullptr; +} + +bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { + const auto* op_builder = GetOpBuilder(node); + if (op_builder) { return op_builder->IsOpSupported(node, input_params, logger); } else { return false; } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (!input.Exists()) { // optional input that is not provided @@ -48,8 +61,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, std::vector shape; // We do not support input with no shape if (!GetShape(input, shape, logger)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has no shape"; + LOGS(logger, VERBOSE) << MakeString("Input [", input_name, "] of Node [", node.Name(), "] type [", node.OpType(), + "] has no shape"); return false; } @@ -63,11 +76,19 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, // For some undocumented reason, Apple CoreML framework will fail loading the model if the model // input has dimension > 16384 // See this issue, https://github.com/apple/coremltools/issues/1003 + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the + // root cause. if (dim > 16384) { LOGS(logger, WARNING) << "CoreML does not support input dim > 16384. Input:" << input_name << ", shape: " << Shape2String(shape); return false; } + + if (dim == 0) { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } } // Limit input shape rank to 5. @@ -87,13 +108,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const logging::Logger& logger) { std::unordered_set supported_nodes{}; -#ifdef __APPLE__ - if (!util::HasRequiredBaseOS()) { - LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS"; - return supported_nodes; - } -#endif - for (const auto& node : graph_viewer.Nodes()) { const bool supported = IsNodeSupported(node, input_params, logger); LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() @@ -149,7 +163,9 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine - LOGS(logger, VERBOSE) << "HasNeuralEngine running on non-Apple hardware for model conversion only"; + LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + "Returning true to enable model conversion and local testing of CoreML EP implementation. " + "No CoreML model will be compiled or run."; has_neural_engine = true; #endif // #ifdef __APPLE__ diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index d8b27ac76ae73..300de2dedd122 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -23,10 +23,14 @@ class Logger; namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags); +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags); -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, - const OpBuilderInputParams& input_params, const logging::Logger& logger); +const IOpBuilder* GetOpBuilder(const Node& node); + +bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, + const logging::Logger& logger); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc index 53f18b205880c..e9e520156576e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class LRNOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_lrn = layer->mutable_lrn(); @@ -56,9 +43,6 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 88d6616b4e097..dee87ce3632a8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -2,44 +2,32 @@ // Licensed under the MIT License. #include "core/common/narrow.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ActivationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& node) const override; }; -// Add operator related - -#ifdef __APPLE__ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -86,7 +74,7 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type(node.OpType()); if (op_type == "Sigmoid") { @@ -115,14 +103,10 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related namespace { // assumes that node.OpType() == "PRelu" -bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) { +bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); // X input rank must be 3 or 4 diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index 7a5d4a5af673b..e9a8176c8349b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -1,37 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ArgMaxOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); @@ -67,9 +56,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 25d5bad14ceb6..2570e6d88ae0d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Shared functions - +namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, const logging::Logger& logger) { @@ -37,93 +34,78 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node return false; } +} // namespace -// Add operator related -#ifdef __APPLE__ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - ORT_RETURN_IF_NOT( - IsOpSupported(node, input_params, logger), - "Unsupported operator ", - node.OpType()); - - ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; - return Status::OK(); -} + Status status = AddToModelBuilderImpl(model_builder, node, logger); -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(ModelBuilder& model_builder, const Node& node) { - auto layer_name = node.Name(); - if (layer_name.empty()) { - // CoreML requires layer has a name, while the node name is optional in ONNX - // In this case, create a unique name for the layer - layer_name = model_builder.GetUniqueName(MakeString("Node_", node.Index(), "_type_", node.OpType())); + if (status.IsOK()) { + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; } - return CreateNNLayer(layer_name); -} -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(const std::string& layer_name) { - std::unique_ptr layer = std::make_unique(); - layer->set_name(layer_name); - return layer; + return status; } -#endif - -// Operator support related bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, input_params, logger)) + if (input_params.create_mlprogram && !SupportsMLProgram()) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support MLProgram"; return false; + } - // We do not support external initializers for now - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (HasExternalInitializer(initializers, node, logger)) + if (!HasSupportedOpSet(node, logger)) { + return false; + } + + if (!HasSupportedInputs(node, input_params, logger)) { return false; + } - if (!HasSupportedOpSet(node, logger)) + // We do not support external initializers for now + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + if (HasExternalInitializer(initializers, node, logger)) { return false; + } return IsOpSupportedImpl(node, input_params, logger); } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(*input, node_name, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger)) { return false; } } - return HasSupportedInputsImpl(node, logger); + return HasSupportedInputsImpl(node, input_params, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - // We only check the type of input 0 by default - // specific op builder can override this +/* static */ +bool BaseOpBuilder::IsInput0Supported(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { const auto& input = *node.InputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + // currently only float is supported + if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } return true; } -bool BaseOpBuilder::HasSupportedOpSet(const Node& node, - const logging::Logger& logger) const { +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 by default + // specific op builder can override this + return IsInput0Supported(node, input_params, logger); +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { auto since_version = node.SinceVersion(); if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b4132d3b770ec..06c4dd94ea30d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -3,11 +3,9 @@ #pragma once -#include "core/providers/coreml/builders/op_builder.h" - -#ifdef __APPLE__ +#include "core/common/span_utils.h" #include "core/providers/coreml/builders/coreml_spec.h" -#endif +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { @@ -18,45 +16,40 @@ class BaseOpBuilder : public IOpBuilder { public: virtual ~BaseOpBuilder() = default; - // Add operator related + // does the operator implementation support creating an ML Program + bool SupportsMLProgram() const override { return false; } + + bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override final; -#ifdef __APPLE__ - public: - virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const override final; - protected: - virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const = 0; - - static std::unique_ptr - CreateNNLayer(ModelBuilder& model_builder, const Node& node); - - static std::unique_ptr CreateNNLayer(const std::string& layer_name); -#endif - - // Operator support related - public: - bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) const override final; + void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - virtual bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const { + // check if the first input's data type is supported. + static bool IsInput0Supported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger); + + private: + virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& /*logger*/) const { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + virtual bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const; - virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } + virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 20; } - private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const; + + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 391b02eaec497..8da58f659acf1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -5,30 +5,20 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class BatchNormalizationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -36,9 +26,6 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } }; -// Add operator related - -#ifdef __APPLE__ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // skip everything except input0 for BatchNormalization const auto& input_defs = node.InputDefs(); @@ -48,10 +35,9 @@ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_buil model_builder.AddInitializerToSkip(input_defs[4]->Name()); // var } -Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -81,9 +67,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 10c9b32d03f37..6074fba1433d9 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -1,35 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - -#include "base_op_builder.h" namespace onnxruntime { namespace coreml { - class BinaryOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related + int GetMinSupportedOpSet(const Node& node) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; }; -#ifdef __APPLE__ -static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { +namespace { +bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); const auto* x_shape_proto = input_defs[0]->Shape(); @@ -57,15 +50,14 @@ static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& y_shape_proto->dim().begin(), y_shape_proto->dim().end(), dim_eq); } - -// Add operator related +} // namespace Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Add") { // original mutable_add() has limited broadcasting support @@ -99,31 +91,28 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now return 7; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - bool is_pow = node.OpType() == "Pow"; - if (!is_pow) { - return BaseOpBuilder::HasSupportedInputsImpl(node, logger); +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (node.OpType() != "Pow") { + return IsInput0Supported(node, input_params, logger); } const auto& input_1 = *node.InputDefs()[0]; const auto& input_2 = *node.InputDefs()[1]; + // Pow we only support both inputs as fp32 for now int32_t input_type_1; - if (!GetType(input_1, input_type_1, logger)) - return false; - int32_t input_type_2; - if (!GetType(input_2, input_type_2, logger)) + if (!GetType(input_1, input_type_1, logger) || + !GetType(input_2, input_type_2, logger)) { return false; + } if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index ef66e6b877a1f..710f596b2a562 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -1,17 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ - #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/common/narrow.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" -#include "coreml_proto/NeuralNetwork.pb.h" +using namespace COREML_SPEC; namespace onnxruntime { namespace coreml { @@ -133,7 +132,182 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span> shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->size()); + for (const auto& dim : *shape) { + if (dim >= 0) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type, + const ONNX_NAMESPACE::TensorShapeProto* shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->dim_size()); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +template +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + // need a 'false' that is dependent on the template types to make gcc happy and give a meaningful error message. + static_assert(false_for_T && false_for_T, "Unsupported data type"); // add specializations below as needed +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_strings()->mutable_values()->Add(data.begin(), data.end()); +} + +// copy int64_t (used by ONNX for strides/indexes/etc.) to int32_t (used by CoreML) +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + auto& int32_out = *tensor_value.mutable_ints()->mutable_values(); + int32_out.Reserve(narrow(data.size())); + for (const int64_t v : data) { + int32_out.AddAlreadyReserved(narrow(v)); + } +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_bools()->mutable_values()->Add(data.begin(), data.end()); +} + +} // namespace + +MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) { + switch (static_cast(onnx_type)) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return MILSpec::DataType::FLOAT32; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return MILSpec::DataType::FLOAT64; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return MILSpec::DataType::BFLOAT16; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return MILSpec::DataType::FLOAT16; + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + return MILSpec::DataType::INT8; + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + return MILSpec::DataType::INT16; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return MILSpec::DataType::INT32; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return MILSpec::DataType::INT64; + + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return MILSpec::DataType::UINT8; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + return MILSpec::DataType::UINT16; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + return MILSpec::DataType::UINT32; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + return MILSpec::DataType::UINT64; + + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + return MILSpec::DataType::BOOL; + case ONNX_NAMESPACE::TensorProto_DataType_STRING: + return MILSpec::DataType::STRING; + default: + ORT_THROW("Unsupported data type: ", onnx_type); + } +} + +template +MILSpec::Value CreateTensorValue(const gsl::span data, + std::optional> shape) { + MILSpec::Value value; + MILSpec::TensorType& tensor_type = *value.mutable_type()->mutable_tensortype(); + + if (shape) { + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), *shape); + } else { + // infer as 1D shape + std::vector coreml_shape{narrow(data.size())}; + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), coreml_shape); + } + + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyDataToTensorValue(tensor_value, data); + + return value; +} + +template +MILSpec::Value CreateScalarTensorValue(const T& data) { + gsl::span data_span{&data, 1}; + std::vector shape = {}; // empty for scalar + return CreateTensorValue(data_span, shape); +} + +// explicit specializations for types we handle so the implementation can be in the .cc file +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); + +template MILSpec::Value CreateScalarTensorValue(const float& data); +template MILSpec::Value CreateScalarTensorValue(const int32_t& data); +template MILSpec::Value CreateScalarTensorValue(const std::string& data); +template MILSpec::Value CreateScalarTensorValue(const bool& data); + +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) { + MILSpec::NamedValueType nvt; + nvt.set_name(node_arg.Name()); + MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()), + node_arg.Shape()); + + return nvt; +} + +void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { + MILSpec::Argument arg; + arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + output_arg.set_name(output.Name()); + + MILSpec::ValueType& value = *output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), + output.Shape()); +} + } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 23b11928f7dc2..8126f0c126914 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -5,22 +5,19 @@ #pragma once -#ifdef __APPLE__ +#include #include "core/common/gsl.h" #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" -namespace CoreML { -namespace Specification { -class WeightParams; -} -} // namespace CoreML +#include "core/providers/coreml/builders/coreml_spec.h" namespace onnxruntime { -namespace coreml { +class NodeArg; +namespace coreml { // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient Status HandleAutoPad(const std::vector input_shape, @@ -32,6 +29,10 @@ Status HandleAutoPad(const std::vector input_shape, AutoPadType auto_pad_type, AutoPadType& auto_pad_type_out); +// +// NeuralNetwork utils +// + // Copy an onnx initializer data to a coreml weight Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONNX_NAMESPACE::TensorProto& tensor); @@ -44,7 +45,90 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// +// MLProgram utils +// + +// helper for static_assert where the value needs to be dependent on a template parameter +template +constexpr bool false_for_T = false; + +template +COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() { + if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT64; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BFLOAT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT16; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BOOL; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::STRING; + } else { + static_assert(false_for_T, "Unsupported type."); + } +} + +// The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value. +// Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally +COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type); + +/// +/// Create a CoreML MILSpec::TensorValue for the given input data. +/// +/// Original C++ data type +/// CoreML C++ data type +/// ONNX data +/// ONNX data shape. Inferred to be a 1D shape of `{data.size()}` if not specified. +/// TensorValue containing data. +template +COREML_SPEC::MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape = std::nullopt); + +template +COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data); + +/// Create a NamedValueType from an ONNX tensor NodeArg. +/// Used to create inputs for the 'main' function in an ML Program. +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg); + +/// +/// Add an input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The name of the value that is providing the input. +/// "https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html" +void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, + std::string_view input_name, std::string_view value_name); + +/// +/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. +/// +/// Operation to update. +/// NodeArg with details of output to add. +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 15ee1f0fc7284..70053c2c606a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -1,34 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" #include "core/providers/coreml/builders/helper.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class CastOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; -}; -// Add operator related + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; -#ifdef __APPLE__ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, const Node& /* node */, const logging::Logger& /* logger */) const { @@ -37,9 +28,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. return Status::OK(); } -#endif - -// Operator support related bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -84,7 +72,8 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 3a3f89d24c7d8..9aca172abec98 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -1,37 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class ClipOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // Both min and max values will be injected into the layer, no need to add to the model if (node.SinceVersion() >= 11) { @@ -50,7 +37,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); float min, max; - ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, min, max, logger), "GetClipMinMax failed"); + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); @@ -58,7 +45,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (!has_min && !has_max) { // Clip without min/max is an identity node // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); *layer->mutable_input()->Add() = input_name; *layer->mutable_output()->Add() = output_name; @@ -83,8 +70,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Handle clipping at min first if (has_min) { - const auto clip_min_layer_name = model_builder.GetUniqueName(MakeString(node_name, "_Clip_min")); - std::unique_ptr min_layer = CreateNNLayer(clip_min_layer_name); + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); if (min == 0.0f) { // If min is 0. then this min will be handled by relu min_layer->mutable_activation()->mutable_relu(); } else { // otherwise, min will be handled by unary->threshold @@ -101,9 +87,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (has_max) { const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); { // Add threshold layer, which is actually max( -1 * min_output, -max) - const auto clip_max_threshold_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_threshold")); - auto threshold_layer = CreateNNLayer(clip_max_threshold_layer_name); + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); threshold_layer->mutable_unary()->set_alpha(-max); threshold_layer->mutable_unary()->set_scale(-1.0f); threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); @@ -112,9 +96,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(threshold_layer)); } { // Add linear activation layer -1 * threshold_output - const auto clip_max_linear_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_linear")); - auto linear_layer = CreateNNLayer(clip_max_linear_layer_name); + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); *linear_layer->mutable_input()->Add() = threshold_output_name; *linear_layer->mutable_output()->Add() = output_name; @@ -125,15 +107,11 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { float min, max; - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - return GetClipMinMax(initializers, node, min, max, logger); + return GetClipMinMax(input_params.graph_viewer, node, min, max, logger); } void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index b1e761024f5c9..34193318a0264 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,37 +4,26 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ConcatOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_concat()->set_sequenceconcat(false); @@ -48,9 +37,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index ff9dcbd9f8874..05e43dbbd16af 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -4,39 +4,35 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" -#include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" -#endif +#include "core/providers/shared/utils/utils.h" + +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { class ConvOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (model_builder.CreateMLProgram()) { + // we add the initializers as 'const' operations via ModelBuilder::RegisterInitializers + return; + } + const auto& input_defs = node.InputDefs(); // skip the weight and bias (if has it) for conv as we will directly set those as part of the NN layer @@ -49,136 +45,251 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); const auto& output_name = output_defs[0]->Name(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - const bool is_1d_conv = (weight_shape.size() == 3); + // https://github.com/apple/coremltools/blob/7.1/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py - if (is_1d_conv) { - // weight_shape needs to be expanded from MXCXH->MXCXHx1 - weight_shape.push_back(1); - } + std::unique_ptr conv_op = model_builder.CreateOperation(node, "conv"); - NodeAttrHelper helper(node); - auto strides = helper.Get("strides", std::vector{1, 1}); - auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 - // to meet the required length 2 (for 2d conv it's normally 2) - // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. - if (is_1d_conv) { - if (strides.size() < 2) { - ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); - strides.push_back(1); + AddOperationInput(*conv_op, "x", input_name); + AddOperationInput(*conv_op, "weight", input_defs[1]->Name()); + + if (input_defs.size() > 2) { + AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - if (dilations.size() < 2) { - ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); - dilations.push_back(1); + + // ONNX attributes. Add as inputs if specified/required + auto strides = helper.GetInt64s("strides"); + auto dilations = helper.GetInt64s("dilations"); + auto groups = helper.GetInt64("group"); + + // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. + const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; + const auto& op_type = conv_op->type(); + + if (strides) { + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", *strides)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", default_value)); } - if (onnx_pads.size() < 4) { - ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); - onnx_pads.insert(onnx_pads.begin() + 1, 0); - onnx_pads.push_back(0); + + if (dilations) { + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", *dilations)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", default_value)); } - } - const auto group = helper.Get("group", static_cast(1)); - - auto* coreml_conv = layer->mutable_convolution(); - - std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); - - if (is_1d_conv) { - const auto expand_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_expand")); - std::unique_ptr expand_layer = CreateNNLayer(expand_layer_name); - // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case - // we need to add an additional dimension here to the input to make it "2d Conv" like. - // NxCxH -> NxCxHx1 - expand_layer->mutable_expanddims()->add_axes(-1); - *expand_layer->mutable_input()->Add() = input_name; - *expand_layer->mutable_output()->Add() = expand_output_name; - model_builder.AddLayer(std::move(expand_layer)); - } - coreml_conv->set_outputchannels(weight_shape[0]); // M - coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group - coreml_conv->add_kernelsize(weight_shape[2]); // H - coreml_conv->add_kernelsize(weight_shape[3]); // W - coreml_conv->set_ngroups(group); - *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; - *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; - - coreml_conv->set_isdeconvolution(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - onnx_pads, strides, dilations, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_conv->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + + if (groups) { + AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - } else { - auto* padding_type = coreml_conv->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + + // pad type (string) + // valid - no pads (ONNX auto_pad VALID) + // custom - pads input (ONNX NOTSET) + // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) + // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) + // + // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value + // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could + // potentially do that for us. + switch (auto_pad_type) { + case AutoPadType::NOTSET: { + // use `pads` attribute. + auto onnx_pads = helper.GetInt64s("pads"); // 'pads' must be provided if auto_pad is NOTSET + if (onnx_pads) { + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); + + // need to re-order from x1_start, x2_start..., x1_end, x2_end... to + // x1_start, x1_end, x2_start, x2_end,... + size_t num_pads = onnx_pads->size(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // in theory the pads may not be provided and in that case the default is no padding. + // as that is the same as 'valid', fall through + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // provide the default value. passing in an empty vector also works. TBD what's better. + std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } } - } - // Add weight - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + // set output + AddOperationOutput(*conv_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(conv_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto group = helper.Get("group", static_cast(1)); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + + const bool is_1d_conv = (weight_shape.size() == 3); + + // add dummy 'W' dim with value of 1 so we can use 2D conv. + if (is_1d_conv) { + input_shape.push_back(1); + weight_shape.push_back(1); + + // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 + // to meet the required length 2 (for 2d conv it's normally 2) + if (strides.size() < 2) { + ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); + strides.push_back(1); + } + + if (dilations.size() < 2) { + ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); + dilations.push_back(1); + } + + // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. + if (onnx_pads.size() < 4) { + ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); + onnx_pads.insert(onnx_pads.begin() + 1, 0); + onnx_pads.push_back(0); + } + } - // Add bias if present - if (input_defs.size() > 2) { - coreml_conv->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); - } + auto* coreml_conv = layer->mutable_convolution(); - if (is_1d_conv) { - std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); - *layer->mutable_input()->Add() = expand_output_name; - *layer->mutable_output()->Add() = conv_output_name; - model_builder.AddLayer(std::move(layer)); - - // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, - // we need to squeeze it back from NxCxHx1->NxCxH. - const auto squeeze_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_squeeze")); - std::unique_ptr squeeze_layer = CreateNNLayer(squeeze_layer_name); - squeeze_layer->mutable_squeeze()->add_axes(-1); - *squeeze_layer->mutable_input()->Add() = conv_output_name; - *squeeze_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(squeeze_layer)); - } else { - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); + std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); + + if (is_1d_conv) { + // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case + // we need to add an additional dimension here to the input to make it "2d Conv" like. + // NxCxH -> NxCxHx1 + auto expand_layer = model_builder.CreateNNLayer(node, "_Conv_expand"); + expand_layer->mutable_expanddims()->add_axes(-1); + *expand_layer->mutable_input()->Add() = input_name; + *expand_layer->mutable_output()->Add() = expand_output_name; + model_builder.AddLayer(std::move(expand_layer)); + } + + coreml_conv->set_outputchannels(weight_shape[0]); // M + coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group + coreml_conv->add_kernelsize(weight_shape[2]); // H + coreml_conv->add_kernelsize(weight_shape[3]); // W + coreml_conv->set_ngroups(group); + *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; + *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; + + coreml_conv->set_isdeconvolution(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + onnx_pads, strides, dilations, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_conv->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_conv->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } + } + + // Add weight + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + + // Add bias if present + if (input_defs.size() > 2) { + coreml_conv->set_hasbias(true); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); + } + + if (is_1d_conv) { + std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); + *layer->mutable_input()->Add() = expand_output_name; + *layer->mutable_output()->Add() = conv_output_name; + model_builder.AddLayer(std::move(layer)); + + // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, + // we need to squeeze it back from NxCxHx1->NxCxH. + auto squeeze_layer = model_builder.CreateNNLayer(node, "_Conv_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(-1); + *squeeze_layer->mutable_input()->Add() = conv_output_name; + *squeeze_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(squeeze_layer)); + } else { + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } } return Status::OK(); } -#endif - -// Operator support related bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -186,23 +297,73 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4 && tensor.dims().size() != 3) { - LOGS(logger, VERBOSE) << "Conv [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d and conv 1d are supported."; + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name, true); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + // ML Program supports non-const weight, 1D, 2D and 3D. + // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now. + // add 3D support as/when needed. + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + if (!weight) { + LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be known"; + } + + // use the weight for the shape as it should always be known + const auto* weight_shape = input_defs[1]->Shape(); + int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1; + + // ONNX spec requires N and C as first 2 dims + if (num_dims != 3 && num_dims != 4) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] is " << num_dims - 2 << "D. " + << "Only 1D and 2D Conv are supported currently."; return false; } - if (input_defs.size() > 2) { - const auto& bias_name = input_defs[2]->Name(); - if (!Contains(initializers, bias_name)) { - LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name(), true)) { + LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + return false; + } + + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + // spec says same_lower is supported in CoreML 5. it lies. CoreML 6 is required otherwise you get + // `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").` + // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify + // the effort as it's not clear how common usage of same_lower is. + if (input_params.create_mlprogram && input_params.coreml_version < 6) { + if (StringToAutoPadType(helper.Get("auto_pad", "NOTSET")) == AutoPadType::SAME_LOWER) { + LOGS(logger, VERBOSE) << "Pad type of SAME_LOWER [" << name << "] is not supported until CoreML 6." + << "Available version is CoreML " << input_params.coreml_version; + return false; + } + } +#endif + + // there's no equivalent to allow a manual kernel shape in CoreML. + // it's OK if a specified kernel_shape matches kH and kW dims of the weight input. + auto kernel_shape = helper.GetInt64s("kernel_shape"); + if (kernel_shape) { + bool valid = true; + if (static_cast(kernel_shape->size()) == num_dims - 2) { + for (int i = 0; i < num_dims - 2; ++i) { + // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter. + if ((*kernel_shape)[i] != weight_shape->dim()[i + 2].dim_value()) { + valid = false; + break; + } + } + } else { + valid = false; + } + + if (!valid) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] kernel_shape attribute does not match the weight shape"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index a4ad1c31b5027..1eba312b2577b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,37 +4,26 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class DepthToSpaceOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); @@ -54,9 +43,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index b303fe7884cb1..f0adb70587bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class FlattenOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams auto* coreml_flatten = layer->mutable_flattento2d(); @@ -51,9 +38,6 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 9c7ec306ca093..7d32675e3e510 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -2,34 +2,24 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class GatherOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) namespace { int64_t GetAxisAttribute(const Node& node) { NodeAttrHelper node_attr_helper{node}; @@ -38,8 +28,8 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_gather()->set_axis(GetAxisAttribute(node)); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices @@ -47,10 +37,9 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 71b08db6d44d8..48f77354d7c30 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -7,38 +7,25 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class GemmOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; }; -// Add operator related - -#ifdef __APPLE__ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); const auto& input_defs(node.InputDefs()); @@ -71,7 +58,7 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -120,9 +107,6 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc index ba12600e8bc40..99d6f01cb8c5b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc @@ -7,30 +7,20 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PadOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -64,9 +54,6 @@ static InlinedVector GetPaddingAxesData(const InitializedTensorSet& ini return axes_tensor_data; } -// Add operator related - -#ifdef __APPLE__ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // pads model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // constant_value @@ -78,7 +65,7 @@ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pad = layer->mutable_padding(); auto* constant_padding_type = coreml_pad->mutable_constant(); // CoreML::Specification::PaddingLayerParams_PaddingConstant @@ -122,9 +109,6 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index fd1c77c851e6f..01aced739b36d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -4,38 +4,27 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PoolOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); @@ -108,9 +97,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 6a2014e7952a2..32378b1f654d8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -1,36 +1,27 @@ // Copyright (c) Shukant Pal. // Licensed under the MIT License. +#include "core/optimizer/initializer.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/optimizer/initializer.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ReductionOpBuilder : public BaseOpBuilder { -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -#ifdef __APPLE__ namespace { template void AddReductionParams(T* params, const std::vector& axes, bool keepdims, bool noop_with_empty_axes) { @@ -76,7 +67,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "ReduceSum") { AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); @@ -93,7 +84,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -124,4 +114,4 @@ void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 67aee73630cdb..7ae1746be3122 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -6,31 +6,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ReshapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -38,9 +28,6 @@ class ReshapeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } }; -// Add operator related - -#ifdef __APPLE__ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } @@ -48,7 +35,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -69,9 +56,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 5f963dc30dd8f..35dcde41a6bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -8,31 +8,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -41,7 +31,7 @@ class ResizeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } }; -// Helper functions +namespace { bool GetResizeScales(const InitializedTensorSet& initializers, const Node& node, std::vector& scales, const logging::Logger&) { @@ -73,10 +63,8 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, sizes = std::vector(sizes_data.begin(), sizes_data.end()); return true; } +} // namespace -// Add operator related - -#ifdef __APPLE__ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // We don't really use ROI here, so add it to skipped list if it's an initializer tensor model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI @@ -96,7 +84,7 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); @@ -131,9 +119,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index fd64153ffd283..a86e3d9538d87 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,44 +2,30 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class ShapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_getshape(); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { NodeAttrHelper node_attr_helper{node}; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 2c250b3cc9f5a..b716af738e1b1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -1,39 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class SliceOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: int GetMinSupportedOpSet(const Node& /* node */) const override { // Before Slice-10, some inputs were attributes instead. We don't support that for now. return 10; } - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const override; }; @@ -107,9 +99,6 @@ bool ValidateSliceComputeMetadataForCoreML(const SliceOp::PrepareForComputeMetad } } // namespace -// Add operator related -#if defined(__APPLE__) - void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -132,7 +121,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadataFromConstantInitializers(node, model_builder.GetGraphViewer(), compute_metadata)); - auto layer = CreateNNLayer(model_builder, node); + auto layer = model_builder.CreateNNLayer(node); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); auto* slice_static = layer->mutable_slicestatic(); @@ -163,10 +152,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -#endif // defined(__APPLE__) - -// Operator support related -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index c454a2a779f6e..266396a0fe90e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -1,43 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/coreml/shape_utils.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class SoftmaxOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); @@ -68,9 +54,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); { // Add reshape layer - const auto softmax_reshape1_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); - auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; *reshape_layer->mutable_input()->Add() = input_name; *reshape_layer->mutable_output()->Add() = reshape1_output_name; @@ -86,9 +70,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } { // Add reshape back layer - const auto softmax_reshape2_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); - auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; *reshape_layer->mutable_input()->Add() = softmax_output_name; *reshape_layer->mutable_output()->Add() = output_name; @@ -99,10 +81,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 56c87c883156b..0497357c45c54 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -1,35 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class SplitOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -37,10 +26,6 @@ class SplitOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } }; -// Add operator related - -#ifdef __APPLE__ - void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -63,7 +48,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // attribute introduced since opset 18 uint64_t num_outputs; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_splitnd = layer->mutable_splitnd(); coreml_splitnd->set_axis(axis); @@ -82,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt("num_outputs").value()); + num_outputs = narrow(helper.GetInt64("num_outputs").value()); // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); @@ -111,10 +96,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -159,7 +140,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - const auto num_outputs = helper.GetInt("num_outputs"); + const auto num_outputs = helper.GetInt64("num_outputs"); if (!num_outputs.has_value()) { LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; @@ -169,9 +150,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); return false; } - if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { - LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." - << "The value should be smaller or equal to the size of dimension being split. num_outputs: " + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || + num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " << num_outputs.value(); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index 2e14c85ce69c1..e9cc1c2dbf638 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -1,48 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include + +#include "core/common/safeint.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/optimizer/initializer.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/optimizer/initializer.h" namespace onnxruntime { namespace coreml { class SqueezeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ -void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); - } -} - -/* static */ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +namespace { +Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all @@ -62,11 +44,18 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const return Status::OK(); } +} // namespace + +void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + } +} Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_squeeze = layer->mutable_squeeze(); std::vector axes; @@ -84,9 +73,6 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc index 7d5018a19f74c..f6a61d55a3d63 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc @@ -3,33 +3,23 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class TransposeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -// Add operator related - -#ifdef __APPLE__ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); NodeAttrHelper helper(node); std::vector perm = helper.Get("perm", std::vector()); @@ -51,7 +41,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 660755b43c043..3403378d59114 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,32 +3,25 @@ #include "core/providers/common.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -#ifdef __APPLE__ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Sqrt") { layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); @@ -45,9 +38,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); @@ -55,4 +45,4 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9c8b7bce507e4..daab36f7b933d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -2,56 +2,555 @@ // Licensed under the MIT License. #include -#include - -#include "model_builder.h" -#include "helper.h" -#include "op_builder_factory.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" -#include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" +#if defined(COREML_ENABLE_MLPROGRAM) +// includes from coremltools-src in _deps +#include "modelpackage/src/ModelPackage.hpp" +#include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp" +using MILBlob::Blob::StorageWriter; +#endif + +using namespace CoreML::Specification; + namespace onnxruntime { namespace coreml { -ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags) - : graph_viewer_(graph_viewer), - logger_(logger), - coreml_flags_(coreml_flags) { +namespace { +#if defined(COREML_ENABLE_MLPROGRAM) +// Should the initializer be written to file or kept as an immediate value +bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57 + + bool use_weight_file = false; + + switch (tensor_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + auto num_elements = TensorShape(utils::GetTensorShapeFromTensorProto(tensor_proto)).Size(); + use_weight_file = num_elements >= 10; + break; + } + default: + break; + } + + return use_weight_file; +} + +// copy from the ONNX TensorProto to a CoreML field. +// T1 is the source type. T2 is the target type. If the types differ, T1 must be smaller than T2. +// e.g. uint32_t data can be written to RepeatedField +template +void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto, + google::protobuf::RepeatedField& repeated_field) { + const auto& raw_data = tensor_proto.raw_data(); + const T1* data = reinterpret_cast(raw_data.data()); + const T1* data_end = data + (raw_data.size() / sizeof(T1)); + if constexpr (sizeof(T1) == sizeof(T2)) { + repeated_field.Add(data, data_end); + } else { + static_assert(sizeof(T1) < sizeof(T2)); + // we need to iterate over the data and copy to the repeated field, converting to T2 as we go. + repeated_field.Resize(data_end - data, T2(0)); + for (int i = 0; data != data_end; ++data, ++i) { + repeated_field[i] = static_cast(*data); + } + } +} + +// copy T data from the TensorProto.int32_t field to TensorValue.bytes +template +void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.int32_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// copy T data from the TensorProto.uint64_data field to TensorValue.bytes +template +void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.uint64_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const uint64_t* in = tensor_proto.uint64_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// NOTE: This supports all the ONNX data types. Weights in CoreML may not need all these +void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILSpec::TensorValue& tensor_value) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + // handling based on + // ONNX TensorProto field usage + // https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/onnx/onnx.proto#L544-L572 + // CoreMLTools conversion implementation that maps data types to fields + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L98 + // along with some special cased types that are stored in bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L23 + // IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_floats()->mutable_values()); + } else { + tensor_value.mutable_floats()->mutable_values()->CopyFrom(tensor_proto.float_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + // from: double_data/raw, to: doubles + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_doubles()->mutable_values()); + } else { + tensor_value.mutable_doubles()->mutable_values()->CopyFrom(tensor_proto.double_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + // from: int32_data/raw, to: ints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + // from: int64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + } else { + tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // iterate the int32_data, taking the 16-bits from each entry, and copying to the bytes. + // we use uint16_t as only the size of the data type matters + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy from int32_data to bytes. uint8_t for both as only the size of the data type matters when copying + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + // from: uint64_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy uint32_t values from TensorProto.uint64_data + CopyUInt64DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + // from: uint64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + } else { + // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // individual value. + tensor_value.mutable_longints()->mutable_values()->CopyFrom( + reinterpret_cast&>(tensor_proto.uint64_data())); + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { + // from: int32_data/raw, to: bools + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_bools()->mutable_values()); + } else { + const auto& int32s = tensor_proto.int32_data(); + auto& bools = *tensor_value.mutable_bools()->mutable_values(); + const int num_entries = int32s.size(); + bools.Reserve(num_entries); + const int32_t* in = int32s.data(); + for (int i = 0; i < num_entries; ++i) { + *bools.AddAlreadyReserved() = *in++; + } + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_STRING: { + // from: string_data (which is protobuf type bytes), to: strings (protobuf type string) + // due to the protobuf type mismatch we need to iterate and copy + auto& in = tensor_proto.string_data(); + auto& out = *tensor_value.mutable_strings()->mutable_values(); + out.Reserve(in.size()); + for (const auto& iter : in) { + *out.Add() = iter; + } + + break; + } + /* Not clear if there's an actual use-case for 16-bit int data currently, so leaving commented out + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + // from: int32_data/raw, to: ints + // WARNING: This may change to write to mutable_bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L113-L115 + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } */ + default: + ORT_THROW("AddTensorProtoDataToMILSpecTensorValue: Unsupported data type: ", data_type); + } +} + +template +uint64_t WriteRawDataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + MILBlob::Util::Span data(reinterpret_cast(tensor_proto.raw_data().data()), + tensor_proto.raw_data().size() / sizeof(T)); + return writer.WriteData(data); +} + +// Write T1 data from the TensorProto.int32_data field using StorageWriter. +// Currently int32_data can have any of these data types: +// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, +// FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ +// T1 provides the size of the ONNX data type. T2 is the CoreML type. +// The sizes and layout of T1 and T2 must match as we simply cast the bytes to T2. +template +uint64_t WriteFromInt32DataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + static_assert(sizeof(T1) == sizeof(T2), "Data sizes must match"); + + // need to copy to temporary data as we have to extract a subset of bytes from each int32_t entry. + // works better to extract the ONNX type first with static_cast, and reinterpret_cast to the CoreML type at the end. + std::vector values; + const int num_values = tensor_proto.int32_data_size(); + values.resize(num_values); // resize so we're not updating the length inside the copy loop + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_values; ++i) { + values[i] = static_cast(in[i]); + } + + MILBlob::Util::Span data(reinterpret_cast(values.data()), + num_values); + return writer.WriteData(data); +} + +// write the initializer to weight.bin and return the offset +// StorageWriter is currently limited to fp32, fp16, bfloat16, uint8/int8, uint16/int16. +// AFAIK we don't use bfloat16/int16/uint16 for weights in ONNX, so limit handling to fp32, fp16, uint8/int8 +uint64_t CopyOnnxTensorToCoreMLWeightsFile(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + uint64_t offset = 0; + + // See AddTensorProtoDataToMILSpecTensorValue for links to sources for info on where the different typed data is + // stored for ONNX and CoreML + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + MILBlob::Util::Span data(tensor_proto.float_data().data(), tensor_proto.float_data().size()); + offset = writer.WriteData(data); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + + break; + } + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + default: + ORT_THROW("AddWeightToFile: Unsupported data type: ", data_type); + } + + return offset; +} + +MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& weights_file_writer) { + MILSpec::Value value; + + // populate ValueType with tensor data type, dims and rank + MILSpec::ValueType& value_type = *value.mutable_type(); + MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype(); + tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type())); + + tensor_type.set_rank(tensor_proto.dims().size()); + for (const auto& dim : tensor_proto.dims()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(dim); + } + + // add data to either weights.bin or as an immediate value + if (ShouldWriteInitializerToWeightsFile(tensor_proto)) { + uint64_t offset = CopyOnnxTensorToCoreMLWeightsFile(tensor_proto, weights_file_writer); + + auto* file_value = value.mutable_blobfilevalue(); + // Filename copied from + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L329 + file_value->set_filename("@model_path/weights/weight.bin"); + file_value->set_offset(offset); + } else { + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyOnnxTensorToCoreMLTensor(tensor_proto, tensor_value); + } + + return value; +} + +void CreateEmptyFile(const std::string& filename) { + std::ofstream file(filename, std::ofstream::out | std::ofstream::binary); + ORT_ENFORCE(file.is_open(), "Failed to open file ", filename); } -Status ModelBuilder::Initialize() { - coreml_model_ = std::make_unique(); - { // initialize CoreML model +#endif // defined(COREML_ENABLE_MLPROGRAM) + +std::string GetModelOutputPath(bool create_ml_program) { + // path is used to create the ML Package directory for ML Program, and for the model directly otherwise. + auto path = util::GetTemporaryFilePath(); + if (!create_ml_program) { + path += ".model.mlmodel"; + } + + return path; +} +} // namespace + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags) + : graph_viewer_(graph_viewer), + logger_(logger), + coreml_version_(coreml_version), + coreml_flags_(coreml_flags), + create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + model_output_path_(GetModelOutputPath(create_ml_program_)), + coreml_model_(std::make_unique()) { + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + coreml_model_->set_specificationversion(CoreMLSpecVersion()); + MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); + MILSpec::Function& main = (*mlprogram.mutable_functions())["main"]; + + const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); + *main.mutable_opset() = coreml_opset; + mlprogram_main_ = &(*main.mutable_block_specializations())[coreml_opset]; + + // create the ModelPackage. this creates the output directory. + mlpackage_ = std::make_unique(model_output_path_, /* create */ true); + + // ModelPackage::addItem does a copy of the file. Due to this we 'add' an empty file first, + // and do the actual writes to the file created in the package. + // We can't use ModelPackage::createFile as we have to add a directory for the weights. + std::string tmp_dir = model_output_path_ + "/tmp"; + ORT_THROW_IF_ERROR(Env::Default().CreateFolder(ToPathString(tmp_dir))); + CreateEmptyFile(tmp_dir + "/weight.bin"); + + std::string weights_id = mlpackage_->addItem(tmp_dir, "weights", "com.microsoft.OnnxRuntime", + "CoreML Model Weights"); + auto weights_info = mlpackage_->findItem(weights_id); + weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); +#else + // should never happen due to handling in coreml_execution_provider.cc + ORT_THROW("ML Program is not enabled in this build"); +#endif + } else { // We support CorelML Specification Version 4 (Core ML 3) coreml_model_->set_specificationversion(4); auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); + neural_network->set_arrayinputshapemapping( + CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } +} - PreprocessInitializers(); - ORT_RETURN_IF_ERROR(RegisterInitializers()); - ORT_RETURN_IF_ERROR(RegisterModelInputs()); - ORT_RETURN_IF_ERROR(AddOperations()); - ORT_RETURN_IF_ERROR(RegisterModelOutputs()); +ModelBuilder::~ModelBuilder() = default; - return Status::OK(); +/* + * NeuralNetwork related helpers + */ +std::unique_ptr ModelBuilder::CreateNNLayer(const Node& node, std::string_view suffix) { + auto layer_name = GetUniqueName(node, suffix); + + std::unique_ptr layer = std::make_unique(); + layer->set_name(layer_name); + return layer; +} + +void ModelBuilder::AddLayer(std::unique_ptr layer) { + auto* neural_network = coreml_model_->mutable_neuralnetwork(); + neural_network->mutable_layers()->AddAllocated(layer.release()); } -/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { - const auto& op_builders = GetOpBuilders(); - const auto it = op_builders.find(node.OpType()); - if (it != op_builders.cend()) - return it->second; +#if defined(COREML_ENABLE_MLPROGRAM) + +/* + * ML Program related helpers + */ +std::unique_ptr ModelBuilder::CreateOperation(const Node& node, + std::string_view op_type, + std::string_view suffix) { + std::string operation_name = GetUniqueName(node, suffix); + + std::unique_ptr op = std::make_unique(); + op->set_type(std::string(op_type)); + (*op->mutable_attributes())["name"] = CreateScalarTensorValue(operation_name); + + return op; +} + +void ModelBuilder::AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(initializer, *weights_file_writer_); + AddConstantOperation(name, std::move(coreml_tensor)); +} + +void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { + // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic + MILSpec::Operation& const_op = *mlprogram_main_->mutable_operations()->Add(); + const_op.set_type("const"); + + MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); + output.set_name(std::string(name)); + *output.mutable_type() = coreml_tensor.type(); + + auto& attr_map = *const_op.mutable_attributes(); + attr_map["name"] = CreateScalarTensorValue(std::string(name)); + attr_map["val"] = std::move(coreml_tensor); +} + +// Add operation to the Block for the main function in the ML Program +void ModelBuilder::AddOperation(std::unique_ptr operation) { + mlprogram_main_->mutable_operations()->AddAllocated(operation.release()); +} + +std::string ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + MILSpec::Value&& input_value) { + auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); + AddConstantOperation(unique_value_name, std::move(input_value)); + return unique_value_name; +} + +template +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape) { + // add specialization below + static_assert(false_for_T, "Missing specialization for value type"); + return ""; // unreachable +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} - return nullptr; +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +#endif // defined(COREML_ENABLE_MLPROGRAM) + +/* + * General implementation + */ void ModelBuilder::PreprocessInitializers() { - // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places + // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places. + // non-constant initializers need to be passed in as model inputs in case they're overridden at runtime. const auto& initializers = graph_viewer_.GetAllInitializedTensors(); const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); @@ -64,6 +563,7 @@ void ModelBuilder::PreprocessInitializers() { initializer_usage_[input->Name()]++; } } + if (const auto* op_builder = GetOpBuilder(node)) { op_builder->AddInitializersToSkip(*this, node); } @@ -77,27 +577,34 @@ Status ModelBuilder::RegisterInitializers() { // skip initializer if there is no remaining usage auto usage_count = initializer_usage_[name]; - if (usage_count == 0) + if (usage_count == 0) { continue; + } - std::unique_ptr layer = std::make_unique(); - layer->set_name(GetUniqueName("initializer_" + name)); - - // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer - auto* constant_tensor = layer->mutable_loadconstantnd(); - const auto& shape = tensor.dims(); - if (shape.empty()) { - // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor - constant_tensor->mutable_shape()->Add(1); + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + AddConstant(name, tensor); +#endif } else { - std::transform(shape.cbegin(), shape.cend(), - google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), - [](int64_t dim) -> uint64_t { return SafeInt(dim); }); - } + std::unique_ptr layer = std::make_unique(); + layer->set_name(GetUniqueName("initializer_" + name)); + + // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer + auto* constant_tensor = layer->mutable_loadconstantnd(); + const auto& shape = tensor.dims(); + if (shape.empty()) { + // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor + constant_tensor->mutable_shape()->Add(1); + } else { + std::transform(shape.cbegin(), shape.cend(), + google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), + [](int64_t dim) -> uint64_t { return SafeInt(dim); }); + } - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); - *layer->mutable_output()->Add() = name; - AddLayer(std::move(layer)); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); + *layer->mutable_output()->Add() = name; + AddLayer(std::move(layer)); + } } return Status::OK(); @@ -179,15 +686,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i data_type = type_proto->tensor_type().elem_type(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::FLOAT32); + multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: // If we have an int64 input/output type, since COREML_SPEC:ArrayFeatureType does not support INT64 // we assign it to be INT32 here - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); if (!is_input) { // Record the output names and we need to change them back to Int64 when CoreML EP returns these values to ORT AddInt64Output(name); @@ -204,6 +711,19 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + MILSpec::Function& main = (*coreml_model_->mutable_mlprogram()->mutable_functions())["main"]; + if (is_input) { + // the model inputs need to be wired up as args to the 'main' function + main.mutable_inputs()->Add(CreateNamedTensorValueType(node_arg)); + } else { + // the model outputs need to be set as outputs of the Block for the 'main' function + *mlprogram_main_->mutable_outputs()->Add() = node_arg.Name(); + } + } +#endif // defined(COREML_ENABLE_MLPROGRAM) + return Status::OK(); } @@ -215,16 +735,16 @@ Status ModelBuilder::RegisterModelInputs() { return Status::OK(); } -Status ModelBuilder::AddOperations() { - const auto builder_params = MakeOpBuilderParams(graph_viewer_, coreml_flags_); - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, builder_params, logger_)); +Status ModelBuilder::ProcessNodes() { + for (const auto node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { + const auto& node = *graph_viewer_.GetNode(node_idx); + if (const auto* op_builder = GetOpBuilder(node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node, logger_)); } else { + // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing + // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] is not supported"); } } @@ -239,29 +759,72 @@ Status ModelBuilder::RegisterModelOutputs() { return Status::OK(); } -Status ModelBuilder::Compile(std::unique_ptr& model, const std::string& path) { - ORT_RETURN_IF_ERROR(SaveCoreMLModel(path)); - model.reset(new Model(path, logger_, coreml_flags_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); - model->SetInt64Outputs(std::move(int64_outputs_)); - model->SetInputOutputInfo(std::move(input_output_info_)); - return model->LoadModel(); +Status ModelBuilder::CreateModel() { + PreprocessInitializers(); + + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(ProcessNodes()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + + return Status::OK(); } -Status ModelBuilder::SaveCoreMLModel(const std::string& path) { - ORT_RETURN_IF_ERROR(Initialize()); - std::ofstream stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Save the CoreML model failed"); +Status ModelBuilder::SaveModel() { + std::string output_path = model_output_path_; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel"; + CreateEmptyFile(tmp_model_path); + + std::string model_id = mlpackage_->setRootModel(tmp_model_path, "model.mlmodel", "com.microsoft.OnnxRuntime", + "CoreML Model Specification"); + auto model_info = mlpackage_->findItem(model_id); + output_path = model_info->path(); + } +#endif - // TODO, Delete, debug only - if (const char* path = std::getenv("ORT_COREML_EP_CONVERTED_MODEL_PATH")) { - std::ofstream temp_stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&temp_stream), "Save the CoreML model failed"); + // scope this so the stream is closed and flushed by the ofstream dtor + { + LOGS(logger_, INFO) << "Writing CoreML Model to " << output_path; + std::ofstream stream(output_path, std::ofstream::out | std::ofstream::binary); + ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path); } +#if defined(COREML_ENABLE_MLPROGRAM) + // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program + // related types as well. + mlprogram_main_ = nullptr; + mlpackage_.reset(); + weights_file_writer_.reset(); +#endif + return Status::OK(); } +Status ModelBuilder::LoadModel(std::unique_ptr& model) { + model = std::make_unique(model_output_path_, + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + + return model->LoadModel(); // load using CoreML API, including compilation +} + +// static +Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model) { + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags); + + ORT_RETURN_IF_ERROR(builder.CreateModel()); + ORT_RETURN_IF_ERROR(builder.SaveModel()); + + return builder.LoadModel(model); +} + void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } @@ -270,11 +833,6 @@ void ModelBuilder::AddInt64Output(const std::string& output_name) { int64_outputs_.insert(output_name); } -void ModelBuilder::AddLayer(std::unique_ptr layer) { - auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->mutable_layers()->AddAllocated(layer.release()); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { // decrement usage count if this is a known initializer. // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names @@ -289,7 +847,7 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(const std::string& base_name) { +std::string ModelBuilder::GetUniqueName(std::string_view base_name) { std::string unique_name; do { std::ostringstream os; @@ -300,5 +858,12 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +std::string ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { + if (node.Name().empty()) { + return GetUniqueName(MakeString("Node_", node.Index(), "_", node.OpType(), suffix)); + } else { + return GetUniqueName(node.Name() + std::string(suffix)); + } +} } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index af2d5437be8d1..961ba647257b5 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -3,57 +3,171 @@ #pragma once +#include "core/common/span_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/coreml/model/model.h" + +#if defined(COREML_ENABLE_MLPROGRAM) +// coremltools classes +namespace MPL { +class ModelPackage; +} + +namespace MILBlob { +namespace Blob { +class StorageWriter; +} +} // namespace MILBlob +#endif namespace onnxruntime { namespace coreml { class IOpBuilder; class Model; -struct OnnxTensorInfo; class ModelBuilder { + private: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags); + public: - ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags); - ~ModelBuilder() = default; + // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` + static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model); - Status Compile(std::unique_ptr& model, const std::string& path); - Status SaveCoreMLModel(const std::string& path); + ~ModelBuilder(); - // Accessors for members const GraphViewer& GetGraphViewer() const { return graph_viewer_; } const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } - + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name) const { + return graph_viewer_.GetConstantInitializer(name, true); + } + + // Since CoreML 2 the spec version is +1 as CoreML 1.1 was spec version 2. + // We only support CoreML 3 and later so the spec version is always version + 1. + int32_t CoreMLVersion() const { return coreml_version_; } + int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; } + + // Returns true if we are creating an ML Program + bool CreateMLProgram() const { +#if defined(COREML_ENABLE_MLPROGRAM) + return create_ml_program_; +#else + return false; +#endif + } + + /* + * NeuralNetworkLayer helpers + */ + + // Create a NeuralNetwork layer using the node name and optional suffix for the name. + // If Node has no name a unique name will be generated from the node index and operator. + std::unique_ptr CreateNNLayer(const Node& node, std::string_view suffix = ""); + + // Add layer to the Core ML NeuralNetwork model void AddLayer(std::unique_ptr layer); - // The initializer will be processed separately, skip it as an initializer +#if defined(COREML_ENABLE_MLPROGRAM) + /* + * MLProgram helpers + */ + + // Create Operation, set type and the unique name attribute. + std::unique_ptr CreateOperation(const Node& node, std::string_view op_type, + std::string_view suffix = ""); + + // + // Helpers for adding attributes from ONNX nodes as inputs to an ML Program Operation + // + + /// + /// Add a value as a 'const' operation, generating a unique name for the value from op_type and value_type. + /// Use for values that were not initializers in the original ONNX model. e.g. attributes from ONNX nodes. + /// Add existing initializers using AddConstant with the TensorProto. + /// + /// e.g. adding the bias input of Gemm would have op_type='gemm' and value_type='bias'. + /// + /// Value type. + /// Typically MILSpec::Operation.type(). + /// Typically the input name of the operation that will consume the value. + /// Value to add. + /// Optional shape for the value. + /// If T is a primitive type `shape` is ignored and the value is treated as a scalar. + /// For a container type, if `shape` is not provided the shape is inferred to be 1-D of {value.size()}. + /// + /// Unique name generated for value. + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + // add specialization in AddConstantImpl for new types if needed + "AddConstant currently supports float, int64_t, std::string and bool."); + return AddConstantImpl(op_type, value_type, value, shape); + } + + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { + return AddConstant(op_type, value_type, AsSpan(value), shape); + } + + /// + /// Add a scalar value as a 'const' operation. See AddConstant for details. + /// + template + std::string AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); + } + + /// + /// Add an existing a constant ONNX initializer to the ML Program as a 'const' operation + /// + /// Initializer name + /// Initializer data + void AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer); + + // add the operation to the main function + void AddOperation(std::unique_ptr operation); +#endif + + /* + * General helpers + */ + + // The initializer is processed separately (e.g. layout is transformed) by the operator builder, + // so we don't do a copy of the original initializer into the model. void AddInitializerToSkip(const std::string& tensor_name); // There are some input which will not be used, add it to a list which will not // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(const std::string& base_name); + std::string GetUniqueName(std::string_view base_name); + std::string GetUniqueName(const Node& node, std::string_view suffix); private: - const GraphViewer& graph_viewer_; - const logging::Logger& logger_; - uint32_t coreml_flags_; - - std::unique_ptr coreml_model_; - std::unordered_set scalar_outputs_; - std::unordered_set int64_outputs_; - std::unordered_map input_output_info_; - - std::unordered_map initializer_usage_; - std::unordered_set skipped_inputs_; - - uint32_t name_token_{0}; - std::unordered_set unique_names_; - - // Convert the onnx model to CoreML::Specification::Model - Status Initialize(); +#if defined(COREML_ENABLE_MLPROGRAM) + template + std::string AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + void AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + std::string AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); +#endif + + // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. + // We then load it using CoreML in order compile it. + Status CreateModel(); + Status SaveModel(); + Status LoadModel(std::unique_ptr& model); // If a CoreML operation will use initializers directly, we will add the initializers to the skip list void PreprocessInitializers(); @@ -61,7 +175,7 @@ class ModelBuilder { // Copy and process all the initializers to CoreML model Status RegisterInitializers(); - Status AddOperations(); + Status ProcessNodes(); Status RegisterModelInputs(); Status RegisterModelOutputs(); Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input); @@ -72,7 +186,32 @@ class ModelBuilder { // Record the onnx int64 type output names void AddInt64Output(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + const int32_t coreml_version_; + const uint32_t coreml_flags_; + const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) + const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel + + std::unique_ptr coreml_model_; + std::unordered_set scalar_outputs_; + std::unordered_set int64_outputs_; + std::unordered_map input_output_info_; + + std::unordered_map initializer_usage_; + std::unordered_set skipped_inputs_; + + uint32_t name_token_{0}; + std::unordered_set unique_names_; + +#if defined(COREML_ENABLE_MLPROGRAM) + // mlprogram_main_ is the main block of the CoreML ML Program. + // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] + // entry we create. + COREML_SPEC::MILSpec::Block* mlprogram_main_{nullptr}; + std::unique_ptr mlpackage_; + std::unique_ptr weights_file_writer_; +#endif }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder.h b/onnxruntime/core/providers/coreml/builders/op_builder.h index 79de6438c9700..0bb7f280c33e6 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder.h @@ -11,36 +11,39 @@ namespace coreml { class ModelBuilder; struct OpBuilderInputParams { - OpBuilderInputParams(const GraphViewer& graph_viewer, bool only_allow_static_input_shapes) + OpBuilderInputParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + bool only_allow_static_input_shapes, + bool create_mlprogram) : graph_viewer(graph_viewer), - only_allow_static_input_shapes(only_allow_static_input_shapes) {} + coreml_version(coreml_version), + only_allow_static_input_shapes(only_allow_static_input_shapes), + create_mlprogram(create_mlprogram) {} const GraphViewer& graph_viewer; + const int32_t coreml_version; // required to determine which version of an operation can be used. const bool only_allow_static_input_shapes; + const bool create_mlprogram; // whether to create ML Program (Core ML 5+) or NeuralNetwork (Core ML 3+) }; class IOpBuilder { public: virtual ~IOpBuilder() = default; - // Add operator related -#ifdef __APPLE__ - public: // Check if the initializers of this operator need preprocess // which will not be copied virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; // Add the operator to CoreML model virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; -#endif - // Operator support related - public: // Check if an operator is supported virtual bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; + + // Does the builder implementation support creating an ML Program? + virtual bool SupportsMLProgram() const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index d72420bcfff88..6469b4cefa5ea 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -3,7 +3,7 @@ #pragma once -#include "op_builder.h" +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index c133f7b82aba4..8e718da07703c 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags #include +#include "core/common/logging/logging.h" #include "core/framework/compute_capability.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" @@ -12,12 +14,10 @@ #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_cxx_api.h" -#ifdef __APPLE__ #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" -#endif namespace onnxruntime { @@ -25,7 +25,24 @@ constexpr const char* COREML = "CoreML"; CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags) { + coreml_flags_(coreml_flags), + coreml_version_(coreml::util::CoreMLVersion()) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { + LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && + (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#else + if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -35,28 +52,34 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes - // TODO investigate whether we want to support subgraph using CoreML EP - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { return result; } const auto& logger = *GetLogger(); + // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes + // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the + // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. + if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + return result; + } + const bool has_neural_engine = coreml::HasNeuralEngine(logger); if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, VERBOSE) << "The current system does not have Apple Neural Engine"; + LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used."; return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); - const auto gen_metadef_name = [&]() { - HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(COREML, "_", model_hash, "_", metadef_id); - }; + const auto gen_metadef_name = + [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(COREML, "_", model_hash, "_", metadef_id); + }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, gen_metadef_name, COREML, kCoreMLExecutionProvider); @@ -86,17 +109,16 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return result; } -#ifdef __APPLE__ +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - coreml::ModelBuilder builder(graph_viewer, *GetLogger(), coreml_flags_); std::unique_ptr coreml_model; - const std::string coreml_model_file_path = coreml::util::GetTemporaryFilePath(); - ORT_RETURN_IF_ERROR(builder.Compile(coreml_model, coreml_model_file_path)); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + coreml_model)); { const auto& input_defs = fused_node.InputDefs(); @@ -241,22 +263,6 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { - ORT_UNUSED_PARAMETER(fused_node_and_graph); - NodeComputeInfo compute_info; - compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; }; - compute_info.release_state_func = [](FunctionState /*state*/) {}; - compute_info.compute_func = [](FunctionState /* state */, const OrtApi* /* api */, - OrtKernelContext* /* context */) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build."); - }; - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} -#endif //__APPLE__ +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0201739547dd1..24a001280eef5 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,9 +3,9 @@ #pragma once +#include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" #include "core/framework/model_metadef_id_generator.h" -#include "core/providers/coreml/coreml_provider_factory.h" namespace onnxruntime { namespace coreml { @@ -26,15 +26,14 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; #endif + private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - const uint32_t coreml_flags_; - - private: -// > -#ifdef __APPLE__ - std::unordered_map> coreml_models_; -#endif + uint32_t coreml_flags_; + const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; + + // map of fused_node_name to compiled_coreml_model + InlinedHashMap> coreml_models_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index f7f45bce087bc..4f9a014c4d885 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -8,10 +8,50 @@ #include -#define API_AVAILABLE_OS_VERSIONS API_AVAILABLE(macos(10.15), ios(13)) +#if defined(__APPLE__) +// See https://apple.github.io/coremltools/mlmodel/Format/Model.html for the info on each CoreML specification version. +// See https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html for the list of ops +// in each CoreML specification version. -// Base requireed OS to run CoreML Specification Version 4 (Core ML 3) -#define HAS_VALID_BASE_OS_VERSION @available(macOS 10.15, iOS 13, *) +// Specification Versions : OS Availability(Core ML Version) +// +// 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) +// - initial version of CoreML EP +// 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) +// - additional layers in NeuralNetwork but currently none are implemented by the CoreML EP +// 6 : iOS 15, macOS 12, tvOS 15, watchOS 8 (Core ML 5) +// - adds MLProgram (MILSpec.Program) +// - iOS 15 ops +// 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6) +// - iOS 16 ops +// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) +// - iOS 17 ops +// +// **NOTE** We use the Core ML version not the spec version. +// +// e.g. iOS 13 has Core ML 3 (which is Core ML Specification version 4), and the related macros are +// API_AVAILABLE_COREML3, HAS_COREML3_OR_LATER and onnxruntime::coreml::util::CoreMLVersion() will return 3. + +// https://developer.apple.com/documentation/swift/marking-api-availability-in-objective-c +// API_AVAILABLE is used to decorate Objective-C APIs +#define API_AVAILABLE_COREML3 API_AVAILABLE(macos(10.15), ios(13)) +#define API_AVAILABLE_COREML4 API_AVAILABLE(macos(11), ios(14)) +#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) +#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) +#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) + +// @available is used in implementation code +// Base required OS to run CoreML Specification Version 4 (Core ML 3) +#define HAS_COREML3_OR_LATER @available(macOS 10.15, iOS 13, *) +#define HAS_COREML4_OR_LATER @available(macOS 11, iOS 14, *) +#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) +#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) +#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) + +#endif + +#define MINIMUM_COREML_VERSION 3 // first version we support +#define MINIMUM_COREML_MLPROGRAM_VERSION 5 // first version where ML Program was available namespace onnxruntime { namespace coreml { @@ -21,6 +61,9 @@ namespace util { // This corresponds to [CoreML Specification Version 4 (Core ML 3)] bool HasRequiredBaseOS(); +// Return the CoreML version if 3 or higher. Otherwise returns -1. +int CoreMLVersion(); + // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 4c394386cd37a..0ae0cf8f0d207 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -10,19 +10,33 @@ namespace util { bool HasRequiredBaseOS() { - // This may look strange, but it is required "@available(macOS ....)" to safe-guard some code - // otherwise the compiler will spit -Wunsupported-availability-guard - if (HAS_VALID_BASE_OS_VERSION) - return true; - else - return false; + return CoreMLVersion() >= 3; +} + +int32_t CoreMLVersion() { + if (HAS_COREML7_OR_LATER) + return 7; + if (HAS_COREML6_OR_LATER) + return 6; + if (HAS_COREML5_OR_LATER) + return 5; + if (HAS_COREML4_OR_LATER) + return 4; + if (HAS_COREML3_OR_LATER) + return 3; + + return -1; } std::string GetTemporaryFilePath() { - // Get temporary directory. + // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; + + // make it easy to see who generated it + temporary_filename = [@"onnxruntime-" stringByAppendingString:temporary_filename]; + // Create URL to that file. NSURL* temporary_file_url = [temporary_directory_url URLByAppendingPathComponent:temporary_filename]; diff --git a/onnxruntime/core/providers/coreml/model/host_utils_stub.cc b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc new file mode 100644 index 0000000000000..5c383b0274e8c --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/platform/env.h" +#include "core/providers/coreml/model/host_utils.h" + +namespace onnxruntime { +namespace coreml { +namespace util { + +bool HasRequiredBaseOS() { + return true; +} + +int CoreMLVersion() { + return 7; // CoreML 7 is the latest we support. +} + +std::string GetTemporaryFilePath() { + static std::atomic counter = 0; + + // we want to avoid creating endless directories/names whilst avoiding clashes if tests run in parallel so cycle + // through 20 potential output names. + auto dir_name = "coreml_ep_test_run." + std::to_string(counter++ % 20); + + // to replicate the iOS/macOS host_utils.mm behavior where the output is / + // we want to return the name of something that does not exist. this is required for ML Package creation. + auto& env = Env::Default(); + if (env.FolderExists(dir_name)) { + ORT_THROW_IF_ERROR(env.DeleteFolder(ToPathString(dir_name))); + } + + return dir_name; +} + +} // namespace util +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 105b6a0333b15..b940c4b768aec 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -33,19 +33,29 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; class Model { - friend class ModelBuilder; - public: + Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, uint32_t coreml_flags); + ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + Status LoadModel(); + Status Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); - bool IsScalarOutput(const std::string& output_name) const; + bool IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); + } - bool IsInt64Output(const std::string& output_name) const; + bool IsInt64Output(const std::string& output_name) const { + return Contains(int64_outputs_, output_name); + } // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } @@ -57,35 +67,27 @@ class Model { const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } - const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const; - const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { + const auto info_it = input_output_info_.find(name); + return info_it != input_output_info_.end() ? &info_it->second : nullptr; + } + + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const { + const auto* info = TryGetInputOutputInfo(name); + ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); + return *info; + } private: std::unique_ptr execution_; + std::unordered_map input_output_info_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; std::vector onnx_inputs_; std::vector onnx_outputs_; - std::unordered_map input_output_info_; - OrtMutex mutex_; - - Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - Status LoadModel(); - - void SetInputOutputInfo(std::unordered_map&& input_output_info) { - input_output_info_ = std::move(input_output_info); - } - - void SetScalarOutputs(std::unordered_set&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - - void SetInt64Outputs(std::unordered_set&& int64_outputs) { - int64_outputs_ = std::move(int64_outputs); - } }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 155201ad4c39c..d5cd70bff9479 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -252,14 +252,14 @@ - (instancetype)initWithPath:(const std::string&)path coreml_flags:(uint32_t)coreml_flags; - (void)cleanup; - (void)dealloc; -- (Status)loadModel API_AVAILABLE_OS_VERSIONS; +- (Status)loadModel API_AVAILABLE_COREML3; - (Status)predict:(const std::unordered_map&)inputs outputs:(const std::unordered_map&)outputs getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&) get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_OS_VERSIONS; + API_AVAILABLE_COREML3; -@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_COREML3; @end @@ -308,6 +308,10 @@ - (Status)loadModel { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); } + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. NSError* error = nil; NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; @@ -454,7 +458,7 @@ Status Predict(const std::unordered_map& inputs, return Status::OK(); } - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { Status status{}; @autoreleasepool { status = [execution_ loadModel]; @@ -471,7 +475,7 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { @autoreleasepool { return [execution_ predict:inputs outputs:outputs @@ -482,8 +486,16 @@ Status Predict(const std::unordered_map& inputs, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } -Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)) { +Model::Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, + uint32_t coreml_flags) + : execution_(std::make_unique(path, logger, coreml_flags)), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { } Model::~Model() {} @@ -497,25 +509,5 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { return execution_->Predict(inputs, outputs, get_output_tensor_mutable_raw_data_fn); } - -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - -bool Model::IsInt64Output(const std::string& output_name) const { - return Contains(int64_outputs_, output_name); -} - -const OnnxTensorInfo* Model::TryGetInputOutputInfo(const std::string& name) const { - const auto info_it = input_output_info_.find(name); - return info_it != input_output_info_.end() ? &info_it->second : nullptr; -} - -const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { - const auto* info = TryGetInputOutputInfo(name); - ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); - return *info; -} - } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc new file mode 100644 index 0000000000000..087c9f8c05d5f --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/model/model.h" + +namespace onnxruntime { +namespace coreml { + +class Execution {}; + +Model::Model(const std::string& /*path*/, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& /*logger*/, + uint32_t /*coreml_flags*/) + : execution_(std::make_unique()), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { +} + +Model::~Model() { +} + +Status Model::LoadModel() { + // return OK so we hit more CoreML EP code. + return Status::OK(); +} + +Status Model::Predict(const std::unordered_map& /*inputs*/, + const std::unordered_map& /*outputs*/, + const GetOutputTensorMutableRawDataFn& /*get_output_tensor_mutable_raw_data_fn*/) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Executing a CoreML model is not supported on this platform."); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index cbdf79caf3afd..813fdc54ecd0d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/cpu_execution_provider.h" +#include #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/mlas/inc/mlas.h" @@ -29,7 +30,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) std::vector CPUExecutionProvider::CreatePreferredAllocators() { bool create_arena = info_.create_arena; -#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) +#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) || defined(ABSL_HAVE_ADDRESS_SANITIZER) // JEMalloc/mimalloc already have memory pool, so just use device allocator. create_arena = false; #elif !(defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index 59b512def619d..e1dcaf500a325 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -700,7 +700,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, BiCubicParamsAntiAlias::type> p; p.cubic_coeff_a = cubic_coeff_a; SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, - alloc, get_original_coordinate, exclude_outside, false); + alloc, get_original_coordinate, exclude_outside, true); return UpsampleBaseAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, use_extrapolation, extrapolation_value, diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index e9941ce743bc3..41c999bacee13 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -141,8 +141,7 @@ class HalfGemmOptions { } #else cublasMath_t GetMathMode() const { - // CublasMathModeSetter will check whether device has tensor cores later. - return CUBLAS_TENSOR_OP_MATH; + return CUBLAS_DEFAULT_MATH; } cudaDataType GetComputeType() const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 77e682e05a2a4..48a952e6dd98f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); @@ -1882,6 +1883,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index d0bb2321edf0a..55f0b5570e0ee 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -78,6 +78,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool UseTF32() const { return info_.use_tf32; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 7b507296d5982..c96381e3e68b1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -31,8 +31,10 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; -constexpr const char* kPreferNCHWMode = "prefer_nhwc"; -constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kPreferNHWCMode = "prefer_nhwc"; +constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kUseTF32 = "use_tf32"; + } // namespace provider_option_names } // namespace cuda @@ -112,8 +114,9 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) - .AddAssignmentToReference(cuda::provider_option_names::kPreferNCHWMode, info.prefer_nhwc) - .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) + .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -164,8 +167,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; @@ -185,8 +189,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index b286f5a9161b0..1cac3d1513698 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -76,6 +76,9 @@ struct CUDAExecutionProviderInfo { bool use_ep_level_unified_stream{false}; + // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. + bool use_tf32{true}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -83,12 +86,37 @@ struct CUDAExecutionProviderInfo { } // namespace onnxruntime template <> -struct std::hash<::onnxruntime::cuda::TunableOpInfo> { - size_t operator()(const ::onnxruntime::cuda::TunableOpInfo& info) const { - size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enable, seed_and_value); - onnxruntime::HashCombine(info.tuning_enable, seed_and_value); - onnxruntime::HashCombine(info.max_tuning_duration_ms, seed_and_value); - return seed_and_value; +struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { + size_t operator()(const ::onnxruntime::CUDAExecutionProviderInfo& info) const { + size_t value{0xbc9f1d34}; // seed + + // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + size_t data = static_cast(info.device_id) ^ + (static_cast(info.arena_extend_strategy) << 16) ^ + (static_cast(info.cudnn_conv_algo_search) << 18) ^ + (static_cast(info.do_copy_in_default_stream) << 20) ^ + (static_cast(info.has_user_compute_stream) << 21) ^ + (static_cast(info.cudnn_conv_use_max_workspace) << 22) ^ + (static_cast(info.enable_cuda_graph) << 23) ^ + (static_cast(info.tunable_op.enable) << 24) ^ + (static_cast(info.tunable_op.tuning_enable) << 25) ^ + (static_cast(info.cudnn_conv1d_pad_to_nc1d) << 26) ^ + (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ + (static_cast(info.prefer_nhwc) << 28) ^ + (static_cast(info.use_ep_level_unified_stream) << 29) ^ + (static_cast(info.use_tf32) << 30); + onnxruntime::HashCombine(data, value); + + onnxruntime::HashCombine(info.gpu_mem_limit, value); + onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + + // Memory pointers + onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + + // The default memory arena cfg is not used in hashing right now. + return value; } }; diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index e3106e41e77c8..288da23f35ec8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -90,6 +90,10 @@ class CudaKernel : public OpKernel { return stream->cublas_handle_; } + bool UseTF32() const { + return provider_->UseTF32(); + } + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 892e8d5329eba..103c79c93b2ca 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -225,6 +225,7 @@ struct CUDA_Provider : Provider { info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; + info.use_tf32 = params->use_tf32 != 0; return std::make_shared(info); } @@ -258,6 +259,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; + cuda_options.use_tf32 = internal_options.use_tf32; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 0a256394b7d99..3c0bf183362dd 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -212,6 +212,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::use_tf32_t: + return reinterpret_cast(ep_info_.use_tf32); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index 3e50116eafd17..ee0334e552022 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -51,25 +51,27 @@ Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, CudaT one = cuda::ToCudaType::FromFloat(1.0f); CudaT zero = cuda::ToCudaType::FromFloat(0.0f); - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(static_cast(einsum_cuda_assets)->cublas_handle_, - CUBLAS_OP_N, - CUBLAS_OP_N, - static_cast(N), - static_cast(M), - static_cast(K), - &one, - reinterpret_cast(input_2_data), - static_cast(N), - static_cast(right_stride), - reinterpret_cast(input_1_data), - static_cast(K), - static_cast(left_stride), - &zero, - reinterpret_cast(output_data), - static_cast(N), - static_cast(output_stride), - static_cast(num_batches), - static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp())); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + static_cast(einsum_cuda_assets)->cublas_handle_, + CUBLAS_OP_N, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &one, + reinterpret_cast(input_2_data), + static_cast(N), + static_cast(right_stride), + reinterpret_cast(input_1_data), + static_cast(K), + static_cast(left_stride), + &zero, + reinterpret_cast(output_data), + static_cast(N), + static_cast(output_stride), + static_cast(num_batches), + static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp(), + static_cast(einsum_cuda_assets)->cuda_ep_->UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 8fe23c9a036cc..4e61e0c8c69c6 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -118,7 +118,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const b_data, N, GetConstOnes(M, Stream(ctx)), 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( @@ -130,7 +130,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const GetConstOnes(N, Stream(ctx)), N, b_data, 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else { // B is (M, N), no broadcast needed. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, static_cast(M) * N * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); @@ -153,7 +153,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const // ideally we need to set the output buffer contents to 0 if bias is missing, // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer B != nullptr ? &beta : &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index e4c37c52a1780..6e126fbeadce8 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -173,7 +173,8 @@ Status FuncMatMul( &cuda_zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { @@ -195,7 +196,8 @@ Status FuncMatMul( ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } @@ -213,12 +215,12 @@ Status FuncMatMul( ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + cuda_kernel->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -238,7 +240,8 @@ Status FuncMatMul( Y_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } @@ -321,7 +324,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help &zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(left_X->Shape(), right_X->Shape(), transa, transb, trans_batch_a_, trans_batch_b_, stride_A, stride_B, stride_C, batch_count)) { @@ -343,7 +347,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } @@ -361,12 +366,12 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + this->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -386,7 +391,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help output_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 655877f425054..fd8b69d7bd2f5 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -160,7 +160,7 @@ UNARY_OP_CSILHFD(Neg, 13) UNARY_OP_HFD(Floor, 13) UNARY_OP_HFD(Ceil, 13) UNARY_OP_HFD(Reciprocal, 13) -UNARY_OP_HFD(Sqrt, 13) +UNARY_OP_HFDX(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 5c3db4a499972..73c5ac80756be 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -83,7 +83,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..a417be5a86c32 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -326,7 +326,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +352,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -399,6 +405,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..181fbc99fd8e9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 510cc5cfbb7dd..053c66ddcb34a 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -29,13 +29,15 @@ cublasGemmHelper(cublasHandle_t handle, const float* B, int ldb, const float* beta, float* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - // TF32 uses 10 bit mantissa which has sufficient margin of precision for most use cases. It gets 8x throughput than FP32 in A100. - // It can be overrided by setting environment variable NVIDIA_TF32_OVERRIDE = 0 to disable TF32 - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); + // To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 or set provider option use_tf32 = 0 + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemm(handle, @@ -58,7 +60,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const double* B, int ldb, const double* beta, double* C, int ldc, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemm(handle, transa, transb, @@ -79,7 +82,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const half* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -121,7 +125,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -155,10 +160,11 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, + const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -169,7 +175,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&) { + BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -185,7 +191,17 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const float* beta, float* Carray[], int ldc, int batch_count, - const cudaDeviceProp&) { + const cudaDeviceProp& prop, + bool use_tf32) { +// The caller shall check memory alignments of the matrices when use_tf32 is true. +#if defined(USE_CUDA) + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); +#else + ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); +#endif + return cublasSgemmBatched(handle, transa, transb, @@ -208,7 +224,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const double* beta, double* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmBatched(handle, transa, transb, @@ -231,7 +248,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const half* beta, half* Carray[], int ldc, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -266,11 +284,12 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], + int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, + BFloat16* Carray[], int ldc, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -282,7 +301,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOpera #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&) { + const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -301,15 +321,14 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, float* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { -#ifdef ENABLE_TRAINING_OPS + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); -#else - ORT_UNUSED_PARAMETER(prop); -#endif + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemmStridedBatched(handle, @@ -337,7 +356,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, double* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmStridedBatched(handle, transa, transb, @@ -363,7 +383,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -411,7 +432,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -447,49 +469,66 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const BFloat16* alpha, const BFloat16* A, int lda, + long long int strideA, const BFloat16* B, int ldb, + long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, + long long int strideC, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 - return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, - ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); + return cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, + ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT); } #else -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&) { +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + int, const BFloat16*, const BFloat16*, int, long long int, + const BFloat16*, int, long long int, const BFloat16*, BFloat16*, + int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif // transpose using geam -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, + float* C, int ldc) { return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, + double* C, int ldc) { return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } bool CanUse_cublasTransposeHelper_MLFloat16(int m, int n); -cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); + +cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, + int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); // copy -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { return cublasScopy(handle, n, x, incx, y, incy); } -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { return cublasDcopy(handle, n, x, incx, y, incy); } -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); + +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4736..2456b396de3f6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. uint32_t c_maxConstNodeDataSize = 8; - ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; - if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { - // The tensor description's size should be no larger than the constant input unless it was rounded to + constantInput = constantCpuGraphInputGetter(arg->Name()); + } + + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index dbd06abf82f72..d524780de71b8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter m_requiredConstantCpuInputs.begin(), m_requiredConstantCpuInputs.end(), inputIndex) != m_requiredConstantCpuInputs.end(); - + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } @@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. - if (impl->has_raw_data()) + if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor)); + m_dataPtr = reinterpret_cast(m_unpackedExternalTensor.data()); + m_tensorByteSize = m_unpackedExternalTensor.size(); + } + else if (impl->has_raw_data()) { m_dataPtr = reinterpret_cast(impl->mutable_raw_data()->data()); m_tensorByteSize = impl->raw_data().size(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6530d89d895e7..59e253e88457a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; + std::vector m_unpackedExternalTensor; std::byte* m_dataPtr = nullptr; // Lifetime is managed by the caller and guaranteed to outlive this class diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 30c339b845b36..44004b5d77f70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise, + // it has the shape [batchSize, sequenceLength, hiddenSize] + const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4; + // When positionIds is a scalar, it represents the start offset for each sequence const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); - const uint32_t batchSize = inputDataSizes[1]; + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputDataSizes[3] / headSize; + const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + + if (inputIs4D) + { + const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; + stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + } + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Swap the 2 halves and join them together DML_JOIN_OPERATOR_DESC joinInputDesc{}; joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; joinInputDesc.Axis = splitInputDesc.Axis; joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; - mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.ATensor = &joinedDataDmlTensorDesc; mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; - mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; // Multiply the non-rotated data with the cos and the rotated data with the sin DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; - mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc; mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; - mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; addDesc.ATensor = &inputOutputDmlTensorDesc; addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 3209ad734fa20..0b32508a5bb38 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -184,9 +184,8 @@ bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) { return true; } -common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, - float& scale, int32_t& zero_point) { +common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, + const Path& model_path, float& scale, int32_t& zero_point) { scale = 0.0f; zero_point = 0; @@ -198,14 +197,24 @@ common::Status GetQuantizationScaleAndZeroPoint( const auto& quant_param = *io_def.quant_param; { // get the scale const auto& name = quant_param.scale.Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); + }; + + Initializer unpacked_tensor(*s, model_path); // The scale should be one or more floats scale = unpacked_tensor.DataAsSpan()[0]; } if (quant_param.zero_point) { // get the zero point if it's there const auto& name = quant_param.zero_point->Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* zp = graph_viewer.GetConstantInitializer(name); + if (!zp) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); + }; + + Initializer unpacked_tensor(*zp, model_path); // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI zero_point = static_cast(unpacked_tensor.DataAsByteSpan()[0]); } @@ -213,13 +222,13 @@ common::Status GetQuantizationScaleAndZeroPoint( return Status::OK(); } -common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, - float& scale, int32_t& zero_point, ArgType arg_type) { +common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& name, float& scale, int32_t& zero_point, + ArgType arg_type) { const auto& io_defs = arg_type == ArgType::kInput ? node_unit.Inputs() : node_unit.Outputs(); for (const auto& io_def : io_defs) { if (io_def.node_arg.Name() == name) - return GetQuantizationScaleAndZeroPoint(initializers, io_def, node_unit.ModelPath(), + return GetQuantizationScaleAndZeroPoint(graph_viewer, io_def, node_unit.ModelPath(), scale, zero_point); } @@ -348,7 +357,7 @@ bool IsNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer, } const auto* op_builder = op_builder_it->second; - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node_unit, params); + return op_builder->IsOpSupported(graph_viewer, node_unit, params); } bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer, @@ -381,11 +390,11 @@ uint32_t ShapeSize(const Shape& shape, size_t begin_idx, size_t end_idx) { SafeInt{1}, std::multiplies>{}); } -bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::string& input_name, const char* input_description) { - if (!Contains(initializers, input_name)) { +bool CheckIsConstantInitializer(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& input_name, const char* input_description) { + if (!graph_viewer.GetConstantInitializer(input_name)) { LOGS_DEFAULT(VERBOSE) << input_description << " of " << node_unit.Name() << "of type [" - << node_unit.OpType() << "] must be an initializer tensor"; + << node_unit.OpType() << "] must be a constant initializer"; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index 766034b3decea..a606b8aceb63d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -132,11 +132,11 @@ bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const Path& model_path, float& scale, int32_t& zero_point); common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const std::string& name, float& scale, int32_t& zero_point, ArgType arg_type = ArgType::kInput); // Get Shape/Type of a NodeArg @@ -167,11 +167,11 @@ inline uint32_t ShapeSize(const Shape& shape) { return ShapeSize(shape, 0, shape.size()); } -// Check the given input is an initializer tensor +// Check the given input is a constant initializer // input_name is the name of the initializer // input_description is the string describing the input in the output message (if any) -bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::string& input_name, const char* input_description); +bool CheckIsConstantInitializer(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& input_name, const char* input_description); // Convert ONNX int64 input to NNAPI int32 type input and optionally handle negative axis if needed // Mostly used in handling `axes` input for now diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc index 00bca4001326c..91cad034d8854 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc @@ -29,7 +29,7 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -91,7 +91,7 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No // Operator support related -bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc index 7797e0a47caaf..adc79576272ab 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/graph/graph_viewer.h" #include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h" namespace onnxruntime { @@ -11,10 +12,11 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node const auto is_ext_initializer = [&](const NodeArg& node_arg) { const auto& input_name(node_arg.Name()); - if (!Contains(initializers, input_name)) + const auto initializer = initializers.find(input_name); + if (initializer == initializers.end()) return false; - const auto& tensor = *initializers.at(input_name); + const auto& tensor = *initializer->second; if (tensor.has_data_location() && tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name @@ -51,8 +53,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeU model_builder.GetEffectiveFeatureLevel(), model_builder.UseNCHW(), }; - ORT_RETURN_IF_NOT(IsOpSupported(model_builder.GetInitializerTensors(), node_unit, params), - "Unsupported operator ", node_unit.OpType()); + + // We checked supported in IExecutionProvider::GetCapability. + // Checking again in AddToModelBuilder which is called in IExecutionProvider::Compile is redundant. + // ORT_RETURN_IF_NOT(IsOpSupported(model_builder.GetGraphViewer(), node_unit, params), + // "Unsupported operator ", node_unit.OpType()); + #ifndef NDEBUG model_builder.SetDebugCurrentOnnxNodeIndex(node_unit.Index()); #endif @@ -64,7 +70,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeU // Operator support related -bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BaseOpBuilder::IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t required_feature_level = GetMinSupportedNNAPIFeatureLevel(node_unit, params); if (required_feature_level > params.android_feature_level) { @@ -77,20 +83,20 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!IsNodeUnitTypeSupported(node_unit)) return false; - if (!HasSupportedInputOutputs(initializers, node_unit, params)) + if (!HasSupportedInputOutputs(graph_viewer, node_unit, params)) return false; // We do not support external initializers for now - if (HasExternalInitializer(initializers, node_unit)) + if (HasExternalInitializer(graph_viewer.GetAllInitializedTensors(), node_unit)) return false; if (!HasSupportedOpSet(node_unit)) return false; - return IsOpSupportedImpl(initializers, node_unit, params); + return IsOpSupportedImpl(graph_viewer, node_unit, params); } -bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BaseOpBuilder::HasSupportedInputOutputs(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // We do not support unknown(null) input shape auto has_supported_shape = [](const NodeArg& node_arg, const std::string& name, const std::string& op_type) { @@ -128,12 +134,12 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial return false; } } - return HasSupportedInputOutputsImpl(initializers, node_unit, params); + + return HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } -bool BaseOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const { +bool BaseOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, + const OpSupportCheckParams& /* params */) const { // We only check the type of input 0 by default // specific op builder can override this const auto& input = node_unit.Inputs()[0].node_arg; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h index 339ccd67f33e3..6a54bf7bdb938 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h @@ -52,11 +52,11 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related public: - bool IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; protected: - virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& /* node_unit */, + virtual bool IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& /* node_unit */, const OpSupportCheckParams& /* params */) const { return true; } @@ -68,9 +68,8 @@ class BaseOpBuilder : public IOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_1; } - virtual bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const; + virtual bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const; virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 19; } @@ -82,7 +81,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const NodeUnit& node_unit) const; - bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool HasSupportedInputOutputs(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const; }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc index 3add0ac26c0d4..75a66d3a14643 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc @@ -33,7 +33,7 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // BatchNormalization opset 6- has unsupported attributes @@ -127,7 +127,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu // Operator support related -bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { if (node_unit.Outputs().size() != 1) { LOGS_DEFAULT(VERBOSE) << "Your onnx model may be in training mode, please export " @@ -158,20 +158,20 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& const auto& b_name = inputs[2].node_arg.Name(); const auto& mean_name = inputs[3].node_arg.Name(); const auto& var_name = inputs[4].node_arg.Name(); - if (!Contains(initializers, scale_name)) { - LOGS_DEFAULT(VERBOSE) << "Scale of BN must be known"; + if (!graph_viewer.GetConstantInitializer(scale_name)) { + LOGS_DEFAULT(VERBOSE) << "Scale of BN must be a constant initializer"; return false; } - if (!Contains(initializers, b_name)) { - LOGS_DEFAULT(VERBOSE) << "B of BN must be known"; + if (!graph_viewer.GetConstantInitializer(b_name)) { + LOGS_DEFAULT(VERBOSE) << "B of BN must be a constant initializer"; return false; } - if (!Contains(initializers, mean_name)) { - LOGS_DEFAULT(VERBOSE) << "Mean of BN must be known"; + if (!graph_viewer.GetConstantInitializer(mean_name)) { + LOGS_DEFAULT(VERBOSE) << "Mean of BN must be a constant initializer"; return false; } - if (!Contains(initializers, var_name)) { - LOGS_DEFAULT(VERBOSE) << "Var of BN must be known"; + if (!graph_viewer.GetConstantInitializer(var_name)) { + LOGS_DEFAULT(VERBOSE) << "Var of BN must be a constant initializer"; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc index dce1a7c8659bf..5599fbdc69bdd 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc @@ -34,10 +34,10 @@ class BinaryOpBuilder : public BaseOpBuilder { private: int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; @@ -95,7 +95,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (is_quant_op) { ORT_RETURN_IF_ERROR(GetBinaryOpQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit, + model_builder.GetGraphViewer(), node_unit, a_scale, b_scale, y_scale, a_zero_point, b_zero_point, y_zero_point)); } @@ -163,22 +163,22 @@ int BinaryOpBuilder::GetMinSupportedOpSet(const NodeUnit& node_unit) const { } bool BinaryOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { bool is_quantized_op = IsQuantizedOp(node_unit); bool is_pow = node_unit.OpType() == "Pow"; if (!is_quantized_op && !is_pow) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); if (is_quantized_op) { // QLinearAdd/QDQAdd/QLinearMul/QDQMul if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } @@ -203,7 +203,7 @@ bool BinaryOpBuilder::HasSupportedInputOutputsImpl( return true; } -bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool BinaryOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& op_type(node_unit.OpType()); const auto& inputs = node_unit.Inputs(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc index b31ee484dc5a2..9059de817e210 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc @@ -29,7 +29,7 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -70,7 +70,7 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N return Status::OK(); } -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool CastOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { NodeAttrHelper helper(node_unit); const auto to = helper.Get("to", 0); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc index b3e294d2f0845..9821d9267c71f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc @@ -32,7 +32,7 @@ class ClipOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -64,7 +64,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } float min, max; - GetClipMinMax(model_builder.GetInitializerTensors(), node_unit.GetNode(), min, max, + GetClipMinMax(model_builder.GetGraphViewer(), node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger()); int32_t op_code; @@ -85,10 +85,10 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Operator support related -bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ClipOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { float min, max; - if (!GetClipMinMax(initializers, node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(graph_viewer, node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger())) return false; // We only supoort relu6 or relu1 diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc index 2bf8f07e26fd4..a8394faec51be 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc @@ -32,11 +32,11 @@ class ConcatOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } @@ -113,7 +113,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float scale = 0.0f; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[i], node_unit.ModelPath(), + model_builder.GetGraphViewer(), node_unit.Inputs()[i], node_unit.ModelPath(), scale, zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); @@ -128,7 +128,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t y_zero_point = operand_types.at(input0).operandType.zeroPoint; if (is_quant_op) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), + model_builder.GetGraphViewer(), node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); } @@ -151,7 +151,7 @@ bool ConcatOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQConcat; } -bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool ConcatOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -168,7 +168,7 @@ bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ } bool ConcatOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { const auto& op_type = node_unit.OpType(); const auto& op_name = node_unit.Name(); @@ -188,11 +188,11 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( if (IsQuantizedOp(node_unit)) { std::vector input_indices(input_size); std::iota(input_indices.begin(), input_indices.end(), 0); - if (!IsQuantizedIOSupported(initializers, node_unit, input_indices, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, input_indices, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } @@ -203,7 +203,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( size_t input_idx = 0; auto status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[input_idx], node_unit.ModelPath(), + graph_viewer, node_unit.Inputs()[input_idx], node_unit.ModelPath(), input_scales[input_idx], input_zps[input_idx]); if (!status.IsOK()) { @@ -214,7 +214,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( } for (++input_idx; input_idx < input_size; ++input_idx) { - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", op_name, "] input ", input_idx), node_unit.Inputs()[input_idx], node_unit.ModelPath(), @@ -225,7 +225,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( } // NNAPI (28-) requires the output scale and zp be the same as the input 0 - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", op_name, "]'s output 0"), node_unit.Outputs()[0], node_unit.ModelPath(), input_scales[0] /* required_scale */, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc index 5b8bbd338a13d..5477cd16f9c01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc @@ -33,7 +33,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,9 +41,8 @@ class ConvOpBuilder : public BaseOpBuilder { return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -279,19 +278,19 @@ bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { } bool ConvOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); // QLinearConv only supports input of uint8 for now if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -299,7 +298,7 @@ bool ConvOpBuilder::HasSupportedInputOutputsImpl( // Operator support related -bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ConvOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { const auto& op_type = node_unit.OpType(); bool is_quant_conv = IsQuantizedOp(node_unit); @@ -314,8 +313,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, NodeAttrHelper helper(node_unit); const auto group = helper.Get("group", 1); const auto weight_name = inputs[1].node_arg.Name(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); + const auto* weight = graph_viewer.GetConstantInitializer(weight_name); + if (weight) { + const auto& tensor = *weight; if (tensor.dims().size() != 4) { LOGS_DEFAULT(VERBOSE) << "Only conv 2d is supported."; return false; @@ -335,13 +335,13 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } } else { - LOGS_DEFAULT(VERBOSE) << "The weight of convolution must be known"; + LOGS_DEFAULT(VERBOSE) << "The weight of convolution must be a constant initializer"; return false; } if (is_quant_conv) { - if (inputs.size() > 2 && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Bias of QLinearConv must be known"; + if (inputs.size() > 2 && !graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Bias of QLinearConv must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc index 649f1e1cff2b7..ef8709641e2d0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc @@ -29,7 +29,7 @@ class DepthToSpaceOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -66,7 +66,7 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related -bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { NodeAttrHelper helper(node_unit); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc index b2d89ffecdca4..7d0e04fbd7b0e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc @@ -38,9 +38,9 @@ class DequantizeLinearOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override { - return IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput); + return IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput); } }; @@ -61,7 +61,7 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil float scale = 0.0; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(), scale, zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), scale, zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc index 065b9638bdf64..b5e9c011990ce 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc @@ -44,7 +44,7 @@ class FlattenOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -70,7 +70,7 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons // Operator support related -bool FlattenOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool FlattenOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc index ac8970f19df06..d6da9181b5a3d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc @@ -36,7 +36,7 @@ class GatherOpBuilder : public BaseOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_3; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -133,7 +133,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -166,8 +166,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; if (indices_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) { - if (!Contains(initializers, indices_name)) { - LOGS_DEFAULT(VERBOSE) << "Indices of Gather must be known."; + if (!graph_viewer.GetConstantInitializer(indices_name)) { + LOGS_DEFAULT(VERBOSE) << "Indices of Gather must be a constant initializer."; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc index 9b3003d472b02..8488f7cc74a6e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc @@ -69,11 +69,10 @@ class GemmOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } @@ -261,21 +260,20 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Operator support related -bool GemmOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const { +bool GemmOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } // QLinearMatMul/QDQGemm/QDQMatMul if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -295,7 +293,7 @@ bool GemmOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedGemm(GetQuantizedOpType(node_unit)); } -bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // check batch matmul first, then fall back to checking single gemm/matmul { @@ -355,8 +353,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - if (transB == 0 && !Contains(initializers, inputs[1].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "B of Gemm must be known if transB != 1"; + if (transB == 0 && !graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "B of Gemm must be a constant initializer if transB != 1"; return false; } @@ -380,8 +378,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } else if (op_type == "MatMul" || is_qlinear_matmul) { // Only support A*B B is an initializer - if (!Contains(initializers, inputs[1].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known"; + if (!graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "B of MatMul must be a constant initializer"; return false; } } else { @@ -389,8 +387,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } if (is_quant_gemm) { - if (inputs.size() > 2 && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Bias of QDQ Gemm must be known"; + if (inputs.size() > 2 && !graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Bias of QDQ Gemm must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc index 3db63a756ab1a..6a633c443c9e5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc @@ -27,7 +27,7 @@ class LeakyReluOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // LeakyRelu opset 6- has unsupported attributes @@ -111,7 +111,7 @@ Status LeakyReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related -bool LeakyReluOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /*initializers*/, const NodeUnit& node_unit, +bool LeakyReluOpBuilder::IsOpSupportedImpl(const GraphViewer& /*graph_viewer*/, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc index 522f389ae62a0..aeadbd17053cf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc @@ -37,7 +37,7 @@ class MinMaxOpBuilder : public BaseOpBuilder { // Min/Max opset 5- uses consumed_inputs attribute which is not supported for now int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 6; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -53,7 +53,7 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool MinMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool MinMaxOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { // TODO: support 2+ inputs for Min/Max op if (node_unit.Inputs().size() != 2) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc index 11d37f9036b11..b0404ebec0583 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc @@ -45,7 +45,7 @@ class PadOpBuilder : public BaseOpBuilder { return 11; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -115,7 +115,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}); } -bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool PadOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); @@ -152,14 +152,13 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c // only support if `pads` input is known and does not contain negative values { - const auto pads_initializer_it = initializers.find(inputs[1].node_arg.Name()); - if (pads_initializer_it == initializers.end()) { - LOGS_DEFAULT(VERBOSE) << "pads must be known"; + const auto* pads_initializer = graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name()); + if (!pads_initializer) { + LOGS_DEFAULT(VERBOSE) << "pads must be a constant initializer"; return false; } - const ONNX_NAMESPACE::TensorProto& pads_initializer = *pads_initializer_it->second; - Initializer unpacked_tensor(pads_initializer); + Initializer unpacked_tensor(*pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { @@ -173,8 +172,8 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c // only support if `constant_value` input is known // Note: Could add support for non-constant initializer later. Then we need to ensure it is a scalar (with shape []). if (inputs.size() > 2) { - if (!Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "constant_value must be known"; + if (!graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "constant_value must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc index c14568aaccfa3..a2a4786b72ec7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc @@ -32,7 +32,7 @@ class PoolOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -40,10 +40,9 @@ class PoolOpBuilder : public BaseOpBuilder { return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const override; - bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; + bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const override; bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -116,16 +115,16 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N float y_scale = input_operand_type.operandType.scale; int32_t y_zero_point = input_operand_type.operandType.zeroPoint; if (is_quant_pool) { - const auto& initializers = model_builder.GetInitializerTensors(); + const auto& graph_viewer = model_builder.GetGraphViewer(); float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); } InlinedVector input_indices; @@ -171,7 +170,7 @@ bool PoolOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedPool(GetQuantizedOpType(node_unit)); } -bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& op_name = node_unit.Name(); const auto& op_type = node_unit.OpType(); @@ -236,7 +235,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, float input_scale = 0.0f; int32_t input_zp = 0; auto status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), input_scale, input_zp); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), input_scale, input_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name << "] GetQuantizationScaleAndZeroPoint for input_scale/zp failed, message: " @@ -247,7 +246,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, float output_scale = 0.0f; int32_t output_zp = 0; status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), output_scale, output_zp); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), output_scale, output_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name << "] GetQuantizationScaleAndZeroPoint for output_scale/zp failed, message: " @@ -274,7 +273,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } bool PoolOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { const auto& op_type = node_unit.OpType(); bool is_quant_pool = IsQuantizedOp(node_unit); @@ -282,13 +281,13 @@ bool PoolOpBuilder::HasSupportedInputOutputsImpl( bool is_average_pool = op_type == "AveragePool" || op_type == "QLinearAveragePool"; bool is_quant_average_pool = is_quant_pool && is_average_pool; if (!is_max_pool && !is_quant_average_pool) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); if (is_quant_average_pool) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc index 49ff01d27219a..d13b81c2a14b8 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc @@ -38,9 +38,9 @@ class QuantizeLinearOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override { - return IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput); + return IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput); } }; @@ -60,7 +60,7 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde float scale = 0.0f; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), scale, zero_point)); + model_builder.GetGraphViewer(), node_unit.Outputs()[0], node_unit.ModelPath(), scale, zero_point)); Type output_type = Type::TENSOR_QUANT8_ASYMM; const OperandType output_operand_type(output_type, shaper[output], scale, zero_point); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc index 8d0347673ba56..a6da290753b74 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc @@ -35,7 +35,7 @@ class ReductionOpBuilder : public BaseOpBuilder { private: int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -169,7 +169,7 @@ int32_t ReductionOpBuilder::GetMinSupportedNNAPIFeatureLevel( return ANEURALNETWORKS_FEATURE_LEVEL_3; } -bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); const auto& op(node_unit.OpType()); @@ -190,7 +190,7 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; if (inputs.size() > 1 && inputs[1].node_arg.Exists()) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { + if (!graph_viewer.GetConstantInitializer(axes_name)) { LOGS_DEFAULT(VERBOSE) << "Axes of ReduceMean must be a constant initializer."; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc index 869883b98b22e..f2f9165d2f3cc 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc @@ -35,14 +35,13 @@ class ReshapeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -59,10 +58,10 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto& shaper(model_builder.GetShaper()); - const auto& initializers(model_builder.GetInitializerTensors()); + const auto& graph_viewer(model_builder.GetGraphViewer()); auto input = node_unit.Inputs()[0].node_arg.Name(); - const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name()); + const auto& shape_tensor = *graph_viewer.GetConstantInitializer(node_unit.Inputs()[1].node_arg.Name()); Initializer unpacked_tensor(shape_tensor); auto raw_shape = unpacked_tensor.DataAsSpan(); const auto size = SafeInt(shape_tensor.dims()[0]); @@ -80,7 +79,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons int32_t x_zero_point = 0; if (IsQuantizedOp(node_unit)) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -93,12 +92,13 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape; } -bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); const auto& perm_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, perm_name)) { - LOGS_DEFAULT(VERBOSE) << "New shape of reshape must be known"; + const auto* perm = graph_viewer.GetConstantInitializer(perm_name); + if (!perm) { + LOGS_DEFAULT(VERBOSE) << "New shape of reshape must be a constant initializer"; return false; } @@ -112,7 +112,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer return false; } - const auto& perm_tensor = *initializers.at(perm_name); + const auto& perm_tensor = *perm; Initializer unpacked_tensor(perm_tensor); auto raw_perm = unpacked_tensor.DataAsSpan(); const auto perm_size = SafeInt(perm_tensor.dims()[0]); @@ -138,17 +138,17 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer } bool ReshapeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index cdaa1c8fac76c..d75b9cc72ff4b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -33,19 +33,18 @@ class ResizeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, - const OpSupportCheckParams& /* params */) const override; + int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing // We only support Resize opset 11+ here int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -74,7 +73,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& initializers(model_builder.GetInitializerTensors()); NodeAttrHelper helper(node_unit); const auto& inputs = node_unit.Inputs(); const auto android_feature_level = model_builder.GetEffectiveFeatureLevel(); @@ -92,7 +90,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -147,7 +145,7 @@ bool ResizeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQResize; } -bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -228,32 +226,29 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } } - { // scales and sizes (if present) must be initializers + // scales or sizes must be constant initializers + { + // scales is input 3, sizes input 4, one must exist. only one is used. const auto inputs = node_unit.Inputs(); - if (inputs.size() < 3) { + bool using_scales = inputs.size() > 2 && inputs[2].node_arg.Exists(); + bool using_sizes = !using_scales && inputs.size() > 3 && inputs[3].node_arg.Exists(); + if (!using_scales && !using_sizes) { LOGS_DEFAULT(VERBOSE) << "Input scales or sizes of Resize must be known"; return false; } - // scales - bool using_scales = (inputs.size() > 2 && inputs[2].node_arg.Exists()); - if (using_scales && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be known"; - return false; - } - - // sizes - bool using_sizes = inputs.size() > 3 && inputs[3].node_arg.Exists(); - if (using_sizes && !Contains(initializers, inputs[3].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be known"; - return false; - } - bool input_is_nchw = false; // haven't a good solution to check layout when scale is 1.0F // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (using_scales) { // we are using scales - const auto& scales_tensor = *initializers.at(inputs[2].node_arg.Name()); - Initializer const unpacked_tensor(scales_tensor); + bool input_is_nchw = false; + + if (using_scales) { + const auto* scales = graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name()); + if (!scales) { + LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be a constant initializer"; + return false; + } + + const Initializer unpacked_tensor(*scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; float const scale_n = scales_data[0]; @@ -265,10 +260,13 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } } else { - // we are using sizes - const auto& sizes_name = inputs[3].node_arg.Name(); - const auto& sizes_tensor = *initializers.at(sizes_name); - Initializer unpacked_tensor(sizes_tensor); + const auto* sizes = graph_viewer.GetConstantInitializer(inputs[3].node_arg.Name()); + if (!sizes) { + LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be a constant initializer"; + return false; + } + + Initializer unpacked_tensor(*sizes); auto sizes_data = unpacked_tensor.DataAsSpan(); input_is_nchw = sizes_data[1] == input_shape[1]; @@ -308,7 +306,7 @@ int32_t ResizeOpBuilder::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_u } bool ResizeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) @@ -323,10 +321,10 @@ bool ResizeOpBuilder::HasSupportedInputOutputsImpl( } if (IsQuantizedOp(node_unit)) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc index 903469d34e67c..facdc7132dc00 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc @@ -40,7 +40,7 @@ class SliceOpBuilder : public BaseOpBuilder { // We only support slice from opset 10 int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 10; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -201,7 +201,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -219,19 +219,19 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[1].node_arg.Name(), "starts")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[1].node_arg.Name(), "starts")) { return false; } - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[2].node_arg.Name(), "ends")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[2].node_arg.Name(), "ends")) { return false; } const auto& inputs = node_unit.Inputs(); if (inputs.size() > 3) { - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[3].node_arg.Name(), "axes")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[3].node_arg.Name(), "axes")) { return false; } if (inputs.size() > 4) { - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[4].node_arg.Name(), "steps")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[4].node_arg.Name(), "steps")) { return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc index 1e420fec80827..a2a8b4512b028 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc @@ -33,7 +33,7 @@ class SoftMaxOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,7 +41,7 @@ class SoftMaxOpBuilder : public BaseOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_2; } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } @@ -77,8 +77,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons int32_t y_zero_point = 0; if (IsQuantizedOp(node_unit)) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(), - x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); @@ -156,7 +155,7 @@ bool SoftMaxOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQSoftmax; } -bool SoftMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool SoftMaxOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -197,24 +196,23 @@ bool SoftMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftMaxOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const { +bool SoftMaxOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } // NNAPI requires the scale be 1.f/256 and zero point to be 0 if (!HasRequiredScaleAndZeroPoint( - initializers, + graph_viewer, MakeString("Op [", node_unit.OpType(), "] name [", node_unit.Name(), "]'s output 0 "), node_unit.Outputs()[0], node_unit.ModelPath(), 1.f / 256 /* required_scale */, 0 /* required_zp */)) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index 68b63badb8f7e..edee298ad1ccf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -35,7 +35,7 @@ class SplitOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // Split opset 13- uses "split" as attribute. Currently it's not supported. @@ -67,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - num_outputs = SafeInt(*helper.GetInt("num_outputs")); + num_outputs = SafeInt(*helper.GetInt64("num_outputs")); } else { num_outputs = SafeInt(node_unit.Outputs().size()); } @@ -85,7 +85,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -98,13 +98,13 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const auto split_dims_at_axis = input_shape[SafeInt(HandleNegativeAxis(axis, input_shape.size()))]; if (input_defs.size() > 1 && input_defs[1].node_arg.Exists()) { // if optional input `split` is provided - auto split_initializer_it = initializers.find(input_defs[1].node_arg.Name()); - if (split_initializer_it == initializers.end()) { - LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be initializer if provided."; + const auto* splits = graph_viewer.GetConstantInitializer(input_defs[1].node_arg.Name()); + if (!splits) { + LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be a constant initializer if provided."; return false; } - const auto& splits_tensor = *split_initializer_it->second; - Initializer unpacked_tensor(splits_tensor); + + Initializer unpacked_tensor(*splits); auto splits_span = unpacked_tensor.DataAsSpan(); uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt(0)); if (sum_of_splits != split_dims_at_axis) { @@ -119,6 +119,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, auto it = std::adjacent_find(splits_span.begin(), splits_span.end(), [](const auto& a, const auto& b) { return a != b; }); + if (it != splits_span.end()) { LOGS_DEFAULT(VERBOSE) << "NNAPI only supports the case that number of splits evenly divides split axis size"; return false; @@ -126,7 +127,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } else { uint32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - auto num_outputs_attr = helper.GetInt("num_outputs"); + auto num_outputs_attr = helper.GetInt64("num_outputs"); if (!num_outputs_attr.has_value()) { LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc index a0fe744eaacc8..fb3ca5e6175fa 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc @@ -32,7 +32,7 @@ class SqueezeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -59,7 +59,7 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons // Operator support related -bool SqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -76,8 +76,8 @@ bool SqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer // Squeeze opset 13 use input 1 as axes, if we have input 1 then it need to be an initializer if (node_unit.SinceVersion() > 12 && inputs.size() > 1) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be known"; + if (!graph_viewer.GetConstantInitializer(axes_name)) { + LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc index 4d243c730bf05..6fe5ca32fe044 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc @@ -32,7 +32,7 @@ class TransposeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,7 +41,7 @@ class TransposeOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; @@ -59,7 +59,6 @@ void TransposeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto& shaper(model_builder.GetShaper()); - const auto& initializers(model_builder.GetInitializerTensors()); const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); @@ -78,7 +77,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -95,7 +94,7 @@ bool TransposeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQTranspose; } -bool TransposeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool TransposeOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -112,7 +111,7 @@ bool TransposeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia } bool TransposeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) @@ -127,10 +126,10 @@ bool TransposeOpBuilder::HasSupportedInputOutputsImpl( } if (IsQuantizedOp(node_unit)) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc index 796fd207fe428..dbd960ee5536c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc @@ -32,19 +32,18 @@ class UnaryOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, + int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; - static bool IsQuantizedOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + static bool IsQuantizedOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params); }; @@ -117,11 +116,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float y_scale = 0.0f; int32_t y_zero_point = 0; if (is_qlinear_sigmoid) { - const auto& initializers = model_builder.GetInitializerTensors(); float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); @@ -141,10 +139,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool UnaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool UnaryOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (node_unit.OpType() == "QLinearSigmoid") { - return IsQuantizedOpSupported(initializers, node_unit, params); + return IsQuantizedOpSupported(graph_viewer, node_unit, params); } else if (node_unit.OpType() == "Sigmoid") { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -178,16 +176,16 @@ int32_t UnaryOpBuilder::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_un } bool UnaryOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // We only need to override input check for QLinearSigmoid if (node_unit.OpType() != "QLinearSigmoid") - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -204,13 +202,13 @@ int UnaryOpBuilder::GetMinSupportedOpSet(const NodeUnit& node_unit) const { } /* static */ bool UnaryOpBuilder::IsQuantizedOpSupported( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) { + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) { const auto& op_type = node_unit.OpType(); ORT_ENFORCE(op_type == "QLinearSigmoid"); // NNAPI requires the scale be 1.f/256 and zero point to be 0 // See https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/android10-c2f2-release/nn/common/operations/Activation.cpp#180 - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", node_unit.Name(), "]'s output 0 "), node_unit.Outputs()[0], node_unit.ModelPath(), 1.f / 256 /* required_scale */, 0 /* required_zp */)) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc index a9bece7d42364..95cd813800c9a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc @@ -32,7 +32,7 @@ class UnsqueezeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -74,7 +74,7 @@ Status UnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co // Operator support related -bool UnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool UnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -93,8 +93,8 @@ bool UnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ // Unsqueeze opset 13 uses input 1 as axes, if we have input 1 then it needs to be an initializer if (node_unit.SinceVersion() > 12 && inputs.size() > 1) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Unsqueeze must be known"; + if (!graph_viewer.GetConstantInitializer(axes_name)) { + LOGS_DEFAULT(VERBOSE) << "Input axes of Unsqueeze must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index b75e78cbfe7cc..6962a7be94bb6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -100,7 +100,7 @@ void ModelBuilder::PreprocessActivations() { activation_node_units_.emplace(node_unit.get(), ANEURALNETWORKS_FUSED_RELU); } else if (op_type == "Clip") { // Relu1 or Relu6 float min, max; - if (!GetClipMinMax(GetInitializerTensors(), node, min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(graph_viewer_, node, min, max, logging::LoggingManager::DefaultLogger())) continue; if (min == -1.0f && max == 1.0f) { @@ -151,7 +151,7 @@ void ModelBuilder::GetAllQuantizedOpInputs() { } static Status GetInputDataType( - const InitializedTensorSet& initializers, + const GraphViewer& graph_viewer, const std::unordered_map>& all_quantized_op_inputs, const std::string& name, int32_t data_type, const Shape& shape, OperandType& operand_type) { @@ -177,7 +177,7 @@ static Status GetInputDataType( // TODO, verify the scale and zero point match if there are multiple op using same input const auto* node_unit = all_quantized_op_inputs.at(name)[0]; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, *node_unit, name, scale, zero_point, ArgType::kInput)); + graph_viewer, *node_unit, name, scale, zero_point, ArgType::kInput)); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: @@ -226,9 +226,8 @@ Status ModelBuilder::RegisterInitializers() { } OperandType operand_type(Type::TENSOR_FLOAT32, shape); - ORT_RETURN_IF_ERROR( - GetInputDataType(GetInitializerTensors(), all_quantized_op_inputs_, - name, tensor.data_type(), shape, operand_type)); + ORT_RETURN_IF_ERROR(GetInputDataType(graph_viewer_, all_quantized_op_inputs_, name, tensor.data_type(), shape, + operand_type)); shaper_.AddShape(name, operand_type.dimensions); uint32_t index = 0; @@ -304,7 +303,7 @@ Status ModelBuilder::RegisterModelInputs() { "The input of graph doesn't have elem_type: ", input_name); } else { ORT_RETURN_IF_ERROR( - GetInputDataType(GetInitializerTensors(), all_quantized_op_inputs_, + GetInputDataType(graph_viewer_, all_quantized_op_inputs_, input_name, type_proto->tensor_type().elem_type(), shape, operand_type)); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h index c565af491ff90..f6db4022fb8f4 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h @@ -56,7 +56,7 @@ class IOpBuilder { // Operator support check related // Check if an operator is supported - virtual bool IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + virtual bool IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const = 0; }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index 26db7c8e7afea..a066c64dac67d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -679,16 +679,15 @@ Status HandleAutoPad(const Shape& input_shape, return Status::OK(); } -Status GetBinaryOpQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - float& a_scale, float& b_scale, float& y_scale, - int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) { +Status GetBinaryOpQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + float& a_scale, float& b_scale, float& y_scale, + int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), a_scale, a_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), a_scale, a_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[1], node_unit.ModelPath(), b_scale, b_zero_point)); + graph_viewer, node_unit.Inputs()[1], node_unit.ModelPath(), b_scale, b_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); return Status::OK(); } @@ -699,16 +698,18 @@ Status GetConvMatMulOpQuantizationScaleAndZeroPoint( int32_t& a_zero_point, int32_t& w_zero_point, int32_t& y_zero_point, std::optional>& w_scales, bool& is_per_tensor_u8s8) { is_per_tensor_u8s8 = false; - const auto& initializers(model_builder.GetInitializerTensors()); + const auto& graph_viewer(model_builder.GetGraphViewer()); + // Get scale and zero points // We will handle per-channel weight scale and zero point later ORT_RETURN_IF_ERROR( - GetBinaryOpQuantizationScaleAndZeroPoint(initializers, node_unit, + GetBinaryOpQuantizationScaleAndZeroPoint(graph_viewer, node_unit, a_scale, w_scale, y_scale, a_zero_point, w_zero_point, y_zero_point)); const auto& inputs = node_unit.Inputs(); - const auto& weight_tensor = *initializers.at(inputs[1].node_arg.Name()); + // all these were checked to be constant in GemmOpBuilder::IsOpSupportedImpl + const auto& weight_tensor = *graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name()); // We are done here if this is u8u8 QLinearConv if (weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) @@ -719,7 +720,7 @@ Status GetConvMatMulOpQuantizationScaleAndZeroPoint( // For this case we will need to convert the int8 weight tensor to uint8 // And have same scale and 128 as zero point // The conversion of the weight tensor itself will be done in the OpBuilder - const auto& scale_tensor = *initializers.at(inputs[1].quant_param->scale.Name()); + const auto& scale_tensor = *graph_viewer.GetConstantInitializer(inputs[1].quant_param->scale.Name()); int64_t scale_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; if (scale_dim == 1) { w_zero_point = 128; @@ -1072,20 +1073,20 @@ Status AddReshapeOperator(ModelBuilder& model_builder, return Status::OK(); } -bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, +bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const OpSupportCheckParams& params, const std::string& op_type, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight) { const auto scale_name = io_def.quant_param->scale.Name(); - auto it = initializers.find(scale_name); - if (it == initializers.cend()) { - LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be an initializer tensor"; + const auto* scale = graph_viewer.GetConstantInitializer(scale_name); + if (!scale) { + LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be a constant initializer"; return false; } - const auto& scale_tensor = *it->second; + const auto& scale_tensor = *scale; int64_t scales_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; if (!is_conv_matmul_u8s8_weight) { if (scales_dim != 1) { @@ -1123,7 +1124,7 @@ bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, return true; } -bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, +bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, const Path& model_path, @@ -1134,12 +1135,13 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, return true; const auto& zero_point_name = io_def.quant_param->zero_point->Name(); - if (!Contains(initializers, zero_point_name)) { - LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be an initializer tensor"; + const auto* zero_point = graph_viewer.GetConstantInitializer(zero_point_name); + if (!zero_point) { + LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be a constant initializer"; return false; } - const auto& zero_tensor = *initializers.at(zero_point_name); + const auto& zero_tensor = *zero_point; int64_t zero_dim = zero_tensor.dims().empty() ? 1 : zero_tensor.dims()[0]; if (!is_conv_matmul_u8s8_weight) { @@ -1194,8 +1196,9 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, return true; } -bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::vector& indices, const OpSupportCheckParams& params, ArgType arg_type) { +bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::vector& indices, const OpSupportCheckParams& params, + ArgType arg_type) { const auto& op_type = node_unit.OpType(); auto quant_op_type = GetQuantizedOpType(node_unit); @@ -1247,12 +1250,12 @@ bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const Node } // Check scale and zero point - if (!IsQuantizationScaleSupported(initializers, io_def, params, op_type, + if (!IsQuantizationScaleSupported(graph_viewer, io_def, params, op_type, is_quant_matmul, is_conv_matmul_u8s8_weight)) { return false; } - if (!IsQuantizationZeroPointSupported(initializers, io_def, op_type, node_unit.ModelPath(), + if (!IsQuantizationZeroPointSupported(graph_viewer, io_def, op_type, node_unit.ModelPath(), is_quant_matmul, is_conv_matmul_u8s8_weight)) { return false; } @@ -1261,33 +1264,27 @@ bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const Node return true; } -bool HasRequiredScaleAndZeroPoint(const InitializedTensorSet& initializers, +bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, const Path& path, float required_scale, int32_t required_zp) { float scale = 0.0f; int32_t zp = 0; - auto status = GetQuantizationScaleAndZeroPoint(initializers, io_def, path, - scale, zp); + auto status = GetQuantizationScaleAndZeroPoint(graph_viewer, io_def, path, scale, zp); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << op_desc - << " GetQuantizationScaleAndZeroPoint failed, message: " - << status.ErrorMessage(); + LOGS_DEFAULT(ERROR) << op_desc << " GetQuantizationScaleAndZeroPoint failed, message: " << status.ErrorMessage(); return false; } if (scale != required_scale) { - LOGS_DEFAULT(VERBOSE) << op_desc - << " scale can only be [" << required_scale - << "], actual scale: " << scale; + LOGS_DEFAULT(VERBOSE) << op_desc << " scale can only be [" << required_scale << "], actual scale: " << scale; return false; } if (zp != required_zp) { - LOGS_DEFAULT(VERBOSE) << op_desc - << "] zero point can only be [" << required_zp - << "], actual zero point: " << scale; + LOGS_DEFAULT(VERBOSE) << op_desc << "] zero point can only be [" << required_zp << "], actual zero point: " + << zp; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 0cc442890ab6e..7ccf4c1ef7555 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -118,7 +118,7 @@ Status HandleAutoPad(const Shape& input_shape, // Get scales and zero points for the qlinear binary ops (which has 2 input and 1 output) // QLinearConv, QLinearMatmul, QLinearAdd, QLinearMul // a, b are inputs, and y is output -Status GetBinaryOpQuantizationScaleAndZeroPoint(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +Status GetBinaryOpQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, float& a_scale, float& b_scale, float& y_scale, int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point); @@ -193,14 +193,14 @@ inline bool IsNodeLayoutNHWC(const NodeUnit& node_unit) { return node_unit.Domain() == kMSInternalNHWCDomain; } -bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, +bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const OpSupportCheckParams& params, const std::string& op_type, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight); -bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, +bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, const Path& model_path, @@ -208,13 +208,13 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, bool is_conv_matmul_u8s8_weight); // Check if the given quantized input(s) or output(s) is supported -bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const std::vector& indices, const OpSupportCheckParams& params, ArgType arg_type); // Some Quantized NNAPI operations have required output scale and zero point // e.g. Softmax (uint8) requires output scale be 1.f/256 and zp be 0 // This helper function checks if the given io_def has required scale and zp -bool HasRequiredScaleAndZeroPoint(const InitializedTensorSet& initializers, +bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, const Path& path, diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 676ea5b038d7e..bbb157e17d07b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -70,7 +70,9 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else - if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { + if (!subgraph_context_.has_dynamic_input_shape && + global_context_.onnx_model_path_name != "" && + dev_prec != "CPU_FP16") { exe_network_ = global_context_.ie_core.LoadNetwork(global_context_.onnx_model_path_name, hw_target, device_config, diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index d95e2baa9457f..4a9106f0c06af 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +struct HandleConvertResult { + Status status; // Indicates an unexpected error. Check if q_node_unit != nullptr to determine + // whether a DQ -> Q sequence was successfully merged into a Convert. + const NodeUnit* q_node_unit; // Non-null if successfully merged DQ -> Q sequence. + // Set to nullptr if this node unit could not be merged into a Convert. +}; + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer + * to the Q node unit that was successfully merged with the provided DQ node unit. + */ +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc new file mode 100644 index 0000000000000..977a9e0b3d9d0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" + +#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). + +namespace onnxruntime { +namespace qnn { + +class ConvertOpBuilder : public BaseOpBuilder { + public: + ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder); + + Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; +}; + +Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector input_names; + + // Process the input from the DQ node + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names)); + + // Process the output from the Q node. Override the QNN operator type to "Convert". + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {}, + logger, do_op_validation, QNN_OP_CONVERT)); + return Status::OK(); +} + +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Looking for a standalone DQ to start the sequence. + if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + const Node& dq_node = maybe_dq_node_unit.GetNode(); + + // DQ must have a single Q child. DQ must not produce a graph output. + auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); + if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return {}; + } + + const Node& q_node = *children[0]; + const auto q_node_unit_it = node_unit_map.find(&q_node); + + if (q_node_unit_it == node_unit_map.end()) { + return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr}; + } + + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return {}; + } + + ConvertOpBuilder op_builder; + + LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger, + do_op_validation); + return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr}; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 314cab4a36ca9..dc91b9dfa199e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + std::unordered_set handled_node_units; + // Op builer const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // Check whether it's part of NodeUnit const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node) + // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) const std::string& op_type = node_unit.OpType(); + + if (node != &node_unit.GetNode()) { + continue; + } + + if (handled_node_units.count(&node_unit) != 0) { + continue; // Already handled. + } + + // Try to convert particular DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + node_unit, + node_unit_map, + logger_, + false /*do_op_validation*/); + ORT_RETURN_IF_ERROR(convert_result.status); + + if (convert_result.q_node_unit) { + // Successfully merged DQ -> Q sequence into a QNN Convert op. + // Mark both of these node units as handled. + handled_node_units.insert(&node_unit); + handled_node_units.insert(convert_result.q_node_unit); + continue; + } + LOGS(logger_, VERBOSE) << " node name: [" << node->Name() << "] node optype: [" << op_type << "] as part of the NodeUnit type: [" << node_unit.OpType() << "] name: [" << node_unit.Name() << "]"; - if (node != &node_unit.GetNode()) { - continue; - } - if (const auto* op_builder = GetOpBuilder(op_type)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); } + + handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..f5a166d36b15a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { - // If we have visited one of the nodes in the node_unit, use the result directly - const auto it = node_unit_supported_result.find(&node_unit); - if (it != node_unit_supported_result.cend()) { - return it->second; + const std::string& op_type = node_unit.OpType(); + bool supported = false; + const auto* op_builder = qnn::GetOpBuilder(op_type); + if (op_builder == nullptr) { + LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." + << node_unit.OpType() << " node `" << node_unit.Name() + << "` will not be assigned to QNN EP."; } else { - const std::string& op_type = node_unit.OpType(); - - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." - << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); + auto status = op_builder->IsOpSupported(qnn_model_wrapper, + node_unit, logger); + if (Status::OK() != status) { + LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); } - node_unit_supported_result[&node_unit] = supported; - return supported; + supported = (Status::OK() == status); } + return supported; } std::unordered_set @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, if (node != &node_unit->GetNode()) { continue; } - const bool supported = IsNodeSupported(qnn_model_wrapper, - *node_unit, - node_unit_supported_result, - logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; + + if (node_unit_supported_result.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + // Try to convert certain standalone DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + *node_unit, + node_unit_map, + logger, + true /*do_op_validation*/); + if (!convert_result.status.IsOK()) { + LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. " + << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", " + << "Message: " << convert_result.status.ErrorMessage(); + } + + bool supported = false; + + if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op + supported = true; + + // Mark the Q node unit as handled and supported here so that we don't try to process it again. + node_unit_supported_result.insert({convert_result.q_node_unit, true}); + supported_nodes.insert(&convert_result.q_node_unit->GetNode()); + } else { + supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); + LOGS(logger, VERBOSE) << "Node supported: [" << supported + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + } + if (supported) { // If the node_unit is supported, add all of its nodes to the supported list. for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node_in_group); } } + + node_unit_supported_result.insert({node_unit, supported}); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..0bcaa39b22f6d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const; std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index 2f549cc1ac143..c245b18057ca7 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -74,12 +74,32 @@ struct ROCMExecutionProviderInfo { } // namespace onnxruntime template <> -struct std::hash<::onnxruntime::rocm::TunableOpInfo> { - size_t operator()(const ::onnxruntime::rocm::TunableOpInfo& info) const { - size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enable, seed_and_value); - onnxruntime::HashCombine(info.tuning_enable, seed_and_value); - onnxruntime::HashCombine(info.max_tuning_duration_ms, seed_and_value); - return seed_and_value; +struct std::hash<::onnxruntime::ROCMExecutionProviderInfo> { + size_t operator()(const ::onnxruntime::ROCMExecutionProviderInfo& info) const { + size_t value{0xbc9f1d34}; // seed + + // Bits: device_id (16), arena_extend_strategy/miopen_conv_exhaustive_search (reserved 2), boolean options (1 each) + size_t data = static_cast(info.device_id) ^ + (static_cast(info.arena_extend_strategy) << 16) ^ + (static_cast(info.miopen_conv_exhaustive_search) << 18) ^ + (static_cast(info.do_copy_in_default_stream) << 20) ^ + (static_cast(info.has_user_compute_stream) << 21) ^ + (static_cast(info.miopen_conv_use_max_workspace) << 22) ^ + (static_cast(info.enable_hip_graph) << 23) ^ + (static_cast(info.tunable_op.enable) << 24) ^ + (static_cast(info.tunable_op.tuning_enable) << 25); + onnxruntime::HashCombine(data, value); + + onnxruntime::HashCombine(info.gpu_mem_limit, value); + onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + + // Memory pointers + onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + + // The default memory arena cfg is not used in hashing right now. + return value; } }; diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index c0b7d4722d3e4..70bf08d65401a 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -101,6 +101,10 @@ class RocmKernel : public OpKernel { return static_cast(provider_->GetTuningContext()); } + bool UseTF32() const { + return false; + } + // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished template diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 7cbc37cb64c5a..d93f70785c093 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -115,7 +115,8 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -154,7 +155,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with extra arguments (see cublasGemmHelper) template rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_operation transa, @@ -165,7 +166,8 @@ rocblas_status rocblasGemmHelper(rocblas_handle handle, const Scalar* B, int ldb, const Scalar* beta, Scalar* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -404,7 +406,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) template rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_operation transa, @@ -419,7 +421,8 @@ rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, Scalar* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, @@ -445,7 +448,8 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, __half* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 39ea4dd8412bb..c07a0929353b1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -25,12 +25,14 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg return true; } -bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, - float& min, float& max, const logging::Logger& logger) { +namespace { +bool GetClipMinMaxImpl(std::function get_const_initializer, + const Node& node, float& min, float& max, const logging::Logger& logger) { const auto& node_name = node.Name(); int32_t input_type; - if (!GetType(*node.InputDefs()[0], input_type, logger)) + if (!GetType(*node.InputDefs()[0], input_type, logger)) { return false; + } min = std::numeric_limits::lowest(); max = std::numeric_limits::max(); @@ -41,49 +43,73 @@ bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, min = helper.Get("min", std::numeric_limits::lowest()); max = helper.Get("max", std::numeric_limits::max()); } else { - if (node.InputDefs().size() > 1) { - // we have input min - const auto& min_name = node.InputDefs()[1]->Name(); - if (!Contains(initializers, min_name)) { - LOGS(logger, VERBOSE) << "Input min of Clip must be known"; + auto get_value = + [&](const ONNX_NAMESPACE::TensorProto* initializer, std::string_view type, float& value) -> bool { + if (!initializer) { + LOGS(logger, VERBOSE) << type << " input of Clip must be a constant initializer"; return false; } - Initializer unpacked_tensor_min(*initializers.at(min_name)); + + Initializer unpacked_tensor_min(*initializer); switch (input_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - min = unpacked_tensor_min.DataAsSpan()[0]; + value = unpacked_tensor_min.DataAsSpan()[0]; break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - min = (unpacked_tensor_min.DataAsSpan()[0]).ToFloat(); + value = unpacked_tensor_min.DataAsSpan()[0].ToFloat(); break; default: - LOGS(logger, VERBOSE) << "GetClipMinMax() only support Clip node with float inputs for now. " - << "The node [" << node_name << "] has input 0 type: " << input_type; + LOGS(logger, VERBOSE) << "GetClipMinMax() only supports float and float16 as min and max inputs for now." + << " The node [" << node_name << "] has input type: " << input_type; return false; } - if (node.InputDefs().size() > 2) { - // we have input max - const auto& max_name = node.InputDefs()[2]->Name(); - if (!Contains(initializers, max_name)) { - LOGS(logger, VERBOSE) << "Input max of Clip must be known"; - return false; - } - Initializer unpacked_tensor_max(*initializers.at(max_name)); - switch (input_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - max = unpacked_tensor_max.DataAsSpan()[0]; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - max = (unpacked_tensor_max.DataAsSpan()[0]).ToFloat(); - break; - } + return true; + }; + + // min and max are both optional. could have neither, one or both. + if (node.InputDefs().size() > 1 && node.InputDefs()[1]->Exists()) { + // we have input min + const auto& min_name = node.InputDefs()[1]->Name(); + const auto* min_value = get_const_initializer(min_name); + if (!get_value(min_value, "Min", min)) { + return false; + } + } + + if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { + // we have input max + const auto& max_name = node.InputDefs()[2]->Name(); + const auto* max_value = get_const_initializer(max_name); + if (!get_value(max_value, "Max", max)) { + return false; } } } return true; } +} // namespace + +bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min, float& max, + const logging::Logger& logger) { + return GetClipMinMaxImpl( + [&graph_viewer](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { + return graph_viewer.GetConstantInitializer(name); + }, + node, min, max, logger); +} + +// deprecated version that is not able to check if the initializer is constant +bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max, + const logging::Logger& logger) { + return GetClipMinMaxImpl( + [&initializers](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { + auto entry = initializers.find(name); + return entry == initializers.end() ? nullptr : entry->second; + }, + node, min, max, logger); +} NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) : node_attributes_(node.GetAttributes()) {} @@ -92,84 +118,134 @@ NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit) : node_attributes_(node_unit.GetNode().GetAttributes()) {} float NodeAttrHelper::Get(const std::string& key, float def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.f(); + } - return node_attributes_.at(key).f(); + return def_val; } int32_t NodeAttrHelper::Get(const std::string& key, int32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } uint32_t NodeAttrHelper::Get(const std::string& key, uint32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.i(); + } - return node_attributes_.at(key).i(); + return def_val; } const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.s(); + } - return node_attributes_.at(key).s(); + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> int32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> int32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> uint32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> uint32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + return std::vector{values.cbegin(), values.cend()}; + } - const auto& source(node_attributes_.at(key).ints()); - return std::vector{source.cbegin(), source.cend()}; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + return std::vector{values.cbegin(), values.cend()}; + } + + return def_val; +} + +std::optional NodeAttrHelper::GetFloat(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.f(); + } + + return result; +} + +std::optional NodeAttrHelper::GetInt64(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.i(); + } - const auto& source(node_attributes_.at(key).floats()); - return std::vector{source.cbegin(), source.cend()}; + return result; } -std::optional NodeAttrHelper::GetInt(const std::string& key) const { - if (!HasAttr(key)) - return std::nullopt; - return node_attributes_.at(key).i(); +std::optional> NodeAttrHelper::GetFloats(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional> NodeAttrHelper::GetInt64s(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional NodeAttrHelper::GetString(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.s(); + } + + return result; } bool NodeAttrHelper::HasAttr(const std::string& key) const { diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 1e93f040711df..5813dcc48d72b 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -16,14 +16,20 @@ namespace logging { class Logger; } +class GraphViewer; class Node; class NodeArg; class NodeUnit; -// Get the min/max of a Clip operator. -// If min/max are not known initializer tensors, will return false -// For now we only support getting float min/max, -// since in most cases, Clip(0,6)[Relu6] will be fused by quantization tool +// Get the min/max of a Clip operator. Reads values from attributes for opset < 11 and inputs after that. +// For opset 11+, if min/max are not constant initializers, will return false. +// For now we only support getting float min/max. +bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, + float& min, float& max, const logging::Logger& logger); + +/// GraphViewer GetConstantInitializer/IsConstantInitializer should be used to ensure the initializer is +/// constant. Low risk for Clip min/max but in general the infrastructure to check if an operator is supported needs +/// to be updated to not use InitializedTensorSet which may contain non-constant initializers. bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max, const logging::Logger& logger); @@ -41,15 +47,17 @@ class NodeAttrHelper { // Get the attributes from the target node of the node_unit explicit NodeAttrHelper(const NodeUnit& node_unit); + /* + * Get with default + */ float Get(const std::string& key, float def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; int64_t Get(const std::string& key, int64_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; @@ -58,7 +66,16 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; - std::optional GetInt(const std::string& key) const; + /* + * Get without default. + */ + std::optional GetFloat(const std::string& key) const; + std::optional> GetFloats(const std::string& key) const; + + std::optional GetInt64(const std::string& key) const; + std::optional> GetInt64s(const std::string& key) const; + + std::optional GetString(const std::string& key) const; bool HasAttr(const std::string& key) const; diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 9de5b889808fc..0d6001bcba89f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -47,7 +47,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& output_name = node.OutputDefs()[0]->Name(); emscripten::val options = emscripten::val::object(); float minValue, maxValue; - ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, minValue, maxValue, logger), + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, minValue, maxValue, logger), "GetClipMinMax failed"); options.set("minValue", minValue); options.set("maxValue", maxValue); @@ -70,6 +70,9 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { + // TODO: Update IsOpSupportedImpl to pass GraphViewer instead of InitializedTensorSet so the implementations + // can ensure initializers are constant. See #19401 for details of how this update was made to the NNAPI EP. + // GetClipMinMax(graph_viewer, node, minValue, maxValue, logger) float min, max; return GetClipMinMax(initializers, node, min, max, logger); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cae714954f72f..b045f30a59797 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -46,10 +46,11 @@ #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/platform/Barrier.h" -#include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" #ifdef _WIN32 #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; +#ifdef _WIN32 +OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +#endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); - // a monotonically increasing session id for use in telemetry - session_id_ = global_session_id_.fetch_add(1); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // a monotonically increasing session id for use in telemetry + session_id_ = global_session_id_.fetch_add(1); + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[global_session_id_++] = this; + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + LogAllSessions(); + } + }); +#endif + SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options); + TraceSessionOptions(session_options, false); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { + (void)captureState; // Otherwise Linux build error + LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 @@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.graph_optimization_level), "graph_optimization_level"), TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), - TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute")); + TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), + TraceLoggingBoolean(captureState, "isCaptureState")); TraceLoggingWrite( telemetry_provider_handle, @@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"), TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), - TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero")); + TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), + TraceLoggingBoolean(captureState, "isCaptureState")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBoolean(captureState, "isCaptureState")); } #endif } @@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); +#endif + active_sessions_.erase(global_session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() { return binding_.get(); } +#ifdef _WIN32 +void InferenceSession::LogAllSessions() { + std::lock_guard lock(active_sessions_mutex_); + for (const auto& session_pair : active_sessions_) { + InferenceSession* session = session_pair.second; + TraceSessionOptions(session->session_options_, true); + } +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 96db49aabdaf6..f8211bfd2dd4e 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -21,11 +22,12 @@ #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" +#include "core/framework/session_options.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/framework/session_options.h" +#include "core/platform/ort_mutex.h" #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -119,6 +121,10 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; +#ifdef _WIN32 + static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ +#endif public: #if !defined(ORT_MINIMAL_BUILD) @@ -642,7 +648,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options); + void TraceSessionOptions(const SessionOptions& session_options, bool captureState); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -679,6 +685,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 32ae15e71acc6..3bec9aa146f76 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1555,6 +1555,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; cuda_options_converted.use_ep_level_unified_stream = 0; + cuda_options_converted.use_tf32 = 1; return cuda_options_converted; } @@ -1681,7 +1682,11 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; + if (legacy_ov_options->enable_npu_fast_compile) { + ov_options_converted_map["enable_npu_fast_compile"] = "false"; + } else { + ov_options_converted_map["enable_npu_fast_compile"] = "true"; + } if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; @@ -1700,14 +1705,12 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; - if (legacy_ov_options->enable_dynamic_shapes != '\0') { - std::string enable_dynamic_shapes = reinterpret_cast(legacy_ov_options->enable_dynamic_shapes); - if (enable_dynamic_shapes == "true" || enable_dynamic_shapes == "True") { - ov_options_converted_map["disable_dynamic_shapes"] = "false"; - } else if (enable_dynamic_shapes == "false" || enable_dynamic_shapes == "False") { - ov_options_converted_map["disable_dynamic_shapes"] = "true"; - } + if (legacy_ov_options->enable_dynamic_shapes) { + ov_options_converted_map["disable_dynamic_shapes"] = "false"; + } else { + ov_options_converted_map["disable_dynamic_shapes"] = "true"; } + // Add new provider option below ov_options_converted_map["num_streams"] = "1"; return ov_options_converted_map; diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ade1d96d617fb..17a955ba8ce1a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + for (const auto& config_pair : provider_options) { + ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str())); + } + if (strcmp(provider_name, "DML") == 0) { #if defined(USE_DML) options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index a5a165e150cf1..2a6c14ff1b058 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -93,22 +93,31 @@ static std::unique_ptr CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { ThreadOptions to; if (options.thread_pool_size <= 0) { // default - auto default_affinities = Env::Default().GetDefaultThreadAffinities(); - if (default_affinities.size() <= 1) { - return nullptr; - } - options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { #ifdef _WIN32 // Only set thread affinity on Server with auto affinity. // On client best to let OS scheduler handle. // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage if (IsWindowsServer()) { + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } #else + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); #endif + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } } if (options.thread_pool_size <= 1) { diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 09f768f53ea65..2fbd118a43ed1 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -358,7 +358,7 @@ class InferenceSession(Session): def __init__( self, path_or_bytes: str | bytes | os.PathLike, - sess_options: Sequence[onnxruntime.SessionOptions] | None = None, + sess_options: onnxruntime.SessionOptions | None = None, providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs, @@ -413,7 +413,7 @@ def __init__( self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1" # internal parameters that we don't expect to be used in general so aren't documented - disabled_optimizers = kwargs["disabled_optimizers"] if "disabled_optimizers" in kwargs else None + disabled_optimizers = kwargs.get("disabled_optimizers") try: self._create_inference_session(providers, provider_options, disabled_optimizers) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu index fd9e9c4fd1612..8b05b96ec38a9 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu @@ -56,6 +56,9 @@ class GemmBenchmark : public IKernelExplorer { typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); + + // TF32 is enable by default. To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 + constexpr bool use_tf32 = true; CUBLAS_CALL_THROW(cublasGemmHelper( params_.cublas_handle, CUBLAS_OP_N, @@ -69,7 +72,8 @@ class GemmBenchmark : public IKernelExplorer { &zero, params_.output_, params_.n_, - device_prop_)); + device_prop_, + use_tf32)); } private: @@ -79,11 +83,11 @@ class GemmBenchmark : public IKernelExplorer { cudaDeviceProp device_prop_; }; -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ .def("Run", &name::Run); KE_REGISTER(m) { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index e32cb032798fc..400a9d8a7a187 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -35,7 +35,11 @@ def sigmoid_function(x): return 1.0 / (1.0 + np.exp(-x)) -def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): +def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip): + add_output = None + if has_skip: + input_x = input_x + skip_x + bias_x + add_output = input_x n, h, w, c = input_x.shape input_x = input_x.transpose([0, 3, 1, 2]) assert c % num_groups == 0 @@ -45,46 +49,82 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): x = x.transpose([0, 2, 3, 1]) x = x * gamma + beta - if with_swish: + if with_silu: x = x * sigmoid_function(x) - return x + return x, add_output -def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func): +def run_group_norm( + batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func +): np.random.seed(0) width = height input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18 - workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32) + workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32) epsilon = 1e-05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish - host_x = input_x.astype(dtype) - input_d = ke.DeviceArray(host_x) + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(np.float32) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + if has_skip: + skip_x_shape = skip_x.shape + b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels + b4 = ( + len(skip_x_shape) == 4 + and skip_x_shape[0] == batch_size + and skip_x_shape[1] == 1 + and skip_x_shape[2] == 1 + and skip_x_shape[3] == num_channels + ) + if b2 or b4: + broadcast_skip = True + channels_per_block = 0 # Compute in params initialization + + input_d = ke.DeviceArray(input_x.astype(dtype)) + skip_d = ke.DeviceArray(skip_x.astype(dtype)) + bias_d = ke.DeviceArray(bias_x.astype(dtype)) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) - y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) + y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip) + y_ref = y_ref.astype(dtype) for impl in my_op.ListOps(): if not my_op.SelectOp(impl): @@ -95,6 +135,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: y_d.UpdateHostNumpyArray() np.testing.assert_allclose(y_ref, output_y, atol=1e-02) + if has_skip: + y_add_d_ref = y_add_d_ref.astype(dtype) + y_add_d.UpdateHostNumpyArray() + np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02) dtypes = ["float32", "float16"] @@ -102,19 +146,21 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm(sd_sizes, dtype, swish): +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [True, False]) +def test_group_norm(sd_sizes, dtype, silu, has_skip): for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, swish, func) + run_group_norm(*sd_sizes, dtype, silu, has_skip, func) @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm_ck(sd_sizes, dtype, swish): - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - run_group_norm(*sd_sizes, dtype, swish, ck_f_name) +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [False]) +def test_group_norm_ck(sd_sizes, dtype, silu, has_skip): + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name) @dataclass @@ -136,37 +182,67 @@ def report(self): def profile_group_norm_func( - batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func + batch_size: int, + height: int, + width: int, + num_channels: int, + num_groups: int, + dtype: str, + silu: bool, + has_skip: bool, + func, ): np.random.seed(0) input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) - workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32) + workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32) epsilon = 0.05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish + + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip_x) + bias_d = ke.DeviceArray(bias_x) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) for impl in my_op.ListOps(): duration_ms = -1 @@ -181,14 +257,14 @@ def profile_group_norm_func( ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True): +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) # ck function - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name) + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name) sd_profile_sizes = [ @@ -227,7 +303,8 @@ def profile(): group.add_argument("num_channels", type=int) group.add_argument("num_groups", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--swish", action="store_true") + group.add_argument("--silu", action="store_true") + group.add_argument("--has_skip", action="store_true") group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: @@ -241,6 +318,7 @@ def profile(): args.num_channels, args.num_groups, args.dtype, - args.swish, + args.silu, + args.has_skip, args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu index 0bd47b2c0387e..6af163ab94b10 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu @@ -12,17 +12,21 @@ #include "python/tools/kernel_explorer/kernel_explorer_interface.h" namespace py = pybind11; - +using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes; namespace onnxruntime { template class GroupNormNHWC : public IKernelExplorer { public: - GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias, + DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, + bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); } @@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCOp op_{}; std::string type_string_{}; @@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer { template class GroupNormNHWCStaticSelection : public IKernelExplorer { public: - GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWCStaticSelection"; } @@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; std::string type_string_{}; }; @@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { template class GroupNormNHWCTunable : public IKernelExplorer { public: - GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCTunableOp op_{}; }; #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGroupNormNHWC : public IKernelExplorer { public: - CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { + CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer { #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL -template +template class GroupNormNHWCTriton : public IKernelExplorer { public: - GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { + GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { name_strings_.emplace_back(name); ops_.emplace_back(std::move(op)); } @@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP(name, type, threads_per_block, vec_size) \ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &name::SetRepeats) \ .def("Profile", &name::Profile) \ .def("Run", &name::Run) \ @@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ .def("Profile", &type<__VA_ARGS__>::Profile) \ .def("Run", &type<__VA_ARGS__>::Run) \ @@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP_TYPED(name, type) \ REGISTER_COMMON(#name "_" #type, name, type) -#define REGISTER_CK(type, with_swish, swish_suffix) \ - REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish) +#define REGISTER_CK(type, with_silu, silu_suffix) \ + REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu) -#define REGISTER_TRITON(type, with_swish, swish_suffix) \ - REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish) +#define REGISTER_TRITON(type, with_silu, silu_suffix) \ + REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); @@ -248,16 +270,16 @@ KE_REGISTER(m) { #ifdef USE_COMPOSABLE_KERNEL REGISTER_CK(half, false, "Pass"); - REGISTER_CK(half, true, "Swish"); + REGISTER_CK(half, true, "Silu"); REGISTER_CK(float, false, "Pass"); - REGISTER_CK(float, true, "Swish"); + REGISTER_CK(float, true, "Silu"); #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL REGISTER_TRITON(half, false, "Pass"); - REGISTER_TRITON(half, true, "Swish"); + REGISTER_TRITON(half, true, "Silu"); REGISTER_TRITON(float, false, "Pass"); - REGISTER_TRITON(float, true, "Swish"); + REGISTER_TRITON(float, true, "Silu"); #endif } diff --git a/onnxruntime/python/tools/qnn/add_trans_cast.py b/onnxruntime/python/tools/qnn/add_trans_cast.py index bd6b8701f8fb8..ced3e3519ad42 100644 --- a/onnxruntime/python/tools/qnn/add_trans_cast.py +++ b/onnxruntime/python/tools/qnn/add_trans_cast.py @@ -270,19 +270,15 @@ def main(): raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.") for node in model.graph.node: - node_input_index = 0 - for node_input in node.input: + for node_input_index, node_input in enumerate(node.input): # update consumer node for graph inputs to connect to inserted node if node_input in graph_input_output_name_dic: node.input[node_input_index] = graph_input_output_name_dic[node_input] - node_input_index += 1 - node_output_index = 0 - for node_output in node.output: + for node_output_index, node_output in enumerate(node.output): # update producer node for graph outputs to connect to inserted node if node_output in graph_input_output_name_dic: node.output[node_output_index] = graph_input_output_name_dic[node_output] - node_output_index += 1 model.graph.node.extend(nodes_to_add) graph_topological_sort(model.graph) diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 77b3dce9fb004..624049b244580 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -1100,12 +1100,10 @@ def create_calibrator( calibrator = None if calibrate_method == CalibrationMethod.MinMax: # default settings for min-max algorithm - symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] - moving_average = False if "moving_average" not in extra_options else extra_options["moving_average"] - averaging_constant = 0.01 if "averaging_constant" not in extra_options else extra_options["averaging_constant"] - max_intermediate_outputs = ( - None if "max_intermediate_outputs" not in extra_options else extra_options["max_intermediate_outputs"] - ) + symmetric = extra_options.get("symmetric", False) + moving_average = extra_options.get("moving_average", False) + averaging_constant = extra_options.get("averaging_constant", 0.01) + max_intermediate_outputs = extra_options.get("max_intermediate_outputs", None) calibrator = MinMaxCalibrater( model, op_types_to_calibrate, @@ -1118,9 +1116,9 @@ def create_calibrator( ) elif calibrate_method == CalibrationMethod.Entropy: # default settings for entropy algorithm - num_bins = 128 if "num_bins" not in extra_options else extra_options["num_bins"] - num_quantized_bins = 128 if "num_quantized_bins" not in extra_options else extra_options["num_quantized_bins"] - symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] + num_bins = extra_options.get("num_bins", 128) + num_quantized_bins = extra_options.get("num_quantized_bins", 128) + symmetric = extra_options.get("symmetric", False) calibrator = EntropyCalibrater( model, op_types_to_calibrate, @@ -1132,9 +1130,9 @@ def create_calibrator( ) elif calibrate_method == CalibrationMethod.Percentile: # default settings for percentile algorithm - num_bins = 2048 if "num_bins" not in extra_options else extra_options["num_bins"] - percentile = 99.999 if "percentile" not in extra_options else extra_options["percentile"] - symmetric = True if "symmetric" not in extra_options else extra_options["symmetric"] + num_bins = extra_options.get("num_bins", 2048) + percentile = extra_options.get("percentile", 99.999) + symmetric = extra_options.get("symmetric", True) calibrator = PercentileCalibrater( model, op_types_to_calibrate, @@ -1147,8 +1145,8 @@ def create_calibrator( elif calibrate_method == CalibrationMethod.Distribution: # default settings for percentile algorithm - num_bins = 2048 if "num_bins" not in extra_options else extra_options["num_bins"] - scenario = "same" if "scenario" not in extra_options else extra_options["scenario"] + num_bins = extra_options.get("num_bins", 2048) + scenario = extra_options.get("scenario", "same") calibrator = DistributionCalibrater( model, diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index 456a75eec2f8c..b54b421226f1a 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -86,11 +86,9 @@ def get_node_attribute(node: onnx.NodeProto, attribute_name: str): @staticmethod def input_index(node_output: str, child_node: onnx.NodeProto) -> int: - index = 0 - for input_name in child_node.input: + for index, input_name in enumerate(child_node.input): if input_name == node_output: return index - index += 1 return -1 @staticmethod diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3e9f9a6544a71..eb7bbec997d59 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -349,6 +349,10 @@ def process(self): self.int4_quant_algo() +def ort_convert_str_to_bool(value): + return value.lower() in ("true", "1") + + def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. @@ -366,7 +370,10 @@ def parse_args(): "--symmetric", required=False, default=True, - type=bool, + const=True, + nargs="?", + type=ort_convert_str_to_bool, + choices=[True, False], help="Indicate whether to quantize the model symmetrically", ) parser.add_argument( diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 898a5f70ac45e..9450426f12444 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -113,14 +113,10 @@ def __init__( "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] ) self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] - self.is_weight_symmetric = ( - weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) - if "WeightSymmetric" not in self.extra_options - else self.extra_options["WeightSymmetric"] - ) - self.is_activation_symmetric = ( - False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] + self.is_weight_symmetric = self.extra_options.get( + "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) ) + self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False) self.min_real_range = self.extra_options.get("MinimumRealRange") self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) @@ -389,7 +385,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." ) @@ -446,6 +442,23 @@ def is_valid_quantize_weight(self, weight_name): return False return self.parent.is_valid_quantize_weight(weight_name) + def _get_default_tensor_type(self, tensor_name): + if "DefaultTensorType" in self.extra_options: + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) + return self.extra_options["DefaultTensorType"] + raise RuntimeError( + f"Unable to find data type for weight_name={tensor_name!r}. " + f"shape_inference failed to return a type probably this node is " + f"from a different domain or using an input produced by such an operator. " + f"This may happen if you quantize a model already quantized. " + f"You may use extra_options `DefaultTensorType` to indicate " + f"the default weight type, usually `onnx.TensorProto.FLOAT`." + ) + def get_tensor_type(self, tensor_name, mandatory=False): weight = find_by_name(tensor_name, self.model.initializer()) if weight is not None: @@ -454,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False): vi = self.value_infos[tensor_name] if vi.type.HasField("tensor_type"): if mandatory and vi.type.tensor_type.elem_type == 0: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return vi.type.tensor_type.elem_type if (not self.enable_subgraph_quantization) or (self.parent is None): if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None otype = self.parent.is_valid_quantize_weight(tensor_name) if otype is not None: @@ -468,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False): if res is not None: return res if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None def is_float_tensor(self, tensor_name): @@ -1336,9 +1349,15 @@ def _dequantize_value(self, value_name): if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names): quantized_value = self.quantized_value_map[value_name] # Add DequantizeLinear Node for this input + scale_init = find_by_name(quantized_value.scale_name, self.model.initializer()) - # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + + # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done. + if self.model.model.producer_name != "onnx-quantizer" or ( + self.model.model.producer_name == "onnx-quantizer" and scale_init is not None + ): + # axis is not specified so scale_init must be a scalar. + assert onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 123cfe913d6e2..775a3e8b8b588 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -87,40 +87,28 @@ def __init__( # because those ops may be followed by nodes that require high resolution inputs. # Adding QDQ for those ops' output may end up with worse accuracy. # So, we don't recommend to add QDQ to node's output under such condition. - self.op_types_to_exclude_output_quantization = ( - [] - if "OpTypesToExcludeOutputQuantization" not in extra_options - else extra_options["OpTypesToExcludeOutputQuantization"] - ) + self.op_types_to_exclude_output_quantization = extra_options.get("OpTypesToExcludeOutputQuantization", []) # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization. # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair. # Therefore, we need to disable this optimization and add qdq pair to weight. - self.add_qdq_pair_to_weight = ( - False if "AddQDQPairToWeight" not in extra_options else extra_options["AddQDQPairToWeight"] - ) + self.add_qdq_pair_to_weight = extra_options.get("AddQDQPairToWeight", False) # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training, # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in # floating point format. To that end, we can use the FakeQuant operator for weights and activations that # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use # FakeQuant because it only ever appears before a DQ (since it is quantized as int32). - self.quantize_bias = True if "QuantizeBias" not in extra_options else extra_options["QuantizeBias"] + self.quantize_bias = extra_options.get("QuantizeBias", True) # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. - self.dedicated_qdq_pair = ( - False if "DedicatedQDQPair" not in extra_options else extra_options["DedicatedQDQPair"] - ) + self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) if self.dedicated_qdq_pair: self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. - self.qdq_op_type_per_channel_support_to_axis = ( - {} - if "QDQOpTypePerChannelSupportToAxis" not in extra_options - else extra_options["QDQOpTypePerChannelSupportToAxis"] - ) + self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9823e8264e17b..4b56bc1e8d828 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -205,6 +205,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "GroupQueryAttention": self._infer_GroupQueryAttention, "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, @@ -212,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "PagedAttention": self._infer_PagedAttention, "PythonOp": self._infer_PythonOp, "QuantizeLinear": self._infer_QuantizeLinear, "QuickGelu": self._infer_FastGelu, @@ -240,6 +242,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} @@ -347,7 +350,7 @@ def _merge_symbols(self, dims): return None if all([d == dims[0] for d in dims]): return dims[0] - merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] + merged = [self.suggested_merge_.get(d, d) for d in dims] if all([d == merged[0] for d in merged]): assert merged[0] in self.symbolic_dims_ return merged[0] @@ -468,9 +471,11 @@ def _onnx_infer_single_node(self, node): "SkipLayerNormalization", "SkipSimplifiedLayerNormalization", "PackedAttention", + "PagedAttention", "PythonOp", "MultiHeadAttention", "GroupNorm", + "GroupQueryAttention", "SkipGroupNorm", "BiasSplitGelu", "BiasAdd", @@ -819,17 +824,21 @@ def _infer_ArrayFeatureExtractor(self, node): # noqa: N802 def _infer_symbolic_compute_ops(self, node): funcs = { "Add": lambda l: l[0] + l[1], # noqa: E741 - "Div": lambda l: int(l[0] // l[1]) # noqa: E741 - if isinstance(l[0] // l[1], float) - else l[0] // l[1], # integer div in sympy + "Div": lambda l: ( # noqa: E741 + int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] + ), # integer div in sympy "Equal": lambda l: l[0] == l[1], # noqa: E741 "Floor": lambda l: sympy.floor(l[0]), # noqa: E741 - "Max": lambda l: l[1] # noqa: E741 - if is_literal(l[0]) and int(l[0]) < -self.int_max_ - else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), - "Min": lambda l: l[1] # noqa: E741 - if is_literal(l[0]) and int(l[0]) > self.int_max_ - else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), + "Max": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) < -self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) + ), + "Min": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) > self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) + ), "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], # noqa: E741 "Sub": lambda l: l[0] - l[1], # noqa: E741 "Where": lambda l: l[1] if l[0] else l[2], # noqa: E741 @@ -1471,9 +1480,11 @@ def _infer_aten_group_norm(self, node): output_dtype, [ N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)), - as_scalar(group) - if group is not None - else str(self._new_symbolic_dim_from_output(node, i, 1)), + ( + as_scalar(group) + if group is not None + else str(self._new_symbolic_dim_from_output(node, i, 1)) + ), ], ) ) @@ -2409,6 +2420,35 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_PagedAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + + def _infer_GroupQueryAttention(self, node): # noqa: N802 + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + past_shape = self._try_get_shape(node, 3) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + if node.input[1] != "" and node.input[2] != "": + self._propagate_shape_and_type(node, 0, 0) + else: + # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + assert node.input[1] == "" and node.input[2] == "" + num_heads = get_attribute(node, "num_heads") + kv_num_heads = get_attribute(node, "kv_num_heads") + query_shape = self._get_shape(node, 0) + if query_shape is not None: + hidden_size = query_shape[2] + if isinstance(hidden_size, int): + head_size = int(hidden_size / (num_heads + 2 * kv_num_heads)) + query_shape[2] = num_heads * head_size + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape)) + def _infer_SkipGroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node, 0, 0) if len(node.output) > 1: diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index b33491b356e86..20bb8a71dc35f 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -1575,15 +1575,13 @@ def output_metrics(model_to_metrics, csv_filename): for value in results: row = [ value["model_name"], - value["ratio_of_ops_in_cuda_not_fallback_cpu"] - if "ratio_of_ops_in_cuda_not_fallback_cpu" in value - else " ", - value["total_ops_in_trt"] if "total_ops_in_trt" in value else " ", - value["total_ops"] if "total_ops" in value else " ", - value["ratio_of_ops_in_trt"] if "ratio_of_ops_in_trt" in value else " ", - value["total_trt_execution_time"] if "total_trt_execution_time" in value else " ", - value["total_execution_time"] if "total_execution_time" in value else " ", - value["ratio_of_execution_time_in_trt"] if "ratio_of_execution_time_in_trt" in value else " ", + value.get("ratio_of_ops_in_cuda_not_fallback_cpu", " "), + value.get("total_ops_in_trt", " "), + value.get("total_ops", " "), + value.get("ratio_of_ops_in_trt", " "), + value.get("total_trt_execution_time", " "), + value.get("total_execution_time", " "), + value.get("ratio_of_execution_time_in_trt", " "), ] csv_writer.writerow(row) diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index b98aafc27579a..2ae64a72d08fe 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -45,7 +45,7 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: :return: A list of common 'docker build' arguments. """ - return [ + command = [ "--no-cache", "-t", f"{args.image_name}", @@ -54,6 +54,14 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: "--build-arg", f"ONNXRUNTIME_BRANCH={args.branch}", ] + if args.use_tensorrt_oss_parser: + command.extend( + [ + "--build-arg", + "PARSER_CONFIG=--use_tensorrt_oss_parser", + ] + ) + return command def is_valid_ver_str(version: str, min_comps: int = 0, max_comps: int = 0) -> bool: @@ -187,7 +195,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("-r", "--repo_path", required=True, help="Path to the onnxruntime repository") parser.add_argument("-i", "--image_name", required=True, help="The resulting Docker image name") parser.add_argument("-b", "--branch", default="main", help="Name of the onnxruntime git branch to checkout") - parser.add_argument("-t", "--trt_version", default="8.4.1.5", help="TensorRT version (e.g., 8.4.1.5)") + parser.add_argument("-t", "--trt_version", default="8.6.1.6", help="TensorRT version (e.g., 8.6.1.6)") parser.add_argument("-a", "--cuda_arch", default="75", help="CUDA architecture (e.g., 75)") # Command-line options for installing TensorRT from binaries. @@ -208,6 +216,12 @@ def parse_arguments() -> argparse.Namespace: help="CUDA version (e.g., 8.6) used to find TensorRT EA binary tar.gz package", ) parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) return parser.parse_args() diff --git a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py index 6e20071683d90..c7d4a7836132a 100755 --- a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py @@ -13,6 +13,12 @@ def parse_arguments(): parser.add_argument("-b", "--branch", required=False, default="master", help="Github branch to test perf off of") parser.add_argument("-s", "--save", required=False, help="Directory to archive wheel file") parser.add_argument("-a", "--use_archived", required=False, help="Archived wheel file") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) args = parser.parse_args() return args @@ -35,14 +41,14 @@ def install_new_ort_wheel(ort_master_path): def main(): args = parse_arguments() - cmake_tar = "cmake-3.18.4-Linux-x86_64.tar.gz" + cmake_tar = "cmake-3.28.3-linux-x86_64.tar.gz" if not os.path.exists(cmake_tar): - subprocess.run(["wget", "-c", "https://cmake.org/files/v3.18/" + cmake_tar], check=True) + subprocess.run(["wget", "-c", "https://cmake.org/files/v3.28/" + cmake_tar], check=True) tar = tarfile.open(cmake_tar) tar.extractall() tar.close() - os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.18.4-Linux-x86_64"), "bin") + ":" + os.environ["PATH"] + os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.28.3-linux-x86_64"), "bin") + ":" + os.environ["PATH"] os.environ["CUDACXX"] = os.path.join(args.cuda_home, "bin", "nvcc") ort_master_path = args.ort_master_path @@ -57,24 +63,24 @@ def main(): subprocess.run(["git", "fetch"], check=True) subprocess.run(["git", "checkout", args.branch], check=True) subprocess.run(["git", "pull", "origin", args.branch], check=True) - subprocess.run( - [ - "./build.sh", - "--config", - "Release", - "--use_tensorrt", - "--tensorrt_home", - args.tensorrt_home, - "--cuda_home", - args.cuda_home, - "--cudnn", - "/usr/lib/x86_64-linux-gnu", - "--build_wheel", - "--skip_tests", - "--parallel", - ], - check=True, - ) + command = [ + "./build.sh", + "--config", + "Release", + "--use_tensorrt", + "--tensorrt_home", + args.tensorrt_home, + "--cuda_home", + args.cuda_home, + "--cudnn", + "/usr/lib/x86_64-linux-gnu", + "--build_wheel", + "--skip_tests", + "--parallel", + ] + if args.use_tensorrt_oss_parser: + command.append("--use_tensorrt_oss_parser") + subprocess.run(command, check=True) ort_wheel_file = install_new_ort_wheel(ort_master_path) diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index f506516442b1e..89f9947688583 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -36,6 +36,8 @@ python benchmark.py -e torchscript onnxruntime -p "int8" -o Run OnnxRuntime with the ROCM provider and graph optimization script: python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm + Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support: + python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm It is recommended to use run_benchmark.sh to launch benchmark. """ @@ -106,6 +108,7 @@ def run_onnxruntime( use_raw_attention_mask, model_fusion_statistics, model_source, + enable_arm64_bfloat16_fastmath_mlas_gemm, args, ): import onnxruntime @@ -209,6 +212,7 @@ def run_onnxruntime( enable_all_optimization=True, num_threads=num_threads, verbose=verbose, + enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm, ) if ort_session is None: continue @@ -344,9 +348,7 @@ def run_pytorch( else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) logger.debug(f"Model {model}") logger.debug(f"Number of parameters {model.num_parameters()}") @@ -498,9 +500,7 @@ def run_tensorflow( tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) for batch_size in batch_sizes: if batch_size <= 0: @@ -764,6 +764,14 @@ def parse_arguments(): help="Manually set the model's layer number", ) + parser.add_argument( + "--enable_arm64_bfloat16_fastmath_mlas_gemm", + required=False, + action="store_true", + help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ", + ) + parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False) + FusionOptions.add_arguments(parser) args = parser.parse_args() @@ -909,6 +917,7 @@ def main(): use_raw_attention_mask, model_fusion_statistics, args.model_source, + args.enable_arm64_bfloat16_fastmath_mlas_gemm, args, ) except Exception: diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index b6f7a44450c62..c7d93470a729e 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -85,6 +85,7 @@ def create_onnxruntime_session( num_threads=-1, enable_profiling=False, verbose=False, + enable_mlas_gemm_fastmath_arm64_bfloat16=False, provider_options={}, # map execution provider name to its option # noqa: B006 ): session = None @@ -136,6 +137,9 @@ def create_onnxruntime_session( if provider_options: providers = [(name, provider_options[name]) if name in provider_options else name for name in providers] + if enable_mlas_gemm_fastmath_arm64_bfloat16: + sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1") + session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers) except Exception: logger.error("Exception", exc_info=True) @@ -341,11 +345,7 @@ def inference_ort_with_io_binding( # Bind inputs to device for name in ort_inputs: np_input = torch.from_numpy(ort_inputs[name]).to(device) - input_type = ( - IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] - if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP - else data_type - ) + input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type) io_binding.bind_input( name, np_input.device.type, diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 84ecae1907cd3..aa82e047df328 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -174,12 +174,10 @@ def output_test_data(directory: str, inputs: Dict[str, np.ndarray]): else: print("Warning: directory %s existed. Files will be overwritten." % directory) - index = 0 - for name, data in inputs.items(): + for index, (name, data) in enumerate(inputs.items()): tensor = numpy_helper.from_array(data, name) with open(os.path.join(directory, f"input_{index}.pb"), "wb") as file: file.write(tensor.SerializeToString()) - index += 1 def fake_test_data( diff --git a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py new file mode 100644 index 0000000000000..9a66afe3ad4f9 --- /dev/null +++ b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +import onnx + + +class DynamoOnnxHelper: + """ + Helper class for processing ONNX models exported by torch Dynamo. + """ + + def __init__(self, model: onnx.ModelProto): + self.model = model + + def update_edges(self, edge_mapping: dict) -> None: + """ + Updates the edges in the model according to the given mapping. + """ + for node in self.model.graph.node: + for i in range(len(node.input)): + if node.input[i] in edge_mapping: + node.input[i] = edge_mapping[node.input[i]] + for i in range(len(node.output)): + if node.output[i] in edge_mapping: + node.output[i] = edge_mapping[node.output[i]] + + for graph_input in self.model.graph.input: + if graph_input.name in edge_mapping: + graph_input.name = edge_mapping[graph_input.name] + for graph_output in self.model.graph.output: + if graph_output.name in edge_mapping: + graph_output.name = edge_mapping[graph_output.name] + + def unroll_function(self, func_name: str) -> None: + """ + Unrolls the function with the given name in the model. + """ + logging.info(f"Unrolling function {func_name}...") + nodes_to_remove = [] + nodes_to_add = [] + edges_to_remove = [] + edges_to_add = [] + for node in self.model.graph.node: + if node.op_type == func_name: + nodes_to_remove.append(node) + edges_to_remove.extend(list(node.input) + list(node.output)) + + func_to_remove = None + for f in self.model.functions: + if f.name == func_name: + nodes_to_add.extend(list(f.node)) + edges_to_add.extend(list(f.input) + list(f.output)) + func_to_remove = f + + assert len(edges_to_remove) == len(edges_to_add) + + for node in nodes_to_remove: + self.model.graph.node.remove(node) + for node in nodes_to_add: + self.model.graph.node.append(node) + if func_to_remove is not None: + self.model.functions.remove(func_to_remove) + + edge_mapping = {} + for i in range(len(edges_to_remove)): + k = edges_to_remove[i] + v = edges_to_add[i] + if k != v: + edge_mapping[k] = v + + return self.update_edges(edge_mapping) + + def remove_function(self, func_name: str, input_id: int, output_id: int) -> None: + """ + Removes the function in the model. + """ + edge_mapping = {} + nodes_to_remove = [] + for node in self.model.graph.node: + if node.op_type.find(func_name) != -1: + edge_mapping[node.input[input_id]] = node.output[output_id] + nodes_to_remove.append(node) + for node in nodes_to_remove: + self.model.graph.node.remove(node) + + self.update_edges(edge_mapping) + + def remove_dropout_layer(self) -> None: + """ + Removes the dropout layer in the model. + """ + logging.info("Removing dropout layer...") + self.remove_function("Dropout", 0, 0) + + def remove_lm_head_layer(self) -> None: + """ + Removes the LM head layer in the model. + """ + logging.info("Removing LM head layer...") + # bugbug: need to copy the right vi over + self.remove_function("Linear_lm_head", 2, 0) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index f680a15fc2c1b..48c79b1d5fa0f 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -174,6 +174,7 @@ def convert_float_to_float16( node_block_list=None, force_fp16_initializers=False, force_fp16_inputs=None, + use_bfloat16_as_blocked_nodes_dtype=False, ): """Convert tensor float type in the input ONNX model to tensor float16. @@ -436,6 +437,7 @@ def convert_float_to_float16( node.input[i] = output_name break + accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, @@ -450,10 +452,10 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) - new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] + new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name @@ -469,7 +471,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) input_name = node.name + "_output_cast_" + str(i) new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float) to tensor(float16) after current node node_name = node.name + "_output_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 71801401e9d06..ebecc1db24792 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -74,13 +74,74 @@ def check_runtime_shape_path( return True + def check_runtime_shape_path_openai( + self, + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ): + reshape_qkv_2_path = self.model.match_parent_path( + reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0] + ) + if reshape_qkv_2_path is None: + return False + else: + if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]: + return False + + matmul_qk_path_1 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] + ) + matmul_qk_path_2 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] + ) + if matmul_qk_path_1 is None or matmul_qk_path_2 is None: + return False + + mul_1 = matmul_qk_path_1[0] + mul_2 = matmul_qk_path_2[0] + if mul_1.input[1] != mul_2.input[1]: + return False + if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: + return False + + # For decoder attentions only + if add_qk is not None: + add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) + if add_qk_path is None: + return False + slice_q_path_1 = self.model.match_parent_path( + add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] + ) + slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if slice_q_path_1 is None and slice_q_path_2 is None: + return False + _, unsqueeze_1, _, _ = slice_q_path_1 + unsqueeze_2, _, _ = slice_q_path_2 + if unsqueeze_1.input[0] != unsqueeze_2.input[0]: + return False + if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: + return False + + return True + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Track if fusion is occurring for OpenAI implementation of Whisper + model_impl_openai = False + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0, 0, 0], ) + qkv_nodes_openai = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) if qkv_nodes is not None: ( add_out, @@ -90,6 +151,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): reshape_qkv_1, matmul_qkv, ) = qkv_nodes + elif qkv_nodes_openai is not None: + qkv_nodes = qkv_nodes_openai + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + # Set model implementation to openai + model_impl_openai = True else: return @@ -137,6 +209,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None], ) + v_nodes_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) v_nodes_with_past_self_attn = self.model.match_parent_path( # Decoder attention with past value concatenated before MatMul matmul_qkv, @@ -149,12 +226,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape"], [1], ) + v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) present_v = transpose_v.output[0] + elif v_nodes_openai is not None: + v_nodes = v_nodes_openai + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + + # Find the child path to access the correct present_v values + # Openai impl provides present/past v values in 3D format + # whereas ort MultiHeadAttention expects v values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Add -> Reshape -> Transpose -> Present_V + reshape_path = self.model.match_child_path( + add_v, + ["Reshape", "Transpose"], + exclude=[reshape_v_1], + ) + # For decoder attention types + # add_v_node Reshape <- Transpose <-Past_V + # \ / + # \ / + # -> Concat <- + # | + # |--> Reshape -> Transpose -> Present_V + concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_add_v) = reshape_path + if transpose_add_v.output[0] in graph_output_names: + present_v = transpose_add_v.output[0] + if concat_path is not None: + (concat_v, _, transpose_concat_v) = concat_path + if transpose_concat_v.output[0] in graph_output_names: + present_v = transpose_concat_v.output[0] + concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_v_in = concat_nodes + past_v = transpose_concat_v_in.input[0] elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn v_nodes = v_nodes_with_past_self_attn @@ -171,6 +288,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) ) present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + elif ( + v_nodes_with_past_cross_attn_openai is not None + and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn_openai + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return @@ -181,12 +310,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] ) + qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + add_qk = None if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: _, _, add_qk, _, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 + elif qk_nodes_2_openai is not None: + _, add_qk, matmul_qk = qk_nodes_2_openai + qk_nodes = qk_nodes_2_openai else: return @@ -195,8 +329,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, 0, 1], ) + q_nodes_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + reshape_q_2 = None if q_nodes is not None: reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + elif q_nodes_openai is not None: + q_nodes = q_nodes_openai + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return @@ -205,6 +348,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, 1], ) + k_nodes_with_bias_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) k_nodes_no_bias = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], @@ -222,11 +370,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape"], [1, 0], ) + k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias + elif k_nodes_with_bias_openai is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai + k_nodes = k_nodes_with_bias_openai + present_k = matmul_k.output[0] + + # Find the child path to access the correct present_k values + # Openai impl provides present/past k values in 3D format + # whereas ort MultiHeadAttention expects k values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Matmul -> Reshape -> Transpose -> Present_K + reshape_path = self.model.match_child_path( + matmul_k, + ["Reshape", "Transpose"], + exclude=[reshape_k_1], + ) + # For decoder attention types + # matmul_k_node Reshape <- Transpose <- Past_K + # \ / + # \ / + # -> Concat <- + # | + # |--> Reshape -> Transpose -> Present_K + concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_matmul_k) = reshape_path + if transpose_matmul_k.output[0] in graph_output_names: + present_k = transpose_matmul_k.output[0] + if concat_path is not None: + (concat_k, _, transpose_concat_k) = concat_path + if transpose_concat_k.output[0] in graph_output_names: + present_k = transpose_concat_k.output[0] + concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_k_in = concat_nodes + past_k = transpose_concat_k_in.input[0] elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -249,12 +438,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) ) present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + elif ( + k_nodes_no_bias_with_past_cross_attn_openai is not None + and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): # Create empty Add node for attention graph bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] empty_bias_name = "empty_bias" @@ -270,13 +471,29 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) - if not past_k and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, + if ( + model_impl_openai + and not past_k + and not self.check_runtime_shape_path_openai( + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ) + ): + return + elif ( + not model_impl_openai + and not past_k + and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ) ): return @@ -301,8 +518,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) + decoder_attention_with_past = ( + (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v + ) decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index b9b92d2fe8a00..edac1989e4e9e 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from argparse import ArgumentParser +from enum import Enum class AttentionMaskFormat: @@ -19,6 +20,23 @@ class AttentionMaskFormat: NoMask = 3 +class AttentionOpType(Enum): + Attention = "Attention" + MultiHeadAttention = "MultiHeadAttention" + GroupQueryAttention = "GroupQueryAttention" + PagedAttention = "PagedAttention" + + def __str__(self): + return self.value + + # Override __eq__ to return string comparison + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value + + class FusionOptions: """Options of fusion in graph optimization""" @@ -57,6 +75,8 @@ def __init__(self, model_type): elif model_type == "vit": self.attention_mask_format = AttentionMaskFormat.NoMask + self.attention_op_type = None + # options for stable diffusion if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True @@ -76,6 +96,9 @@ def use_raw_attention_mask(self, use_raw_mask=True): def disable_attention_mask(self): self.attention_mask_format = AttentionMaskFormat.NoMask + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attention_op_type = attn_op_type + @staticmethod def parse(args): options = FusionOptions(args.model_type) diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py index df80acbd97807..676052f747967 100644 --- a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -147,7 +147,7 @@ def match_bias_path(self, node, input_name_to_nodes, output_name_to_node): def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node): """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.""" - parent = output_name_to_node[output_name] if output_name in output_name_to_node else None + parent = output_name_to_node.get(output_name, None) if parent is not None and parent.op_type == "Transpose": permutation = OnnxModel.get_node_attribute(parent, "perm") if permutation == [0, 3, 1, 2]: diff --git a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py index 6089c960e47ee..8797fd9c2cfaf 100644 --- a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py +++ b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py @@ -193,7 +193,7 @@ def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name if row: for key in key_names: - row[key] = values[key] if key in values else "" + row[key] = values.get(key, "") csv_writer.writerow(row) csv_file.flush() diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py index a1e6d3125e7fb..4823f0d5874dd 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py @@ -171,12 +171,10 @@ def print_wins(wins, rows, test_name): rank = 0 previous_value = -1 - count = 0 - for key, value in sorted_wins.items(): + for count, (key, value) in enumerate(sorted_wins.items()): if value != previous_value: rank = count previous_value = value - count += 1 for row in rows: if row["run_id"] == key: diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py index 12700f00ad0c2..f4705bef6a988 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py @@ -387,8 +387,8 @@ def test_generation( if i % 10 == 0: print(f"{i}") input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] if "position_ids" in inputs else None - attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None + position_ids = inputs.get("position_ids", None) + attention_mask = inputs.get("attention_mask", None) onnx_runner = Gpt2Tester( input_ids, diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index c9a679c4eac8a..51a967cf22608 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -289,9 +289,7 @@ def inference(): def load_torch_model(model_name, device): - torch_model_name_or_dir = ( - PRETRAINED_LONGFORMER_MODELS[model_name] if model_name in PRETRAINED_LONGFORMER_MODELS else model_name - ) + torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name) model = LongformerModel.from_pretrained(torch_model_name_or_dir) model.to(device) return model diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md new file mode 100644 index 0000000000000..da62bba0f02fb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -0,0 +1,120 @@ +# Phi2 Optimizations +## Prerequisites +A Linux machine for [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter)\ +Install onnx, onnxscript and transformers by running +```bash +pip install -r requirements.txt +``` +To export ONNX, PyTorch version 2.2.0 or higher is required. The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs. +\ +\ +**There are two options to run the conversion script:**\ +_From source:_ +```bash +# Default onnxruntime package is built with CUDA 11.8. For CUDA 12.x, refer to https://onnxruntime.ai/docs/install/#python-installs +pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu +git clone git@github.com:microsoft/onnxruntime.git +cd onnxruntime/onnxruntime/python/tools/transformers +python -m models.phi2.convert_to_onnx -h +``` +_From wheel:_ \ +Install [ORT nightly package](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages) +```bash +python -m onnxruntime.transformers.models.phi2.convert_to_onnx -h +``` + +## Export optimized phi2 onnx model for different scenarios +**Export FP32 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu +``` +\ +**Export INT4 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +\ +**Export INT4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +\ +**Export FP32 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu +``` +\ +**Export INT4 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_cpu +``` +\ +**Export all at once** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +## Run example with ORT +**(e.g) Export FP16 and INT4 ONNX models for Nvidia GPUs with CUDA architecture SM=80~89 and run examples.** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +The inference example currently supports all models running on CUDA. + +## Limitations +- TorchDynamo-based ONNX Exporter only supports Linux. +- The program may not run as expected if the machine has limited memory. e.g Dynamo export may use ~11.6GB; Optimization may use ~4.5GB for each. diff --git a/onnxruntime/python/tools/transformers/models/phi2/__init__.py b/onnxruntime/python/tools/transformers/models/phi2/__init__.py new file mode 100644 index 0000000000000..e80f36a391fe1 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + +transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) +if transformers_dir not in sys.path: + sys.path.append(transformers_dir) diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py new file mode 100644 index 0000000000000..796d6ec55ef80 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -0,0 +1,499 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import argparse +import logging +import os +from pathlib import Path + +import onnx +import torch +from benchmark_helper import Precision +from fusion_options import AttentionOpType +from transformers import AutoConfig, AutoModelForCausalLM + +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer + + +class ConvertPhi2ToONNX: + def __init__( + self, + device: torch.device, + model_class: str = "microsoft/phi-2", + cache_dir: str = "./cache", + ): + self.model_class = model_class + self.device = device + self.cache_dir = cache_dir + self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir) + self.phi_model = None + self.batch_size = 2 + self.sequence_length = 8 + self.attn_op_type = None + self.precision = None + self.block_size = 16 + self.accuracy_level = None + + def set_quantization_params(self, block_size: int, accuracy_level: int | None): + self.block_size = block_size + self.accuracy_level = accuracy_level + + def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision): + self.attn_op_type = attn_op_type + self.precision = precision + + def erase_onnx_model(self, onnx_path: str) -> None: + assert onnx_path.endswith(".onnx") + if not os.path.exists(onnx_path): + return + + model = onnx.load_model(onnx_path, load_external_data=False) + onnx_data_path = None + for initializer in model.graph.initializer: + if initializer.data_location == 1 and initializer.external_data[0].key == "location": + onnx_data_path = "./" + initializer.external_data[0].value + break + logging.info(f"Erasing {onnx_path}...") + os.remove(onnx_path) + if onnx_data_path is not None: + onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path) + logging.info(f"Erasing {onnx_data_path}...") + os.remove(onnx_data_path) + + def get_phi2_torch_model(self): + logging.info("Loading phi2 torch model...") + if self.phi_model is not None: + return + self.phi_model = AutoModelForCausalLM.from_pretrained( + self.model_class, trust_remote_code=True, cache_dir=self.cache_dir + ) + self.phi_model.eval() + self.phi_model.to(self.device) + + def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int): + input_ids = torch.randint( + low=0, + high=self.phi_config.vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=self.device, + ) + self.get_phi2_torch_model() + torch_inputs = self.phi_model.prepare_inputs_for_generation( + input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"] + ) + return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"] + + def dynamo_export(self, onnx_path: str): + input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length) + self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values) + + from torch._dynamo import config + + config.capture_scalar_outputs = True + + logging.info("Exporting Phi2 torch model to ONNX...") + torch.onnx.dynamo_export( + self.phi_model, + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ).save(onnx_path) + onnx.checker.check_model(onnx_path) + onnx.shape_inference.infer_shapes_path(onnx_path) + + def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): + from fusion_options import FusionOptions + from optimizer import optimize_model + + optimization_options = FusionOptions("phi") + optimization_options.set_attention_op_type(self.attn_op_type) + optimizer = optimize_model( + onnx_path, + model_type="phi", + num_heads=self.phi_config.num_attention_heads, + hidden_size=self.phi_config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + + fused_op_count = optimizer.get_fused_operator_statistics() + if optimizer.is_fully_optimized(fused_op_count): + logging.info("Model is fully optimized.") + else: + logging.info("Model is not fully optimized.") + + if self.precision == Precision.FLOAT32: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + + if ( + self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 + ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: + # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. + node_block_list = ( + [ + "Attention_29", + "Attention_30", + "Attention_31", + ] + if self.attn_op_type != AttentionOpType.PagedAttention + else [] + ) # TODO: temp setting for paged attention + logging.info("Converting onnx model to float16/bfloat16...") + optimizer.convert_float_to_float16( + keep_io_types=False, + node_block_list=node_block_list, + use_symbolic_shape_infer=True, + use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention, + ) + logging.info("Converting onnx model to float16/bfloat16 done.") + + if self.precision == Precision.FLOAT16: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + else: + assert self.precision == Precision.INT4 + quant = MatMul4BitsQuantizer( + model=optimizer.model, + block_size=self.block_size, + is_symmetric=True, + accuracy_level=self.accuracy_level, + ) + quant.process() + quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--fp32_cpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for CPU", + ) + + parser.add_argument( + "--int4_cpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for CPU", + ) + + parser.add_argument( + "--fp32_gpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--int4_gpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu_sm8x", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--int4_gpu_sm8x", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--fp16_vllm", + required=False, + action="store_true", + help="Generate fp16 ONNX model for ORT VLLM", + ) + + parser.add_argument( + "--int4_vllm", + required=False, + action="store_true", + help="Generate int4 ONNX model for ORT VLLM", + ) + + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite existing ONNX models", + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./cache", + help="The cache directory for the pytorch model", + ) + + parser.add_argument( + "--device_id", + required=False, + type=int, + default=0, + help="The device id for the pytorch model", + ) + + parser.add_argument( + "--run_example", + required=False, + action="store_true", + help="Run ORT inference example", + ) + + parser.add_argument( + "--skip_export", + required=False, + action="store_true", + help="Skip exporting ONNX model", + ) + + parser.add_argument( + "--output_dir", + type=str, + help="The output directory for the ONNX models", + default="phi2_onnx_models", + ) + + parser.add_argument( + "--block_size", + required=False, + default=16, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + parser.add_argument( + "--int4_accuracy_level", + required=False, + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation. " + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + + device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu") + + converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir) + converter.set_quantization_params(args.block_size, args.int4_accuracy_level) + + output_dir = args.output_dir + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + original_onnx_path = os.path.join(output_dir, "phi2_original.onnx") + + if not args.skip_export: + if not os.path.exists(original_onnx_path) or args.overwrite: + converter.dynamo_export(original_onnx_path) + + model_type_to_args = { + "fp32_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"), + ), + "int4_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"), + ), + "fp32_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"), + ), + "fp16_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"), + ), + "int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")), + "fp16_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"), + ), + "int4_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"), + ), + "fp16_vllm": ( + AttentionOpType.PagedAttention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"), + ), + "int4_vllm": ( + AttentionOpType.PagedAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"), + ), + } + + if not args.skip_export: + from multiprocessing import Process + + def run_optimize_phi2_onnx( + converter: ConvertPhi2ToONNX, + original_onnx_path: str, + attention_type: AttentionOpType, + precision: Precision, + optimized_onnx_path: str, + ): + converter.init_attn_type_and_precision(attention_type, precision) + converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path) + + processes = [] + if args.fp32_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"]) + ) + ) + + if args.int4_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"]) + ) + ) + + if args.fp32_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"]) + ) + ) + + if args.fp16_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"]) + ) + ) + + if args.int4_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"]) + ) + ) + + if args.fp16_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]), + ) + ) + + if args.int4_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]), + ) + ) + + if args.fp16_vllm: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]), + ) + ) + + if args.int4_vllm: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]), + ) + ) + + [p.start() for p in processes] + [p.join() for p in processes] + + if args.run_example: + from inference_example import run_phi2 + + if args.fp16_gpu_sm8x: + logging.info("Running fp16_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.int4_gpu_sm8x: + logging.info("Running int4_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.fp32_gpu: + logging.info("Running fp32_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp32_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + use_fp16=False, + ) + if args.fp16_gpu: + logging.info("Running fp16_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.int4_gpu: + logging.info("Running int4_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm: + raise NotImplementedError("CPU/vllm inference example is not implemented yet.") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py new file mode 100644 index 0000000000000..28828ffb853cb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -0,0 +1,215 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import numpy as np +import torch +from transformers import AutoTokenizer + +import onnxruntime as ort + +pt_to_np = { + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.float32": np.float32, + "torch.float16": np.float16, +} + + +class ORTGenerator: + def __init__(self, decoder_path): + self.onnx_decoder_path = decoder_path + self.num_heads = 32 + self.head_size = 80 + self.num_layers = 32 + self.max_sequence_length = 2048 + + def get_initial_inputs_and_outputs(self, encodings_dict): + self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 + + input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) + attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) + step = torch.tensor([0], device=self.device, dtype=torch.int64) + + inputs = { + "input_ids": input_ids.contiguous(), + "attention_mask": attention_mask.contiguous(), + } + + if self.use_step: + inputs["step"] = step.contiguous() + + batch_size, sequence_length = input_ids.shape + + past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 + past_shape = ( + (2, batch_size, self.num_heads, past_seq_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, past_seq_length, self.head_size) + ) + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) + outputs = {"logits": logits.contiguous()} + + if not self.use_buffer_share: + present_shape = ( + (2, batch_size, self.num_heads, sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + return inputs, outputs + + def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict): + io_binding = model.io_binding() + device = None + + for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type=v.device.type, + device_id=0 if v.device.type == "cpu" else v.device.index, + element_type=pt_to_np[repr(v.dtype)], + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + device = v.device + + for output in model.get_outputs(): + name = output.name + if self.use_buffer_share and "present" in name: + v = inputs[name.replace("present", "past")] + io_binding.bind_output( + name=name, + device_type=v.device.type, + device_id=v.device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + else: + v = outputs[name] + io_binding.bind_output( + name=name, + device_type=device.type, + device_id=0 if device.type == "cpu" else device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + + return io_binding + + def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + sess_options = ort.SessionOptions() + ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + + self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.use_fp16 = use_fp16 + self.use_buffer_share = use_buffer_share + self.packed_kv = packed_kv + self.use_step = use_step + + self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + self.tokenizer.pad_token = "[PAD]" + + def generate(self, prompt, max_length): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) + + all_token_ids = inputs["input_ids"].clone() + batch_size, sequence_length = all_token_ids.shape + + current_length = sequence_length + has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + + while current_length < max_length: + io_binding = self.apply_io_binding(self.sess, inputs, outputs) + + io_binding.synchronize_inputs() + self.sess.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # Sample with argmax (greedy search) + next_token_logits = outputs["logits"][:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # Check if we previously reached EOS token id or if generated token id is EOS token id + has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id + + # Determine which new tokens to add to list of all token ids + # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't) + tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1]) + all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) + + # Return early if all batch entries have reached EOS token id + if torch.all(has_eos): + break + + # Update inputs for next inference run + current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_step: + inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) + inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( + torch.int32 + ) + + # Set logits to zeros for next inference run and re-use memory buffer + if outputs["logits"].shape[1] != 1: + outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + outputs["logits"].zero_() + + if not self.use_buffer_share: + for i in range(self.num_layers): + if not self.packed_kv: + inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"] + inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"] + else: + inputs[f"past_{i}"] = outputs[f"present_{i}"] + + new_sequence_length = inputs["attention_mask"].shape[1] + present_shape = ( + (2, batch_size, self.num_heads, new_sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, new_sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) + return texts + + +def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False): + prompt = [ + '''```python + def print_prime(n): + """ + Print all primes between 1 and n + """''' + ] + + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) + texts = generator.generate(prompt, max_length=200) + + for i in range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) diff --git a/onnxruntime/python/tools/transformers/models/phi2/requirements.txt b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt new file mode 100644 index 0000000000000..af6f441c149d0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt @@ -0,0 +1,3 @@ +onnx>=1.15.0 +transformers>=4.36.2 +onnxscript>=0.1.0.dev20240126 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index c03c6f0b21cd3..26b9a2792e9e1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -92,7 +92,7 @@ def get_diffusers_module_name(self, model_name): "unetxl": "unet", "vae": "vae_decoder", } - return name_mapping[model_name] if model_name in name_mapping else model_name + return name_mapping.get(model_name, model_name) def get_cached_model_name(self, model_name): model_name = self.get_diffusers_module_name(model_name) diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 02100266200f8..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). @@ -10,10 +27,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,40 +56,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. +To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 759ae6d14f184..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime @@ -54,6 +60,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +189,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +201,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +509,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +596,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e15a12c07bed7..35211aab272e4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,34 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), + ) + + conversion_args.add_argument( + "--model_impl", + required=False, + default="hf", + choices=["hf", "openai"], + type=str, + help="Select implementation for export of encoder and decoder subgraphs", ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -46,7 +63,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -54,19 +71,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -76,221 +98,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) + optional_inputs.set_defaults(use_prefix_vocab_mask=False) - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + ) + optional_inputs.set_defaults(extra_decoding_ids=False) + + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(use_temperature=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) + + ################################### + # Quantization options for Whisper + ################################### - parser.add_argument( + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -300,6 +327,7 @@ def parse_arguments(argv=None): def export_onnx_models( model_name_or_path, + model_impl, cache_dir, output_dir, use_gpu, @@ -307,7 +335,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -321,7 +349,7 @@ def export_onnx_models( device = torch.device("cuda:0" if use_gpu else "cpu") models = WhisperHelper.load_model( - model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path + model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) config = models["decoder"].config @@ -352,7 +380,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, use_int32_inputs=use_int32_inputs, ) else: @@ -396,7 +423,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -431,6 +458,7 @@ def main(argv=None): output_paths = export_onnx_models( args.model_name_or_path, + args.model_impl, cache_dir, output_dir, args.use_gpu, @@ -438,7 +466,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -451,7 +479,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... :") args.beam_model_output_dir = WhisperHelper.get_onnx_path( output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt new file mode 100644 index 0000000000000..db2cd95324328 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.17.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt new file mode 100644 index 0000000000000..9bd215de9bc09 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system. +# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.17.1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt new file mode 100644 index 0000000000000..c307a3665f8a0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -0,0 +1,11 @@ +torch>=1.13.0 +transformers>=4.24.0 +openai-whisper +ffmpeg-python +datasets +soundfile +librosa +optimum +onnxruntime-extensions>=0.9.0 +protobuf==3.20.2 +numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..14691da4ad643 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import logging import os @@ -9,7 +15,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +52,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,38 +64,27 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") - if args.output_scores: - beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) - - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None - output_scores_cast_node = output_sequence_scores_cast_node = None + sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" + scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores" + beam_outputs = [ + "sequences", + sequence_scores_name if args.output_sequence_scores else "", + scores_name if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] + + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -98,6 +107,18 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + if args.output_sequence_scores: output_sequence_scores_cast_node = helper.make_node( "Cast", @@ -106,6 +127,8 @@ def chain_model(args): name="CastOutputSequenceScoresToFp32", to=TensorProto.FLOAT, ) + graph_nodes.append(output_sequence_scores_cast_node) + if args.output_scores: output_scores_cast_node = helper.make_node( "Cast", @@ -114,26 +137,38 @@ def chain_model(args): name="CastScoresToFp32", to=TensorProto.FLOAT, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.append(output_scores_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -143,73 +178,63 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - graph_inputs = [ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) - # graph outputs + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -230,19 +255,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [ - input_features_cast_node, - len_pen_cast_node, - rep_pen_cast_node, - node, - output_sequence_scores_cast_node, - output_scores_cast_node, - ] - if args.precision == Precision.FLOAT16 - else [node] - ) - graph_nodes = [node for node in graph_nodes if node is not None] + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -251,9 +264,16 @@ def chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -287,10 +307,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + ".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index eca5ce3de15d3..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -18,6 +18,7 @@ from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export from transformers import WhisperConfig, file_utils +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -67,10 +68,13 @@ def forward( class WhisperDecoder(torch.nn.Module): """A Whisper decoder with past key values""" - def __init__(self, decoder, config): + def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None): super().__init__() self.decoder = decoder self.config = config + self.model_impl = model_impl + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) def forward(self, decoder_input_ids, *past): encoder_outputs = file_utils.ModelOutput() @@ -78,6 +82,14 @@ def forward(self, decoder_input_ids, *past): encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states encoder_outputs["hidden_states"] = dummy_encoder_hidden_states encoder_outputs["attentions"] = None + + if self.model_impl == "openai": + dummy_encoder_hidden_states.unsqueeze(0) + dec_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, dummy_encoder_hidden_states, past=past + ) + return dec_out, present + if len(past) == 0: past_key_values = None else: @@ -213,7 +225,7 @@ def export_onnx( decoder.config, batch_size=2, encode_sequence_length=3000, - past_decode_sequence_length=5 if isinstance(decoder, WhisperDecoder) else 0, + past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, use_int32_inputs=use_int32_inputs, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 826d6e42c0775..93281848a5c9c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -25,12 +25,15 @@ class WhisperEncoder(torch.nn.Module): """Whisper encoder outputs only the last hidden state""" - def __init__(self, encoder, config: WhisperConfig): + def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"): super().__init__() self.encoder = encoder self.config = config + self.model_impl = model_impl def forward(self, input_features): + if self.model_impl == "openai": + return self.encoder(input_features) return self.encoder.model.encoder(input_features)[0] @@ -40,7 +43,11 @@ def __init__(self, input_features): @staticmethod def create_dummy( - batch_size: int, sequence_length: int, feature_size: int, device: torch.device, use_int32_inputs: bool + batch_size: int, + sequence_length: int, + feature_size: int, + device: torch.device, + use_int32_inputs: bool = False, ): """Create dummy inputs for Whisper encoder. @@ -61,9 +68,9 @@ def create_dummy( return WhisperEncoderInputs(input_features) def to_list(self) -> List: - if self.input_features is None: + if self.input_ids is None: return [] - return [self.input_features] + return [self.input_ids] class WhisperEncoderHelper: @@ -74,6 +81,7 @@ def export_onnx( onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export encoder to ONNX @@ -90,6 +98,7 @@ def export_onnx( sequence_length=3000, feature_size=config.num_mel_bins, device=device, + use_int32_inputs=use_int32_inputs, ) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index a145178dbf37e..832f692e9980d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- +import copy import logging import os import tempfile @@ -19,6 +20,7 @@ from transformers import WhisperConfig from whisper_decoder import WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderInputs +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -34,11 +36,16 @@ def __init__( decoder: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: Optional[int] = None, + model_impl: str = "hf", + model: torch.nn.Module = None, ): super().__init__() self.config = config - self.whisper_encoder = WhisperEncoder(encoder, config) + self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl) self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id) + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) + self.model_impl = model_impl def forward( self, @@ -47,9 +54,14 @@ def forward( ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) - decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) - present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) - present = present_self + present_cross + if self.model_impl == "openai": + encoder_hidden_states.unsqueeze(0) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + return decinit_out, encoder_hidden_states, present + else: + decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) + present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) + present = present_self + present_cross return decinit_out[0], encoder_hidden_states, present @@ -63,7 +75,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -72,7 +84,6 @@ def create_dummy( sequence_length=3000, feature_size=config.num_mel_bins, device=device, - use_int32_inputs=use_int32_inputs, ) decoder_input_ids = None if use_decoder_input_ids: @@ -114,13 +125,15 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) input_list = inputs.to_list() - out = model(inputs.encoder_input_ids, inputs.decoder_input_ids) + # TODO : Investigate whether copy of model if needed + cloned_model = copy.deepcopy(model).to(device) + out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) @@ -146,7 +159,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a4bef1f06b4fe..1b47b9426d983 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,24 +23,20 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff # noqa: E402 -from onnx_model import OnnxModel # noqa: E402 -from optimizer import optimize_model # noqa: E402 - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -72,9 +70,49 @@ def get_onnx_path( directory = os.path.join(output_dir, model_name) if new_folder else output_dir return os.path.join(directory, model_name + ".onnx") + @staticmethod + def load_model_openai( + model_name_or_path: str, + cache_dir: str, + device: torch.device, + ) -> torch.nn.Module: + """Load model given a pretrained name or path, then build models for ONNX conversion. + + Args: + model_name_or_path (str): pretrained model name or path + cache_dir (str): cache directory + device (torch.device): device to run the model + merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True. + Returns: + Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. + """ + from whisper import _ALIGNMENT_HEADS, _MODELS, _download + from whisper.model import ModelDimensions, Whisper + + in_memory = False + + model_name = model_name_or_path.split("/")[-1][8:] + checkpoint_file, alignment_heads = None, None + if model_name in _MODELS: + checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory) + alignment_heads = _ALIGNMENT_HEADS[model_name] + + with open(checkpoint_file, "rb") as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) + @staticmethod def load_model( model_name_or_path: str, + model_impl: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True, @@ -94,18 +132,29 @@ def load_model( if version.parse(transformers_version) >= version.parse("4.36.0"): extra_kwargs["attn_implementation"] = "eager" model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) + + if model_impl == "openai": + openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device) + model_encoder, model_decoder = openai_model.encoder, openai_model.decoder + passed_model = openai_model + else: + model_encoder, model_decoder = model, model + passed_model = None + if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) - decoder = WhisperDecoder(model, model.config) + decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model) decoder.eval().to(device) if merge_encoder_and_decoder_init: encoder_decoder_init = WhisperEncoderDecoderInit( - model, - model, + model_encoder, + model_decoder, model.config, decoder_start_token_id=None, + model_impl=model_impl, + model=passed_model, ) return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} else: @@ -290,12 +339,17 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset # noqa: F811 + from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -332,43 +386,51 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + parity = ( + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options + ) + max_diff = 0 - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py new file mode 100644 index 0000000000000..941f61cf7cc29 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class WhisperDecoderInitOpenai(torch.nn.Module): + """WhisperDecoderInit for Openai.""" + + def __init__( + self, + model: torch.nn.Module, + decoder: torch.nn.Module, + ): + super().__init__() + self.whisper_model = model + self.whisper_decoder = decoder + self.kv_cache = {} + + @torch.no_grad() + def forward( + self, + tokens, + audio_features, + past=None, + ): + # Create a kv_cache for past_values + past_kv_cache = dict() + if past is not None: + # Convert past values from 4D to 3D + past = [torch.transpose(val, 1, 2) for val in past] + past = [val.reshape(val.shape[:2] + (-1,)) for val in past] + half_idx = len(past) // 2 + for idx, block in enumerate(self.whisper_decoder.blocks): + past_kv_cache[block.attn.key] = past[2 * idx] + past_kv_cache[block.attn.value] = past[2 * idx + 1] + past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx] + past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1] + + if not self.kv_cache: + self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() + + logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + + # Add concat node for past values + if past is not None: + for block in self.whisper_decoder.blocks: + self.kv_cache[block.attn.key] = torch.cat( + [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1 + ).detach() + self.kv_cache[block.attn.value] = torch.cat( + [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1 + ).detach() + + present_self, present_cross = [], [] + # Group self and cross values + for block in self.whisper_decoder.blocks: + present_self.append(self.kv_cache[block.attn.key]) + present_self.append(self.kv_cache[block.attn.value]) + if past is None: + present_cross.append(self.kv_cache[block.cross_attn.key]) + present_cross.append(self.kv_cache[block.cross_attn.value]) + + present_self = present_self + present_cross + # Add reshape and transpose ops to convert from 3D to 4D + present_self = [ + present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self + ] + return logits, present_self diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 4e064fa53bfc6..3967a7875f3a7 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -492,10 +492,7 @@ def export_onnx_model_from_pt( example_inputs = image_processor(data, return_tensors="pt") else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) - + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") example_inputs = filter_inputs(example_inputs, input_names) @@ -599,9 +596,7 @@ def export_onnx_model_from_tf( # Fix "Using pad_token, but it is not set yet" error. if tokenizer.pad_token is None: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier) model.resize_token_embeddings(len(tokenizer)) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 9d1066b6e372b..a8fc6e661933e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -82,6 +82,10 @@ def output_name_to_node(self): output_name_to_node[output_name] = node return output_name_to_node + def functions(self): + all_functions = [list(self.model.functions)] + return all_functions + def nodes(self): all_nodes = [] for graph in self.graphs(): @@ -426,6 +430,54 @@ def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, r return None + def match_child_path( + self, + node, + child_op_types, + child_output_index=None, + return_indice=None, + exclude=[], # noqa: B006 + ): + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + child_op_types (str): constraint of child node op_type of each input edge. + child_output_index (list): constraint of input index of each input edge. None means no constraint. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + children: a list of matched children node. + """ + if child_output_index is not None: + assert len(child_output_index) == len(child_op_types) + + current_node = node + matched_children = [] + for i, op_type in enumerate(child_op_types): + matched_child = None + node_children = self.get_children(current_node) + for child_i, child in enumerate(node_children): + if child.op_type == op_type and child not in exclude: + if child_output_index is not None and child_output_index[i] != child_i: + logger.debug( + f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", + stack_info=True, + ) + return None + matched_child = child + if matched_child is None: + logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + return None + + matched_children.append(matched_child) + current_node = matched_child + return matched_children + def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): if output_name_to_node is None: output_name_to_node = self.output_name_to_node() @@ -733,6 +785,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): "node_block_list", "force_fp16_initializers", "force_fp16_inputs", + "use_bfloat16_as_blocked_nodes_dtype", ] if key in kwargs } @@ -833,11 +886,9 @@ def get_graph_inputs(self, current_node, recursive=False): @staticmethod def input_index(node_output, child_node): - index = 0 - for input in child_node.input: + for index, input in enumerate(child_node.input): if input == node_output: return index - index += 1 return -1 def remove_unused_constant(self): @@ -903,7 +954,7 @@ def get_first_output(node): num_nodes_removed = 0 for node in self.model.graph.node: first_output = get_first_output(node) - kept_node = output_to_node[first_output] if first_output in output_to_node else None + kept_node = output_to_node.get(first_output) # Need double check the node since fused node might reuse output name of some nodes to be removed. # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 2a48722d17a19..61a786d7af60b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -121,7 +121,7 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): class BartOnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py new file mode 100644 index 0000000000000..0fdce29ae0fa0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -0,0 +1,928 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import List, Optional + +import numpy as np +from dynamo_onnx_helper import DynamoOnnxHelper +from fusion_base import Fusion +from fusion_options import AttentionOpType, FusionOptions +from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization +from fusion_utils import NumpyHelper +from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class ProcessGemmWFunc: + def __call__(self, x): + return np.transpose(x, (1, 0)) + + +class ProcessMatMulQFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[0], (1, 0)) + + +class ProcessMatMulKFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[1], (1, 0)) + + +class ProcessMatMulVFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[2], (1, 0)) + + +class ProcessBiasQFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[0] + return x + + +class ProcessBiasKFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[1] + return x + + +class ProcessBiasVFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[2] + return x + + +class ProcessRotCacheFunc: + def __call__(self, x): + # half rotary embedding + assert len(x.shape) == 2 + if x.shape[1] == 32: + return x[:, 0:16] + return x + + +# TODO: move to a seperate file +class Fission(Fusion): + def __init__( + self, + model: OnnxModel, + nodes_to_find: List[str], + ): + super().__init__(model, "DONOTUSE", nodes_to_find) + + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attn_op_type = attn_op_type + + def get_uname(self, layer_id, name): + return name + "_" + str(layer_id) + + def get_edge_by_name(self, edges, name): + for edge in edges: + if edge == name or edge.endswith(name) or edge.startswith(name): + return edge + raise ValueError(f"Edge {name} not found") + + def get_input_by_name(self, node, name): + return self.get_edge_by_name(node.input, name) + + def get_output_by_name(self, node, name): + return self.get_edge_by_name(node.output, name) + + def process_initializer(self, initializer_name, functor, custom_name=None): + i = self.model.get_initializer(initializer_name) + i_np_array = NumpyHelper.to_array(i) + processed_i_np_array = functor(i_np_array) + new_tensor = helper.make_tensor( + initializer_name + "_processed" if custom_name is None else custom_name, + data_type=TensorProto.FLOAT, + dims=processed_i_np_array.shape, + vals=processed_i_np_array.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(new_tensor, self.this_graph_name) + return new_tensor.name + + def add_fp32_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + + def add_int64_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.INT64 + + def replace_fp32_value_info(self, name, shape): + for value_info in self.model.graph().value_info: + if value_info.name == name: + self.model.graph().value_info.remove(value_info) + break + new_value_info = helper.make_tensor_value_info( + name, + elem_type=TensorProto.FLOAT, + shape=shape, + ) + self.model.graph().value_info.extend([new_value_info]) + + def set_unique_name_and_add_nodes( + self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str] + ): + for new_node in subgraph_nodes: + for i, name in enumerate(new_node.input): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.input[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.input[i]) + for i, name in enumerate(new_node.output): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.output[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.output[i]) + new_node.name = self.get_uname(layer_id, new_node.name) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + node = helper.make_node( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + name=prefix + "_LayerNormalization", + epsilon=9.999999747378752e-06, + ) + return [node] + + def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + matmul = helper.make_node( + "MatMul", + inputs=[inputs[0], inputs[1]], + outputs=[prefix + "matmul_out"], + name=prefix + "MatMul", + ) + add = helper.make_node( + "Add", + inputs=[prefix + "matmul_out", inputs[2]], + outputs=outputs, + name=prefix + "Bias", + ) + return [matmul, add] + + def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32): + assert len(inputs) == 4 + assert len(outputs) == 1 + node = helper.make_node( + "RotaryEmbedding", + inputs=inputs, + outputs=outputs, + name=prefix + "RotaryEmbedding", + domain="com.microsoft", + rotary_embedding_dim=rot_dim, + num_heads=num_heads, + ) + return [node] + + def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 1 + assert len(outputs) == 1 + node = helper.make_node( + "FastGelu", + inputs=inputs, + outputs=outputs, + name=prefix + "FastGelu", + domain="com.microsoft", + ) + return [node] + + def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 2 + assert len(outputs) == 1 + node = helper.make_node( + "Add", + inputs=inputs, + outputs=outputs, + name=prefix + "Add", + ) + return [node] + + def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 8 + assert len(outputs) == 3 + node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "MultiHeadAttention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + ) + return [node] + + def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 7 + assert len(outputs) == 3 + node = helper.make_node( + "GroupQueryAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "GroupQueryAttention", + domain="com.microsoft", + num_heads=num_heads, + kv_num_heads=num_heads, + ) + return [node] + + def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 5 + assert len(outputs) == 2 + node = helper.make_node( + "Attention", + inputs=inputs, + outputs=outputs, + name=prefix + "Attention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + do_rotary=1, + rotary_embedding_dim=32, + ) + return [node] + + def paged_attn( + self, + inputs: List[str], + outputs: List[str], + prefix: str = "", + num_heads=32, + head_size=80, + scale=0.11180339753627777, + ): + assert len(inputs) == 6 + assert len(outputs) == 1 + node = helper.make_node( + "PagedAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "PagedAttention", + domain="vllm.ort.ext", + num_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_size, + scale=scale, + ) + return [node] + + +class Phi2PreProcessor(DynamoOnnxHelper): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.num_hidden_layers = 32 + self.num_attention_heads = num_heads + self.hidden_size = hidden_size + + self.func_name = "modeling_phi_PhiModel_model_1" + + def get_phi2_edge_dict(self) -> dict: + edge_dict = {} + edge_dict["lm_head_1"] = "logits" + edge_dict["l_input_ids_"] = "input_ids" + edge_dict["key_states"] = "past_key_0" + edge_dict["value_states"] = "past_value_0" + for i in range(1, self.num_hidden_layers, 1): + edge_dict[f"key_states_{i}"] = f"past_key_{i}" + edge_dict[f"value_states_{i}"] = f"past_value_{i}" + edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" + edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + + outputs = [o.name for o in self.model.graph.output] + if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs: + edge_dict["model_layers_0_1_1"] = "present_key_0" + edge_dict["model_layers_0_1_2"] = "present_value_0" + else: + assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs + edge_dict["model_layers_0_1"] = "present_key_0" + edge_dict["model_layers_0_1_1"] = "present_value_0" + return edge_dict + + def simplify_phi2_op_type(self): + phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers" + for node in self.model.graph.node: + index = node.op_type.find(phi2_transformer_layer_name) + if index != -1: + node.op_type = node.op_type[index:] + + def process_graph_io(self, attn_op_type: AttentionOpType): + self.use_attn = attn_op_type == AttentionOpType.Attention + self.use_vllm = attn_op_type == AttentionOpType.PagedAttention + graph = self.model.graph + new_inputs = [] + for vi in graph.input: + if "input_ids" in vi.name: + vi_iid = helper.make_tensor_value_info( + vi.name, + elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) + vi_step = helper.make_tensor_value_info( + "step", + elem_type=TensorProto.INT64, + shape=[1], + ) + vi_pid = helper.make_tensor_value_info( + "position_ids", + elem_type=TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) + vi_mask = helper.make_tensor_value_info( + "attention_mask", + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + vi_meta = helper.make_tensor_value_info( + "input_metadata", + elem_type=TensorProto.INT64, + shape=[1], + ) + new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend( + [vi_iid, vi_pid, vi_meta] + ) + if self.use_attn: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("past_key", "past"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + elif self.use_vllm: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"], + ) + new_inputs.extend([vi_cache]) + if "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "num_blocks", + "num_heads", + "head_size", + "block_size", + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + new_outputs = [] + for i, vi in enumerate(graph.output): + if i == 0: + new_outputs.extend([vi]) + else: + if self.use_attn: + if "present_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("present_key", "present"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + elif self.use_vllm: + pass + else: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + + graph.ClearField("output") + graph.output.extend(new_outputs) + + def preprocess_onnx(self, attn_op_type: AttentionOpType): + function_name = None + for func in self.model.functions: + if func.name.endswith(self.func_name): + function_name = func.name + break + assert function_name is not None + self.unroll_function(function_name) + self.update_edges(self.get_phi2_edge_dict()) + self.simplify_phi2_op_type() + self.remove_dropout_layer() + if attn_op_type == AttentionOpType.PagedAttention: + self.remove_lm_head_layer() + self.process_graph_io(attn_op_type) + + +class FissionTransformerEmbeddingPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 2 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + embedding = self.get_input_by_name(node, "embed_tokens.weight") + + layer_known_edges_names = [input, output, embedding] + + subgraph_nodes = [ + helper.make_node( + "Gather", + inputs=[embedding, input], + outputs=[output], + name="Embedding_Gather", + ), + ] + + self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names) + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerLayerNormPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 3 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + ln_weight = self.get_input_by_name(node, "final_layernorm.weight") + ln_bias = self.get_input_by_name(node, "final_layernorm.bias") + + layer_known_edges_names = [input, output, ln_weight, ln_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerCausalLMHeadPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 5 + assert len(node.output) == 1 + + input = node.input[2] + output = node.output[0] + + fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_input_by_name(node, "lm_head.bias") + + layer_known_edges_names = [input, output, fc_weight, fc_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerBlockPhi(Fission): + def __init__( + self, + model: OnnxModel, + num_heads: int, + ): + self.num_heads = num_heads + max_num_layers = 32 + self.func_to_layer_id = {} + nodes_to_find = [] + for layer in range(max_num_layers): + func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1" + nodes_to_find.append(func_name) + self.func_to_layer_id[func_name] = layer + + super().__init__(model, nodes_to_find) + + def get_layer_id(self, node): + return self.func_to_layer_id[node.op_type] + + def get_gqa_aux_nodes(self): + gqa_aux_nodes = [ + helper.make_node( + "Cast", + inputs=["attention_mask"], + outputs=["mask_int64"], + name="Cast_gqa_aux_0", + to=TensorProto.INT64, + ), + helper.make_node( + "ReduceSum", + inputs=["mask_int64", "one"], + outputs=["mask_row_sums"], + name="ReduceSum_gqa_aux", + ), + helper.make_node( + "Sub", + inputs=["mask_row_sums", "one"], + outputs=["seqlens_k_int64"], + name="Sub_gqa_aux", + ), + helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name="Cast_gqa_aux_1", + to=TensorProto.INT32, + ), + helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"), + helper.make_node( + "Gather", + inputs=["mask_shape", "one"], + outputs=["total_seq_len_int64"], + name="Gather_gqa_aux_0", + axis=0, + ), + helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_sequence_length"], + name="Cast_gqa_aux_2", + to=TensorProto.INT32, + ), + ] + return gqa_aux_nodes + + def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name): + q_weight = self.model.get_initializer(q_w) + k_weight = self.model.get_initializer(k_w) + v_weight = self.model.get_initializer(v_w) + qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0)) + kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0)) + vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0)) + qkv_weight = np.stack((qw, kw, vw), axis=1) + + q_bias = self.model.get_initializer(q_b) + k_bias = self.model.get_initializer(k_b) + v_bias = self.model.get_initializer(v_b) + qb = NumpyHelper.to_array(q_bias) + kb = NumpyHelper.to_array(k_bias) + vb = NumpyHelper.to_array(v_bias) + qkv_bias = np.stack((qb, kb, vb), axis=0) + + hidden_size = qkv_weight.shape[0] + + weight = helper.make_tensor( + weight_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size, hidden_size * 3], + vals=qkv_weight.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(weight, self.this_graph_name) + + bias = helper.make_tensor( + bias_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size * 3], + vals=qkv_bias.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(bias, self.this_graph_name) + + self.add_fp32_value_info(weight.name) + self.add_fp32_value_info(bias.name) + + return weight_name, bias_name + + def fuse( + self, + node, + input_name_to_nodes, + output_name_to_node, + ): + logger.info("Optimizing %s...", node.name) + + logger.info(f"AttentionOpType: {self.attn_op_type}") + + layer_id = self.get_layer_id(node) + + i_hidden_states = node.input[0] + i_key_cache = self.get_input_by_name(node, "past_key") + i_value_cache = self.get_input_by_name(node, "past_value") + + o_hidden_states = node.output[-1] + o_key_cache = self.get_output_by_name(node, "present_key") + o_value_cache = self.get_output_by_name(node, "present_value") + + ln_weight = self.get_input_by_name(node, "input_layernorm.weight") + ln_bias = self.get_input_by_name(node, "input_layernorm.bias") + + attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( + None, + None, + None, + None, + None, + None, + ) + attn_qkv_weight, attn_qkv_bias = None, None + cos_cache, sin_cache = None, None + + if self.attn_op_type != AttentionOpType.Attention: + attn_q_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + ) + attn_k_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + ) + attn_v_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + ) + attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias") + + cos_cache = self.process_initializer( + self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + ) + sin_cache = self.process_initializer( + self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + ) + else: + attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( + self.get_input_by_name(node, "self_attn.q_proj.weight"), + self.get_input_by_name(node, "self_attn.k_proj.weight"), + self.get_input_by_name(node, "self_attn.v_proj.weight"), + self.get_input_by_name(node, "self_attn.q_proj.bias"), + self.get_input_by_name(node, "self_attn.k_proj.bias"), + self.get_input_by_name(node, "self_attn.v_proj.bias"), + self.get_uname(layer_id, "attn_qkv_weight"), + self.get_uname(layer_id, "attn_qkv_bias"), + ) + + attn_out_weight = self.process_initializer( + self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + ) + attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias") + + mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias") + + layer_known_edges_names = [] + layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) + layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache]) + layer_known_edges_names.extend([ln_weight, ln_bias]) + if self.attn_op_type != AttentionOpType.Attention: + layer_known_edges_names.extend( + [ + attn_q_weight, + attn_q_bias, + attn_k_weight, + attn_k_bias, + attn_v_weight, + attn_v_bias, + cos_cache, + sin_cache, + ] + ) + else: + layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias]) + layer_known_edges_names.extend( + [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias] + ) + layer_known_edges_names.extend( + ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"] + ) + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) + subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_")) + subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_")) + subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"])) + subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_")) + subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1")) + subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2")) + if self.attn_op_type != AttentionOpType.Attention: + subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + # vllm engine requires full position ids as the input + pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step" + subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_")) + if self.attn_op_type == AttentionOpType.MultiHeadAttention: + subgraph_nodes.extend( + self.mha( + ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + elif self.attn_op_type == AttentionOpType.GroupQueryAttention: + subgraph_nodes.extend( + self.gqa( + [ + "query_rot", + "key_rot", + "value", + i_key_cache, + i_value_cache, + "seqlens_k", + "total_sequence_length", + ], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + if layer_id == 0: + gqa_aux_nodes = self.get_gqa_aux_nodes() + for new_node in gqa_aux_nodes: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.model.add_initializer( + numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + ) + elif self.attn_op_type == AttentionOpType.PagedAttention: + subgraph_nodes.extend( + self.paged_attn( + ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"], + ["attn_out"], + ) + ) + else: + past_name = f"past_{layer_id}" + present_name = f"present_{layer_id}" + layer_known_edges_names.extend([past_name, present_name]) + subgraph_nodes.extend( + self.attention( + ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name] + ) + ) + + self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names) + + self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class PhiOnnxModel(OnnxModel): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size) + self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads) + self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self) + self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self) + self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert options is not None + attn_op_type = options.attention_op_type + + self.fission_transformer_block.set_attention_op_type(attn_op_type) + + self.phi2_preprocessor.preprocess_onnx(attn_op_type) + + self.fission_transformer_block.apply() + self.fission_transformer_layernorm.apply() + self.fission_causal_lm_head.apply() + self.fission_transformer_embedding.apply() + + super().prune_graph() + + # SLN ctor is placed here intentionally to delay the symbolic shape inference + self.fuse_sln = FusionSkipLayerNormalization(self) + self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self) + self.fuse_sln.apply() + self.fuse_bias_sln.apply() + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "MultiHeadAttention", + "GroupQueryAttention", + "PagedAttention", + "Gelu", + "BiasGelu", + "FastGelu", + "LayerNormalization", + "SkipLayerNormalization", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators: {op_count}") + return op_count + + def is_fully_optimized(self, fused_op_count=None): + """ + Returns True when the model is fully optimized. + """ + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + attention = ( + op_count("Attention") + + op_count("MultiHeadAttention") + + op_count("GroupQueryAttention") + + op_count("PagedAttention") + ) + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + + is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention) + + if layer_norm == 0: + logger.debug("Layer Normalization not fused") + + if gelu == 0: + logger.debug("Gelu (or FastGelu) not fused") + + if attention == 0: + logger.warning("Attention (or MultiHeadAttention) not fused") + + return is_perfect diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ba61f4f6e43ba..ce0be6b3449ed 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -34,6 +34,7 @@ from onnx_model_clip import ClipOnnxModel from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_phi import PhiOnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -58,6 +59,7 @@ "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), "conformer": (ConformerOnnxModel, "pytorch", 1), + "phi": (PhiOnnxModel, "pytorch", 0), } diff --git a/onnxruntime/python/tools/transformers/profiler.py b/onnxruntime/python/tools/transformers/profiler.py index 8e45b149eaf03..2306b579f92fe 100644 --- a/onnxruntime/python/tools/transformers/profiler.py +++ b/onnxruntime/python/tools/transformers/profiler.py @@ -329,7 +329,7 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): calls = node_freq[node_name] avg_time = duration / float(calls) percentage = (duration / total) * 100.0 - provider = node_provider[node_name] if node_name in node_provider else "" + provider = node_provider.get(node_name, "") before_percentage += percentage lines.append( f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}" @@ -347,7 +347,7 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): calls = node_freq[node_name] avg_time = duration / float(calls) percentage = (duration / total) * 100.0 - provider = node_provider[node_name] if node_name in node_provider else "" + provider = node_provider.get(node_name, "") lines.append(f"{duration:10d}\t{percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}") return lines @@ -393,7 +393,7 @@ def group_node_results(sess_time, kernel_time_only, use_gpu): total_fence_time += item["dur"] continue - provider = item["args"]["provider"] if "provider" in item["args"] else "" + provider = item["args"].get("provider", "") if provider in provider_counter: provider_counter[provider] += 1 else: @@ -425,7 +425,7 @@ def group_node_results(sess_time, kernel_time_only, use_gpu): lines.append("-" * 64) lines.append("Total(μs)\tTime%\tKernel(μs)\tKernel%\tCalls\tAvgKernel(μs)\tFence(μs)\tOperator") for op_name, kernel_time in sorted(op_kernel_time.items(), key=lambda x: x[1], reverse=True): - fence_time = op_fence_time[op_name] if op_name in op_fence_time else 0 + fence_time = op_fence_time.get(op_name, 0) kernel_time_ratio = kernel_time / total_kernel_time total_time = kernel_time + fence_time time_ratio = total_time / (total_kernel_time + total_fence_time) diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index a449e881ad361..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data onnx_model_path, quantized_model_path, use_external_data_format=use_external_data_format, + extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT}, ) logger.info(f"quantized model saved to:{quantized_model_path}") # TODO: inlcude external data in total model size. diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh old mode 100644 new mode 100755 index f0422839c11eb..64d6ecde618f6 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -34,6 +34,9 @@ run_gpu_fp16=true run_cpu_fp32=false run_cpu_int8=false +# Set this to true to enable bfloat16 fastmath gemm kernels on aarch64 platforms with bfloat16 support +arm64_bfloat16_fastmath_mode=false + average_over=1000 # CPU takes longer time to run, only run 100 inferences to get average latency. if [ "$run_cpu_fp32" = true ] || [ "$run_cpu_int8" = true ]; then @@ -63,7 +66,7 @@ models_to_test="bert-base-cased roberta-base distilbert-base-uncased" # export CUDA_VISIBLE_DEVICES=1 # This script will generate a logs file with a list of commands used in tests. -echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" >> benchmark.log +echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" arm64_bfloat16_fastmath_mode=$arm64_bfloat16_fastmath_mode >> benchmark.log # Set it to false to skip testing. You can use it to dry run this script with the log file. run_tests=true @@ -127,6 +130,10 @@ if [ "$force_layer_number" = true ] ; then benchmark_options="$benchmark_options --force_num_layers $layer_number" fi +if [ "$arm64_bfloat16_fastmath_mode" = true ] ; then + benchmark_options="$benchmark_options --enable_arm64_bfloat16_fastmath_mlas_gemm" +fi + # ------------------------------------------- run_one_test() { if [ "$run_ort" = true ] ; then diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 6afb61bd1f0a1..8ea37ad054ed0 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); - - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); - - tester.AddInput("past", past_dims, reordered_kv_cache); - - // Rel - tester.AddOptionalInputEdge(); - - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); - - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); - - auto k_merged = pair.first; - auto k_transpose = pair.second; - - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); - - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); - - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 384}, + {2, 30, 768}, + {3, 31, 1536}, + {4, 512, 384}, + {1, 1024, 768}, + {1, 2046, 1536}, + {2, 2047, 384}, + {3, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); + + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(float); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); + + tester.AddInput("past", past_dims, reordered_kv_cache); + + // Rel + tester.AddOptionalInputEdge(); + + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); + + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); + + auto k_merged = pair.first; + auto k_transpose = pair.second; + + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); + + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); + + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); - - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 768}, + {3, 30, 384}, + {8, 31, 1536}, + {4, 256, 384}, + {3, 1024, 768}, + {2, 2046, 1536}, + {1, 2047, 384}, + {2, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(MLFloat16); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); - // Rel - tester.AddOptionalInputEdge(); + // Rel + tester.AddOptionalInputEdge(); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto k_merged = pair.first; - auto k_transpose = pair.second; + auto k_merged = pair.first; + auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 84bbee35eed5a..98fb62e435f31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -7,6 +7,7 @@ #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/default_providers.h" #include "test/providers/provider_test_utils.h" @@ -75,6 +76,28 @@ TEST(LayerNormTest, LayerNorm) { test.Run(); } +TEST(LayerNormTest, LayerNorm_BFloat16Input) { +// prevents test from running on non-BF16-supporting hardware +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; + return; + } +#endif + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f})); + test.AddOutput("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 09baf8def05f6..31ef62e69bb88 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,7 +433,8 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - float gpu_threshold = is_float16 ? 0.15f : 0.005f; + // TF32 in SM >= 80 is enabled by default, need larger threshold for float when TF32 is enabled. + float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f); gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index fefd5722054de..ea8537f243f5d 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array channels_last_values = {-1, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; @@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array has_add_out_values = {true, false}; std::array skip_dims = {2, 4}; @@ -237,12 +243,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { constexpr int channels_last = 1; for (const int skip_dim : skip_dims) { for (const bool has_add_out : has_add_out_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b174ee4138be3..d7b1de5c930c5 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test { if (invoke_createPlan_explicityly) { onnxruntime::GraphViewer graph_viewer{graph_}; - status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_, - kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context, - MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/ - ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_); + status = SequentialPlanner::CreatePlan( + nullptr, + graph_viewer, + outer_scope_node_args, + execution_providers_, + kernel_create_info_map, + {}, + {}, + state_->GetOrtValueNameIdxMap(), + test_context, +#ifdef ORT_ENABLE_STREAM + MockStreamHandleRegsitry(), +#endif + /* {{kCpuExecutionProvider, 1}}, {},*/ + ORT_TSTR(""), + DefaultLoggingManager().DefaultLogger(), + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); // AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 2c1cd48d3d02f..8961058628490 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/framework/allocator.h" @@ -15,7 +16,7 @@ TEST(AllocatorTest, CPUAllocatorTest) { EXPECT_EQ(cpu_arena->Info().id, 0); // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtArenaAllocator); #else EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtDeviceAllocator); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 0d3e4449da939..e9f734057da1c 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -337,6 +337,7 @@ struct StreamMock : public Stream { Status CleanUpOnRunEnd() override { return Status::OK(); } }; +#ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); CheckStats(&a, 0, 0, 0, 0); @@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) { EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; a.Free(p2); } +#endif TEST(BFCArenaTest, TestExtendStrategy) { int64_t extend_delta_bytes = 0; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index ec572ce9deed8..60752d7456d97 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable() : nullptr; @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector(6, 1.0f), &v3); std::vector outputs; - ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x1_idx, x2_idx, x3_idx}), + AsSpan({v1, v2, v3}), + AsSpan({t3_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3); OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4); @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { CreateMLValue(cpu_allocator, std::vector{2, 2}, std::vector(4, 1.0f), &t_value); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({x_value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor()); ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value)); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 0c2d8bcb2eb93..ed698ab920147 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "asserts.h" #include "core/framework/execution_providers.h" @@ -215,7 +216,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { // if the relevant session option config flag is set // For this test we need to enable the arena-based allocator which is not supported on x86 builds, so // enable this test only on x64 builds -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { AllocatorPtr cpu_allocator = std::make_shared(); // Part 1: Feature turned ON (i.e.) allocate from non-arena memory diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index 38e3f184ebc18..9202543b75a6f 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -6,7 +6,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" - +#include #include namespace onnxruntime { @@ -138,7 +138,7 @@ TEST(TensorTest, EmptyTensorTest) { EXPECT_EQ(location.id, 0); // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtArenaAllocator); #else EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtDeviceAllocator); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bf02c1741725f..e1fcf835c6043 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,6 +42,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7642,5 +7643,143 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Gather-2 Ops + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose-2 Ops + auto* transpose_out_2 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7cfbe0a84e3e6..3874901f86387 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -128,6 +128,7 @@ namespace perftest { "\t\t The number of affinities must be equal to intra_op_num_threads - 1\n\n" "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" + "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -190,7 +191,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqz"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -373,6 +374,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'Z': test_config.run_config.disable_spinning_between_run = true; break; + case 'n': + test_config.run_config.exit_after_session_creation = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 36f08167c2217..43bf54963cabb 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -43,6 +43,13 @@ int real_main(int argc, char* argv[]) { } std::random_device rd; perftest::PerformanceRunner perf_runner(env, test_config, rd); + + // Exit if user enabled -n option so that user can measure session creation time + if (test_config.run_config.exit_after_session_creation) { + perf_runner.LogSessionCreationTime(); + return 0; + } + auto status = perf_runner.Run(); if (!status.IsOK()) { printf("Run failed:%s\n", status.ErrorMessage().c_str()); diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 9f2cbcf6a21f1..37bf80c80e90b 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -115,6 +115,11 @@ void PerformanceResult::DumpToFile(const std::basic_string& path, boo } } +void PerformanceRunner::LogSessionCreationTime() { + std::chrono::duration session_create_duration = session_create_end_ - session_create_start_; + std::cout << "\nSession creation time cost: " << session_create_duration.count() << " s\n"; +} + Status PerformanceRunner::Run() { if (!Initialize()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "failed to initialize."); diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h index da2df9c39f44c..cb1cb661550a7 100644 --- a/onnxruntime/test/perftest/performance_runner.h +++ b/onnxruntime/test/perftest/performance_runner.h @@ -46,6 +46,8 @@ class PerformanceRunner { ~PerformanceRunner(); Status Run(); + void LogSessionCreationTime(); + inline const PerformanceResult& GetResult() const { return performance_result_; } inline void SerializeResult() const { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 5a49414a49004..74c8eb472cb3e 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -63,6 +63,7 @@ struct RunConfig { std::string intra_op_thread_affinities; bool disable_spinning = false; bool disable_spinning_between_run = false; + bool exit_after_session_creation = false; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md index 59fe946b929f2..309b474c016c9 100644 --- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md +++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md @@ -3,13 +3,13 @@ The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW. ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time. -Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7). +Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](/docs/FAQ.md?plain=1#L7). However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically. - Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D -- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h) -- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner +- Keyword: Logs (0x2) keyword per [logging.h](/include/onnxruntime/core/common/logging/logging.h) +- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](/onnxruntime/core/platform/windows/logging/etw_sink.cc) to [ONNX severity](/include/onnxruntime/core/common/logging/severity.h) in an intuitive manner Notes: - The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 5e746ed0c62d4..d35e5c78cfd69 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -5,6 +5,7 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "core/util/math.h" #include #include @@ -786,13 +787,20 @@ TEST(MathOpTest, Sqrt_Float) { test.Run(); } -#if defined(USE_DNNL) +#if defined(USE_DNNL) || defined(USE_CUDA) TEST(MathOpTest, Sqrt_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; } +#endif +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BFP16"; + return; + } #endif OpTester test_bf16("Sqrt", 13); // only version 13 support bf16 for sqrt test_bf16.AddInput("X", {2, 3}, @@ -804,6 +812,9 @@ TEST(MathOpTest, Sqrt_bfloat16) { std::vector> execution_providers; #if defined(USE_DNNL) execution_providers.push_back(DefaultDnnlExecutionProvider()); +#endif +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); #endif test_bf16.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/providers/cpu/rnn/GRU.py b/onnxruntime/test/providers/cpu/rnn/GRU.py index 846fc3d06b9a9..144acaf14db61 100644 --- a/onnxruntime/test/providers/cpu/rnn/GRU.py +++ b/onnxruntime/test/providers/cpu/rnn/GRU.py @@ -47,8 +47,8 @@ def __init__(self, **params): if "initial_h" in params else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) ) - LBR = params["linear_before_reset"] if "linear_before_reset" in params else 0 # noqa: N806 - self.direction = params["direction"] if "direction" in params else "forward" + LBR = params.get("linear_before_reset", 0) # noqa: N806 + self.direction = params.get("direction", "forward") if num_directions == 1: if self.direction == "forward": diff --git a/onnxruntime/test/providers/cpu/rnn/LSTM.py b/onnxruntime/test/providers/cpu/rnn/LSTM.py index 74299ea2c75a3..116ec3671bf01 100644 --- a/onnxruntime/test/providers/cpu/rnn/LSTM.py +++ b/onnxruntime/test/providers/cpu/rnn/LSTM.py @@ -65,13 +65,13 @@ def __init__(self, **params): # type: (*Any) -> None else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) ) - f = params["f"] if "f" in params else ActivationFuncs.sigmoid - g = params["g"] if "g" in params else ActivationFuncs.tanh - h = params["h"] if "h" in params else ActivationFuncs.tanh - input_forget = params["input_forget"] if "input_forget" in params else False - clip = params["clip"] if "clip" in params else 9999.0 + f = params.get("f", ActivationFuncs.sigmoid) + g = params.get("g", ActivationFuncs.tanh) + h = params.get("h", ActivationFuncs.tanh) + input_forget = params.get("input_forget", False) + clip = params.get("clip", 9999.0) - self.direction = params["direction"] if "direction" in params else "forward" + self.direction = params.get("direction", "forward") if num_directions == 1: if self.direction == "forward": @@ -266,8 +266,8 @@ def SimpleWeightsNoBiasTwoRows(direction): # type: () -> None # noqa: N802 R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32) # noqa: N806 if direction == "bidirectional": - W = W = np.tile(W, (2, 1)).reshape(2, number_of_gates * hidden_size, input_size) # noqa: N806 - R = R = np.tile(R, (2, 1)).reshape(2, number_of_gates * hidden_size, hidden_size) # noqa: N806 + W = np.tile(W, (2, 1)).reshape(2, number_of_gates * hidden_size, input_size) # noqa: N806 + R = np.tile(R, (2, 1)).reshape(2, number_of_gates * hidden_size, hidden_size) # noqa: N806 lstm = LSTM_Helper(X=input, W=W, R=R, direction=direction) diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc index 13d4546d669e3..b6a760f7041ad 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -52,20 +52,31 @@ struct ConvOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.bias = true; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) { - auto op = - ConvOp{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}}; + auto op = ConvOp{}; + op.input_dims = {2, 4, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 4; + op.padding = {4, 4, 4, 4}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 6514feadf0ff7..786b2cb4cedc4 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvTransposeOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -60,15 +60,21 @@ struct ConvTransposeOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) { - auto op = - ConvTransposeOp{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvTransposeOp{}; + op.input_dims = {8, 8, 32, 32}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { - auto op = - ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 8, 80, 80}; + op.kernel_shape = {5, 5}; + op.channels = 16; + op.bias = true; if (HasCudaEnvironment(800)) { MAKE_PROVIDERS_EPS(1e-2) @@ -78,21 +84,23 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { - auto op = ConvTransposeOp{.input_dims = {1, 16, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .padding = {2, 2, 2, 2}, - .output_padding = {}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 16, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.padding = {2, 2, 2, 2}; + op.output_padding = {}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) { - auto op = ConvTransposeOp{.input_dims = {1, 32, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .strides = {2, 2}, - .output_padding = {1, 1, 1, 1}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 32, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.strides = {2, 2}; + op.output_padding = {1, 1, 1, 1}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 2c942bb790096..82b6a286409cd 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -16,11 +16,13 @@ #define MAKE_PROVIDERS_EPS(eps) \ std::vector> execution_providers; \ - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; \ + OrtCUDAProviderOptionsV2 nhwc{}; \ + nhwc.prefer_nhwc = true; \ execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ \ double error_tolerance = eps; \ - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; \ + OrtCUDAProviderOptionsV2 nchw{}; \ + nchw.prefer_nhwc = false; \ auto source_ep = CudaExecutionProviderWithOptions(&nchw); \ auto test = op.get_test(); \ test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); diff --git a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc index 52da8ba557c2d..40f69e3bd5b4f 100644 --- a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc @@ -9,7 +9,7 @@ namespace test { template struct BatchNormOp { - const std::vector input_dims; + std::vector input_dims; std::unique_ptr get_test() { // create rand inputs @@ -40,9 +40,8 @@ struct BatchNormOp { }; TYPED_TEST(CudaNhwcTypedTest, BatchNormNhwc) { - auto op = BatchNormOp{ - .input_dims = {4, 16, 64, 64}, - }; + auto op = BatchNormOp{}; + op.input_dims = {4, 16, 64, 64}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index e0d59901da80c..426170b9588f1 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -9,9 +9,9 @@ namespace test { template struct PoolOp { - const std::string pooling_type; - const std::vector input_dims; - const std::vector kernel_shape; + std::string pooling_type; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; std::vector strides = {1, 1}; @@ -41,22 +41,21 @@ struct PoolOp { }; TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) { - auto op = PoolOp{ - .pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + MAKE_PROVIDERS() } TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) { - auto op = PoolOp{ - .pooling_type = "MaxPool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "MaxPool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; MAKE_PROVIDERS() } @@ -72,21 +71,24 @@ TYPED_TEST(CudaNhwcTypedTest, GlobalMaxPoolNhwc) { test->AddOutput("Y", output_dims, output_data); std::vector> execution_providers; - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; + OrtCUDAProviderOptionsV2 nhwc{}; + nhwc.prefer_nhwc = true; execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); double error_tolerance = 1e-3; - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; + OrtCUDAProviderOptionsV2 nchw{}; + nchw.prefer_nhwc = false; auto source_ep = CudaExecutionProviderWithOptions(&nchw); test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); } TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwcPad) { - auto op = PoolOp{.pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - .padding = {2, 2, 2, 2}}; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.padding = {2, 2, 2, 2}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc index 6cac23f14459e..4917701e5197d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc @@ -17,7 +17,7 @@ TEST(CudaGemmOptions, TestDefaultOptions) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -30,7 +30,7 @@ TEST(CudaGemmOptions, TestCompute16F) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } @@ -43,7 +43,7 @@ TEST(CudaGemmOptions, NoReducedPrecision) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -56,7 +56,7 @@ TEST(CudaGemmOptions, Pedantic) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_PEDANTIC_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F_PEDANTIC); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -69,7 +69,7 @@ TEST(CudaGemmOptions, Compute16F_Pedantic) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_PEDANTIC_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F_PEDANTIC); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } @@ -82,7 +82,7 @@ TEST(CudaGemmOptions, Compute16F_NoReducedPrecision) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2f3b0e84a123e..a6422407d79fd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { kOnnxDomain, true); } + +static GetTestQDQModelFn BuildQDQConvertAddTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def) { + return [input0_def, input1_def](ModelTestBuilder& builder, std::vector>& output_qparams) { + constexpr bool use_contrib_qdq = true; + + // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_u8_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_u8_qparams.scale, + input0_u8_qparams.zero_point, use_contrib_qdq); + + // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float) + QuantParams input0_u16_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_convert = AddQDQNodePair(builder, input0_after_qdq, input0_u16_qparams.scale, + input0_u16_qparams.zero_point, use_contrib_qdq); + + // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test quantization type conversion (mixed precision) with Add. +// First input is converted from uint8_t to uint16_t. +TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8); + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildOpTestCase("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain), + BuildQDQConvertAddTestCase(input0_def, input1_def), + provider_options, + 18, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 68e441c87860e..91b6c71e735a8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -414,6 +414,8 @@ def test_get_and_set_option_with_values(option_name, option_values): str(option_value), ) + test_get_and_set_option_with_values("enable_cuda_graph", ["1", "0"]) + test_get_and_set_option_with_values("arena_extend_strategy", ["kNextPowerOfTwo", "kSameAsRequested"]) test_get_and_set_option_with_values("cudnn_conv_algo_search", ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"]) @@ -426,6 +428,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0" @@ -553,6 +557,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + run_rocm_options_test() def test_invalid_set_providers(self): diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..2b5d1f36070e5 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh + +from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType + + +class TestQuantizerShapeInference(unittest.TestCase): + def test_com_microsoft(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("MatMul", ["X", "W1"], ["T1"]), + oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"), + oh.make_node("MatMul", ["T2", "W3"], ["T3"]), + oh.make_node("MatMul", ["T3", "W4"], ["Y"]), + ], + "name", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])], + [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])], + [ + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"), + ], + ), + opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], + ) + model_shaped = onnx.shape_inference.infer_shapes(model) + shaped_results = set(t.name for t in model_shaped.graph.value_info) + # every result after T1 depends on T2 coming from a node com.microsoft, + # shape_inference cannot go beyond this point + self.assertEqual(shaped_results, {"T1"}) + + # first try: checks it raises an exception + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + {"MatMulConstBOnly": True}, # extra_options, + # {'DefaultTensorType': 1, } + ) + + with self.assertRaises(RuntimeError) as e: + quantizer.quantize_model() + self.assertIn("Unable to find data type for weight_name=", str(e)) + + # second try: checks it works + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + { + "MatMulConstBOnly": True, + "DefaultTensorType": 1, + }, + ) + + model = quantizer.quantize_model() + ops = {n.op_type for n in model.graph.node} + self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"}) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py new file mode 100644 index 0000000000000..c425bf956f976 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_subgraph.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import tempfile +import unittest +import urllib.request + +import onnx + +from onnxruntime.quantization import quantize_dynamic + + +class TestDynamicQuantizationSubgraph(unittest.TestCase): + def test_dynamic_quantization_subgraph(self): + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx") + quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx") + urllib.request.urlretrieve( + "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path + ) + + quantize_dynamic( + model_input=onnx_path, + model_output=quantized_onnx_path, + per_channel=True, + op_types_to_quantize=[ + "Conv", + "MatMul", + "Attention", + "LSTM", + "Gather", + "Transpose", + "EmbedLayerNormalization", + ], + extra_options={"EnableSubgraph": True}, + ) + model = onnx.load(quantized_onnx_path) + + # The initializer `shared.weight_merged_0` is attached to the top-level graph, and used in a Gather node in each subgraphs. + # We expect the quantized Gather (after which a DequantizeLinear is attached) initializer to also be attached to the top-level graph. + found_gather_quantized = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_quantized": + found_gather_quantized = True + break + self.assertTrue(found_gather_quantized) + + found_gather_scale = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_scale": + found_gather_scale = True + break + self.assertTrue(found_gather_scale) + + # No initializers related to the Gather node should be attached to the subgraphs. + for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + for initializer in attr.g.initializer: + self.assertTrue("shared.weight" not in initializer.name) diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c9db1fbc02931..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -361,7 +361,8 @@ def run_configs(self, optional_arguments): # INT8 CPU arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments - self.run_export(arguments) + if "--model_impl" not in arguments: + self.run_export(arguments) @pytest.mark.slow def test_required_args(self): @@ -380,18 +381,24 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) + + @pytest.mark.slow + def test_openai_impl_whisper(self): + optional_args = ["--model_impl", "openai"] + self.run_configs(optional_args) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "timestamp": (0.0, 5.44), } ], diff --git a/onnxruntime/test/quantization/quantization_test.cc b/onnxruntime/test/quantization/quantization_test.cc index bdfac77b336d4..773f56de5361b 100644 --- a/onnxruntime/test/quantization/quantization_test.cc +++ b/onnxruntime/test/quantization/quantization_test.cc @@ -99,24 +99,22 @@ void EnsureQuantizedTensorParam(const float scale, const T zero_point) { // First, create the scale tensor: auto alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - auto num_bytes = shape.Size() * sizeof(float); - void* data = alloc->Alloc(num_bytes); - float* float_data = static_cast(data); + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(alloc, shape.Size()); + float* float_data = buffer.get(); float_data[0] = scale; Tensor scale_tensor(DataTypeImpl::GetType(), shape, - data, + float_data, alloc->Info(), /*offset=*/0); // Next, create the zero_point tensor: - auto T_num_bytes = shape.Size() * sizeof(T); - void* T_data = alloc->Alloc(T_num_bytes); - T* typed_data = static_cast(T_data); + IAllocatorUniquePtr buffer2 = IAllocator::MakeUniquePtr(alloc, shape.Size()); + T* typed_data = buffer2.get(); typed_data[0] = zero_point; Tensor zero_point_tensor(DataTypeImpl::GetType(), shape, - T_data, + typed_data, alloc->Info(), /*offset=*/0); diff --git a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py index 79d41e41d696c..4c1e3a70de1c7 100644 --- a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py +++ b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py @@ -58,10 +58,8 @@ def _torch_tensor_to_str(torch_tensor): def _build_param_index_to_name_mapping(model, map_result): """Build index to name mapping, which is used to retrieve data from optimizer group.""" - index = 0 - for param in model.named_parameters(): + for index, param in enumerate(model.named_parameters()): map_result[index] = param[0] - index += 1 torch.manual_seed(seed) @@ -119,8 +117,7 @@ def _build_param_index_to_name_mapping(model, map_result): _sync_stream() for group in adamw_optimizer.param_groups: - p_index = 0 - for param in group["params"]: + for p_index, param in enumerate(group["params"]): state = adamw_optimizer.state[param] name = param_index_to_name_mapping[p_index] # Collect flattened optimizer state data. @@ -130,7 +127,6 @@ def _build_param_index_to_name_mapping(model, map_result): else: m1_dict[name].append(_torch_tensor_to_str(state["exp_avg"].view(-1))) m2_dict[name].append(_torch_tensor_to_str(state["exp_avg_sq"].view(-1))) - p_index += 1 adamw_optimizer.step() adamw_optimizer.zero_grad() diff --git a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py index a3d7946d63214..173225a21a52f 100644 --- a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py +++ b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py @@ -58,10 +58,8 @@ def _torch_tensor_to_str(torch_tensor): def _build_param_index_to_name_mapping(model, map_result): """Build index to name mapping, which is used to retrieve data from optimizer group.""" - index = 0 - for param in model.named_parameters(): + for index, param in enumerate(model.named_parameters()): map_result[index] = param[0] - index += 1 torch.manual_seed(seed) diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index a94f7b5b707c7..40b40136af1af 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -208,12 +208,18 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab } std::unique_ptr DefaultCoreMLExecutionProvider() { -// For any non - macOS system, CoreML will only be used for ort model converter -// Make it unavailable here, you can still manually append CoreML EP to session for model conversion + // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below. + // The test will create a model but execution of it will obviously fail. + // To test creating an ML Program, set the environment variable COREML_EP_TEST_MLPROGRAM to any value. #if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision uint32_t coreml_flags = 0; coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + + if (!Env::Default().GetEnvironmentVar("COREML_EP_TEST_MLPROGRAM").empty()) { + coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + } + return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #else return nullptr; diff --git a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc index 092ab89d5d760..f30d6ddee253a 100644 --- a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc +++ b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc @@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first))); } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first))); + } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str())); } else { ORT_THROW("Unsupported kwargs data type: ", kv.second.second); } diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 894fe3b052fb2..0b68dc65e41cd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -140,6 +141,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 55cd2af2d0219..b0d1ed50af126 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -47,7 +47,7 @@ void addObjectMethodsForLazyTensor(py::module& m); #endif bool InitArray(); -bool GetDyanmicExecutionProviderHash( +bool GetDynamicExecutionProviderHash( const std::string& ep_shared_lib_path, const ProviderOptions& provider_options, size_t& hash, @@ -87,13 +87,7 @@ bool GetProviderInstanceHash(const std::string& type, if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) { const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info, provider_options_map); - hash = static_cast(info.device_id) ^ - info.gpu_mem_limit ^ - (static_cast(info.arena_extend_strategy) << 16) ^ - (static_cast(info.cudnn_conv_algo_search) << 18) ^ - (static_cast(info.do_copy_in_default_stream) << 20) ^ - (static_cast(info.has_user_compute_stream) << 22) ^ - std::hash{}(info.tunable_op); + hash = std::hash{}(info); return true; } #endif @@ -102,13 +96,7 @@ bool GetProviderInstanceHash(const std::string& type, if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) { const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info, provider_options_map); - hash = static_cast(info.device_id) ^ - info.gpu_mem_limit ^ - (static_cast(info.arena_extend_strategy) << 16) ^ - (static_cast(info.miopen_conv_exhaustive_search) << 18) ^ - (static_cast(info.do_copy_in_default_stream) << 20) ^ - (static_cast(info.has_user_compute_stream) << 22) ^ - std::hash{}(info.tunable_op); + hash = std::hash{}(info); return true; } #endif @@ -128,7 +116,7 @@ bool GetProviderInstanceHash(const std::string& type, provider_options.insert(option); } } - return GetDyanmicExecutionProviderHash(shared_lib_path_it->second, provider_options, hash); + return GetDynamicExecutionProviderHash(shared_lib_path_it->second, provider_options, hash); } } } diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index b7e55bc733ede..a1c3d7d7e1d4f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -30,7 +30,7 @@ def get_variable_name(self, name: str) -> str: # For some operators such as data load/store, we need an internal variable name inside the kernel function. def get_internal_variable_name(self, name: str) -> str: var_name = self._var_map[name] - var_name = self._var_map[var_name] if var_name in self._var_map else var_name + var_name = self._var_map.get(var_name, var_name) return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index 95e6703be8783..877eacc0b775f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl def next_power_of_2(n: int) -> int: - assert n <= 2**32, "32-bit only" + """Return the smallest power of 2 greater than or equal to n""" n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 + n |= n >> 32 n += 1 return n diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index fbf1b7c2bac42..4a03465cf2ead 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -39,7 +39,7 @@ def _defined_from_envvar(name, default_value, warn=True): # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and # assign them new values. Importing them directly do not propagate changes. ################################################################################ -ONNX_OPSET_VERSION = 15 +ONNX_OPSET_VERSION = 17 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1" ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions") _FALLBACK_INIT_EXCEPTION = None diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index f10416a9bb0f4..af5f3c9ceb565 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -376,8 +376,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. - index = 0 - for node in exported_model.graph.node: + for index, node in enumerate(exported_model.graph.node): op_name_prefix = node.op_type if node.domain == "com.microsoft" and node.op_type == "PythonOp": output_names = list(node.output) @@ -391,7 +390,6 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model break node.name = f"{op_name_prefix}_id_{index}" - index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 77317242727b4..75512cb8e8c88 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -48,7 +48,7 @@ def _to_gradient_definition(gradient): attr_def.name = key attr_def.value_json = json.dumps(value["value"]) attr_def.dtype = value["dtype"] - attr_def.is_tensor = value["is_tensor"] if "is_tensor" in value else False + attr_def.is_tensor = value.get("is_tensor", False) attributes.append(attr_def) node_def.attributes = attributes node_defs.append(node_def) @@ -241,7 +241,7 @@ def native_group_norm_gradient(): # are available for all versions, though they are not that convienent to use. def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] - if "bilinear" in backward_fn: + if "bicubic" in backward_fn: scales = ["I(2)", *scales] return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -271,3 +271,8 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) + + +@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") +def upsample_bicubic2d_gradient(): + return _upsample_gradient("upsample_bicubic2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 99e8851b6a697..f81aef5f6b9c4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,3 +808,40 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + + +@register_symbolic("upsample_bicubic2d") +def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): + return g.op( + "org.pytorch.aten::ATen", + input, + output_size, + align_corners, + scale_factors, + operator_s="upsample_bicubic2d", + overload_name_s="vec", + ) + + +@register_symbolic("layer_norm") +@parse_args("v", "is", "v", "v", "f", "none") +def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + + res, new_running_mean, new_running_var = g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + outputs=3, # force all 3 outputs to be exported in training mode + operator_s="layer_norm", + overload_name_s="vec", + ) + + return res diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cc533e549db92..73c32a2f51e41 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs): # Run and get results backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. - del ctx.run_info.state - - # Fast version: all backward_outputs are converted first. - # This version only works if backward_outputs is an OrtValueVector. - transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - - self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + try: + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. + + # Fast version: all backward_outputs are converted first. + # This version only works if backward_outputs is an OrtValueVector. + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + return res + finally: + del ctx.run_info.state return _ORTModuleFunction diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index dcaa202d46fd8..905eb62768a92 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -214,8 +214,7 @@ def recursive_wrap(module, save_onnx=False, onnx_prefix=""): if isinstance(sub_module, torch.nn.ModuleList): # We encounter a list of sub-modules. # Let's wrap them one-by-one. - idx = 0 - for item_name, sub_module_item in sub_module._modules.items(): + for idx, (item_name, sub_module_item) in enumerate(sub_module._modules.items()): # Avoid saving too many graphs. new_save_onnx = save_onnx and idx == 0 sub_new_prefix = new_prefix + "_" + item_name @@ -237,7 +236,6 @@ def recursive_wrap(module, save_onnx=False, onnx_prefix=""): ) else: recursive_wrap(sub_module_item, new_save_onnx, sub_new_prefix) - idx += 1 else: if is_supported(sub_module): # Just wrap it as ORTModule when possible. diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index fa72f3b134917..898c242bb3c32 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -23,7 +23,7 @@ cur_file_dir, ] -extra_compile_args = {"cxx": ["-O3"]} +extra_compile_args = {"cxx": ["-O3", "-std=c++17"]} setup( name="torch_interop_utils", ext_modules=[ diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc index cf510ea43c89f..509937bdd0c3a 100644 --- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc @@ -135,7 +135,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_Allowed) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -206,7 +206,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_LabelNameNotMat } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -277,7 +277,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_ReduceNone) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -344,7 +344,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_NoIgnoreIndex) } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index b774fec11cc8d..bab7c09839273 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1523,7 +1523,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs) { builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1616,7 +1616,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs_LastAddNotHaveScaleI builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1710,7 +1710,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionTwoInputs) { builder.AddNode("Identity", {add1_out}, {graph_output2}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc index ea05b29c8668b..a1629eb73eeb6 100644 --- a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc @@ -67,7 +67,7 @@ TEST(ShapeOptimizerTests, Shape15CannotFold) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -145,7 +145,7 @@ TEST(ShapeOptimizerTests, Shape15) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -218,7 +218,7 @@ TEST(ShapeOptimizerTests, Shape15TakesGraphInput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -289,7 +289,7 @@ TEST(ShapeOptimizerTests, Shape15GeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -366,7 +366,7 @@ TEST(ShapeOptimizerTests, Slice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -446,7 +446,7 @@ TEST(ShapeOptimizerTests, SliceGeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -530,7 +530,7 @@ TEST(ShapeOptimizerTests, Gather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -639,7 +639,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedBySlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> dropout_input_shape; @@ -810,7 +810,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedByGatherSlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; @@ -976,7 +976,7 @@ TEST(ShapeOptimizerTests, SymbolicDimUsedByGather_ConcreteDimUsedByGather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; diff --git a/orttraining/orttraining/test/python/orttraining_test_model_transform.py b/orttraining/orttraining/test/python/orttraining_test_model_transform.py index 3b07aa1f4daf0..095830cd54ab8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_model_transform.py +++ b/orttraining/orttraining/test/python/orttraining_test_model_transform.py @@ -2,10 +2,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_single_output_node(model, arg): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 938d33cc9a714..365c2bb8ebe0e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -34,7 +34,7 @@ from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck -DEFAULT_OPSET = 15 +DEFAULT_OPSET = 17 # PyTorch model definitions for tests @@ -1805,6 +1805,34 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +def test_aten_upsample_bicubic(): + class _NeuralNetUpsampleBicubic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic") + + device = "cuda" + pt_model = _NeuralNetUpsampleBicubic().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D): @@ -5252,7 +5280,7 @@ def run_step(model, x): assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13 -@pytest.mark.parametrize("opset_version", [12, 13, 14, 15]) +@pytest.mark.parametrize("opset_version", [12, 13, 14, 15, 17]) def test_opset_version_change(opset_version): original_env = None if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ: @@ -6400,7 +6428,7 @@ def run_step(model, x): reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now." ) def test_bert_result_with_layerwise_recompute(): - original_val = os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ else None + original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None) # Create PyTorch model with dropout disabled. pt_model = _get_bert_for_sequence_classification_model( "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py index 50016515a69e1..043c70263d31e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py @@ -125,8 +125,6 @@ def run_with_ort_on_gpu(model, args, rank, device): try: mp.spawn(test_Distributed_ReduceWithMarkDirtyModel, nprocs=size, args=(size,)) except Exception: - import sys # noqa: F811 - sys.stdout.flush() sys.stderr.flush() raise diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 4f0925c5c855b..2f240406b25b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -79,7 +79,7 @@ def run_step(model, x): for onnx_model in [onnx_graph_inf, onnx_graph_train]: for oimp in onnx_model.opset_import: if oimp.domain == "": - self.assertEqual(oimp.version, 15) + self.assertEqual(oimp.version, 17) # Needs to match latest default ORTModule opset if op_grad_type is not None: if isinstance(op_grad_type, tuple): text = str(onnx_graph_train) diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md index 6840e98bd9c86..05072b410b730 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/README.md +++ b/orttraining/orttraining/test/python/qat_poc_example/README.md @@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC. -> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True` +> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True` > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag QuantizeBias=False` diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py index 91d7ccd7294f5..601362a59e379 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/model.py +++ b/orttraining/orttraining/test/python/qat_poc_example/model.py @@ -5,7 +5,7 @@ import onnx import torch -import onnxruntime.training.onnxblock as onnxblock +from onnxruntime.training import artifacts class MNIST(torch.nn.Module): @@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix): 4. The checkpoint file """ - class MNISTWithLoss(onnxblock.TrainingModel): - def __init__(self): - super().__init__() - self.loss = onnxblock.loss.CrossEntropyLoss() - - def build(self, output_name): - return self.loss(output_name) - - mnist_with_loss = MNISTWithLoss() - onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None - - # Build the training and eval graphs - logging.info("Using onnxblock to create the training artifacts.") - with onnxblock.onnx_model(onnx_model) as model_accessor: - _ = mnist_with_loss(onnx_model.graph.output[0].name) - eval_model = model_accessor.eval_model - - # Build the optimizer graph - optimizer = onnxblock.optim.AdamW() - with onnxblock.onnx_model() as accessor: - _ = optimizer(mnist_with_loss.parameters()) - optimizer_model = accessor.model + onnx_model = onnx.load(model_path) + + requires_grad = [ + param.name + for param in onnx_model.graph.initializer + if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point")) + ] + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=artifacts_dir, + prefix=model_prefix, + ) # Create the training artifacts - train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx") - logging.info(f"Saving the training model to {train_model_path}.") - onnx.save(onnx_model, train_model_path) - eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx") - logging.info(f"Saving the eval model to {eval_model_path}.") - onnx.save(eval_model, eval_model_path) - optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx") - logging.info(f"Saving the optimizer model to {optimizer_model_path}.") - onnx.save(optimizer_model, optimizer_model_path) - trainable_params, non_trainable_params = mnist_with_loss.parameters() - checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt") - logging.info(f"Saving the checkpoint to {checkpoint_path}.") - onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path) + train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx") + eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx") + optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx") + checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint") return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py index 51a15475ee911..dcc9e116fda7d 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/qat.py +++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py @@ -46,7 +46,7 @@ ) logging.info("Preparing the training artifacts for QAT.") - training_model_name = "mnist_qat" + training_model_name = "mnist_qat_" artifacts_dir = os.path.join(model_dir, "training_artifacts") utils.makedir(artifacts_dir) training_artifacts = create_training_artifacts( diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py index 9a429d2adc6f1..a25c071c58a48 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/train.py +++ b/orttraining/orttraining/test/python/qat_poc_example/train.py @@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader): model.train() cumulative_loss = 0 for data, target in train_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - train_loss = model(forward_inputs) + train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) optimizer.step() model.lazy_reset_grad() - cumulative_loss += train_loss[0] + cumulative_loss += train_loss return cumulative_loss / len(train_loader) @@ -43,12 +39,8 @@ def _eval(model, test_loader): model.eval() cumulative_loss = 0 for data, target in test_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - test_loss = model(forward_inputs) - cumulative_loss += test_loss[0] + test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) + cumulative_loss += test_loss return cumulative_loss / len(test_loader) @@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp train_loader, test_loader = _get_dataloaders("data", batch_size) # Load the checkpoint state. - state = orttraining.CheckpointState(qat_checkpoint) + state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint) # Create the training module. model = orttraining.Module(qat_train_model, state, qat_eval_model) diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index f226db76f7ed7..db8e8558ab884 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -25,12 +25,15 @@ class TritonOp final : public OpKernel { attr.first == "onnx_string") { continue; } - // Support int64 and float only for now, skip other types. + // Support int64, float and string only for now, skip other types. if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); } else if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) { + kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}}); } } } diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index dcf733153bdad..8b2bc7e2ef2b3 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -196,6 +196,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); @@ -452,6 +453,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f6c58445c0a5d..fc5d9b65d0f89 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -114,7 +114,8 @@ Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args_.params.data_type)); + args_.params.data_type, + UseTF32())); if (dB) { const TensorShape& db_shape = dB->Shape(); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 5dc16c68f6210..d23905496c9bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const } template -Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32) { perf_results.resize(1); perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; if (args.params.data_type == CUDNN_DATA_HALF) { perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) { + perf_results[0].mathType = CUDNN_FMA_MATH; } else { perf_results[0].mathType = CUDNN_DEFAULT_MATH; } @@ -256,7 +258,7 @@ Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const std::vector perf_results; ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) + ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32()) : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); for (auto& algo_perf : perf_results) { if (f(algo_perf) == Status::OK()) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index a2d4bf3bdc006..3fdb4306bfbbb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -75,7 +75,7 @@ class AlgoIterator { Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f); - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32); private: const ConvArgs& args_; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index 5f7206fc121ec..d3f5a89434a48 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -182,7 +182,8 @@ Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tenso ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); } return Status::OK(); @@ -287,7 +288,8 @@ Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, cons ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); if (dB) { const auto& b_shape = dB->Shape(); diff --git a/orttraining/tools/scripts/gpt2_model_transform.py b/orttraining/tools/scripts/gpt2_model_transform.py index 06f03e06632b4..294af13fe69b7 100644 --- a/orttraining/tools/scripts/gpt2_model_transform.py +++ b/orttraining/tools/scripts/gpt2_model_transform.py @@ -17,10 +17,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_input_node(model, arg): @@ -139,11 +137,9 @@ def process_concat(model): delete_nodes.append(get_node_index(model, n)) # insert new shape to reshape - index = 0 - for reshape_node_index in new_nodes: + for index, reshape_node_index in enumerate(new_nodes): shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) - index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] # delete nodes @@ -154,28 +150,22 @@ def process_concat(model): def replace_input_arg(model, arg, new_arg): for node in model.graph.node: - i = 0 - while i < len(node.input): - if node.input[i] == arg: + for i, input_name in enumerate(node.input): + if input_name == arg: node.input[i] = new_arg - i += 1 def find_weight_index(model, name): - index = 0 - for w in model.graph.initializer: + for index, w in enumerate(model.graph.initializer): if w.name == name: return index - index += 1 return None def find_input_index(model, name): - index = 0 - for w in model.graph.input: + for index, w in enumerate(model.graph.input): if w.name == name: return index - index += 1 return None diff --git a/orttraining/tools/scripts/model_transform.py b/orttraining/tools/scripts/model_transform.py index 81e9f7b16be14..f0cf53990eac3 100644 --- a/orttraining/tools/scripts/model_transform.py +++ b/orttraining/tools/scripts/model_transform.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import sys import numpy as np import onnx -from onnx import TensorProto, helper, numpy_helper, shape_inference # noqa: F401 +from onnx import numpy_helper if len(sys.argv) < 2: print("Please give model path...") @@ -15,10 +17,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_input_node(model, arg): @@ -118,11 +118,9 @@ def process_concat(model): for n in fuse_nodes: delete_nodes.append(get_node_index(model, n)) # insert new shape to reshape - index = 0 - for reshape_node_index in new_nodes: + for index, reshape_node_index in enumerate(new_nodes): shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) - index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] # delete nodes @@ -199,12 +197,10 @@ def replace_input_arg(model, arg, new_arg): i += 1 -def find_weight_index(model, name): - index = 0 - for w in model.graph.initializer: +def find_weight_index(model, name: str) -> int | None: + for index, w in enumerate(model.graph.initializer): if w.name == name: return index - index += 1 return None diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 25454ce40c263..6836d5df69324 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.11.0 # RUFF -ruff==0.1.4 +ruff==0.2.1 # BLACK-ISORT black==23.10.1 isort==5.12.0 diff --git a/setup.py b/setup.py index 67d34b065ad03..9a5fc29dd5e02 100644 --- a/setup.py +++ b/setup.py @@ -205,18 +205,23 @@ def run(self): rocm_dependencies = [ "libamd_comgr.so.2", "libamdhip64.so.5", + "libamdhip64.so.6", "libdrm.so.2", "libdrm_amdgpu.so.1", "libelf.so.1", "libhipfft.so.0", "libhiprtc.so.5", + "libhiprtc.so.6", "libhsa-runtime64.so.1", "libMIOpen.so.1", "libnuma.so.1", "librccl.so.1", "librocblas.so.3", + "librocblas.so.4", "librocfft.so.0", + "libroctx64.so.4", "librocm_smi64.so.5", + "librocm_smi64.so.6", "libroctracer64.so.4", "libtinfo.so.6", "libmigraphx_c.so.3", @@ -419,6 +424,7 @@ def finalize_options(self): "onnxruntime.transformers.models.gpt2", "onnxruntime.transformers.models.llama", "onnxruntime.transformers.models.longformer", + "onnxruntime.transformers.models.phi2", "onnxruntime.transformers.models.t5", "onnxruntime.transformers.models.stable_diffusion", "onnxruntime.transformers.models.whisper", diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 8ea0481c9b101..f1d3702e3245e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -117,7 +117,6 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("HIPBLAS_R_16F", "rocblas_datatype_f16_r") s = s.replace("HIPBLAS_R_32F", "rocblas_datatype_f32_r") s = s.replace("ROCBLAS_GEMM_DEFAULT_TENSOR_OP", "rocblas_gemm_algo_standard") - s = s.replace("ROCBLAS_TENSOR_OP_MATH", "0 /* CUBLAS_TENSOR_OP_MATH is deprecated */") # compatible layer s = s.replace("rocblas_gemm_strided_batched_ex", "_compat_rocblas_gemm_strided_batched_ex") @@ -182,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("rocm_device_prop_", "cuda_device_prop_") s = s.replace("rocm_device_arch_", "cuda_device_arch_") + s = s.replace("HipTuningContext", "RocmTuningContext") + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names # And we do this last, undoing or fixing hipify mistakes. if "fft" in src_file_path: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index a7cd12c488d6c..aecb3b355bd9f 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1243,9 +1243,15 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support - # full_protobuf option. - if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + # VitisAI and OpenVINO providers currently only support + # full_protobuf option. TensorRT provider only requires it if built with oss_parser + if ( + args.use_full_protobuf + or (args.use_tensorrt and args.use_tensorrt_oss_parser) + or args.use_openvino + or args.use_vitisai + or args.gen_doc + ): cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: @@ -1527,7 +1533,8 @@ def generate_build_tree( ldflags = ["/profile", "/DYNAMICBASE"] # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. if not args.enable_address_sanitizer: - cflags += ["/Qspectre"] + # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs + cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo": @@ -1631,9 +1638,11 @@ def generate_build_tree( [ *temp_cmake_args, f"-DCMAKE_BUILD_TYPE={config}", - f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" - if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) - else "", + ( + f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" + if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) + else "" + ), ], cwd=config_build_dir, cuda_home=cuda_home, @@ -1667,8 +1676,11 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": - # CMake will generate correct build tool args for Xcode - cmd_args += ["--parallel", str(num_parallel_jobs)] + build_tool_args += [ + "-parallelizeTargets", + "-jobs", + str(num_parallel_jobs), + ] else: build_tool_args += [f"-j{num_parallel_jobs}"] @@ -2543,11 +2555,15 @@ def main(): if args.build_nuget and cross_compiling: raise BuildError("Currently nuget package creation is not supported while cross-compiling") - if args.enable_pybind and args.disable_rtti: - raise BuildError("Python bindings use typeid so you can't disable RTTI") + if args.enable_pybind: + if args.disable_rtti: + raise BuildError("Python bindings use typeid so you can't disable RTTI") + + if args.disable_exceptions: + raise BuildError("Python bindings require exceptions to be enabled.") - if args.enable_pybind and args.disable_exceptions: - raise BuildError("Python bindings require exceptions to be enabled.") + if args.minimal_build is not None: + raise BuildError("Python bindings are not supported in a minimal build.") if args.nnapi_min_api: if not args.use_nnapi: diff --git a/tools/ci_build/github/apple/get_simulator_device_info.py b/tools/ci_build/github/apple/get_simulator_device_info.py index 2a36418bac9cb..7de9aa13912e0 100755 --- a/tools/ci_build/github/apple/get_simulator_device_info.py +++ b/tools/ci_build/github/apple/get_simulator_device_info.py @@ -138,13 +138,11 @@ def runtime_id_and_device_pair_key(runtime_id_and_device_pair): def main(): parser = argparse.ArgumentParser(description="Gets simulator info from Xcode and prints it in JSON format.") - _ = parser.parse_args() # no args yet + parser.add_argument("--max-runtime-version", help="The maximum runtime version to allow.") + args = parser.parse_args() info = get_simulator_device_info( - # The macOS-13 hosted agent image has iOS 17 which is currently in beta. Limit it to 16.4 for now. - # See https://github.com/actions/runner-images/issues/8023 - # TODO Remove max_runtime_version limit. - max_runtime_version_str="16.4", + max_runtime_version_str=args.max_runtime_version, ) print(json.dumps(info, indent=2)) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 2b181810b0788..d37266a8e96d8 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index b19a8b11db265..24319184dd0b8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -204,6 +204,7 @@ jobs: --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ $(Repository) \ /bin/bash -c " set -ex; \ diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index e75bb68a8bfeb..eaadc6ad728c0 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -15,6 +15,11 @@ parameters: - 8.6.1.6 - BIN +- name: UseTensorrtOssParser + displayName: Use TensorRT-OSS Parser + type: boolean + default: false + - name: ModelGroups type: object default: @@ -73,7 +78,7 @@ jobs: value: ort-image-$(Build.BuildId) steps: - - ${{ if eq(parameters.TrtVersion, 'BIN') }}: + - ${{ if and(eq(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'ls -al $(trtBinsDir)' displayName: 'Show available TensorRT .tar.gz packages' @@ -83,11 +88,19 @@ jobs: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --tar_cudnn_version=$(tarCudnnVersion) --trt_bins_dir=.' displayName: 'Install TensorRT from binaries and build latest ORT Image' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - - ${{ else }}: + + # Build ORT with TensorRT built-in parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75' - displayName: 'Build latest ORT Image' + displayName: 'Build latest ORT Image with TensorRT built-in parser' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - + + # Build ORT with TensorRT OSS parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, true)) }}: + - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --use_tensorrt_oss_parser' + displayName: 'Build latest ORT Image with TensorRT OSS parser' + workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' + - ${{ if eq(parameters.MemTest, true) }}: - script: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh -d $(image) -p $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/ -w /code/ -l false' displayName: 'Run Memory Test' diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 0312b70d2b1d5..8fa5bdbf90931 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 864d1002a90fc..7b03c0e82f4bb 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: @@ -18,4 +18,3 @@ stages: value: '$(Build.BinariesDirectory)' steps: - template: test.yml - diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 871d7894e5315..dc52e9a22f05b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -3,7 +3,7 @@ parameters: stages: - stage: Nodejs_Test_MacOS_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml index c823ac788f925..9b3c61b2d3d85 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index f244851f8cc37..47b1e0933417e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -13,9 +13,9 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.0.0' - opset_version: '15' + opset_version: '17' cuda_version: '11.8' - cmake_cuda_architectures: 60;61;70;75;80;86;90 + cmake_cuda_architectures: 60;61;70;75;80;86 docker_file: Dockerfile.manylinux2_28_training_cuda11_8 agent_pool: Onnxruntime-Linux-GPU upload_wheel: 'yes' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml index 422fb33eec5de..86dce7ae465fc 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.1.0' - opset_version: '15' + opset_version: '17' cuda_version: '12.2' cmake_cuda_architectures: 70;75;80;86;90 docker_file: Dockerfile.manylinux2_28_training_cuda12_2 diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 5349b1ca67ab1..6b0ae085fa4db 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -34,6 +34,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -70,5 +75,6 @@ stages: enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index b0509467e1689..9a38513d04a79 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 8bdb395c00dc3..1ba0b02560aca 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -501,12 +501,13 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- stage: Nodejs_Packaging_CPU +- stage: Nodejs_Packaging dependsOn: + - Windows_CI_GPU_DML_Dev + - Windows_CI_GPU_DML_Dev_arm64 - Linux_C_API_Packaging_CPU + - Linux_C_API_Packaging_GPU_TensorRT_x64 - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: @@ -533,17 +534,49 @@ stages: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win x64)' - inputs: - artifactName: 'onnxruntime-win-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win ARM64)' - inputs: - artifactName: 'onnxruntime-win-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + ################################################################## + # Node.js binding artifacts preparation + # + # This stage prepares Node.js binding artifacts for publishing. The artifacts support the following platforms: + # - Windows x64 with DML support + # - Windows arm64 with DML support + # - Linux x64 with TensorRT support + # - Linux arm64 (CPU only) + # - macOS x64 (CPU only) + # - macOS arm64 (CPU only) + # + # ORT Node.js binding artifacts contain 2 parts: + # 1. ONNX Runtime native shared libraries and their dependencies + # - Windows (x64, arm64): + # - onnxruntime.dll + # - DirectML.dll + # - Linux (x64, arm64): + # - libonnxruntime.so{.version} + # - libonnxruntime_providers_shared.so + # - libonnxruntime_providers_{provider}.so + # - macOS (x64, arm64): + # - libonnxruntime.dylib + # 2. ONNX Runtime Node.js binding + # - onnxruntime_binding.node + # + # For windows platform, the artifact is named as 'onnxruntime-nodejs-win-x64-dml' for x64, and + # 'onnxruntime-nodejs-win-arm64-dml' for arm64. Each artifact contains both (1) and (2). + # + # For Linux and macOS platforms, (1) and (2) are packed into separate artifacts. + # The following artifacts contain (1): + # - onnxruntime-osx + # - onnxruntime-linux-x64-tensorrt + # - onnxruntime-linux-aarch64 + # The following artifacts contain (2): + # - drop-onnxruntime-nodejs-linux-x64-tensorrt + # - drop-onnxruntime-nodejs-linux-aarch64 + # - drop-onnxruntime-nodejs-osx-x86_64 + # - drop-onnxruntime-nodejs-osx-arm64 + # + # All binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': + # $(Build.SourcesDirectory)\js\node\bin\napi-v3\{os}\{cpu_arch}\ + # + # {os} is one of 'win32', 'darwin', 'linux' and {cpu_arch} is one of 'x64', 'arm64'. - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (OSX)' @@ -554,7 +587,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (Linux x64)' inputs: - artifactName: 'onnxruntime-linux-x64' + artifactName: 'onnxruntime-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - task: DownloadPipelineArtifact@0 @@ -566,13 +599,13 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-x64' + artifactName: 'drop-onnxruntime-nodejs-win-x64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/' - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win ARM64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-arm64' + artifactName: 'drop-onnxruntime-nodejs-win-arm64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/arm64/' - task: DownloadPipelineArtifact@0 @@ -590,7 +623,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Linux x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-linux-x64' + artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/x64/' - task: DownloadPipelineArtifact@0 @@ -631,38 +664,32 @@ stages: # Node.js binding win32/x64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64\lib' - Contents: '*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' - - task: CopyFiles@2 - displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\x64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' # Node.js binding win32/arm64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-arm64\lib' - Contents: '*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' - - task: CopyFiles@2 - displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\arm64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' # Node.js binding linux/x64 - task: CopyFiles@2 displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64\lib' - Contents: 'libonnxruntime.so.*' + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-tensorrt\lib' + Contents: | + libonnxruntime.so.* + libonnxruntime_providers_*.so TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index c2ef565a6e9ee..f1418e75bffa2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -5,10 +5,12 @@ parameters: default: 'succeeded' # could be 'ci_only', 'always', 'succeeded' steps: -- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: +- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - task: DeleteFiles@1 inputs: - contents: $(Build.BinariesDirectory)/* + SourceFolder: '$(Build.BinariesDirectory)' + contents: | + **/* displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 146e3e58444c1..5ac5bda8b0964 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -40,6 +40,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -459,3 +464,9 @@ stages: QNN_SDK: 'qnn-v2.18.0.240101_win' PYTHON_VERSION: '3.11' NUMPY_VERSION: '1.25.2' + + - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - template: py-win-x64-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: 'qnn-v2.18.0.240101_win' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 18368e59cad52..4315eae503ebd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -120,17 +120,17 @@ jobs: $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 + # building with build.py so the parallelization parameters are added to the msbuild command + - task: PythonScript@0 displayName: 'Build' inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: x64 - configuration: RelWithDebInfo - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --parallel --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - template: win-esrp-dll.yml @@ -188,7 +188,7 @@ jobs: condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) inputs: GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - template: component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml new file mode 100644 index 0000000000000..30f21e933ee36 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -0,0 +1,177 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'Onnxruntime-QNNEP-Windows-2022-CPU' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_x64_qnn_Wheels + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + strategy: + matrix: + Python38_x64: + PythonVersion: '3.8' + Python39_x64: + PythonVersion: '3.9' + Python310_x64: + PythonVersion: '3.10' + Python311_x64: + PythonVersion: '3.11' + Python312_x64: + PythonVersion: '3.12' + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import sys + np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: download-deps.yml + + - task: PythonScript@0 + displayName: 'Update deps.txt' + inputs: + scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py + arguments: --new_dir $(Build.BinariesDirectory)/deps + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Install ONNX' + inputs: + filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' + workingDirectory: '$(Build.BinariesDirectory)' + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'x64' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 13d4589a67cdc..dc861f7f1ed79 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 6246bb83566e5..534d5c6d6135b 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 0c95083d614ed..fafc47b6e9de6 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -7,7 +7,7 @@ ARG PLATFORM=x86_64 ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 -ARG PREPEND_PATH=/usr/local/cuda/binet +ARG PREPEND_PATH=/usr/local/cuda/bin ARG TRT_VERSION=8.6.1.6-1.cuda11.8 #Build manylinux docker image begin diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index dd7c669c37885..e1914d5fe2f06 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -178,7 +178,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.8 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index a6a75afb0f4c3..fed29689fbe5e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.0.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 index d29157daef611..e1caa141ef317 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.1.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index 04a6af962b5e6..f1ffba3b3e1c9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -82,8 +82,9 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" -RUN /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' +ENV CUDA_MODULE_LOADING "LAZY" +ARG PARSER_CONFIG="" +RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' # Switch to root to continue following steps of CI USER root diff --git a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py index dec070e3f5c75..2264742079d15 100644 --- a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py +++ b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py @@ -66,8 +66,8 @@ def diff_sections_total_size(base_binary_path, binary_path, readelf_path="readel results = collections.OrderedDict() for section in sorted(merged_keys): - base_size = base_section_sizes[section] if section in base_section_sizes else 0 - size = section_sizes[section] if section in section_sizes else 0 + base_size = base_section_sizes.get(section, 0) + size = section_sizes.get(section, 0) base_total += base_size total += size diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 4767c74afd28f..496b57b417fbd 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -112,7 +112,7 @@ RUN pip install \ cerberus \ sympy \ h5py \ - datasets==1.9.0 \ + datasets==2.17.0 \ requests \ sacrebleu==1.5.1 \ sacremoses \ @@ -131,7 +131,7 @@ RUN pip install \ # Install migraphx RUN apt update && apt install -y migraphx -ENV ORTMODULE_ONNX_OPSET_VERSION=15 +ENV ORTMODULE_ONNX_OPSET_VERSION=17 ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py index 8222437f7b42e..5c7edfa88a48b 100644 --- a/tools/ci_build/op_registration_validator.py +++ b/tools/ci_build/op_registration_validator.py @@ -165,7 +165,7 @@ def _validate_last_registration(self, last_r: RegistrationInfo) -> bool: # domain that have newer registrations in a non-contrib op file differently. They should only be considered # deprecated as contrib ops. domain_and_op_str = last_r.domain_and_op_str() - deprecation_version = deprecated_ops.get(domain_and_op_str, None) + deprecation_version = deprecated_ops.get(domain_and_op_str) allow_missing_unversioned_registration = ( deprecation_version is not None and last_r.end_version == deprecation_version - 1 diff --git a/tools/python/fix_long_lines.py b/tools/python/fix_long_lines.py index 383fdc9623551..8a3c249ef672a 100644 --- a/tools/python/fix_long_lines.py +++ b/tools/python/fix_long_lines.py @@ -20,9 +20,8 @@ def _process_files(filenames, clang_exe, tmpdir): bad_lines = [] with open(path, encoding="UTF8") as f: - line_num = 0 - for line in f: - line_num += 1 # clang-format line numbers start at 1 + for i, line in enumerate(f): + line_num = i + 1 # clang-format line numbers start at 1 if len(line) > 120: bad_lines.append(line_num) diff --git a/tools/python/gen_opkernel_doc.py b/tools/python/gen_opkernel_doc.py index 1075ed8192fdd..f6f9f21396859 100644 --- a/tools/python/gen_opkernel_doc.py +++ b/tools/python/gen_opkernel_doc.py @@ -22,11 +22,9 @@ def format_version_range(v): def format_type_constraints(tc): - counter = 0 tcstr = "" firsttcitem = True for tcitem in tc: - counter += 1 if firsttcitem: firsttcitem = False else: @@ -98,7 +96,7 @@ def main(output_path: pathlib.Path, provider_filter: [str]): paramstr += f"*out* {outp.name}:**{outp.typeStr}**" paramstr += "" - paramset = paramdict.get(fullname, None) + paramset = paramdict.get(fullname) if paramset is None: paramdict[fullname] = set() @@ -145,9 +143,8 @@ def main(output_path: pathlib.Path, provider_filter: [str]): else: fout.write("|||") fout.write(format_version_range(version_range) + "|") - tnameindex = 0 - for tname, tcset in sorted(typemap.items()): - tnameindex += 1 + for i, (tname, tcset) in enumerate(sorted(typemap.items())): + tnameindex = i + 1 tclist = [] for tc in sorted(tcset): tclist.append(tc) diff --git a/tools/python/ort_test_dir_utils.py b/tools/python/ort_test_dir_utils.py index cd1f5022af526..3af407b2aeee6 100644 --- a/tools/python/ort_test_dir_utils.py +++ b/tools/python/ort_test_dir_utils.py @@ -115,8 +115,7 @@ def create_test_dir( model_outputs = model.graph.output def save_data(prefix, name_data_map, model_info): - idx = 0 - for name, data in name_data_map.items(): + for idx, (name, data) in enumerate(name_data_map.items()): if isinstance(data, dict): # ignore. map from traditional ML ops pass @@ -130,8 +129,6 @@ def save_data(prefix, name_data_map, model_info): with open(filename, "wb") as f: f.write(tensor.SerializeToString()) - idx += 1 - if not name_input_map: name_input_map = {} diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 7a77839c4a4e7..df4e70b1e51fe 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,8 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", # not currently required, but running ensures we're hitting all mobile platforms "Android CI Pipeline", "iOS CI Pipeline", diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py index 9eccb7c36455f..f8cc34e04afa0 100644 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py @@ -105,7 +105,7 @@ def _node_output_is_supported(name): # some models don't have complete imports. use 1 as a default as that's valid for custom domains and should # result in an error for any others. not sure why ONNX or ORT validation allows this though. - opset = opsets[domain] if domain in opsets else 1 + opset = opsets.get(domain, 1) if ( domain not in required_ops or opset not in required_ops[domain] diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index a89ac561f8860..d04e276347170 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -14,7 +14,7 @@ struct LogicalProcessorInformation { struct CoreCounter { uint32_t PhysicalCores = 0; - uint32_t SocDieCores = 0; + uint32_t Num2CacheCores = 0; }; static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { @@ -75,7 +75,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() { read += currentProcessorInfo->Size; } - cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); return cores; } @@ -83,8 +83,27 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); - // We want to use the number of physical cores, but exclude soc cores - return cores.PhysicalCores - cores.SocDieCores; + +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + int regs_leaf0[4]; + int regs_leaf7[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf7, 0x7); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && + (kVendorID_Intel[2] == regs_leaf0[3]); + + auto isHybrid = (regs_leaf7[3] & (1 << 15)); + + if (isIntel && isHybrid) { + // We want to use the number of physical cores, but exclude soc cores + // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores + return cores.PhysicalCores - cores.Num2CacheCores; + } +#endif + + return cores.PhysicalCores; } } // namespace WINMLP