@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             yield (new_name, data)
 
 
+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+            name = name.replace("_lora.lora.0.weight", "1.weight")
+            name = name.replace("_lora.lora.2.weight", "2.weight")
+            name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easist way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA