diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/paged_attention_opt.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/paged_attention_opt.cl
index 85617f887d5571..0070a30d4a1cf4 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/paged_attention_opt.cl
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/paged_attention_opt.cl
@@ -61,6 +61,9 @@ KERNEL(pa_sdpa_opt)(
 #endif
 #if HAS_ALIBI
     const __global ALIBI_INPUT_TYPE* alibi_slopes,
+#endif
+#if HAS_SINK_INPUT
+    const __global SINK_DATA_T* sink_ptr,
 #endif
     __global OUTPUT_TYPE* output,
 #if PAGED_ATTENTION_SCORES_OUTPUT
@@ -341,7 +344,13 @@ KERNEL(pa_sdpa_opt)(
 
     // Final max value after reduction across of all SG and WI
     unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
+        #ifdef HAS_SINK_INPUT
+        const uint head_idx = get_global_id(1);
+        const SOFTMAX_ACCUMULATOR_TYPE qk_max_tmp = sub_group_reduce_max(GET_VECTOR_ELEMENT(qk_max, q_idx));
+        GET_VECTOR_ELEMENT(qk_max, q_idx) = qk_max_tmp > sink_ptr[head_idx] ? qk_max_tmp : sink_ptr[head_idx];
+        #else
         GET_VECTOR_ELEMENT(qk_max, q_idx) = sub_group_reduce_max(GET_VECTOR_ELEMENT(qk_max, q_idx));
+        #endif
     }
 
     SOFTMAX_ACCUMULATOR_VEC_TYPE exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
@@ -388,6 +397,10 @@ KERNEL(pa_sdpa_opt)(
     // Final sum of all exp_sum values
     unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
         GET_VECTOR_ELEMENT(exp_sum, q_idx) = sub_group_reduce_add(GET_VECTOR_ELEMENT(exp_sum, q_idx));
+        #ifdef HAS_SINK_INPUT
+        const uint head_idx = get_global_id(1);
+        GET_VECTOR_ELEMENT(exp_sum, q_idx) += native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(sink_ptr[head_idx] - GET_VECTOR_ELEMENT(qk_max, q_idx)));
+        #endif
     }
 
     for (uint qk_idx = 0; qk_idx < qk_iters_num; qk_idx++) {
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp
index 9b51f99cc6bcda..e9ce67de6dcfad 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp
@@ -308,6 +308,11 @@ class PagedAttentionGeneratorBase : public KernelGenerator {
             jit.add(make_type_jit_constants("ALIBI_INPUT", params.input_layouts[alibi_input_idx].data_type));
         }
 
+        if (desc->has_sink_input) {
+            const auto& sink_layout = params.input_layouts[PagedAttentionInputIdx::SINKS];
+            jit.make("SINK_DATA_T", to_ocl_type(sink_layout.data_type));
+            jit.make("HAS_SINK_INPUT", 1);
+        }
         if (params.output_layouts.size() > 1) {
             jit.make("PAGED_ATTENTION_SCORES_OUTPUT", 1);
             if (desc->has_score_aggregation) {
@@ -396,7 +401,7 @@ class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase {
         const auto has_alibi = params.get_input_layout(PagedAttentionInputIdx::ALIBI).count() > 0;
         const auto has_scale_input = !desc->scale_val.has_value();
         const auto has_scores_output = params.output_layouts.size() > 1;
-
+        const auto has_sink_input = desc->has_sink_input;
         if (params.is_dynamic()) {
             args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0});
         }
@@ -416,6 +421,10 @@ class PagedAttentionGeneratorSingleToken : public PagedAttentionGeneratorBase {
             args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::ALIBI});  // alibi
         }
 
+        if (has_sink_input) {
+            args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SINKS});  // sink
+        }
+
         args.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
 
         add_intermediate_inputs(args, has_scores_output, false, desc->has_score_aggregation);
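The kernel changes above implement softmax with a per-head "attention sink": the sink logit joins the running max (for numerical stability) and adds one extra term, exp(sink - max), to the normalizing sum, but contributes no weighted value. Below is a minimal host-side sketch of that math; it is not code from this PR, and `softmax_with_sink` is a hypothetical helper name used only for illustration:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference for the sink-aware softmax used by pa_sdpa_opt.
std::vector<float> softmax_with_sink(const std::vector<float>& logits, float sink) {
    // The max reduction includes the sink, mirroring the qk_max update in the kernel.
    float max_val = sink;
    for (float x : logits)
        max_val = std::max(max_val, x);

    // The denominator gains one extra term, exp(sink - max), mirroring the
    // post-reduction exp_sum adjustment; the sink produces no output row.
    float exp_sum = std::exp(sink - max_val);
    std::vector<float> probs(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_val);
        exp_sum += probs[i];
    }
    for (float& p : probs)
        p /= exp_sum;
    return probs;  // sums to less than 1; the remaining mass is absorbed by the sink
}
```

Because the sink enters only through the max and the denominator, no extra key/value column is ever materialized; each query's attention weights simply sum to less than one.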
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/sdpa_gen_micro.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/sdpa_gen_micro.cpp
index 660d280d884ff3..4c98d9f757cd12 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/sdpa_gen_micro.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/sdpa_gen_micro.cpp
@@ -952,6 +952,12 @@ JitConstants SDPAMicroGenerator::get_jit_constants(const kernel_impl_params& par
         jit.add(make_layout_jit_constants("INPUT" + to_code_string(5), params.input_layouts[tensor_id], in_offsets_map.at(tensor_id)));
     }
 
+    if (desc->has_sink_input) {
+        const auto& sink_layout = params.input_layouts[PagedAttentionInputIdx::SINKS];
+        jit.make("SINK_DATA_T", to_ocl_type(sink_layout.data_type));
+        jit.make("HAS_SINK_INPUT", 1);
+    }
+
    jit.add(make_layout_jit_constants("OUTPUT", params.output_layouts[0], out_offsets_map.at(0)));
    if (has_scores_output) {
        jit.add(make_layout_jit_constants("OUTPUT" + to_code_string(1), params.output_layouts[1], out_offsets_map.at(1)));
@@ -1235,6 +1241,7 @@ Arguments SDPAMicroGenerator::get_arguments_desc(const kernel_impl_params& param
     auto data_inputs_num = micro_get_input_num(params, config);
 
     if (config.is_paged_attention) {
+        const auto desc = params.typed_desc<paged_attention>();
         if (m_is_prefill) {
             args.push_back({ArgumentDescriptor::Types::INPUT, 1});  // Key
             args.push_back({ArgumentDescriptor::Types::INPUT, 0});  // Q
@@ -1254,7 +1261,11 @@ Arguments SDPAMicroGenerator::get_arguments_desc(const kernel_impl_params& param
         }
         if (!config.has_const_scale_val)
             args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SCALE});  // scale
-        args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});  // blocked_indexes_start_and_gws_mapping
+
+        if (desc->has_sink_input)
+            args.push_back({ArgumentDescriptor::Types::INPUT, PagedAttentionInputIdx::SINKS});  // sink
+
+        args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});  // blocked_indexes_start_and_gws_mapping
     } else {
         args.push_back({ArgumentDescriptor::Types::INPUT, ScaledDotProductAttentionInputIdx::KEY});    // K
         args.push_back({ArgumentDescriptor::Types::INPUT, ScaledDotProductAttentionInputIdx::QUERY});  // Q
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa_micro.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa_micro.cl
index 6d99b7a9ec80e6..1097aab40f37ab 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa_micro.cl
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa_micro.cl
@@ -530,6 +530,9 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
     tile_elementwise(S_tile, scale);
 #endif  // LOG_2_MUL_SCALE
     tile_binary(S_tile, mask_tile_float, binary_add);
+#elif IS_CAUSAL && HAS_SINK_INPUT
+#define scale(x) ((x) * scale)
+    tile_elementwise(S_tile, scale);
 #endif
 
     /* Apply k mask */
@@ -543,7 +546,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
 #else
 #define greater_than(offset_k, offset_q) (offset_k > offset_q)
 #endif
-
+
     int col_offset = wg_j0 + sg_j0_kq;
 #if IS_PAGED_ATTENTION && IS_PREFILL == 0
     col_offset += k - q;
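One subtlety in the sdpa_micro.cl hunk: in the attention-mask path, the scale is applied together with the mask addition, so a pure-causal kernel previously needed no separate scaling pass at this point. With a sink input, the scores must be brought onto the same scale as the sink logit before both enter a single softmax, hence the new `#elif IS_CAUSAL && HAS_SINK_INPUT` branch that only scales `S_tile`. A toy check of that ordering follows; `raw_scores`, `scale`, and `sink` are illustrative assumptions, not values from the PR:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float raw_scores[3] = {8.0f, 4.0f, 2.0f};  // unscaled q*k products
    const float scale = 0.5f;                        // e.g. 1/sqrt(head_size)
    const float sink = 1.0f;                         // per-head sink logit

    // Scale first (what tile_elementwise(S_tile, scale) does), then fold the
    // sink into the max and the denominator as in the single-token kernel.
    float max_val = sink;
    for (float s : raw_scores)
        max_val = std::fmax(max_val, s * scale);

    float denom = std::exp(sink - max_val);
    for (float s : raw_scores)
        denom += std::exp(s * scale - max_val);

    for (float s : raw_scores)
        std::printf("p = %f\n", std::exp(s * scale - max_val) / denom);
    return 0;
}
```

Scaling before the softmax keeps the scores comparable with the sink logit, matching the single-token pa_sdpa_opt kernel, where `sink_ptr[head_idx]` is compared directly against the reduced `qk_max`.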