Commit 8fcb563 (1 parent: add2a3a)

Load all MoE experts during warmup (ggml-org#11571)

* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup
* common : use new API to enable warmup mode during model warmup

Co-authored-by: Stanisław Szymczyk <[email protected]>

File tree: 6 files changed, +22 -2 lines

common/common.cpp (+3)

@@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
+        llama_set_warmup(lctx, true);
+
         std::vector<llama_token> tmp;
         llama_token bos = llama_vocab_bos(vocab);
         llama_token eos = llama_vocab_eos(vocab);
@@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
+        llama_set_warmup(lctx, false);
     }
 
     iparams.model.reset(model);
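
For an application that calls the llama.cpp API directly rather than going through common_init_from_params(), a minimal sketch of the same warmup pattern might look like the following. The warmup() helper and the single-token batch are illustrative assumptions, not code from the commit:

#include "llama.h"

#include <vector>

// Decode one token with warmup mode enabled so that every MoE expert's
// weights are read (and thus loaded and cached) before real inference starts.
static void warmup(llama_context * ctx, const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_set_warmup(ctx, true); // route through all experts during decode

    std::vector<llama_token> tmp;
    if (llama_vocab_bos(vocab) != LLAMA_TOKEN_NULL) {
        tmp.push_back(llama_vocab_bos(vocab));
    } else {
        tmp.push_back(llama_vocab_eos(vocab)); // no BOS token; fall back to EOS
    }

    llama_decode(ctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size()));

    llama_kv_self_clear(ctx);     // the warmup tokens are not real context
    llama_synchronize(ctx);
    llama_set_warmup(ctx, false); // restore normal top-k expert routing
}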

include/llama.h (+4)

@@ -945,6 +945,10 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Set whether the model is in warmup mode or not
+    // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+    LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);

src/llama-context.cpp (+12 -1)

@@ -39,6 +39,7 @@ llama_context::llama_context(
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
+    cparams.warmup = false;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -948,6 +949,12 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_warmup(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.warmup = value;
+}
+
 void llama_context::set_adapter_lora(
         llama_adapter_lora * adapter,
         float scale) {
@@ -1594,7 +1601,7 @@ void llama_context::output_reorder() {
 //
 
 int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(8192, 5*model.n_tensors());
+    return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_init() {
@@ -2372,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
    ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+    ctx->set_warmup(warmup);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
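
Worth noting alongside the new setter: graph_max_nodes() raises its fixed floor from 8192 to 65536. A plausible reading (not stated in the commit message) is that a warmup graph activating every expert can contain far more nodes than the 5*n_tensors heuristic predicts; for a hypothetical model with 1000 weight tensors, the cap was max(8192, 5000) = 8192 nodes and is now max(65536, 5000) = 65536, leaving headroom for the all-experts graph.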

src/llama-context.h (+1)

@@ -64,6 +64,7 @@ struct llama_context {
 
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
+    void set_warmup(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,

src/llama-cparams.h (+1)

@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool warmup;
 
     enum llama_pooling_type pooling_type;

src/llama-graph.cpp (+1 -1)

@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_embd_head_v (hparams.n_embd_head_v),
     n_embd_v_gqa (hparams.n_embd_v_gqa()),
     n_expert (hparams.n_expert),
-    n_expert_used (hparams.n_expert_used),
+    n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
     freq_base (cparams.rope_freq_base),
     freq_scale (cparams.rope_freq_scale),
     ext_factor (cparams.yarn_ext_factor),
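
This one-line change is the actual loading mechanism: while warmup is active, the graph is built as if the model routed tokens through all of its experts, so llama_decode() reads every expert tensor instead of only the top-k selection. The helper below is an illustrative simplification of that member initializer, not code from the commit (llama_cparams and llama_hparams come from the internal llama-cparams.h and llama-hparams.h headers):

// With n_expert = 8 and n_expert_used = 2 (a Mixtral-style MoE), a normal
// decode touches only the 2 routed experts per token; during warmup all 8
// experts are selected, so every expert's weights get paged in once.
static uint32_t n_expert_used_effective(const llama_cparams & cparams,
                                        const llama_hparams & hparams) {
    return cparams.warmup ? hparams.n_expert : hparams.n_expert_used;
}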
