
MlasTranspose multi-threads support. #24261

Merged

Conversation

@msy-kato (Contributor) commented Apr 1, 2025

Description

MlasTranspose previously ran single-threaded, which resulted in suboptimal performance on multi-core CPUs. To address this, I have modified it to utilize multi-threading. A sketch of the idea appears below.
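For context, here is a minimal sketch of the stripe-partitioning idea, not the PR's actual diff: MLAS dispatches work through its own thread pool, whereas this sketch uses std::thread to stay self-contained, and the names TransposeStripe and ParallelTranspose are illustrative.

```cpp
// Minimal sketch (not the PR's actual code): split the input rows of an
// M x N row-major matrix into stripes and transpose each stripe on its own
// thread. Each thread writes a disjoint set of output elements, so no
// synchronization is needed.
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Transpose rows [row_begin, row_end) of the M x N input into the N x M output.
static void TransposeStripe(const float* input, float* output,
                            size_t M, size_t N,
                            size_t row_begin, size_t row_end) {
    for (size_t m = row_begin; m < row_end; ++m) {
        for (size_t n = 0; n < N; ++n) {
            output[n * M + m] = input[m * N + n];
        }
    }
}

void ParallelTranspose(const float* input, float* output,
                       size_t M, size_t N, size_t num_threads) {
    std::vector<std::thread> workers;
    const size_t rows_per_thread = (M + num_threads - 1) / num_threads;
    for (size_t t = 0; t < num_threads; ++t) {
        const size_t begin = t * rows_per_thread;
        const size_t end = std::min(M, begin + rows_per_thread);
        if (begin >= end) break;
        workers.emplace_back(TransposeStripe, input, output, M, N, begin, end);
    }
    for (auto& worker : workers) {
        worker.join();
    }
}
```

In the PR itself the per-stripe work is handed to the MLAS thread pool (see the MlasExecuteThreaded reference in the build error further down), so it shares threads with the rest of the session instead of spawning its own.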

Motivation and Context

We encountered this issue while running the multilingual-e5-large model, converted to ONNX format and executed on a multi-core CPU (Xeon 6338). Below are the performance metrics before and after the modification:

| | INTER_NUM_THREADS | INTRA_NUM_THREADS | INPUT_LENGTH | BATCH_SIZE | Duration [sec] |
|---|---|---|---|---|---|
| BEFORE | 1 | 16 | 512 | 4 | 1.24 |
| AFTER | 1 | 16 | 512 | 4 | 1.09 |

Conditions

  • FP32
  • CPUExecutionProvider

This change resulted in an end-to-end performance improvement of approximately 14%. The standalone MlasTranspose performance improvement is as follows:

| | INTRA_NUM_THREADS | BEFORE | AFTER |
|---|---|---|---|
| MlasTranspose [ms] | 16 | 182.55 | 11.60 |

MlasTranspose is roughly 15-16x faster (182.55 ms / 11.60 ms ≈ 15.7).

@msy-kato msy-kato requested a review from a team as a code owner April 1, 2025 00:01
@snnn (Member) commented Apr 1, 2025

/azp run Big Models, Linux CPU Minimal Build E2E CI Pipeline, Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 7 pipeline(s).

@msy-kato (Contributor, Author) commented Apr 1, 2025

@microsoft-github-policy-service agree company="Fujitsu Ltd."

@snnn (Member) commented Apr 1, 2025

##[error]D:\a_work\onnxruntime\onnxruntime\onnxruntime\core\mlas\lib\transpose.cpp(986,5): error C2664: 'void MlasExecuteThreaded(MLAS_THREADED_ROUTINE (__cdecl *),void *,ptrdiff_t,MLAS_THREADPOOL *)': cannot convert argument 1 from 'void (__stdcall *)(void *,ptrdiff_t)' to 'MLAS_THREADED_ROUTINE (__cdecl *)' [D:\a_work\onnxruntime\onnxruntime\build\RelWithDebInfo\onnxruntime_mlas.vcxproj]
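For readers hitting the same C2664: on 32-bit MSVC targets the calling convention is part of a function pointer's type, so a `__stdcall` worker cannot be passed where the (default `__cdecl`) `MLAS_THREADED_ROUTINE*` is expected. Below is a minimal sketch with hypothetical stand-ins modeled on the error message; the real MLAS typedefs and dispatch function differ.

```cpp
// Hypothetical stand-ins modeled on the error message above, not the real
// MLAS declarations. Compile as 32-bit MSVC to reproduce the mismatch; on
// x64 the conventions collapse and both calls compile.
#include <cstddef>

// The routine type uses the default calling convention (__cdecl on x86).
typedef void(MLAS_THREADED_ROUTINE_STUB)(void* Context, ptrdiff_t Index);

// Serial stand-in for the thread-pool dispatch done by MlasExecuteThreaded.
void ExecuteThreadedStub(MLAS_THREADED_ROUTINE_STUB* routine, void* context,
                         ptrdiff_t iterations) {
    for (ptrdiff_t i = 0; i < iterations; ++i) {
        routine(context, i);
    }
}

#if defined(_MSC_VER) && defined(_M_IX86)
// Declared __stdcall: a distinct function-pointer type on 32-bit MSVC.
void __stdcall StdcallWorker(void*, ptrdiff_t) {}
#endif

// Declared with the default convention: matches the typedef everywhere.
void CdeclWorker(void*, ptrdiff_t) {}

int main() {
    // ExecuteThreadedStub(StdcallWorker, nullptr, 4);  // C2664 on 32-bit MSVC
    ExecuteThreadedStub(CdeclWorker, nullptr, 4);        // OK
    return 0;
}
```

The fix, accordingly, is to declare the worker routine with the convention the typedef expects.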

@snnn (Member) commented Apr 1, 2025

Lint for onnxruntime/test/mlas/unittest/test_transpose.cpp:

Warning (CLANGFORMAT) format
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.


```diff
   static const std::string GetTypeString() {
-    if(std::is_same<ElementType, float>::value) return std::string("FP32");
-    if(std::is_same<ElementType, uint32_t>::value) return std::string("U32");
-    if(std::is_same<ElementType, uint16_t>::value) return std::string("U16");
-    if(std::is_same<ElementType, uint8_t>::value) return std::string("U8");
+    if (std::is_same<ElementType, float>::value) return std::string("FP32");
+    if (std::is_same<ElementType, uint32_t>::value) return std::string("U32");
+    if (std::is_same<ElementType, uint16_t>::value) return std::string("U16");
+    if (std::is_same<ElementType, uint8_t>::value) return std::string("U8");
     return std::string("unknown");
   }
```

@msy-kato (Contributor, Author) commented Apr 1, 2025

OK, I'll apply the patch and push again.

@jywu-msft jywu-msft requested a review from amarin16 April 2, 2025 00:38
@jywu-msft (Member)
@amarin16 please test this out and review. thanks!

@snnn snnn closed this Apr 3, 2025
@snnn snnn reopened this Apr 3, 2025
@amarin16 (Collaborator) commented Apr 3, 2025

The code changes look good to me. Waiting for the pipelines to pass.

@amarin16 (Collaborator) commented Apr 3, 2025

@msy-kato Could you please provide some details about how you ran the performance tests? Did you use onnxruntime-genai?

@msy-kato (Contributor, Author) commented Apr 4, 2025

Thanks for the review.

> Could you please provide some details about how you ran the performance tests?

Sure! I converted the HF model with torch.onnx.export and ran it directly via onnxruntime.InferenceSession. These are the snippets I used:

  • convert.py

```python
import torch
from transformers import AutoTokenizer, AutoModel

model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
input_texts = [' '.join(['Hello']) * 32] * 2
inputs = dict(tokenizer(input_texts, return_tensors="pt"))
# Export with dynamic batch and sequence-length axes.
torch.onnx.export(
    model,
    inputs,
    "model.onnx",
    input_names=list(inputs.keys()),
    output_names=['last_hidden_state', 'pooler_output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'max_input_length'},
        'attention_mask': {0: 'batch_size', 1: 'max_input_length'},
    }
)
```
  • run.py

```python
import onnxruntime
import time
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
input_texts = [' '.join(['Hello']) * 510] * 4
options = onnxruntime.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 16
ort_session = onnxruntime.InferenceSession("model.onnx", sess_options=options)
batch_dict = dict(tokenizer(input_texts, max_length=512, return_tensors="pt"))
batch_dict = {name: tensor.numpy() for name, tensor in batch_dict.items()}

# warmup
_ = ort_session.run(['last_hidden_state'], batch_dict)

# Average latency over 10 runs.
start_time = time.time()
for i in range(10):
    _ = ort_session.run(['last_hidden_state'], batch_dict)
end_time = time.time()
print('step duration(avg) = {:.7f} sec/step'.format((end_time - start_time) / 10))
```

Commands:

```sh
$ python3 convert.py
$ numactl -C 0-15 python3 run.py
```

@msy-kato (Contributor, Author) commented Apr 8, 2025

> The code changes look good to me. Waiting for the pipelines to pass.

@amarin16 Thank you for approving my PR. I noticed that the CI/CD pipeline hasn't completed yet. Could you advise if there's anything I can do?

@amarin16 (Collaborator) commented Apr 8, 2025

You could try closing the PR and re-opening it.

@amarin16 amarin16 closed this Apr 8, 2025
@amarin16 amarin16 reopened this Apr 8, 2025
@snnn (Member) commented Apr 8, 2025

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).

@msy-kato msy-kato force-pushed the feature-mlastranspose-multithread-v2 branch from 1a22f09 to 09aade9 on April 10, 2025 01:51
@msy-kato msy-kato closed this Apr 10, 2025
@msy-kato msy-kato reopened this Apr 10, 2025
@amarin16 (Collaborator)
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).

@amarin16 amarin16 merged commit 7a03764 into microsoft:main Apr 11, 2025
69 checks passed