1
+ import os
2
+ import sentencepiece as spm
3
+ import tiktoken
4
+ from tiktoken .load import load_tiktoken_bpe
5
+ from pathlib import Path
6
+ from typing import Dict
7
+
8
class TokenizerInterface:
    """Common surface shared by the tokenizer wrappers in this module.

    The base class only records the model path; every tokenizer operation
    raises ``NotImplementedError`` until a subclass supplies it.
    """

    # Single message reused by all placeholder methods below.
    _OVERRIDE_MSG = "This method should be overridden by subclasses."

    def __init__(self, model_path):
        """Remember where the tokenizer model lives.

        Args:
            model_path: Location of the tokenizer model; stored unchanged.
        """
        self.model_path = model_path

    def encode(self, text):
        """Turn ``text`` into tokens; subclasses must implement this."""
        raise NotImplementedError(self._OVERRIDE_MSG)

    def decode(self, tokens):
        """Turn ``tokens`` back into text; subclasses must implement this."""
        raise NotImplementedError(self._OVERRIDE_MSG)

    def bos_id(self):
        """Return the beginning-of-sequence id; subclasses must implement this."""
        raise NotImplementedError(self._OVERRIDE_MSG)

    def eos_id(self):
        """Return the end-of-sequence id; subclasses must implement this."""
        raise NotImplementedError(self._OVERRIDE_MSG)
23
+
24
class SentencePieceWrapper(TokenizerInterface):
    """Tokenizer wrapper that delegates to a ``SentencePieceProcessor``."""

    def __init__(self, model_path):
        """Load the SentencePiece model found at ``model_path``.

        Args:
            model_path: Location of the model file; converted with ``str``
                before being handed to ``SentencePieceProcessor``.
        """
        super().__init__(model_path)
        model_file = str(model_path)
        self.processor = spm.SentencePieceProcessor(model_file)

    def encode(self, text):
        """Return the result of ``EncodeAsIds`` for ``text``."""
        token_ids = self.processor.EncodeAsIds(text)
        return token_ids

    def decode(self, tokens):
        """Return the result of ``DecodeIds`` for ``tokens``."""
        decoded = self.processor.DecodeIds(tokens)
        return decoded

    def bos_id(self):
        """Return the processor's beginning-of-sequence id."""
        return self.processor.bos_id()

    def eos_id(self):
        """Return the processor's end-of-sequence id."""
        return self.processor.eos_id()
40
+
41
class TiktokenWrapper(TokenizerInterface):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    The BPE ranks are loaded from ``model_path`` and a fixed list of special
    tokens is appended after them: each special token's id is the number of
    base tokens plus its position in that list.
    """

    # Mapping from special-token text to its id; populated in ``__init__``.
    special_tokens: Dict[str, int]

    # Total number of special tokens registered with the encoding
    # (10 listed explicitly below + 246 generated reserved tokens).
    num_reserved_special_tokens = 256

    # Pre-tokenization split pattern passed to ``tiktoken.Encoding``.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        """Build a tiktoken ``Encoding`` from the BPE file at ``model_path``.

        Args:
            model_path: Path to an existing tiktoken BPE rank file.

        Raises:
            AssertionError: If ``model_path`` is not an existing file.
        """
        super().__init__(model_path)
        assert os.path.isfile(model_path), str(model_path)
        mergeable_ranks = load_tiktoken_bpe(str(model_path))
        num_base_tokens = len(mergeable_ranks)
        # Order matters: a token's position here determines its id.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            # No whitespace inside the token text, so the generated names
            # follow the same format as reserved tokens 0-4 listed above.
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # BOS / EOS token IDs
        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self._eos_id: int = self.special_tokens["<|end_of_text|>"]

    def encode(self, text):
        """Return the result of the underlying encoding's ``encode`` for ``text``."""
        return self.model.encode(text)

    def decode(self, tokens):
        """Return the result of the underlying encoding's ``decode`` for ``tokens``."""
        return self.model.decode(tokens)

    def bos_id(self):
        """Return the id assigned to ``<|begin_of_text|>``."""
        return self._bos_id

    def eos_id(self):
        """Return the id assigned to ``<|end_of_text|>``."""
        return self._eos_id
96
+
97
def get_tokenizer(tokenizer_model_path, model_name):
    """
    Build the tokenizer wrapper that corresponds to ``model_name``.

    Args:
    - tokenizer_model_path (str): The file path to the tokenizer model.
    - model_name (str): The name of the model; when its string form contains
      "Llama-3" the Tiktoken wrapper is used, otherwise SentencePiece.

    Returns:
    - TokenizerInterface: An instance of a tokenizer.
    """
    uses_tiktoken = "Llama-3" in str(model_name)
    wrapper_cls = TiktokenWrapper if uses_tiktoken else SentencePieceWrapper
    return wrapper_cls(tokenizer_model_path)
0 commit comments