Skip to content

Commit 11dd2bf

Browse files
committed
Add temporary rms_norm_eps parameter
1 parent 8cd64d4 commit 11dd2bf

File tree

2 files changed: 17 additions and 7 deletions

llama_cpp/llama.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ def __init__(
224 224
rope_freq_base: float = 10000.0,
225 225
rope_freq_scale: float = 1.0,
226 226
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
227-
rms_eps_norm: Optional[float] = None, # (TEMPORARY)
227+
rms_norm_eps: Optional[float] = None, # (TEMPORARY)
228 228
verbose: bool = True,
229 229
):
230 230
"""Load a llama.cpp model from `model_path`.
@@ -287,8 +287,8 @@ def __init__(
287 287
if n_gqa is not None:
288 288
self.params.n_gqa = n_gqa
289 289

290-
if rms_eps_norm is not None:
291-
self.params.rms_eps_norm = rms_eps_norm
290+
if rms_norm_eps is not None:
291+
self.params.rms_norm_eps = rms_norm_eps
292 292

293 293
self.last_n_tokens_size = last_n_tokens_size
294 294
self.n_batch = min(n_ctx, n_batch)
@@ -1533,7 +1533,7 @@ def __getstate__(self):
1533 1533
tensor_split=self.tensor_split,
1534 1534
### TEMPORARY ###
1535 1535
n_gqa=self.params.n_gqa,
1536-
rms_eps_norm=self.params.rms_eps_norm,
1536+
rms_norm_eps=self.params.rms_norm_eps,
1537 1537
### TEMPORARY ###
1538 1538
### DEPRECATED ###
1539 1539
n_parts=self.n_parts,
@@ -1559,11 +1559,11 @@ def __setstate__(self, state):
1559 1559
lora_base=state["lora_base"],
1560 1560
lora_path=state["lora_path"],
1561 1561
tensor_split=state["tensor_split"],
1562-
n_gqa=state["n_gqa"],
1563-
### TEMPORARY ###
1564-
rms_eps_norm=state["rms_eps_norm"],
1565 1562
verbose=state["verbose"],
1566 1563
### TEMPORARY ###
1564+
n_gqa=state["n_gqa"],
1565+
rms_norm_eps=state["rms_norm_eps"],
1566+
### TEMPORARY ###
1567 1567
### DEPRECATED ###
1568 1568
n_parts=state["n_parts"],
1569 1569
### DEPRECATED ###

llama_cpp/server/app.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ class Settings(BaseSettings):
95 95
default=True,
96 96
description="Whether to interrupt requests when a new request is received.",
97 97
)
98+
n_gqa: Optional[int] = Field(
99+
default=None,
100+
description="TEMPORARY: Set to 8 for Llama2 70B",
101+
)
102+
rms_norm_eps: Optional[float] = Field(
103+
default=None,
104+
description="TEMPORARY",
105+
)
98 106

99 107

100 108
class ErrorResponse(TypedDict):
@@ -320,6 +328,8 @@ def create_app(settings: Optional[Settings] = None):
320 328
last_n_tokens_size=settings.last_n_tokens_size,
321 329
vocab_only=settings.vocab_only,
322 330
verbose=settings.verbose,
331+
n_gqa=settings.n_gqa,
332+
rms_norm_eps=settings.rms_norm_eps,
323 333
)
324 334
if settings.cache:
325 335
if settings.cache_type == "disk":

0 commit comments

Comments
 (0)