Skip to content

Commit d95345d

Browse files
committed
fix: keep eval-all llm cuda usage within memory
1 parent 919c4c1 commit d95345d

7 files changed

Lines changed: 68 additions & 6 deletions

File tree

CONFIGURATION.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,4 +203,9 @@ Uses the **Gemma 3** family of models for semantic reranking. Similar to the Qwe
203203
| `SIFT_BLOBS_CACHE` | Specific override for the blob store. |
204204
| `SIFT_MANIFESTS_CACHE` | Specific override for the project manifests. |
205205
| `SIFT_MODELS_CACHE` | Specific override for downloaded ML models. |
206+
| `SIFT_DENSE_DEVICE` | Dense embedding device override: `cpu` or `cuda`. |
207+
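| `SIFT_LLM_DEVICE` | Default device override for Candle-backed LLM paths: `cpu` or `cuda`. A model-specific override (`SIFT_QWEN_DEVICE`, `SIFT_JINA_DEVICE`, `SIFT_GEMMA_DEVICE`) takes precedence when it is set. |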
208+
| `SIFT_QWEN_DEVICE` | Qwen-specific device override: `cpu` or `cuda`. |
209+
| `SIFT_JINA_DEVICE` | Jina-specific device override: `cpu` or `cuda`. |
210+
| `SIFT_GEMMA_DEVICE` | Gemma-specific device override: `cpu` or `cuda`. |
206211
| `HF_TOKEN` | Hugging Face API token for downloading gated models (e.g., Jina Reranker v3 and Gemma 3). |

EVALUATIONS.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,22 @@ just sift --cuda eval agentic \
119119
```
120120

121121
This `--cuda` switch is handled by the `just` recipe, not by the `sift` CLI itself.
122-
By default, that recipe keeps the dense embedder on CPU (`SIFT_DENSE_DEVICE=cpu`) so local GPUs can be reserved for Qwen/Jina/Gemma during evals. If you want dense embeddings on CUDA too, override it explicitly with `SIFT_DENSE_DEVICE=cuda just sift --cuda ...`.
122+
By default, that recipe keeps the dense embedder on CPU (`SIFT_DENSE_DEVICE=cpu`) so local GPUs can be reserved for LLM-backed eval paths. For `eval all`, it also keeps the heavier Jina and Gemma rerankers on CPU by default (`SIFT_JINA_DEVICE=cpu`, `SIFT_GEMMA_DEVICE=cpu`) to avoid CUDA OOM across back-to-back strategy runs.
123+
124+
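To override those defaults, set any of the variables below when invoking the recipe. A model-specific variable (`SIFT_QWEN_DEVICE`, `SIFT_JINA_DEVICE`, `SIFT_GEMMA_DEVICE`) takes precedence over `SIFT_LLM_DEVICE`, so for `eval all` the Jina and Gemma CPU defaults can only be changed through `SIFT_JINA_DEVICE` and `SIFT_GEMMA_DEVICE`: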
125+
126+
- `SIFT_DENSE_DEVICE=cuda`
127+
- `SIFT_LLM_DEVICE=cuda|cpu`
128+
- `SIFT_QWEN_DEVICE=cuda|cpu`
129+
- `SIFT_JINA_DEVICE=cuda|cpu`
130+
- `SIFT_GEMMA_DEVICE=cuda|cpu`
131+
132+
Example:
133+
134+
```bash
135+
SIFT_JINA_DEVICE=cuda SIFT_GEMMA_DEVICE=cuda \
136+
just sift --cuda eval all --dataset scifact
137+
```
123138

124139
---
125140

justfile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,23 @@ sift *args:
7373
cargo_args=(--release); \
7474
env_args=(); \
7575
sift_args=(); \
76+
cuda_enabled=0; \
7677
for arg in "$@"; do \
7778
if [ "$arg" = "--cuda" ]; then \
79+
cuda_enabled=1; \
7880
cargo_args+=(--features cuda); \
7981
env_args+=("SIFT_DENSE_DEVICE=${SIFT_DENSE_DEVICE:-cpu}"); \
8082
else \
8183
sift_args+=("$arg"); \
8284
fi; \
8385
done; \
86+
if [ "$cuda_enabled" -eq 1 ] \
87+
&& [ "${#sift_args[@]}" -ge 2 ] \
88+
&& [ "${sift_args[0]}" = "eval" ] \
89+
&& [ "${sift_args[1]}" = "all" ]; then \
90+
env_args+=("SIFT_JINA_DEVICE=${SIFT_JINA_DEVICE:-cpu}"); \
91+
env_args+=("SIFT_GEMMA_DEVICE=${SIFT_GEMMA_DEVICE:-cpu}"); \
92+
fi; \
8493
env "${env_args[@]}" cargo run "${cargo_args[@]}" -- "${sift_args[@]}" \
8594
' -- {{args}}
8695

src/search/adapters/gemma.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl GemmaReranker {
7474
let tokenizer = Tokenizer::from_file(&tokenizer_path)
7575
.map_err(|m| anyhow!("failed to load tokenizer: {}", m))?;
7676

77-
let device = super::llm_utils::get_device()?;
77+
let device = super::llm_utils::get_device_for("GEMMA")?;
7878
let vb = load_mmaped_safetensors_with_repair(
7979
&spec.model_id,
8080
&spec.revision,

src/search/adapters/jina.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl JinaReranker {
9595
let tokenizer = Tokenizer::from_file(&tokenizer_path)
9696
.map_err(|m| anyhow!("failed to load tokenizer: {}", m))?;
9797

98-
let device = super::llm_utils::get_device()?;
98+
let device = super::llm_utils::get_device_for("JINA")?;
9999
let vb = load_mmaped_safetensors_with_repair(
100100
&spec.model_id,
101101
&spec.revision,

src/search/adapters/llm_utils.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,48 @@ pub fn load_mmaped_safetensors_with_repair(
172172
}
173173

174174
/// Resolves the device for generic LLM paths.
///
/// Delegates to `get_device_for("LLM")`, so the only environment override
/// consulted here is `SIFT_LLM_DEVICE`.
///
/// # Errors
///
/// Propagates any error returned by `get_device_for`, such as an override
/// value other than `cpu` or `cuda`.
pub fn get_device() -> Result<Device> {
175+
get_device_for("LLM")
176+
}
177+
178+
pub fn get_device_for(kind: &str) -> Result<Device> {
179+
let specific_env = format!("SIFT_{}_DEVICE", kind);
180+
let requested_device = match std::env::var(&specific_env) {
181+
Ok(value) => Some((specific_env.clone(), value)),
182+
Err(_) => std::env::var("SIFT_LLM_DEVICE")
183+
.ok()
184+
.map(|value| ("SIFT_LLM_DEVICE".to_string(), value)),
185+
};
186+
187+
if let Some((source, value)) = requested_device {
188+
match value.to_ascii_lowercase().as_str() {
189+
"cpu" => {
190+
tracing::info!("Using CPU for {} via {}", kind, source);
191+
return Ok(Device::Cpu);
192+
}
193+
"cuda" => {}
194+
other => {
195+
bail!(
196+
"unsupported device override '{}' in {} (expected 'cpu' or 'cuda')",
197+
other,
198+
source
199+
);
200+
}
201+
}
202+
}
203+
175204
#[cfg(feature = "cuda")]
176205
{
177206
match Device::new_cuda(0) {
178207
Ok(d) => {
179-
tracing::info!("Using CUDA device 0");
208+
tracing::info!("Using CUDA device 0 for {}", kind);
180209
Ok(d)
181210
}
182211
Err(e) => {
183-
tracing::warn!("Failed to initialize CUDA, falling back to CPU: {:?}", e);
212+
tracing::warn!(
213+
"Failed to initialize CUDA for {}, falling back to CPU: {:?}",
214+
kind,
215+
e
216+
);
184217
Ok(Device::Cpu)
185218
}
186219
}

src/search/adapters/qwen.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl QwenReranker {
7676
let tokenizer = Tokenizer::from_file(&tokenizer_path)
7777
.map_err(|m| anyhow!("failed to load tokenizer: {}", m))?;
7878

79-
let device = super::llm_utils::get_device()?;
79+
let device = super::llm_utils::get_device_for("QWEN")?;
8080
let vb = load_mmaped_safetensors_with_repair(
8181
&spec.model_id,
8282
&spec.revision,

0 commit comments

Comments
 (0)