Skip to content

Commit ce3d60b

Browse files
Replaced gsl::naroow with gsl::narrow_cast
1 parent 3d4e022 commit ce3d60b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,19 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W
111111
}
112112

113113
Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) {
114-
const auto half_rotary_embedding_dim = gsl::narrow<uint32_t>(cos_cache->Shape()[1]);
114+
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
115115
const auto head_size = params.head_size_;
116116
const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_;
117117
const TensorShape global_shape({params.batch_size_, params.sequence_length_, hidden_size / head_size, static_cast<int64_t>(head_size - half_rotary_embedding_dim)});
118118
const auto rank = global_shape.NumDimensions();
119119
std::vector<uint32_t> global_dims(rank);
120120
std::vector<uint32_t> global_strides(rank);
121121
for (size_t j = 0; j < rank; ++j) {
122-
global_dims[j] = gsl::narrow<uint32_t>(global_shape[j]);
123-
global_strides[j] = gsl::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1));
122+
global_dims[j] = gsl::narrow_cast<uint32_t>(global_shape[j]);
123+
global_strides[j] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(j + 1));
124124
}
125-
const auto input_output_strides = std::vector<uint32_t>({gsl::narrow<uint32_t>(input->Shape().SizeFromDimension(1)), gsl::narrow<uint32_t>(hidden_size), gsl::narrow<uint32_t>(head_size), 1});
126-
const auto output_size = gsl::narrow<const uint32_t>(global_shape.Size());
125+
const auto input_output_strides = std::vector<uint32_t>({gsl::narrow_cast<uint32_t>(input->Shape().SizeFromDimension(1)), gsl::narrow_cast<uint32_t>(hidden_size), gsl::narrow_cast<uint32_t>(head_size), 1});
126+
const auto output_size = gsl::narrow_cast<const uint32_t>(global_shape.Size());
127127

128128
RotaryEmbeddingProgram program(params.rotary_interleaved_);
129129
program

0 commit comments

Comments
 (0)