Skip to content

Commit ce68bb1

Browse files
committed
Address PR review: fix heap alloc, D2H copy, input name, EP check
1 parent ac3fbfe commit ce68bb1

4 files changed

Lines changed: 23 additions & 24 deletions

File tree

src/models/position_inputs.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,10 @@ void DefaultPositionInputs::CreateAndInitializeCompactAttentionMask(DeviceSpan<i
455455
void DefaultPositionInputs::UpdateCompactAttentionMask() {
456456
// In compact mode, attention_mask has shape [batch_size, 1] containing total seq len per batch.
457457
// Each decode step adds one token, so increment each value by 1.
458+
// Use CpuSpan() as the source of truth — avoids per-token device-to-host readback.
459+
// This is safe because all non-fast-path updates go through CpuSpan() + CopyCpuToDevice().
458460
auto byte_span = attention_mask_->GetByteSpan();
459-
// CopyDeviceToCpu() ensures the CPU buffer reflects the current device contents before reading.
460-
// This is needed because the WebGPU fast path writes directly to GPU memory, so CpuSpan() alone
461-
// would read stale/uninitialized data.
462-
auto cpu_data = byte_span.CopyDeviceToCpu();
461+
auto cpu_data = byte_span.CpuSpan();
463462
if (type_ == Ort::TypeToTensorType<int32_t>) {
464463
auto* data = reinterpret_cast<int32_t*>(cpu_data.data());
465464
for (int64_t i = 0; i < attention_mask_shape_[0]; i++)

src/python/py/models/builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def check_extra_options(kv_pairs, execution_provider):
105105
)
106106
kv_pairs["enable_webgpu_graph"] = False
107107

108+
if kv_pairs.get("compact_attention_mask", False) and execution_provider != "webgpu":
109+
print(
110+
"WARNING: compact_attention_mask is currently only supported with WebGPU execution provider. Disabling compact_attention_mask."
111+
)
112+
kv_pairs["compact_attention_mask"] = False
113+
kv_pairs["enable_webgpu_graph"] = False
114+
108115

109116
def parse_extra_options(kv_items, execution_provider):
110117
"""

src/python/py/models/builders/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4463,7 +4463,7 @@ def make_attention_mask_compact_reformatting_for_gqa(self, attn_mask_basename):
44634463

44644464
# Cast from INT64 to INT32
44654465
cast_name = f"{attn_mask_basename}/Cast"
4466-
self.make_cast(cast_name, "attention_mask", dtype=ir.DataType.INT32, shape=["batch_size", 1])
4466+
self.make_cast(cast_name, self.input_names["attention_mask"], dtype=ir.DataType.INT32, shape=["batch_size", 1])
44674467

44684468
# Reshape from [batch_size, 1] to [batch_size]
44694469
reshape_name = f"{attn_mask_basename}/Reshape"

src/webgpu/interface.cpp

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -216,43 +216,36 @@ struct InterfaceImpl : DeviceInterface {
216216
return true;
217217
}
218218

219-
// Compact attention mask: write total_length into a [batch_beam_size, 1] tensor on WebGPU.
220-
// This avoids GPU->CPU->GPU round-trips by only doing a single CPU->GPU copy.
219+
// Compact attention mask: write total_length into a [1, 1] tensor on WebGPU.
220+
// Only supports batch_beam_size == 1 to avoid per-token heap allocation.
221+
// For batch_beam_size > 1, returns false to fall back to the CPU path.
221222
bool UpdateCompactAttentionMask(void* mask_data, int batch_beam_size, int total_length, ONNXTensorElementDataType type) override {
222-
if (!ort_allocator_) {
223-
throw std::runtime_error("WebGPU allocator not initialized");
223+
if (!ort_allocator_ || batch_beam_size != 1) {
224+
return false;
224225
}
225226

226-
// Prepare the values on CPU using properly aligned buffers and perform a single CPU->GPU copy
227+
// Single scalar on the stack — no heap allocation
227228
if (type == Ort::TypeToTensorType<int32_t>) {
228-
const size_t elem_size = sizeof(int32_t);
229-
const size_t byte_count = static_cast<size_t>(batch_beam_size) * elem_size;
230-
std::vector<int32_t> cpu_buffer(batch_beam_size);
231-
for (int i = 0; i < batch_beam_size; ++i) {
232-
cpu_buffer[i] = static_cast<int32_t>(total_length);
233-
}
229+
int32_t value = static_cast<int32_t>(total_length);
230+
const size_t byte_count = sizeof(int32_t);
234231

235232
int64_t shape_val = static_cast<int64_t>(byte_count);
236233
std::span<const int64_t> shape{&shape_val, 1};
237234
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
238-
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_buffer.data(), byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
235+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, &value, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
239236
auto dst_tensor = OrtValue::CreateTensor(*ort_memory_info_, mask_data, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
240237

241238
const std::vector<const OrtValue*> src_ptrs = {src_tensor.get()};
242239
const std::vector<OrtValue*> dst_ptrs = {dst_tensor.get()};
243240
GetOrtEnv().CopyTensors(src_ptrs, dst_ptrs, nullptr);
244241
} else {
245-
const size_t elem_size = sizeof(int64_t);
246-
const size_t byte_count = static_cast<size_t>(batch_beam_size) * elem_size;
247-
std::vector<int64_t> cpu_buffer(batch_beam_size);
248-
for (int i = 0; i < batch_beam_size; ++i) {
249-
cpu_buffer[i] = static_cast<int64_t>(total_length);
250-
}
242+
int64_t value = static_cast<int64_t>(total_length);
243+
const size_t byte_count = sizeof(int64_t);
251244

252245
int64_t shape_val = static_cast<int64_t>(byte_count);
253246
std::span<const int64_t> shape{&shape_val, 1};
254247
auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
255-
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_buffer.data(), byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
248+
auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, &value, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
256249
auto dst_tensor = OrtValue::CreateTensor(*ort_memory_info_, mask_data, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
257250

258251
const std::vector<const OrtValue*> src_ptrs = {src_tensor.get()};

0 commit comments

Comments (0)