Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ struct vk_device_struct {
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
Expand Down Expand Up @@ -3917,6 +3918,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -8400,6 +8404,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_DIAG:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
Expand Down Expand Up @@ -9091,6 +9101,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
Expand Down Expand Up @@ -9778,6 +9789,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
}

static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));

ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
}

static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
Expand Down Expand Up @@ -11857,6 +11874,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TRI:
ggml_vk_tri(ctx, compute_ctx, src0, node);

break;
case GGML_OP_DIAG:
ggml_vk_diag(ctx, compute_ctx, src0, node);

break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
Expand Down Expand Up @@ -14001,6 +14022,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_DIAG:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_ARGSORT:
Expand Down Expand Up @@ -14591,6 +14613,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_DIAG) {
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
Expand Down
29 changes: 29 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/diag.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#version 450

#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint idx = get_idx();

if (idx >= p.ne) {
return;
}

const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;

if (i10 == i11) {
const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
} else {
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
}
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,8 @@ void process_shaders() {

string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
Expand Down
Loading