[WebNN EP] Automatically use ml-tensor for outputs #24282

Merged · 5 commits · Apr 16, 2025
26 changes: 26 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
@@ -79,11 +79,20 @@ export class WebNNBackend {
* Maps from session id to list of graph inputs.
*/
private sessionGraphInputs: Map<number, string[]> = new Map();
/**
* Maps from session id to list of graph outputs.
*/
private sessionGraphOutputs: Map<number, string[]> = new Map();
/**
* Temporary graph inputs for the current session.
* These inputs will be registered when the session is created.
*/
private temporaryGraphInputs: string[] = [];
/**
* Temporary graph outputs for the current session.
* These outputs will be registered when the session is created.
*/
private temporaryGraphOutputs: string[] = [];
/**
* Temporary tensors for the current session.
*/
@@ -167,10 +176,15 @@ export class WebNNBackend {
this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
this.temporaryGraphInputs = [];
}
if (this.temporaryGraphOutputs.length > 0) {
this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs);
this.temporaryGraphOutputs = [];
}
}

public onReleaseSession(sessionId: number): void {
this.sessionGraphInputs.delete(sessionId);
this.sessionGraphOutputs.delete(sessionId);
const mlContext = this.mlContextBySessionId.get(sessionId)!;
if (!mlContext) {
// Current session is not a WebNN session.
@@ -363,6 +377,10 @@ export class WebNNBackend {
this.temporaryGraphInputs.push(inputName);
}

public registerGraphOutput(outputName: string): void {
this.temporaryGraphOutputs.push(outputName);
}

public isGraphInput(sessionId: number, inputName: string): boolean {
const inputNames = this.sessionGraphInputs.get(sessionId);
if (!inputNames) {
return false;
}
@@ -371,6 +389,14 @@ export class WebNNBackend {
return inputNames.includes(inputName);
}

public isGraphOutput(sessionId: number, outputName: string): boolean {
const outputNames = this.sessionGraphOutputs.get(sessionId);
if (!outputNames) {
return false;
}
return outputNames.includes(outputName);
}

public isInt64Supported(sessionId: number): boolean {
const context = this.mlContextBySessionId.get(sessionId);
return !!context?.opSupportLimits().input.dataTypes.includes('int64');
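The backend changes above mirror, for outputs, the two-phase registration that already existed for inputs: names are collected while the graph is being built (before a session id exists) and then bound to the session once it is created. A condensed standalone sketch of the pattern — class and method names assumed for illustration, not part of the diff:

class GraphOutputRegistry {
  // Outputs registered during model compilation, before a session id exists.
  private temporaryGraphOutputs: string[] = [];
  // Once the session is created, the collected names are bound to its id.
  private sessionGraphOutputs: Map<number, string[]> = new Map();

  registerGraphOutput(outputName: string): void {
    this.temporaryGraphOutputs.push(outputName);
  }

  onSessionCreate(sessionId: number): void {
    if (this.temporaryGraphOutputs.length > 0) {
      this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs);
      this.temporaryGraphOutputs = [];
    }
  }

  isGraphOutput(sessionId: number, outputName: string): boolean {
    return this.sessionGraphOutputs.get(sessionId)?.includes(outputName) ?? false;
  }
}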
50 changes: 44 additions & 6 deletions js/web/lib/wasm/wasm-core-impl.ts
@@ -172,7 +172,13 @@ export const initEp = async (env: Env, epName: string): Promise<void> => {
/**
* valid data locations for input/output tensors.
*/
type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor';
type SupportedTensorDataLocationForInputOutput =
| 'cpu'
| 'cpu-pinned'
| 'gpu-buffer'
| 'ml-tensor'
// Use 'ml-tensor' during inference, but output a tensor located on the CPU.
| 'ml-tensor-cpu-output';

type IOBindingState = {
/**
@@ -424,6 +430,11 @@ export const createSession = async (
typeof options?.preferredOutputLocation === 'string'
? options.preferredOutputLocation
: (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
const isGraphOutput = wasm.webnnIsGraphOutput;
if (location === 'cpu' && isGraphOutput && isGraphOutput(sessionHandle, nameString)) {
outputPreferredLocations.push('ml-tensor-cpu-output');
continue;
}
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
@@ -438,7 +449,10 @@

// use IO binding only when at least one output is preferred to be on GPU.
let bindingState: IOBindingState | null = null;
if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) {
if (
!BUILD_DEFS.DISABLE_JSEP &&
outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output')
) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError("Can't create IO binding.");
@@ -447,7 +461,10 @@
bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)),
outputPreferredLocationsEncoded: outputPreferredLocations
// 'ml-tensor-cpu-output' is treated as 'ml-tensor' for the purpose of IO binding.
.map((l) => (l === 'ml-tensor-cpu-output' ? 'ml-tensor' : l))
.map((l) => dataLocationStringToEnum(l)),
};
}

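Aside: the promotion rule added to createSession above can be stated in isolation. A minimal sketch — the local type alias and the resolveOutputLocation helper are hypothetical, introduced only for illustration:

type OutputLocation = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor' | 'ml-tensor-cpu-output';

// A 'cpu' request for a WebNN graph output is silently upgraded: inference
// still runs on an MLTensor, and the data is downloaded afterwards so the
// caller receives ordinary CPU output.
function resolveOutputLocation(requested: OutputLocation, isWebnnGraphOutput: boolean): OutputLocation {
  if (requested === 'cpu' && isWebnnGraphOutput) {
    return 'ml-tensor-cpu-output';
  }
  return requested;
}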
@@ -599,10 +616,11 @@ export const prepareInputOutputTensor = async (
}
} else {
const isGraphInput = wasm.webnnIsGraphInput;
if (dataType !== 'string' && isGraphInput) {
const isGraphOutput = wasm.webnnIsGraphOutput;
if (dataType !== 'string' && isGraphInput && isGraphOutput) {
const tensorName = wasm.UTF8ToString(tensorNameUTF8Encoded);
// Promote the tensor to 'ml-tensor' if it is a graph input.
if (isGraphInput(sessionId, tensorName)) {
if (isGraphInput(sessionId, tensorName) || isGraphOutput(sessionId, tensorName)) {
const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
actualLocation = 'ml-tensor';
@@ -810,6 +828,7 @@ export const run = async (
}

const output: TensorMetadata[] = [];
const outputPromises: Array<Promise<[number, Tensor.DataType]>> = [];

for (let i = 0; i < outputCount; i++) {
const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*'));
@@ -958,6 +977,20 @@
},
'ml-tensor',
]);
} else if (preferredLocation === 'ml-tensor-cpu-output' && size > 0) {
const data = wasm.webnnCreateMLTensorDownloader!(dataOffset, type as Tensor.MLTensorDataTypes)();
const index = output.length;
// Delay the data download and releasing the tensor until we can wait for all output tensors to be downloaded.
keepOutputTensor = true;
outputPromises.push(
(async () => {
const result: [number, Tensor.DataType] = [index, await data];
wasm.webnnReleaseTensorId!(dataOffset);
wasm._OrtReleaseTensor(tensor);
return result;
})(),
);
output.push([type, dims, [], 'cpu']);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
@@ -975,7 +1008,6 @@
if (!keepOutputTensor) {
wasm._OrtReleaseTensor(tensor);
}
wasm.webnnOnRunEnd?.(sessionHandle);
}
}

@@ -992,8 +1024,14 @@
false,
]);
}
// Wait for all output tensor data to be downloaded.
for (const [index, data] of await Promise.all(outputPromises)) {
output[index][2] = data;
}
return output;
} finally {
wasm.webnnOnRunEnd?.(sessionHandle);

wasm.stackRestore(beforeRunStack);

if (BUILD_DEFS.USE_WEBGPU_EP) {
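The run-path change above follows a deferred-download pattern: each 'ml-tensor-cpu-output' output pushes a placeholder entry into the output list, starts its MLTensor download immediately, and the CPU copy is patched into place once every download has settled, so the downloads overlap rather than serialize. A self-contained sketch under assumed names and types (collectOutputs and the Float32Array downloaders are illustrative):

// Deferred-download pattern: reserve output slots in order, start all
// downloads, then fill the slots after Promise.all resolves.
async function collectOutputs(
  downloaders: Array<() => Promise<Float32Array>>,
): Promise<Array<Float32Array | null>> {
  const output: Array<Float32Array | null> = [];
  const outputPromises: Array<Promise<[number, Float32Array]>> = [];
  for (const download of downloaders) {
    // Reserve the slot now so ordering matches the graph's output order.
    const index = output.length;
    output.push(null);
    // Start the download immediately, but resolve it later.
    outputPromises.push((async (): Promise<[number, Float32Array]> => [index, await download()])());
  }
  // Wait for every download, then patch the CPU copies into place.
  for (const [index, data] of await Promise.all(outputPromises)) {
    output[index] = data;
  }
  return output;
}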
13 changes: 13 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
@@ -287,6 +287,19 @@ export declare namespace JSEP {
* @returns whether the input is a WebNN graph input.
*/
webnnIsGraphInput: (sessionId: number, inputName: string) => boolean;
/**
* [exported from pre-jsep.js] Register a WebNN graph output.
* @param outputName - specify the output name.
*/
webnnRegisterGraphOutput: (outputName: string) => void;
/**
* [exported from pre-jsep.js] Check if a graph output is a WebNN graph output.
* @param sessionId - specify the session ID.
* @param outputName - specify the output name.
* @returns whether the output is a WebNN graph output.
*/
webnnIsGraphOutput: (sessionId: number, outputName: string) => boolean;

/**
* [exported from pre-jsep.js] Create a temporary MLTensor for a session.
* @param sessionId - specify the session ID.
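These declarations surface the backend methods on the wasm Module object; they are only present in WebNN-enabled builds, which is why call sites check for their existence first. A hedged call-site sketch (the declare statements, sessionId, and the output name are illustrative):

declare const wasm: { webnnIsGraphOutput?: (sessionId: number, outputName: string) => boolean };
declare const sessionId: number;

// Guard against non-WebNN builds before calling the export.
const isGraphOutput = wasm.webnnIsGraphOutput;
if (isGraphOutput && isGraphOutput(sessionId, 'output_0')) {
  // This output stays in an MLTensor during inference.
}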
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -290,6 +290,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
emscripten::val::module_property("webnnRegisterGraphInput")(name);
input_names_.push_back(name);
} else {
emscripten::val::module_property("webnnRegisterGraphOutput")(name);
output_names_.push_back(name);
}

3 changes: 3 additions & 0 deletions onnxruntime/wasm/pre-jsep.js
@@ -151,6 +151,9 @@ Module["jsepInit"] = (name, params) => {
Module["webnnRegisterGraphInput"] =
backend["registerGraphInput"].bind(backend);
Module["webnnIsGraphInput"] = backend["isGraphInput"].bind(backend);
Module["webnnRegisterGraphOutput"] =
backend["registerGraphOutput"].bind(backend);
Module["webnnIsGraphOutput"] = backend["isGraphOutput"].bind(backend);

Module["webnnCreateTemporaryTensor"] =
backend["createTemporaryTensor"].bind(backend);
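Taken together, these changes mean callers get the optimization without opting in: outputs of a WebNN graph are kept in MLTensors during inference and downloaded to CPU automatically, with no explicit preferredOutputLocation. A hedged usage sketch ('model.onnx', the input/output names, and the shape are placeholders):

import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webnn'],
  // No preferredOutputLocation needed: WebNN graph outputs are promoted to
  // 'ml-tensor-cpu-output' internally and downloaded to CPU automatically.
});
const input = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
const results = await session.run({ input });
// results.output.data is an ordinary CPU-side typed array.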