Skip to content

Commit 90e88eb

Browse files
authored
Error handling (#28)
* Error handling * bump version * Changed publish workflow trigger to `release` * Added test workflow before release
1 parent 8aed0e9 commit 90e88eb

File tree

10 files changed

+157
-53
lines changed

10 files changed

+157
-53
lines changed

.github/workflows/publish.yml

+42-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,49 @@
11
name: Publish Package
22
on:
3-
push:
4-
branches:
5-
- master
3+
release:
4+
types: [ published ]
5+
66
jobs:
7+
build-and-test:
8+
strategy:
9+
fail-fast: false
10+
matrix:
11+
os: [ ubuntu-latest, macos-latest, windows-latest ]
12+
13+
name: Test multiple workspaces on ${{ matrix.os }}
14+
runs-on: ${{ matrix.os }}
15+
16+
env:
17+
CARGO_TERM_COLOR: always
18+
19+
steps:
20+
- uses: actions/checkout@v4
21+
22+
- uses: Swatinem/rust-cache@v2
23+
with:
24+
workspaces: |
25+
crates/glowrs
26+
crates/glowrs-server
27+
28+
- name: Build
29+
run: cargo build --verbose
30+
31+
- name: Check formatting
32+
run: cargo fmt -- --check
33+
34+
- name: Check clippy
35+
run: cargo clippy -- -D warnings
36+
37+
- name: Publish dry-run
38+
run: cargo publish -p glowrs --dry-run
39+
40+
- name: Run tests
41+
run: cargo test --verbose
42+
43+
744
publish:
845
runs-on: ubuntu-20.04
46+
needs: [ build-and-test ]
947
steps:
1048
- uses: actions/checkout@v4
1149

@@ -15,7 +53,7 @@ jobs:
1553
LAST_PUBLISHED_VERSION=$(cargo search glowrs --limit 1 | awk '{print $3}' | tr -d '"')
1654
LOCAL_VERSION=$(grep -e '^version\s*=\s*"' Cargo.toml | head -1 | cut -d '"' -f2)
1755
if [ "$LAST_PUBLISHED_VERSION" == "$LOCAL_VERSION" ]; then
18-
echo "::set-output name=skip-publish::true"
56+
exit 1 # Force a failure if the versions match
1957
fi
2058
2159
- uses: Swatinem/rust-cache@v2
@@ -26,4 +64,3 @@ jobs:
2664
2765
- name: Publish glowrs
2866
run: cargo publish -p glowrs --token ${{ secrets.CRATES_TOKEN }}
29-
if: steps.check-version.outputs.skip-publish != 'true'

.github/workflows/rust.yml

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
name: Build & Test
22

33
on:
4+
push:
5+
branches: [ "master" ]
6+
paths:
7+
- "crates/**"
8+
- ".github/workflows/rust.yml"
9+
- "tests/**"
10+
- "Cargo.toml"
11+
- ".cargo/**"
412
pull_request:
513
branches: [ "master" ]
614
paths:

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ candle-transformers = { opt-level = 3 }
1919

2020
[workspace.package]
2121
license = "Apache-2.0"
22-
version = "0.2.2"
22+
version = "0.3.0"
2323

crates/glowrs/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ tracing = "0.1.37"
2222
uuid = { version = "1.6.1", features = ["v4"] }
2323
tokenizers = "0.19.1"
2424
hf-hub = { version = "0.3.2", features = ["tokio"] }
25-
anyhow = "1.0.79"
2625
thiserror = "1.0.56"
2726
once_cell = "1.19.0"
2827

crates/glowrs/src/error.rs

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use thiserror::Error;
2+
3+
#[derive(Error, Debug)]
4+
pub enum Error {
5+
#[error("Invalid model name: {0}")]
6+
InvalidModelName(&'static str),
7+
#[error("Model load error: {0}")]
8+
ModelLoad(&'static str),
9+
#[error("Invalid model architecture: {0}")]
10+
InvalidModelConfig(&'static str),
11+
#[error("Candle error: {0}")]
12+
Candle(#[from] candle_core::Error),
13+
#[error("Tokenization error: {0}")]
14+
Tokenization(#[from] tokenizers::Error),
15+
#[error("Serde JSON error: {0}")]
16+
Serde(#[from] serde_json::Error),
17+
#[error("IO error: {0}")]
18+
IO(#[from] std::io::Error),
19+
#[error("HF Hub error: {0}")]
20+
HFHub(#[from] hf_hub::api::sync::ApiError),
21+
}
22+
23+
pub(crate) type Result<T> = std::result::Result<T, Error>;
24+
25+
#[cfg(test)]
26+
mod test {
27+
use super::*;
28+
29+
#[test]
30+
fn test_error_display() {
31+
let error = Error::InvalidModelName("test");
32+
assert_eq!(error.to_string(), "Invalid model name: test");
33+
34+
let error = Error::ModelLoad("test");
35+
assert_eq!(error.to_string(), "Model load error: test");
36+
37+
let error = Error::InvalidModelConfig("test");
38+
assert_eq!(error.to_string(), "Invalid model architecture: test");
39+
40+
let error = Error::Candle(candle_core::Error::UnexpectedNumberOfDims {
41+
shape: (32, 32).into(),
42+
expected: 3,
43+
got: 2,
44+
});
45+
assert_eq!(
46+
error.to_string(),
47+
"Candle error: unexpected rank, expected: 3, got: 2 ([32, 32])"
48+
);
49+
50+
let error = Error::IO(std::io::Error::new(std::io::ErrorKind::Other, "test"));
51+
assert_eq!(error.to_string(), "IO error: test");
52+
53+
let error = Error::HFHub(hf_hub::api::sync::ApiError::MissingHeader("test"));
54+
assert_eq!(error.to_string(), "HF Hub error: Header test is missing");
55+
}
56+
}

crates/glowrs/src/lib.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#![doc = include_str!("../README.md")]
22

3+
mod error;
34
pub mod model;
5+
pub use error::Error;
6+
pub(crate) use error::Result;
7+
48
pub use model::pooling::PoolingStrategy;
59
pub use model::sentence_transformer::SentenceTransformer;
610
use serde::Serialize;

crates/glowrs/src/model/embedder.rs

+16-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use anyhow::{Context, Error, Result};
21
use candle_core::{DType, Module, Tensor};
32
use candle_nn::VarBuilder;
43
use candle_transformers::models::{
@@ -16,10 +15,10 @@ pub use candle_transformers::models::{
1615
use serde::Deserialize;
1716

1817
use crate::model::device::DEVICE;
18+
use crate::model::pooling::{pool_embeddings, PoolingStrategy};
1919
use crate::model::utils::normalize_l2;
20-
use crate::Usage;
20+
use crate::{Error, Result, Usage};
2121

22-
use crate::model::pooling::{pool_embeddings, PoolingStrategy};
2322
#[cfg(test)]
2423
use candle_nn::VarMap;
2524

@@ -35,16 +34,17 @@ struct BaseModelConfig {
3534
}
3635

3736
pub(crate) fn parse_config(config_str: &str) -> Result<ModelConfig> {
37+
use Error::*;
3838
let base_config: BaseModelConfig = serde_json::from_str(config_str)?;
3939

4040
let config = match base_config.architectures {
4141
Some(arch) => {
4242
if arch.is_empty() {
43-
return Err(Error::msg("No architectures found"));
43+
return Err(InvalidModelConfig("No architectures found"));
4444
}
4545

4646
if arch.len() > 1 {
47-
return Err(Error::msg("Multiple architectures not supported"));
47+
return Err(InvalidModelConfig("Multiple architectures not supported"));
4848
}
4949

5050
match arch.first().map(String::as_str) {
@@ -60,10 +60,10 @@ pub(crate) fn parse_config(config_str: &str) -> Result<ModelConfig> {
6060
let config: DistilBertConfig = serde_json::from_str(config_str)?;
6161
ModelConfig::DistilBert(config)
6262
}
63-
_ => return Err(Error::msg("Invalid model architecture")),
63+
_ => return Err(InvalidModelConfig("Invalid model architecture")),
6464
}
6565
}
66-
None => return Err(Error::msg("Model architecture not found")),
66+
None => return Err(InvalidModelConfig("Model architecture not found")),
6767
};
6868

6969
Ok(config)
@@ -149,13 +149,15 @@ pub(crate) fn encode_batch_with_usage<'s, E>(
149149
where
150150
E: Into<EncodeInput<'s>> + Send,
151151
{
152-
let tokens = tokenizer
153-
.encode_batch(sentences, true)
154-
.map_err(Error::msg)
155-
.context("Failed to encode batch.")?;
152+
let tokens = tokenizer.encode_batch(sentences, true)?;
156153

157154
let prompt_tokens = tokens.len() as u32;
158155

156+
let usage = Usage {
157+
prompt_tokens,
158+
total_tokens: prompt_tokens,
159+
};
160+
159161
let token_ids = tokens
160162
.iter()
161163
.map(|tokens| {
@@ -170,20 +172,16 @@ where
170172
let embeddings = model.encode(&token_ids)?;
171173
tracing::trace!("generated embeddings {:?}", embeddings.shape());
172174

173-
// Apply some avg-pooling by taking the mean model value for all tokens (including padding)
174-
let (_n_sentence, out_tokens, _hidden_size) = embeddings.dims3()?;
175+
// Apply pooling
175176
let pooled_embeddings = pool_embeddings(&embeddings, pooling_strategy)?;
177+
178+
// Normalize embeddings (if required)
176179
let embeddings = if normalize {
177180
normalize_l2(&pooled_embeddings)?
178181
} else {
179182
pooled_embeddings
180183
};
181184

182-
// TODO: Incorrect usage calculation - fix
183-
let usage = Usage {
184-
prompt_tokens,
185-
total_tokens: prompt_tokens + (out_tokens as u32),
186-
};
187185
Ok((embeddings, usage))
188186
}
189187

crates/glowrs/src/model/pooling.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use anyhow::Result;
1+
use crate::Result;
22
use candle_core::Tensor;
33
use serde::Deserialize;
44

@@ -12,9 +12,9 @@ pub enum PoolingStrategy {
1212

1313
pub fn pool_embeddings(embeddings: &Tensor, strategy: &PoolingStrategy) -> Result<Tensor> {
1414
match strategy {
15-
PoolingStrategy::Mean => Ok(mean_pooling(embeddings)?),
16-
PoolingStrategy::Max => Ok(max_pooling(embeddings)?),
17-
PoolingStrategy::Sum => Ok(sum_pooling(embeddings)?),
15+
PoolingStrategy::Mean => mean_pooling(embeddings),
16+
PoolingStrategy::Max => max_pooling(embeddings),
17+
PoolingStrategy::Sum => sum_pooling(embeddings),
1818
}
1919
}
2020

crates/glowrs/src/model/sentence_transformer.rs

+18-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use anyhow::{Context, Error, Result};
21
use candle_core::Tensor;
32
use hf_hub::api::sync::{Api, ApiRepo};
43
use hf_hub::{Repo, RepoType};
@@ -10,7 +9,7 @@ use crate::model::embedder::{
109
encode_batch, encode_batch_with_usage, load_pretrained_model, EmbedderModel,
1110
};
1211
use crate::model::utils;
13-
use crate::Usage;
12+
use crate::{Error, Result, Usage};
1413

1514
#[cfg(test)]
1615
use crate::model::embedder::{load_zeros_model, parse_config};
@@ -57,8 +56,9 @@ impl SentenceTransformer {
5756
///
5857
/// ```rust
5958
/// # use glowrs::SentenceTransformer;
59+
/// # use std::error::Error;
6060
///
61-
/// # fn main() -> anyhow::Result<()> {
61+
/// # fn main() -> Result<(), Box<dyn Error>> {
6262
/// let encoder = SentenceTransformer::from_repo_string("sentence-transformers/all-MiniLM-L6-v2")?;
6363
///
6464
/// # Ok(())
@@ -84,26 +84,19 @@ impl SentenceTransformer {
8484
}
8585

8686
pub fn from_api(api: ApiRepo) -> Result<Self> {
87-
let model_path = api
88-
.get("model.safetensors")
89-
.context("Model repository is not available or doesn't contain `model.safetensors`.")?;
87+
let model_path = api.get("model.safetensors")?;
9088

91-
let config_path = api
92-
.get("config.json")
93-
.context("Model repository doesn't contain `config.json`.")?;
89+
let config_path = api.get("config.json")?;
9490

95-
let tokenizer_path = api
96-
.get("tokenizer.json")
97-
.context("Model repository doesn't contain `tokenizer.json`.")?;
91+
let tokenizer_path = api.get("tokenizer.json")?;
9892

9993
Self::from_path(&model_path, &config_path, &tokenizer_path)
10094
}
10195

10296
pub fn from_path(model_path: &Path, config_path: &Path, tokenizer_path: &Path) -> Result<Self> {
103-
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(Error::msg)?;
97+
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
10498

105-
let model = load_pretrained_model(model_path, config_path)
106-
.context("Something went wrong while loading the model.")?;
99+
let model = load_pretrained_model(model_path, config_path)?;
107100

108101
Ok(Self::new(model, tokenizer))
109102
}
@@ -119,7 +112,9 @@ impl SentenceTransformer {
119112
/// use glowrs::SentenceTransformer;
120113
/// use std::path::Path;
121114
///
122-
/// # fn main() -> anyhow::Result<()> {
115+
/// # type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
116+
///
117+
/// # fn main() -> Result<()> {
123118
/// let path = Path::new("path/to/folder");
124119
///
125120
/// let encoder = SentenceTransformer::from_folder(path)?;
@@ -133,7 +128,9 @@ impl SentenceTransformer {
133128
let tokenizer_path = folder_path.join("tokenizer.json");
134129

135130
if !model_path.exists() || !config_path.exists() || !tokenizer_path.exists() {
136-
Err(anyhow::anyhow!("model.safetensors, config.json, or tokenizer.json does not exist in the given directory"))
131+
Err(Error::ModelLoad(
132+
"model.safetensors, config.json, or tokenizer.json does not exist in the given directory"
133+
))
137134
} else {
138135
Self::from_path(&model_path, &config_path, &tokenizer_path)
139136
}
@@ -147,7 +144,9 @@ impl SentenceTransformer {
147144
/// # use glowrs::SentenceTransformer;
148145
/// # use glowrs::PoolingStrategy;
149146
///
150-
/// # fn main() -> anyhow::Result<()> {
147+
/// # type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
148+
///
149+
/// # fn main() -> Result<()> {
151150
/// let encoder = SentenceTransformer::from_repo_string("sentence-transformers/all-MiniLM-L6-v2")?
152151
/// .with_pooling_strategy(PoolingStrategy::Sum);
153152
///
@@ -161,7 +160,7 @@ impl SentenceTransformer {
161160

162161
#[cfg(test)]
163162
pub(crate) fn test_from_config_json(config_path: &Path, tokenizer_path: &Path) -> Result<Self> {
164-
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(Error::msg)?;
163+
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
165164

166165
let config_str = std::fs::read_to_string(config_path)?;
167166

0 commit comments

Comments
 (0)