
Commit 7dfad38

Authored by Molly Sophia
llama: Add support for RWKV v7 architecture (ggml-org#12412)
* ggml: Add op l2_norm
* ggml: Add op rwkv_wkv7
* llama: Add support for RWKV7 and ARWKV7 models
* llama: fix inference with RWKV6Qwen2
* llama: add more (a)rwkv7 variants in size
* Apply code-format changes
* fix MUSA build
* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <[email protected]>
1 parent 60c9029 · commit 7dfad38

35 files changed: +2949 / -439 lines

convert_hf_to_gguf.py

Lines changed: 197 additions & 32 deletions
@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
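Editor's note: each line of rwkv_vocab_v20230424.txt pairs a token id with a Python literal (a str or bytes repr) and the token's byte length, and the new `_set_vocab_rwkv_world` helper recovers the raw bytes with `ast.literal_eval`. Below is a minimal sketch of that parsing, using two made-up entries rather than lines copied from the real vocab file:

    import ast

    # hypothetical entries in the "<id> <token literal> <byte length>" layout
    sample_lines = [
        "1 '\\x00' 1",
        "259 b'\\xe4\\xb8\\xad' 3",
    ]

    for line in sample_lines:
        parts = line.split(' ')
        # middle fields are rejoined so literals containing spaces survive the split
        token = ast.literal_eval(' '.join(parts[1:-1]))
        token = token.encode("utf-8") if isinstance(token, str) else token
        assert len(token) == int(parts[-1])
        print(parts[0], token)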
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6

     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()

     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield (new_name, data)


+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+            name = name.replace("_lora.lora.0.weight", "1.weight")
+            name = name.replace("_lora.lora.2.weight", "2.weight")
+            name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easist way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
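Editor's note: the `calc_lora_rank` fallback above scales `hidden_size ** exponent` by a multiplier and snaps the result to a multiple of 32; it is only used when a checkpoint's config does not spell out the low-rank dimensions. A quick check of the arithmetic with a hypothetical hidden_size of 2048 (not taken from any particular model):

    def calc_lora_rank(hidden_size, exponent, multiplier):
        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

    hidden_size = 2048  # hypothetical model width
    print(calc_lora_rank(hidden_size, 0.5, 1.8))  # decay / iclr rank       -> 96
    print(calc_lora_rank(hidden_size, 0.5, 1.3))  # value residual mix rank -> 64
    print(calc_lora_rank(hidden_size, 0.8, 0.6))  # gate rank               -> 256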

ggml/include/ggml.h

Lines changed: 24 additions & 0 deletions
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,

         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -502,6 +503,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,

         GGML_OP_UNARY,

@@ -1095,6 +1097,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);

+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
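Editor's note: as the comment says, ggml_l2_norm normalizes along rows, dividing each row of `a` by its L2 norm while `eps` keeps near-zero rows finite. A rough NumPy restatement of the intended math (the kernel's exact eps placement may differ):

    import numpy as np

    def l2_norm_rows(a: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        # normalize along the last axis (ggml rows), clamping the norm with eps
        norm = np.sqrt(np.sum(a * a, axis=-1, keepdims=True))
        return a / np.maximum(norm, eps)

    x = np.array([[3.0, 4.0], [0.0, 0.0]])
    print(l2_norm_rows(x))  # first row -> [0.6, 0.8]; the zero row stays zero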
@@ -1890,6 +1904,16 @@ extern "C" {
             struct ggml_tensor  * state,
             float                 scale);

+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
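Editor's note, for orientation only: per attention head, RWKV-7's WKV step carries a head_size x head_size state matrix that is decayed by w, corrected with a delta-rule style term built from a and b, fed with the outer product of v and k, and then read out against r. The NumPy sketch below is one plain restatement of that recurrence under the argument names above; the actual op's tensor layout, head batching, and state packing are defined by the ggml kernels, not by this snippet:

    import numpy as np

    def wkv7_step(state, r, w, k, v, a, b):
        # state: (N, N) per-head matrix; r, w, k, v, a, b: (N,) per-head vectors
        sa = state @ a                                        # project state onto a
        state = state * w + np.outer(sa, b) + np.outer(v, k)  # decay + delta-rule update + k/v injection
        y = state @ r                                         # per-head output for this token
        return state, y

    N = 4
    rng = np.random.default_rng(0)
    state = np.zeros((N, N))
    state, y = wkv7_step(state, *(rng.standard_normal(N) for _ in range(6)))
    print(y.shape)  # (4,)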
