@@ -111,19 +111,19 @@ Status GeneratePositionIDs(onnxruntime::webgpu::ComputeContext& context, const W
111
111
}
112
112
113
113
Status RunRotaryEmbedding (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* input, const Tensor* pos_ids, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* output, bool is_query_input) {
114
- const auto half_rotary_embedding_dim = gsl::narrow <uint32_t >(cos_cache->Shape ()[1 ]);
114
+ const auto half_rotary_embedding_dim = gsl::narrow_cast <uint32_t >(cos_cache->Shape ()[1 ]);
115
115
const auto head_size = params.head_size_ ;
116
116
const auto hidden_size = is_query_input ? params.hidden_size_ : params.kv_hidden_size_ ;
117
117
const TensorShape global_shape ({params.batch_size_ , params.sequence_length_ , hidden_size / head_size, static_cast <int64_t >(head_size - half_rotary_embedding_dim)});
118
118
const auto rank = global_shape.NumDimensions ();
119
119
std::vector<uint32_t > global_dims (rank);
120
120
std::vector<uint32_t > global_strides (rank);
121
121
for (size_t j = 0 ; j < rank; ++j) {
122
- global_dims[j] = gsl::narrow <uint32_t >(global_shape[j]);
123
- global_strides[j] = gsl::narrow <uint32_t >(global_shape.SizeFromDimension (j + 1 ));
122
+ global_dims[j] = gsl::narrow_cast <uint32_t >(global_shape[j]);
123
+ global_strides[j] = gsl::narrow_cast <uint32_t >(global_shape.SizeFromDimension (j + 1 ));
124
124
}
125
- const auto input_output_strides = std::vector<uint32_t >({gsl::narrow <uint32_t >(input->Shape ().SizeFromDimension (1 )), gsl::narrow <uint32_t >(hidden_size), gsl::narrow <uint32_t >(head_size), 1 });
126
- const auto output_size = gsl::narrow <const uint32_t >(global_shape.Size ());
125
+ const auto input_output_strides = std::vector<uint32_t >({gsl::narrow_cast <uint32_t >(input->Shape ().SizeFromDimension (1 )), gsl::narrow_cast <uint32_t >(hidden_size), gsl::narrow_cast <uint32_t >(head_size), 1 });
126
+ const auto output_size = gsl::narrow_cast <const uint32_t >(global_shape.Size ());
127
127
128
128
RotaryEmbeddingProgram program (params.rotary_interleaved_ );
129
129
program
0 commit comments