diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c6f5809ccd7..4ed1981d029 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -722,6 +722,11 @@ struct vk_device_struct { vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; + + vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16; + vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16; + vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; @@ -3996,6 +4001,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -10111,7 +10123,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, { + vk_op_soft_max_push_constants pc { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], @@ -10122,7 +10134,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, n_head_log2, nrows_x, src2 != nullptr - }); + }; + + if (ncols <= 16384) { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc)); + } else { + + vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a; + vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a; + vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst); + + uint32_t elems_per_wg = 128 * 4; + uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg); + size_t tmp_size = num_wgs * nrows_x * sizeof(float); + + if (ctx->prealloc_size_x < tmp_size) { + ctx->prealloc_size_x = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_size_y < tmp_size) { + ctx->prealloc_size_y = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size }; + vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size }; + + std::array elements = { num_wgs, nrows_x, 1 }; + + vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32; + vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32; + vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + + ctx->prealloc_x_need_sync = true; + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp new file mode 100644 index 00000000000..39c46639122 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp @@ -0,0 +1,62 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy_start + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + if (tid == 0) { + max_val = vals[0]; + data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp new file mode 100644 index 00000000000..69524f5f756 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp @@ -0,0 +1,79 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + sum += val; + data_d[i] = D_TYPE(val); + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + sum = vals[0]; + data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp new file mode 100644 index 00000000000..06efd7d9fb4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + sum += data_s[rowx * gl_NumWorkGroups.x + i + tid]; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + sumsh[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + sumsh[tid] += sumsh[tid + s]; + } + barrier(); + } + + max_val = vals[0]; + sum = sumsh[0]; + + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl new file mode 100644 index 00000000000..6636d1f8dea --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl @@ -0,0 +1,53 @@ +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; + uint has_sinks; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 128; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(constant_id = 1) const uint num_iters = 4; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; +layout (binding = 4) buffer M {float data_m[];}; +layout (binding = 5) buffer S {float data_s[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +float get_slope(uint rowx) { + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = (rowx / p.ne01) % p.ne02; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + return slope; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 92bae088b20..72c63e81732 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -899,6 +899,13 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e94a53da25..a0e7f663cfb 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7627,6 +7627,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + for (float max_bias : {0.0f, 8.0f}) { for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) {