Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,11 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_soft_max_back_f32;

vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;

vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
Expand Down Expand Up @@ -3996,6 +4001,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);

ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);

ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
Expand Down Expand Up @@ -10111,7 +10123,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
vk_op_soft_max_push_constants pc {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
Expand All @@ -10122,7 +10134,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
n_head_log2,
nrows_x,
src2 != nullptr
});
};

if (ncols <= 16384) {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
} else {

vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);

uint32_t elems_per_wg = 128 * 4;
uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
size_t tmp_size = num_wgs * nrows_x * sizeof(float);

if (ctx->prealloc_size_x < tmp_size) {
ctx->prealloc_size_x = tmp_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_size_y < tmp_size) {
ctx->prealloc_size_y = tmp_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}

vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };

std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };

vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;

ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);

ctx->prealloc_x_need_sync = true;
ctx->prealloc_y_need_sync = true;
}
}

static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Expand Down
62 changes: 62 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#version 450

#include "soft_max_large_common.glsl"

void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;

uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}

if (rowx >= p.nrows_x) {
return;
}

float slope = get_slope(rowx);

// Find max
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];

[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;

FLOAT_TYPE a = FLOAT_TYPE(0);
if (col < p.KX) {
a = data_a[rowx * p.KX + col];
}

FLOAT_TYPE b = FLOAT_TYPE(0);
if (p.KY > 0 && col < p.KX) {
b = data_b[rowy_start + col];
}

FLOAT_TYPE v = a * p.scale + slope * b;

if (col < p.KX) {
max_val = max(max_val, v);
}
}

// reduce across the workgroup
vals[tid] = max_val;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(vals[tid], vals[tid + s]);
}
barrier();
}

if (tid == 0) {
max_val = vals[0];
data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
}
}
79 changes: 79 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#version 450

#include "soft_max_large_common.glsl"

void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;

uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}

if (rowx >= p.nrows_x) {
return;
}

float slope = get_slope(rowx);

// Find max
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];

[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
if (i + tid < gl_NumWorkGroups.x) {
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
}
}

// reduce across the workgroup
vals[tid] = max_val;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(max_val, vals[tid + s]);
}
barrier();
}

max_val = vals[0];
barrier();

FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

// Compute sum{exp(x - max)}
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;

if (col >= p.KX) {
break;
}

// compute exp(a*scale+b*slope), add it to sum
const uint i = rowx * p.KX + col;
FLOAT_TYPE val;
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
sum += val;
data_d[i] = D_TYPE(val);
}

// reduce across the workgroup
vals[tid] = sum;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}

if (tid == 0) {
sum = vals[0];
data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
}
}
65 changes: 65 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#version 450

#include "soft_max_large_common.glsl"

shared FLOAT_TYPE sumsh[BLOCK_SIZE];

void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.y;
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;

const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;

uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}

if (rowx >= p.nrows_x) {
return;
}

FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);

[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
if (i + tid < gl_NumWorkGroups.x) {
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
}
}

// reduce across the workgroup
vals[tid] = max_val;
sumsh[tid] = sum;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(max_val, vals[tid + s]);
sumsh[tid] += sumsh[tid + s];
}
barrier();
}

max_val = vals[0];
sum = sumsh[0];

if (p.has_sinks != 0) {
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
}

FLOAT_TYPE rcpdivisor = 1.0/sum;

[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;

if (col >= p.KX) {
continue;
}

data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
}
}
53 changes: 53 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#extension GL_EXT_control_flow_attributes : enable

layout (push_constant) uniform parameter
{
uint KX;
uint KY;
uint ne00;
uint ne01;
uint ne02;
uint ne12;
uint ne13;
uint nb11;
uint nb12;
uint nb13;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
uint nrows_x;
uint has_sinks;
} p;

#include "types.glsl"

layout(constant_id = 0) const uint BLOCK_SIZE = 128;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 1) const uint num_iters = 4;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {float data_c[];};
layout (binding = 3) buffer D {D_TYPE data_d[];};
layout (binding = 4) buffer M {float data_m[];};
layout (binding = 5) buffer S {float data_s[];};

shared FLOAT_TYPE vals[BLOCK_SIZE];

float get_slope(uint rowx) {
float slope = 1.0f;

// ALiBi
if (p.max_bias > 0.0f) {
const uint h = (rowx / p.ne01) % p.ne02; // head index

const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;

slope = pow(base, exp);
}

return slope;
}
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,13 @@ void process_shaders() {
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));

string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
Expand Down
3 changes: 3 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7627,6 +7627,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));

test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));

for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
Expand Down
Loading