@@ -61,6 +61,9 @@ KERNEL(pa_sdpa_opt)(
#endif
#if HAS_ALIBI
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
#if HAS_SINK_INPUT
const __global SINK_DATA_T* sink_ptr,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
@@ -341,7 +344,13 @@ KERNEL(pa_sdpa_opt)(

// Final max value after reduction across all SGs and WIs
unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
#ifdef HAS_SINK_INPUT
const uint head_idx = get_global_id(1);
const SOFTMAX_ACCUMULATOR_TYPE qk_max_tmp = sub_group_reduce_max(GET_VECTOR_ELEMENT(qk_max, q_idx));
GET_VECTOR_ELEMENT(qk_max, q_idx) = qk_max_tmp > sink_ptr[head_idx] ? qk_max_tmp : sink_ptr[head_idx];
#else
GET_VECTOR_ELEMENT(qk_max, q_idx) = sub_group_reduce_max(GET_VECTOR_ELEMENT(qk_max, q_idx));
#endif
}

SOFTMAX_ACCUMULATOR_VEC_TYPE exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
@@ -388,6 +397,10 @@ KERNEL(pa_sdpa_opt)(
// Final sum of all exp_sum values
unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
GET_VECTOR_ELEMENT(exp_sum, q_idx) = sub_group_reduce_add(GET_VECTOR_ELEMENT(exp_sum, q_idx));
#ifdef HAS_SINK_INPUT
const uint head_idx = get_global_id(1);
GET_VECTOR_ELEMENT(exp_sum, q_idx) += (native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(sink_ptr[head_idx] - GET_VECTOR_ELEMENT(qk_max, q_idx))));
#endif
}

for (uint qk_idx = 0; qk_idx < qk_iters_num; qk_idx++) {
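For context, the math the HAS_SINK_INPUT branches above implement: the per-head sink logit joins the softmax normalization (it can raise the row maximum and it adds one extra exp term to the denominator) but contributes no value row to the output. A minimal host-side sketch for a single query row; the name softmax_with_sink and the use of plain float are mine, not part of the PR — the kernel does the same thing with sub-group reductions over qk_max and exp_sum:

#include <algorithm>
#include <cmath>
#include <vector>

// Reference-only sketch of the sink-augmented softmax for one query row.
// The sink takes part in the max and in the denominator, but no V row is
// accumulated for it, so it only dilutes the attention weights.
std::vector<float> softmax_with_sink(const std::vector<float>& logits, float sink) {
    float m = sink;
    for (float l : logits)
        m = std::max(m, l);
    float denom = std::exp(sink - m);          // the extra exp term added to exp_sum
    for (float l : logits)
        denom += std::exp(l - m);
    std::vector<float> probs;
    probs.reserve(logits.size());
    for (float l : logits)
        probs.push_back(std::exp(l - m) / denom);
    return probs;
}

The weighted sum over V then uses these probabilities as usual; because the sink has no value row, its only effect is to shrink every weight.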
@@ -308,6 +308,11 @@ class PagedAttentionGeneratorBase : public KernelGenerator {
jit.add(make_type_jit_constants("ALIBI_INPUT", params.input_layouts[alibi_input_idx].data_type));
}

if (desc->has_sink_input) {
const auto& sink_layout = params.input_layouts[PagedAttentionInputIdx::SINKS];
jit.make("SINK_DATA_T", to_ocl_type(sink_layout.data_type));
jit.make("HAS_SINK_INPUT", 1);
}
if (params.output_layouts.size() > 1) {
jit.make("PAGED_ATTENTION_SCORES_OUTPUT", 1);
if (desc->has_score_aggregation) {
@@ -396,7 +401,7 @@ class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase {
const auto has_alibi = params.get_input_layout(PagedAttentionInputIdx::ALIBI).count() > 0;
const auto has_scale_input = !desc->scale_val.has_value();
const auto has_scores_output = params.output_layouts.size() > 1;

const auto has_sink_input = desc->has_sink_input;
if (params.is_dynamic()) {
args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0});
}
@@ -416,6 +421,10 @@ class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase {
args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::ALIBI}); // alibi
}

if (has_sink_input) {
args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SINKS}); // sink
}

args.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
add_intermediate_inputs(args, has_scores_output, false, desc->has_score_aggregation);

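Assuming the SINKS input is f16, the two jit.make calls in the first generator hunk above should surface in the kernel build options roughly as the defines below (the exact spelling depends on to_ocl_type and the JIT machinery); these are what the #if HAS_SINK_INPUT / SINK_DATA_T guards in pa_sdpa_opt consume:

-DSINK_DATA_T=half -DHAS_SINK_INPUT=1

The argument change in the second hunk is positional: the sink descriptor is pushed between the alibi input and the output, mirroring where sink_ptr sits in the kernel's #if-guarded parameter list, so host and kernel stay in sync whenever the define is emitted.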
@@ -952,6 +952,12 @@ JitConstants SDPAMicroGenerator::get_jit_constants(const kernel_impl_params& par
jit.add(make_layout_jit_constants("INPUT" + to_code_string(5), params.input_layouts[tensor_id], in_offsets_map.at(tensor_id)));
}

if (desc->has_sink_input) {
const auto& sink_layout = params.input_layouts[PagedAttentionInputIdx::SINKS];
jit.make("SINK_DATA_T", to_ocl_type(sink_layout.data_type));
jit.make("HAS_SINK_INPUT", 1);
}

jit.add(make_layout_jit_constants("OUTPUT", params.output_layouts[0], out_offsets_map.at(0)));
if (has_scores_output) {
jit.add(make_layout_jit_constants("OUTPUT" + to_code_string(1), params.output_layouts[1], out_offsets_map.at(1)));
@@ -1235,6 +1241,7 @@ Arguments SDPAMicroGenerator::get_arguments_desc(const kernel_impl_params& param
auto data_inputs_num = micro_get_input_num(params, config);

if (config.is_paged_attention) {
const auto desc = params.typed_desc<paged_attention>();
if (m_is_prefill) {
args.push_back({ArgumentDescriptor::Types::INPUT, 1}); // Key
args.push_back({ArgumentDescriptor::Types::INPUT, 0}); // Q
@@ -1254,7 +1261,11 @@ Arguments SDPAMicroGenerator::get_arguments_desc(const kernel_impl_params& param
}
if (!config.has_const_scale_val)
args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SCALE}); // scale
args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // blocked_indexes_start_and_gws_mapping

if (desc->has_sink_input)
args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SINKS}); // sink

args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); // blocked_indexes_start_and_gws_mapping
} else {
args.push_back({ArgumentDescriptor::Types::INPUT, ScaledDotProductAttentionInputIdx::KEY}); // K
args.push_back({ArgumentDescriptor::Types::INPUT, ScaledDotProductAttentionInputIdx::QUERY}); // Q
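The same positional rule drives the SDPAMicroGenerator change: the sink input is appended just before the blocked_indexes_start_and_gws_mapping internal buffer, so the micro-kernel has to declare its sink argument in that same relative slot. A stripped-down illustration of the invariant, with hypothetical stand-ins (ArgKind, ArgDesc, build_args, and every index below) for the real descriptor machinery:

#include <cstdint>
#include <vector>

enum class ArgKind { Input, Output, InternalBuffer };
struct ArgDesc { ArgKind kind; uint32_t index; };

// Entry i of the returned vector is bound to kernel argument i at enqueue time,
// so a conditionally compiled kernel parameter needs an equally conditional
// push_back at the same relative position.
std::vector<ArgDesc> build_args(bool has_sink_input) {
    std::vector<ArgDesc> args;
    args.push_back({ArgKind::Input, 0});           // query
    args.push_back({ArgKind::Input, 1});           // key cache
    if (has_sink_input)
        args.push_back({ArgKind::Input, 13});      // sinks (illustrative index)
    args.push_back({ArgKind::InternalBuffer, 3});  // blocked-indexes buffer comes after the sink
    args.push_back({ArgKind::Output, 0});
    return args;
}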
@@ -530,6 +530,9 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
tile_elementwise(S_tile, scale);
#endif // LOG_2_MUL_SCALE
tile_binary(S_tile, mask_tile_float, binary_add);
#elif IS_CAUSAL && HAS_SINK_INPUT
#define scale(x) ((x)* scale)
tile_elementwise(S_tile, scale);
#endif

/* Apply k mask */
@@ -543,7 +546,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
#else
#define greater_than(offset_k, offset_q) (offset_k > offset_q)
#endif

int col_offset = wg_j0 + sg_j0_kq;
#if IS_PAGED_ATTENTION && IS_PREFILL == 0
col_offset += k - q;
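In the micro_sdpa hunk above, the IS_CAUSAL && HAS_SINK_INPUT branch applies the softmax scale to S_tile directly, whereas the WITH_ATTN_MASK path folds it into the mask add; my reading (not stated in the diff) is that the sink value is used unscaled later, so the q*k logits have to be on the sink's scale before the two are normalized together. A tiny numeric check of that ordering concern, with made-up numbers:

#include <cmath>
#include <cstdio>

int main() {
    const float scale = 0.125f;           // e.g. 1/sqrt(head_size) for head_size = 64
    const float qk = 12.0f;               // one raw q*k logit, made up
    const float sink = 0.5f;              // per-head sink logit, made up
    // Scale the attention logit first, then normalize it together with the sink.
    const float m = std::fmax(qk * scale, sink);
    const float denom = std::exp(qk * scale - m) + std::exp(sink - m);
    std::printf("weight on the real token: %f\n", std::exp(qk * scale - m) / denom);
    // Scaling after the sink had been mixed in would also shrink the sink term
    // and yield a different distribution, hence the explicit scale in this order.
    return 0;
}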