Commit 8a2afb7

llama : allow custom list of swa_layers (#13726)
1 parent 9ecf3e6 commit 8a2afb7

File tree

3 files changed: +54, -23 lines

src/llama-hparams.cpp

Lines changed: 21 additions & 1 deletion
@@ -2,6 +2,26 @@
 
 #include "ggml.h"
 
+llama_hparams::llama_hparams() {
+    swa_layers.fill(false);
+}
+
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];

@@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
+        return swa_layers[il];
     }
 
     GGML_ABORT("fatal error");
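
For reference, here is a minimal standalone sketch (not part of this commit) that reproduces the mask set_swa_pattern(4) computes, assuming a hypothetical 8-layer model; only the expression assigned inside the loop is taken from the diff above.

// Standalone sketch: assumes a hypothetical 8-layer model.
// The expression assigned to swa_layers[il] is the one from set_swa_pattern above.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer   = 8; // assumed layer count, for illustration only
    const uint32_t n_pattern = 4; // "3 chunked - 1 full", as used in the llama-model.cpp hunk below

    std::array<bool, 8> swa_layers{};
    for (uint32_t il = 0; il < n_layer; ++il) {
        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        printf("il == %u: %s\n", il, swa_layers[il] ? "swa" : "dense");
    }
    return 0;
}

With n_pattern = 4 this prints swa for layers 0-2 and 4-6, and dense for layers 3 and 7, i.e. every 4th layer is dense.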

src/llama-hparams.h

Lines changed: 25 additions & 14 deletions
@@ -102,20 +102,12 @@ struct llama_hparams {
 
     // Sliding Window Attention (SWA)
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
-                                // by default n == 1, all layers are dense
-                                // note that if n_swa_pattern == 0, all layers are SWA
-                                // example: n_swa_pattern = 3
-                                //   il == 0: swa
-                                //   il == 1: swa
-                                //   il == 2: dense
-                                //   il == 3: swa
-                                //   il == 4: swa
-                                //   il == 5: dense
-                                //   il == 6: swa
-                                //   etc ...
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;

@@ -153,6 +145,25 @@ struct llama_hparams {
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    llama_hparams();
+
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
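
Because the pattern is now materialized as a plain per-layer array, a loader could in principle mark an arbitrary set of layers as SWA rather than a strict every-nth pattern, in the spirit of the commit title. Below is a hypothetical sketch of that idea, using a stand-in struct and made-up layer indices rather than the real llama_hparams.

// Hypothetical sketch only: stand-in struct and arbitrary layer indices,
// illustrating the kind of custom per-layer SWA selection the array enables.
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>

constexpr size_t MAX_LAYERS_SKETCH = 512; // stand-in for LLAMA_MAX_LAYERS

struct hparams_sketch {
    std::array<bool, MAX_LAYERS_SKETCH> swa_layers{}; // false == dense (non-SWA)
};

int main() {
    hparams_sketch hp;
    for (uint32_t il : {0u, 1u, 2u, 4u, 5u}) { // assumed layer indices, for illustration
        hp.swa_layers[il] = true;              // these layers use sliding-window attention
    }
    return 0;
}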

src/llama-model.cpp

Lines changed: 8 additions & 8 deletions
@@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
                 hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
+                hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
 
                 switch (hparams.n_expert) {
                     case 16: type = LLM_TYPE_17B_16E; break;

@@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
 
                     hparams.n_swa = 0;
-                    hparams.n_swa_pattern = 1;
+                    hparams.set_swa_pattern(1);
                 }
             } break;
         case LLM_ARCH_PHIMOE:

@@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
-                hparams.n_swa_pattern = 2;
+                hparams.set_swa_pattern(2);
                 hparams.attn_soft_cap = true;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);

@@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA3:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 6;
+                hparams.set_swa_pattern(6);
 
                 hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;

@@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_COHERE2:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 4;
+                hparams.set_swa_pattern(4);
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);

@@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
     LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
     LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
     LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
-    LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
+    LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
     LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());

@@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
             if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                GGML_ASSERT(hparams.n_swa_pattern != 1);
+                GGML_ASSERT(hparams.is_swa_any());
 
                 res = new llama_kv_cache_unified_iswa(
                         *this,

@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.n_batch,
                         padding);
             } else {
-                GGML_ASSERT(hparams.n_swa_pattern == 1);
+                GGML_ASSERT(!hparams.is_swa_any());
 
                 res = new llama_kv_cache_unified(
                         *this,
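
The two GGML_ASSERT changes above preserve the same invariant in the new representation: the SWA-aware KV cache is created only when at least one layer is actually SWA. Here is a minimal sketch of that selection logic, assuming a small hand-built mask and using strings in place of the real cache classes.

// Minimal sketch, under assumptions: an 8-layer hand-built mask and strings
// in place of the real llama_kv_cache_unified / llama_kv_cache_unified_iswa classes.
#include <array>
#include <cstdint>
#include <cstdio>

static bool is_swa_any(const std::array<bool, 8> & swa_layers, uint32_t n_layer) {
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (swa_layers[il]) {
            return true;
        }
    }
    return false;
}

int main() {
    const uint32_t n_layer = 8;        // assumed layer count
    std::array<bool, 8> swa_layers{};  // all layers dense by default
    swa_layers[0] = true;              // pretend layer 0 is SWA

    // mirrors the asserts above: the iSWA cache only makes sense if some layer is SWA
    const bool any_swa = is_swa_any(swa_layers, n_layer);
    printf("%s\n", any_swa ? "llama_kv_cache_unified_iswa" : "llama_kv_cache_unified");
    return 0;
}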
