Skip to content

Commit 7aa68a5

Browse files
authored
Aligned tokenizer strategy w.r.t. sentence-transformers (#32)
* Added simple example
* Fixed tokenizer strategy; added example; added comparison test w.r.t. sentence-transformers
* Fixed pooling test
* Reduced printing in similarity test
1 parent 385b78c commit 7aa68a5

File tree

11 files changed

+201
-20
lines changed

11 files changed

+201
-20
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
/target
22
Cargo.lock
3-
.idea
3+
.idea
4+
5+
traces/
6+
**/trace-*.json

crates/glowrs/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ tokenizers = "0.19.1"
2424
hf-hub = { version = "0.3.2", features = ["tokio"] }
2525
thiserror = "1.0.56"
2626
once_cell = "1.19.0"
27+
clap = { version = "4.5.4", features = ["derive"] }
2728

2829
[features]
2930
default = []
@@ -35,4 +36,6 @@ cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
3536
dirs = "5.0.1"
3637
tempfile = "3.10.1"
3738
approx = "0.5.1"
39+
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
40+
tracing-chrome = "0.7.2"
3841

crates/glowrs/examples/simple.rs

+29-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
1+
use clap::Parser;
12
use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer};
23
use std::process::ExitCode;
4+
use tracing_subscriber::prelude::*;
5+
6+
#[derive(Debug, Parser)]
7+
pub struct App {
8+
#[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
9+
pub model_repo: String,
10+
11+
#[clap(short, long, default_value = "debug")]
12+
pub log_level: String,
13+
}
314

415
fn main() -> Result<ExitCode, Error> {
5-
let sentences = vec![
16+
let app = App::parse();
17+
18+
let sentences = [
619
"The cat sits outside",
720
"A man is playing guitar",
821
"I love pasta",
@@ -11,14 +24,25 @@ fn main() -> Result<ExitCode, Error> {
1124
"A woman watches TV",
1225
"The new movie is so great",
1326
"Do you like pizza?",
14-
"The cat sits",
1527
];
28+
29+
tracing_subscriber::registry()
30+
.with(
31+
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
32+
eprintln!("No environment variables found that can initialize tracing_subscriber::EnvFilter. Using defaults.");
33+
// axum logs rejections from built-in extractors with the `axum::rejection`
34+
// target, at `TRACE` level. `axum::rejection=trace` enables showing those events
35+
format!("glowrs={},tower_http=debug,axum::rejection=trace", app.log_level).into()
36+
}),
37+
)
38+
.with(tracing_subscriber::fmt::layer()).init();
39+
40+
println!("Using model {}", app.model_repo);
1641
let device = Device::Cpu;
17-
let encoder =
18-
SentenceTransformer::from_repo_string("Snowflake/snowflake-arctic-embed-xs", &device)?;
42+
let encoder = SentenceTransformer::from_repo_string(&app.model_repo, &device)?;
1943

2044
let pooling_strategy = PoolingStrategy::Mean;
21-
let embeddings = encoder.encode_batch(sentences.clone(), false, pooling_strategy)?;
45+
let embeddings = encoder.encode_batch(sentences.into(), false, pooling_strategy)?;
2246
println!("Embeddings: {:?}", embeddings);
2347

2448
let (n_sentences, _) = embeddings.dims2()?;

crates/glowrs/src/error.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pub enum Error {
88
ModelLoad(&'static str),
99
#[error("Invalid model architecture: {0}")]
1010
InvalidModelConfig(&'static str),
11+
#[error("Inference error: {0}")]
12+
InferenceError(&'static str),
1113
#[error("Candle error: {0}")]
1214
Candle(#[from] candle_core::Error),
1315
#[error("Tokenization error: {0}")]
@@ -20,7 +22,7 @@ pub enum Error {
2022
HFHub(#[from] hf_hub::api::sync::ApiError),
2123
}
2224

23-
pub(crate) type Result<T> = std::result::Result<T, Error>;
25+
pub type Result<T> = std::result::Result<T, Error>;
2426

2527
#[cfg(test)]
2628
mod test {

crates/glowrs/src/lib.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ pub mod model;
66

77
pub use exports::*;
88

9-
pub use crate::error::Error;
10-
pub(crate) use error::Result;
9+
pub use crate::error::{Error, Result};
1110

1211
pub use model::pooling::PoolingStrategy;
1312
pub use model::sentence_transformer::SentenceTransformer;

crates/glowrs/src/model/embedder.rs

+21-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use candle_transformers::models::{
44
bert::Config as BertConfig, distilbert::Config as DistilBertConfig,
55
jina_bert::Config as JinaBertConfig,
66
};
7+
use serde::Deserialize;
78
use std::ops::Deref;
89
use std::path::Path;
910
use tokenizers::{EncodeInput, Tokenizer};
@@ -12,7 +13,6 @@ use tokenizers::{EncodeInput, Tokenizer};
1213
pub use candle_transformers::models::{
1314
bert::BertModel, distilbert::DistilBertModel, jina_bert::BertModel as JinaBertModel,
1415
};
15-
use serde::Deserialize;
1616

1717
use crate::model::pooling::{pool_embeddings, PoolingStrategy};
1818
use crate::model::utils::normalize_l2;
@@ -129,8 +129,15 @@ impl EmbedderModel for JinaBertModel {
129129
impl EmbedderModel for DistilBertModel {
130130
#[inline]
131131
fn encode(&self, token_ids: &Tensor) -> Result<Tensor> {
132-
let attention_mask = token_ids.ones_like()?;
133-
Ok(self.forward(token_ids, &attention_mask)?)
132+
let size = token_ids.dim(0)?;
133+
134+
let mask: Vec<_> = (0..size)
135+
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
136+
.collect();
137+
138+
let mask = Tensor::from_slice(&mask, (size, size), token_ids.device())?;
139+
140+
Ok(self.forward(token_ids, &mask)?)
134141
}
135142

136143
fn get_device(&self) -> &Device {
@@ -179,18 +186,28 @@ where
179186
.iter()
180187
.map(|tokens| {
181188
let tokens = tokens.get_ids().to_vec();
189+
182190
Tensor::new(tokens.as_slice(), model.get_device())
183191
})
184192
.collect::<candle_core::Result<Vec<_>>>()?;
185193

186194
let token_ids = Tensor::stack(&token_ids, 0)?;
187195

196+
let pad_id: u32;
197+
if let Some(pp) = tokenizer.get_padding() {
198+
pad_id = pp.pad_id;
199+
} else {
200+
pad_id = 0;
201+
}
202+
203+
let pad_mask = token_ids.ne(pad_id)?;
204+
188205
tracing::trace!("running inference on batch {:?}", token_ids.shape());
189206
let embeddings = model.encode(&token_ids)?;
190207
tracing::trace!("generated embeddings {:?}", embeddings.shape());
191208

192209
// Apply pooling
193-
let pooled_embeddings = pool_embeddings(&embeddings, pooling_strategy)?;
210+
let pooled_embeddings = pool_embeddings(&embeddings, &pad_mask, pooling_strategy)?;
194211

195212
// Normalize embeddings (if required)
196213
let embeddings = if normalize {

crates/glowrs/src/model/pooling.rs

+11-6
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@ pub enum PoolingStrategy {
1010
Sum,
1111
}
1212

13-
pub fn pool_embeddings(embeddings: &Tensor, strategy: PoolingStrategy) -> Result<Tensor> {
13+
pub fn pool_embeddings(
14+
embeddings: &Tensor,
15+
pad_mask: &Tensor,
16+
strategy: PoolingStrategy,
17+
) -> Result<Tensor> {
1418
match strategy {
15-
PoolingStrategy::Mean => mean_pooling(embeddings),
19+
PoolingStrategy::Mean => mean_pooling(embeddings, pad_mask),
1620
PoolingStrategy::Max => max_pooling(embeddings),
1721
PoolingStrategy::Sum => sum_pooling(embeddings),
1822
}
1923
}
2024

21-
pub fn mean_pooling(embeddings: &Tensor) -> Result<Tensor> {
22-
let (_, out_tokens, _) = embeddings.dims3()?;
25+
pub fn mean_pooling(embeddings: &Tensor, pad_mask: &Tensor) -> Result<Tensor> {
26+
let out_tokens = pad_mask.sum(1)?.to_vec1::<u8>()?.iter().sum::<u8>() as f64;
2327

24-
Ok((embeddings.sum(1)? / (out_tokens as f64))?)
28+
Ok((embeddings.sum(1)? / (out_tokens))?)
2529
}
2630

2731
pub fn max_pooling(embeddings: &Tensor) -> Result<Tensor> {
@@ -43,7 +47,8 @@ mod test {
4347
) -> Result<()> {
4448
// 1 sentence, 20 tokens, 32 dimensions
4549
let v = Tensor::ones(&[1, 20, 32], DType::F32, &Device::Cpu)?;
46-
let v_pool = pool_embeddings(&v, strategy)?;
50+
let pad_mask = Tensor::ones(&[1, 20], DType::U8, &Device::Cpu)?;
51+
let v_pool = pool_embeddings(&v, &pad_mask, strategy)?;
4752
let (sent, dim) = v_pool.dims2()?;
4853
assert_eq!(sent, 1);
4954
assert_eq!(dim, 32);

crates/glowrs/src/model/sentence_transformer.rs

+31-1
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,15 @@ impl SentenceTransformer {
6161
/// # }
6262
/// ```
6363
pub fn from_repo_string(repo_string: &str, device: &Device) -> Result<Self> {
64+
let span = tracing::span!(tracing::Level::TRACE, "st-from-repo-string");
65+
let _enter = span.enter();
6466
let (model_repo, default_revision) = utils::parse_repo_string(repo_string)?;
6567
Self::from_repo(model_repo, default_revision, device)
6668
}
6769

6870
pub fn from_repo(repo_name: &str, revision: &str, device: &Device) -> Result<Self> {
71+
let span = tracing::span!(tracing::Level::TRACE, "st-from-repo");
72+
let _enter = span.enter();
6973
let api = Api::new()?.repo(Repo::with_revision(
7074
repo_name.into(),
7175
RepoType::Model,
@@ -76,6 +80,8 @@ impl SentenceTransformer {
7680
}
7781

7882
pub fn from_api(api: ApiRepo, device: &Device) -> Result<Self> {
83+
let span = tracing::span!(tracing::Level::TRACE, "st-from-api");
84+
let _enter = span.enter();
7985
let model_path = api.get("model.safetensors")?;
8086

8187
let config_path = api.get("config.json")?;
@@ -91,7 +97,19 @@ impl SentenceTransformer {
9197
tokenizer_path: &Path,
9298
device: &Device,
9399
) -> Result<Self> {
94-
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
100+
let span = tracing::span!(tracing::Level::TRACE, "st-from-path");
101+
let _enter = span.enter();
102+
let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
103+
104+
if let Some(pp) = tokenizer.get_padding_mut() {
105+
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
106+
} else {
107+
let pp = tokenizers::PaddingParams {
108+
strategy: tokenizers::PaddingStrategy::BatchLongest,
109+
..Default::default()
110+
};
111+
tokenizer.with_padding(Some(pp));
112+
}
95113

96114
let model = load_pretrained_model(model_path, config_path, device)?;
97115

@@ -119,6 +137,8 @@ impl SentenceTransformer {
119137
/// # Ok(())
120138
/// # }
121139
pub fn from_folder(folder_path: &Path, device: &Device) -> Result<Self> {
140+
let span = tracing::span!(tracing::Level::TRACE, "st-from-folder");
141+
let _enter = span.enter();
122142
// Construct PathBuf objects for model, config, and tokenizer json files
123143
let model_path = folder_path.join("model.safetensors");
124144
let config_path = folder_path.join("config.json");
@@ -177,6 +197,9 @@ impl SentenceTransformer {
177197
where
178198
E: Into<EncodeInput<'s>> + Send,
179199
{
200+
let span = tracing::span!(tracing::Level::TRACE, "st-encode-batch");
201+
let _enter = span.enter();
202+
180203
let (embeddings, usage) = encode_batch_with_usage(
181204
self.model.as_ref(),
182205
&self.tokenizer,
@@ -196,6 +219,9 @@ impl SentenceTransformer {
196219
where
197220
E: Into<EncodeInput<'s>> + Send,
198221
{
222+
let span = tracing::span!(tracing::Level::TRACE, "st-encode-batch");
223+
let _enter = span.enter();
224+
199225
encode_batch(
200226
self.model.as_ref(),
201227
&self.tokenizer,
@@ -204,6 +230,10 @@ impl SentenceTransformer {
204230
normalize,
205231
)
206232
}
233+
234+
pub fn get_tokenizer_mut(&mut self) -> &mut Tokenizer {
235+
&mut self.tokenizer
236+
}
207237
}
208238

209239
#[cfg(test)]

crates/glowrs/tests/fixtures/embeddings/examples.json

+1
Large diffs are not rendered by default.
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use candle_core::Tensor;
2+
use serde::Deserialize;
3+
use std::process::ExitCode;
4+
5+
use glowrs::model::utils::normalize_l2;
6+
use glowrs::{PoolingStrategy, Result};
7+
8+
#[derive(Deserialize)]
9+
struct EmbeddingsExample {
10+
sentence: String,
11+
embedding: Vec<f32>,
12+
}
13+
14+
#[derive(Deserialize)]
15+
struct EmbeddingsFixture {
16+
model: String,
17+
examples: Vec<EmbeddingsExample>,
18+
}
19+
20+
#[derive(Deserialize)]
21+
struct Examples {
22+
fixtures: Vec<EmbeddingsFixture>,
23+
}
24+
25+
#[test]
26+
fn test_similarity_sentence_transformers() -> Result<ExitCode> {
27+
use approx::assert_relative_eq;
28+
let examples: Examples =
29+
serde_json::from_str(include_str!("./fixtures/embeddings/examples.json"))?;
30+
let device = glowrs::Device::Cpu;
31+
for fixture in examples.fixtures {
32+
let encoder = glowrs::SentenceTransformer::from_repo_string(&fixture.model, &device)?;
33+
println!("Loaded model: {}", &fixture.model);
34+
for example in fixture.examples {
35+
let embedding =
36+
encoder.encode_batch(vec![example.sentence], false, PoolingStrategy::Mean)?;
37+
let embedding = normalize_l2(&embedding)?;
38+
39+
let expected_dim = example.embedding.len();
40+
let expected = Tensor::from_vec(example.embedding, (1, expected_dim), &device)?;
41+
let expected = normalize_l2(&expected)?;
42+
43+
assert_eq!(embedding.dims(), expected.dims());
44+
45+
let sim = embedding.matmul(&expected.t()?)?.squeeze(1)?;
46+
47+
let sim = sim.to_vec1::<f32>()?;
48+
let sim = sim.first().expect("Expected a value");
49+
assert_relative_eq!(*sim, 1.0, epsilon = 1e-3);
50+
}
51+
println!("Passed all examples for model: {}", &fixture.model)
52+
}
53+
54+
Ok(ExitCode::SUCCESS)
55+
}

tests/generate-fixtures.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import json
2+
from sentence_transformers import SentenceTransformer
3+
4+
SENTENCES = [
5+
"The cat sits outside",
6+
"A man is playing guitar",
7+
"I love pasta",
8+
"The new movie is awesome",
9+
"The cat plays in the garden",
10+
"A woman watches TV",
11+
"The new movie is so great",
12+
"Do you like pizza?",
13+
"The cat sits",
14+
]
15+
16+
MODELS = [
17+
"jinaai/jina-embeddings-v2-small-en",
18+
"sentence-transformers/all-MiniLM-L6-v2",
19+
"sentence-transformers/multi-qa-distilbert-cos-v1",
20+
]
21+
22+
23+
def generate_examples(model: str) -> list:
24+
model = SentenceTransformer(model, trust_remote_code=True)
25+
embeddings = model.encode(SENTENCES, normalize_embeddings=False, batch_size=len(SENTENCES))
26+
return [
27+
{"sentence": sentence, "embedding": embedding.tolist()} for sentence, embedding in zip(SENTENCES, embeddings)
28+
]
29+
30+
31+
if __name__ == "__main__":
32+
out = {
33+
"fixtures": [
34+
{
35+
"model": m,
36+
"examples": generate_examples(m)
37+
38+
} for m in MODELS]
39+
}
40+
41+
with open("crates/glowrs/tests/fixtures/embeddings/examples.json", "w") as f:
42+
json.dump(out, f)

0 commit comments

Comments (0)