Add MIGraphX execution provider support #2069
Open
aditya-dl wants to merge 6 commits into microsoft:main from
Conversation
Register MIGraphX as an execution provider following the new
per-provider session_options architecture. Creates
src/migraphx/session_options.{h,cpp} with an AppendExecutionProvider
that tries the V2 plugin path, then falls back to the V1 legacy API.
Adds provider name normalization ("migraphx" -> "MIGraphX") and
registers the provider in the dispatch table and the CMake build.
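The normalization and fallback described above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the function names and the `v2_plugin_available` flag are assumptions standing in for the real V2 plugin probe and V1 registration calls.

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Hypothetical sketch: normalize a user-supplied provider name so that
// "migraphx" (any case) maps to the canonical "MIGraphX" spelling.
std::string NormalizeProviderName(std::string name) {
  std::transform(name.begin(), name.end(), name.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  if (name == "migraphx") return "MIGraphX";
  return name;  // other providers pass through unchanged
}

// Hypothetical sketch of the V2-then-V1 dispatch: prefer the V2 plugin
// path when available, otherwise fall back to the V1 legacy API. The
// returned tag only marks which path was taken, for illustration.
std::string AppendExecutionProvider(std::string name, bool v2_plugin_available) {
  name = NormalizeProviderName(name);
  return (v2_plugin_available ? "v2:" : "v1:") + name;
}
```

The same call site then works unchanged once the V2 plugin EP ships; only the availability probe flips.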
The MIGraphX static padding in State::Run() hardcoded int64_t* for data access regardless of the actual element type, zeroed only the first batch row, and used a flat copy loop that ignored per-batch-row strides. The fix dispatches on elem_type for int32/int64, zeros the entire padded tensor, and copies each batch row with the correct source/destination strides.
When prompt_gen_ pads position_ids_shape_ to max_length, the function used shape[1] (max_length) to index into next_tokens, which has only new_length elements per row, causing out-of-bounds reads for batch_size > 1. Add a seq_length parameter to separate the actual token count from the padded tensor shape. Zero-fill the full tensor upfront and populate only the real token positions in each row.
The prompt_gen_ flag was set unconditionally in AppendTokens(), causing all EPs (CUDA, DML, WebGPU, etc.) to pad inputs to max_length during prompt processing. This is only needed for MIGraphX, to avoid recompilation on varying input shapes. Add a NeedsStaticInputShapes() config query that returns true only when MIGraphX is the configured provider. Cache the result as use_static_input_shapes on GeneratorParams (following the same pattern as use_graph_capture and use_multi_profile). Set prompt_gen_ from this flag instead of unconditionally to true.
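The query-and-cache pattern can be sketched as below. The struct layouts are assumptions for illustration; only the names `NeedsStaticInputShapes` and `use_static_input_shapes` come from the description above.

```cpp
#include <string>

// Hypothetical sketch: the config answers a provider-specific question,
// and the answer is cached once on the generator params, mirroring how
// use_graph_capture / use_multi_profile are handled.
struct Config {
  std::string provider;  // normalized provider name, e.g. "MIGraphX"

  // True only for MIGraphX; all other EPs keep dynamic prompt shapes.
  bool NeedsStaticInputShapes() const { return provider == "MIGraphX"; }
};

struct GeneratorParams {
  bool use_static_input_shapes;  // cached at construction time

  explicit GeneratorParams(const Config& config)
      : use_static_input_shapes(config.NeedsStaticInputShapes()) {}
};
```

AppendTokens() then sets prompt_gen_ from `params.use_static_input_shapes`, so CUDA, DML, and WebGPU are no longer forced onto the padded path.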
Summary
Adds MIGraphX (AMD GPU) execution provider support to ONNX Runtime GenAI.
Gates prompt-time padding behind the prompt_gen_ flag, set only for MIGraphX, so it avoids recompilation on varying prompt lengths.
How it works
MIGraphX requires static input dimensions to avoid expensive recompilation. During prompt processing,
input_ids, position_ids, and logits are padded to max_length. During token generation, shapes are [batch, 1] and graph capture is enabled. The attention mask is padded via the existing InitializeStaticMask path (triggered by graph capture). MIGraphX loads as a shared provider via the V1 AppendExecutionProvider path. When the plugin EP (V2) becomes available, the same code will work through the V2 path automatically.
Known limitations