diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index c2a855bedca22..4de02983d068d 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -79,11 +79,20 @@ export class WebNNBackend { * Maps from session id to list of graph inputs. */ private sessionGraphInputs: Map<number, string[]> = new Map(); + /** + * Maps from session id to list of graph outputs. + */ + private sessionGraphOutputs: Map<number, string[]> = new Map(); /** * Temporary graph inputs for the current session. * These inputs will be registered when the session is created. */ private temporaryGraphInputs: string[] = []; + /** + * Temporary graph outputs for the current session. + * These outputs will be registered when the session is created. + */ + private temporaryGraphOutputs: string[] = []; /** * Temporary tensors for the current session. */ @@ -167,10 +176,15 @@ export class WebNNBackend { this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs); this.temporaryGraphInputs = []; } + if (this.temporaryGraphOutputs.length > 0) { + this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs); + this.temporaryGraphOutputs = []; + } } public onReleaseSession(sessionId: number): void { this.sessionGraphInputs.delete(sessionId); + this.sessionGraphOutputs.delete(sessionId); const mlContext = this.mlContextBySessionId.get(sessionId)!; if (!mlContext) { // Current session is not a WebNN session. 
@@ -363,6 +377,10 @@ export class WebNNBackend { this.temporaryGraphInputs.push(inputName); } + public registerGraphOutput(outputName: string): void { + this.temporaryGraphOutputs.push(outputName); + } + public isGraphInput(sessionId: number, inputName: string): boolean { const inputNames = this.sessionGraphInputs.get(sessionId); if (!inputNames) { @@ -371,6 +389,14 @@ export class WebNNBackend { return inputNames.includes(inputName); } + public isGraphOutput(sessionId: number, outputName: string): boolean { + const outputNames = this.sessionGraphOutputs.get(sessionId); + if (!outputNames) { + return false; + } + return outputNames.includes(outputName); + } + public isInt64Supported(sessionId: number): boolean { const context = this.mlContextBySessionId.get(sessionId); return !!context?.opSupportLimits().input.dataTypes.includes('int64'); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8dd643293937b..227c89a53afc6 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -172,7 +172,13 @@ export const initEp = async (env: Env, epName: string): Promise<void> => { /** * valid data locations for input/output tensors. */ -type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor'; +type SupportedTensorDataLocationForInputOutput = + | 'cpu' + | 'cpu-pinned' + | 'gpu-buffer' + | 'ml-tensor' + // Use 'ml-tensor' during inference, but output a tensor located on the CPU. + | 'ml-tensor-cpu-output'; type IOBindingState = { /** @@ -424,6 +430,11 @@ export const createSession = async ( typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : (options?.preferredOutputLocation?.[nameString] ?? 
'cpu'); + const isGraphOutput = wasm.webnnIsGraphOutput; + if (location === 'cpu' && isGraphOutput && isGraphOutput(sessionHandle, nameString)) { + outputPreferredLocations.push('ml-tensor-cpu-output'); + continue; + } if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') { throw new Error(`Not supported preferred output location: ${location}.`); } @@ -438,7 +449,10 @@ export const createSession = async ( // use IO binding only when at least one output is preferred to be on GPU. let bindingState: IOBindingState | null = null; - if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) { + if ( + !BUILD_DEFS.DISABLE_JSEP && + outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output') + ) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); if (ioBindingHandle === 0) { checkLastError("Can't create IO binding."); @@ -447,7 +461,10 @@ export const createSession = async ( bindingState = { handle: ioBindingHandle, outputPreferredLocations, - outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)), + outputPreferredLocationsEncoded: outputPreferredLocations + // 'ml-tensor-cpu-output' is treated as 'ml-tensor' for the purpose of IO binding. + .map((l) => (l === 'ml-tensor-cpu-output' ? 'ml-tensor' : l)) + .map((l) => dataLocationStringToEnum(l)), }; } @@ -599,10 +616,11 @@ export const prepareInputOutputTensor = async ( } } else { const isGraphInput = wasm.webnnIsGraphInput; - if (dataType !== 'string' && isGraphInput) { + const isGraphOutput = wasm.webnnIsGraphOutput; + if (dataType !== 'string' && isGraphInput && isGraphOutput) { const tensorName = wasm.UTF8ToString(tensorNameUTF8Encoded); // Promote the tensor to 'ml-tensor' if it is a graph input. 
- if (isGraphInput(sessionId, tensorName)) { + if (isGraphInput(sessionId, tensorName) || isGraphOutput(sessionId, tensorName)) { const dataTypeEnum = tensorDataTypeStringToEnum(dataType); dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!; actualLocation = 'ml-tensor'; @@ -810,6 +828,7 @@ export const run = async ( } const output: TensorMetadata[] = []; + const outputPromises: Array<Promise<[number, Tensor.DataType]>> = []; for (let i = 0; i < outputCount; i++) { const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); @@ -958,6 +977,20 @@ export const run = async ( }, 'ml-tensor', ]); + } else if (preferredLocation === 'ml-tensor-cpu-output' && size > 0) { + const data = wasm.webnnCreateMLTensorDownloader!(dataOffset, type as Tensor.MLTensorDataTypes)(); + const index = output.length; + // Delay the data download and releasing the tensor until we can wait for all output tensors to be downloaded. + keepOutputTensor = true; + outputPromises.push( + (async () => { + const result: [number, Tensor.DataType] = [index, await data]; + wasm.webnnReleaseTensorId!(dataOffset); + wasm._OrtReleaseTensor(tensor); + return result; + })(), + ); + output.push([type, dims, [], 'cpu']); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); @@ -975,7 +1008,6 @@ export const run = async ( if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } - wasm.webnnOnRunEnd?.(sessionHandle); } } @@ -992,8 +1024,14 @@ export const run = async ( false, ]); } + // Wait for all output tensor data to be downloaded. 
+ for (const [index, data] of await Promise.all(outputPromises)) { + output[index][2] = data; + } return output; } finally { + wasm.webnnOnRunEnd?.(sessionHandle); + wasm.stackRestore(beforeRunStack); if (BUILD_DEFS.USE_WEBGPU_EP) { diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index b2ca8480f1546..22af02b2790f4 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -287,6 +287,19 @@ export declare namespace JSEP { * @returns whether the input is a WebNN graph input. */ webnnIsGraphInput: (sessionId: number, inputName: string) => boolean; + /** + * [exported from pre-jsep.js] Register a WebNN graph output. + * @param outputName - specify the output name. + */ + webnnRegisterGraphOutput: (outputName: string) => void; + /** + * [exported from pre-jsep.js] Check if a graph output is a WebNN graph output. + * @param sessionId - specify the session ID. + * @param outputName - specify the output name. + * @returns whether the output is a WebNN graph output. + */ + webnnIsGraphOutput: (sessionId: number, outputName: string) => boolean; + /** * [exported from pre-jsep.js] Create a temporary MLTensor for a session. * @param sessionId - specify the session ID. 
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 399cc5faf6273..37df96e9ebaf7 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -290,6 +290,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i emscripten::val::module_property("webnnRegisterGraphInput")(name); input_names_.push_back(name); } else { + emscripten::val::module_property("webnnRegisterGraphOutput")(name); output_names_.push_back(name); } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index cca8da0525fbe..4dca86d287dae 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -151,6 +151,9 @@ Module["jsepInit"] = (name, params) => { Module["webnnRegisterGraphInput"] = backend["registerGraphInput"].bind(backend); Module["webnnIsGraphInput"] = backend["isGraphInput"].bind(backend); + Module["webnnRegisterGraphOutput"] = + backend["registerGraphOutput"].bind(backend); + Module["webnnIsGraphOutput"] = backend["isGraphOutput"].bind(backend); Module["webnnCreateTemporaryTensor"] = backend["createTemporaryTensor"].bind(backend);