@@ -61,11 +61,15 @@ impl SentenceTransformer {
61
61
/// # }
62
62
/// ```
63
63
pub fn from_repo_string ( repo_string : & str , device : & Device ) -> Result < Self > {
64
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-from-repo-string" ) ;
65
+ let _enter = span. enter ( ) ;
64
66
let ( model_repo, default_revision) = utils:: parse_repo_string ( repo_string) ?;
65
67
Self :: from_repo ( model_repo, default_revision, device)
66
68
}
67
69
68
70
pub fn from_repo ( repo_name : & str , revision : & str , device : & Device ) -> Result < Self > {
71
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-from-repo" ) ;
72
+ let _enter = span. enter ( ) ;
69
73
let api = Api :: new ( ) ?. repo ( Repo :: with_revision (
70
74
repo_name. into ( ) ,
71
75
RepoType :: Model ,
@@ -76,6 +80,8 @@ impl SentenceTransformer {
76
80
}
77
81
78
82
pub fn from_api ( api : ApiRepo , device : & Device ) -> Result < Self > {
83
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-from-api" ) ;
84
+ let _enter = span. enter ( ) ;
79
85
let model_path = api. get ( "model.safetensors" ) ?;
80
86
81
87
let config_path = api. get ( "config.json" ) ?;
@@ -91,7 +97,19 @@ impl SentenceTransformer {
91
97
tokenizer_path : & Path ,
92
98
device : & Device ,
93
99
) -> Result < Self > {
94
- let tokenizer = Tokenizer :: from_file ( tokenizer_path) ?;
100
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-from-path" ) ;
101
+ let _enter = span. enter ( ) ;
102
+ let mut tokenizer = Tokenizer :: from_file ( tokenizer_path) ?;
103
+
104
+ if let Some ( pp) = tokenizer. get_padding_mut ( ) {
105
+ pp. strategy = tokenizers:: PaddingStrategy :: BatchLongest
106
+ } else {
107
+ let pp = tokenizers:: PaddingParams {
108
+ strategy : tokenizers:: PaddingStrategy :: BatchLongest ,
109
+ ..Default :: default ( )
110
+ } ;
111
+ tokenizer. with_padding ( Some ( pp) ) ;
112
+ }
95
113
96
114
let model = load_pretrained_model ( model_path, config_path, device) ?;
97
115
@@ -119,6 +137,8 @@ impl SentenceTransformer {
119
137
/// # Ok(())
120
138
/// # }
121
139
pub fn from_folder ( folder_path : & Path , device : & Device ) -> Result < Self > {
140
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-from-folder" ) ;
141
+ let _enter = span. enter ( ) ;
122
142
// Construct PathBuf objects for model, config, and tokenizer json files
123
143
let model_path = folder_path. join ( "model.safetensors" ) ;
124
144
let config_path = folder_path. join ( "config.json" ) ;
@@ -177,6 +197,9 @@ impl SentenceTransformer {
177
197
where
178
198
E : Into < EncodeInput < ' s > > + Send ,
179
199
{
200
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-encode-batch" ) ;
201
+ let _enter = span. enter ( ) ;
202
+
180
203
let ( embeddings, usage) = encode_batch_with_usage (
181
204
self . model . as_ref ( ) ,
182
205
& self . tokenizer ,
@@ -196,6 +219,9 @@ impl SentenceTransformer {
196
219
where
197
220
E : Into < EncodeInput < ' s > > + Send ,
198
221
{
222
+ let span = tracing:: span!( tracing:: Level :: TRACE , "st-encode-batch" ) ;
223
+ let _enter = span. enter ( ) ;
224
+
199
225
encode_batch (
200
226
self . model . as_ref ( ) ,
201
227
& self . tokenizer ,
@@ -204,6 +230,10 @@ impl SentenceTransformer {
204
230
normalize,
205
231
)
206
232
}
233
+
234
/// Returns a mutable reference to the wrapped tokenizer so callers can
/// adjust its configuration (e.g. padding or truncation settings) before
/// encoding. Mutations take effect on subsequent encode calls.
pub fn get_tokenizer_mut(&mut self) -> &mut Tokenizer {
    &mut self.tokenizer
}
207
237
}
208
238
209
239
#[ cfg( test) ]
0 commit comments