@@ -216,43 +216,36 @@ struct InterfaceImpl : DeviceInterface {
     return true;
   }
 
-  // Compact attention mask: write total_length into a [batch_beam_size, 1] tensor on WebGPU.
-  // This avoids GPU->CPU->GPU round-trips by only doing a single CPU->GPU copy.
+  // Compact attention mask: write total_length into a [1, 1] tensor on WebGPU.
+  // Only supports batch_beam_size == 1 to avoid per-token heap allocation.
+  // For batch_beam_size > 1, returns false to fall back to the CPU path.
   bool UpdateCompactAttentionMask(void* mask_data, int batch_beam_size, int total_length, ONNXTensorElementDataType type) override {
-    if (!ort_allocator_) {
-      throw std::runtime_error("WebGPU allocator not initialized");
+    if (!ort_allocator_ || batch_beam_size != 1) {
+      return false;
     }
 
-    // Prepare the values on CPU using properly aligned buffers and perform a single CPU->GPU copy
+    // Single scalar on the stack — no heap allocation
     if (type == Ort::TypeToTensorType<int32_t>) {
-      const size_t elem_size = sizeof(int32_t);
-      const size_t byte_count = static_cast<size_t>(batch_beam_size) * elem_size;
-      std::vector<int32_t> cpu_buffer(batch_beam_size);
-      for (int i = 0; i < batch_beam_size; ++i) {
-        cpu_buffer[i] = static_cast<int32_t>(total_length);
-      }
+      int32_t value = static_cast<int32_t>(total_length);
+      const size_t byte_count = sizeof(int32_t);
 
       int64_t shape_val = static_cast<int64_t>(byte_count);
       std::span<const int64_t> shape{&shape_val, 1};
       auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_buffer.data(), byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
+      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, &value, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
       auto dst_tensor = OrtValue::CreateTensor(*ort_memory_info_, mask_data, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
 
       const std::vector<const OrtValue*> src_ptrs = {src_tensor.get()};
       const std::vector<OrtValue*> dst_ptrs = {dst_tensor.get()};
       GetOrtEnv().CopyTensors(src_ptrs, dst_ptrs, nullptr);
     } else {
-      const size_t elem_size = sizeof(int64_t);
-      const size_t byte_count = static_cast<size_t>(batch_beam_size) * elem_size;
-      std::vector<int64_t> cpu_buffer(batch_beam_size);
-      for (int i = 0; i < batch_beam_size; ++i) {
-        cpu_buffer[i] = static_cast<int64_t>(total_length);
-      }
+      int64_t value = static_cast<int64_t>(total_length);
+      const size_t byte_count = sizeof(int64_t);
 
       int64_t shape_val = static_cast<int64_t>(byte_count);
       std::span<const int64_t> shape{&shape_val, 1};
       auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, cpu_buffer.data(), byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
+      auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, &value, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
       auto dst_tensor = OrtValue::CreateTensor(*ort_memory_info_, mask_data, byte_count, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
 
       const std::vector<const OrtValue*> src_ptrs = {src_tensor.get()};
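For context, a minimal caller-side sketch of the fallback contract this change introduces. `DeviceInterface` and `UpdateCompactAttentionMask` are from the code above; `UpdateMaskOnCpu` is a hypothetical placeholder for the existing CPU path, not part of this PR:

```cpp
// Hypothetical caller (illustrative only): try the on-device fast path
// first; when it returns false (e.g. batch_beam_size > 1, or the WebGPU
// allocator is not initialized), fall back to filling the mask on the CPU.
void UpdateAttentionMask(DeviceInterface& device, void* mask_data,
                         int batch_beam_size, int total_length,
                         ONNXTensorElementDataType type) {
  if (device.UpdateCompactAttentionMask(mask_data, batch_beam_size,
                                        total_length, type)) {
    return;  // handled with a single CPU->GPU copy, no heap allocation
  }
  // UpdateMaskOnCpu is a placeholder: it would write total_length into
  // each of the batch_beam_size entries of the mask on the CPU.
  UpdateMaskOnCpu(mask_data, batch_beam_size, total_length, type);
}
```

Returning `false` instead of throwing keeps beam search (`batch_beam_size > 1`) on the working CPU path, while the common single-sequence case avoids both the per-token `std::vector` allocation and the fill loop of the old code.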