Skip to content

Commit 9194c25

Browse files
nvidiaclaude
and committed
feat: add circuit breaker to prevent API retry storms
Cherry-picked from upstream PR volcengine#772. Adds a thread-safe CircuitBreaker that trips after consecutive failures (or immediately on permanent errors like 403/401) and blocks further API calls until a cooldown elapses. Integrated into both SemanticProcessor and TextEmbeddingHandler: - Permanent errors (403/401): drop message, trip breaker immediately - Transient errors (429/5xx/timeout): re-enqueue for retry, trip after threshold - HALF_OPEN probe: allow one request after cooldown to test recovery Replaces the old is_429_error() check with proper error classification. Ref: volcengine#772 Ref: volcengine#729 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4039315 commit 9194c25

3 files changed

Lines changed: 249 additions & 8 deletions

File tree

openviking/storage/collection_schemas.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@
1616
from typing import Any, Dict, List, Optional
1717

1818
from openviking.models.embedder.base import EmbedResult
19-
from openviking.models.embedder.volcengine_embedders import is_429_error
2019
from openviking.server.identity import RequestContext, Role
2120
from openviking.storage.errors import CollectionNotFoundError
2221
from openviking.storage.queuefs.embedding_msg import EmbeddingMsg
2322
from openviking.storage.queuefs.named_queue import DequeueHandlerBase
2423
from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend
2524
from openviking.telemetry import bind_telemetry, resolve_telemetry
25+
from openviking.utils.circuit_breaker import (
26+
CircuitBreaker,
27+
CircuitBreakerOpen,
28+
classify_api_error,
29+
)
2630
from openviking_cli.session.user_id import UserIdentifier
2731
from openviking_cli.utils import get_logger
2832
from openviking_cli.utils.config.open_viking_config import OpenVikingConfig
@@ -162,6 +166,7 @@ def __init__(self, vikingdb: VikingVectorIndexBackend):
162166
self._collection_name = config.storage.vectordb.name
163167
self._vector_dim = config.embedding.dimension
164168
self._initialize_embedder(config)
169+
self._circuit_breaker = CircuitBreaker()
165170

166171
def _initialize_embedder(self, config: "OpenVikingConfig"):
167172
"""Initialize the embedder instance from config."""
@@ -236,6 +241,23 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
236241
self.report_success()
237242
return data
238243

244+
# Circuit breaker: if API is known-broken, re-enqueue and wait
245+
try:
246+
self._circuit_breaker.check()
247+
except CircuitBreakerOpen:
248+
logger.warning(
249+
f"Circuit breaker is open, re-enqueueing embedding: {embedding_msg.id}"
250+
)
251+
if self._vikingdb.has_queue_manager:
252+
wait = self._circuit_breaker.retry_after
253+
if wait > 0:
254+
await asyncio.sleep(wait)
255+
await self._vikingdb.enqueue_embedding_msg(embedding_msg)
256+
self.report_success()
257+
return None
258+
self.report_error("Circuit breaker open and no queue manager", data)
259+
return None
260+
239261
# Initialize embedder if not already initialized
240262
if not self._embedder:
241263
from openviking_cli.utils.config import get_openviking_config
@@ -253,13 +275,23 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
253275
)
254276
except Exception as embed_err:
255277
error_msg = f"Failed to generate embedding: {embed_err}"
256-
logger.error(error_msg)
278+
error_class = classify_api_error(embed_err)
279+
280+
if error_class == "permanent":
281+
logger.critical(error_msg)
282+
self._circuit_breaker.record_failure(embed_err)
283+
self._merge_request_stats(embedding_msg.telemetry_id, error_count=1)
284+
self.report_error(error_msg, data)
285+
return None
257286

258-
if is_429_error(embed_err) and self._vikingdb.has_queue_manager:
287+
# Transient or unknown — re-enqueue for retry
288+
logger.warning(error_msg)
289+
self._circuit_breaker.record_failure(embed_err)
290+
if self._vikingdb.has_queue_manager:
259291
try:
260292
await self._vikingdb.enqueue_embedding_msg(embedding_msg)
261293
logger.info(
262-
f"Re-enqueued embedding message after rate limit: {embedding_msg.id}"
294+
f"Re-enqueued embedding message after transient error: {embedding_msg.id}"
263295
)
264296
self.report_success()
265297
return None
@@ -342,6 +374,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
342374

343375
self._merge_request_stats(embedding_msg.telemetry_id, processed=1)
344376
self.report_success()
377+
self._circuit_breaker.record_success()
345378
return inserted_data
346379

347380
except Exception as e:

openviking/storage/queuefs/semantic_processor.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
from openviking.storage.queuefs.semantic_msg import SemanticMsg
2929
from openviking.storage.viking_fs import get_viking_fs
3030
from openviking.telemetry import bind_telemetry, resolve_telemetry
31+
from openviking.utils.circuit_breaker import (
32+
CircuitBreaker,
33+
CircuitBreakerOpen,
34+
classify_api_error,
35+
)
3136
from openviking_cli.session.user_id import UserIdentifier
3237
from openviking_cli.utils import VikingURI
3338
from openviking_cli.utils.config import get_openviking_config
@@ -82,6 +87,7 @@ def __init__(self, max_concurrent_llm: int = 100):
8287
self._dag_executor: Optional[SemanticDagExecutor] = None
8388
self._current_ctx = RequestContext(user=UserIdentifier.the_default_user(), role=Role.ROOT)
8489
self._current_msg: Optional[SemanticMsg] = None
90+
self._circuit_breaker = CircuitBreaker()
8591

8692
@classmethod
8793
def _cache_dag_stats(cls, telemetry_id: str, uri: str, stats: DagStats) -> None:
@@ -204,6 +210,24 @@ async def _check_file_content_changed(
204210
except Exception:
205211
return True
206212

213+
async def _reenqueue_semantic_msg(self, msg: SemanticMsg) -> None:
214+
"""Re-enqueue a semantic message for later processing."""
215+
import asyncio
216+
217+
from openviking.storage.queuefs import get_queue_manager
218+
219+
wait = self._circuit_breaker.retry_after
220+
if wait > 0:
221+
await asyncio.sleep(wait)
222+
223+
queue_manager = get_queue_manager()
224+
if queue_manager is not None:
225+
semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC)
226+
await semantic_queue.enqueue(msg)
227+
logger.info(f"Re-enqueued semantic message: {msg.uri}")
228+
else:
229+
logger.warning(f"No queue manager available, cannot re-enqueue: {msg.uri}")
230+
207231
async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
208232
"""Process dequeued SemanticMsg, recursively process all subdirectories."""
209233
msg: Optional[SemanticMsg] = None
@@ -219,6 +243,17 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
219243

220244
assert data is not None
221245
msg = SemanticMsg.from_dict(data)
246+
247+
# Circuit breaker: if API is known-broken, re-enqueue and wait
248+
try:
249+
self._circuit_breaker.check()
250+
except CircuitBreakerOpen:
251+
logger.warning(
252+
f"Circuit breaker is open, re-enqueueing semantic message: {msg.uri}"
253+
)
254+
await self._reenqueue_semantic_msg(msg)
255+
self.report_success()
256+
return None
222257
collector = resolve_telemetry(msg.telemetry_id)
223258
telemetry_ctx = bind_telemetry(collector) if collector is not None else nullcontext()
224259
with telemetry_ctx:
@@ -276,13 +311,37 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str,
276311
self._merge_request_stats(msg.telemetry_id, processed=1)
277312
logger.info(f"Completed semantic generation for: {msg.uri}")
278313
self.report_success()
314+
self._circuit_breaker.record_success()
279315
return None
280316

281317
except Exception as e:
282-
logger.error(f"Failed to process semantic message: {e}", exc_info=True)
283-
if msg is not None:
284-
self._merge_request_stats(msg.telemetry_id, error_count=1)
285-
self.report_error(str(e), data)
318+
error_class = classify_api_error(e)
319+
if error_class == "permanent":
320+
logger.critical(
321+
f"Permanent API error processing semantic message, dropping: {e}",
322+
exc_info=True,
323+
)
324+
self._circuit_breaker.record_failure(e)
325+
if msg is not None:
326+
self._merge_request_stats(msg.telemetry_id, error_count=1)
327+
self.report_error(str(e), data)
328+
else:
329+
logger.warning(
330+
f"Transient API error processing semantic message, re-enqueueing: {e}",
331+
exc_info=True,
332+
)
333+
self._circuit_breaker.record_failure(e)
334+
if msg is not None:
335+
try:
336+
await self._reenqueue_semantic_msg(msg)
337+
except Exception as requeue_err:
338+
logger.error(f"Failed to re-enqueue semantic message: {requeue_err}")
339+
self._merge_request_stats(msg.telemetry_id, error_count=1)
340+
self.report_error(str(e), data)
341+
return None
342+
self.report_success()
343+
else:
344+
self.report_error(str(e), data)
286345
return None
287346
finally:
288347
# Safety net: release lifecycle lock if still held (e.g. on exception
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Circuit breaker and error classification for API call protection."""
4+
5+
from __future__ import annotations
6+
7+
import threading
8+
import time
9+
10+
from openviking_cli.utils.logger import get_logger
11+
12+
logger = get_logger(__name__)
13+
14+
# --- Error classification ---
15+
16+
_PERMANENT_PATTERNS = ("403", "401", "Forbidden", "Unauthorized", "AccountOverdue")
17+
_TRANSIENT_PATTERNS = (
18+
"429",
19+
"500",
20+
"502",
21+
"503",
22+
"504",
23+
"TooManyRequests",
24+
"RateLimit",
25+
"timeout",
26+
"Timeout",
27+
"ConnectionError",
28+
"Connection refused",
29+
"Connection reset",
30+
)
31+
32+
33+
def classify_api_error(error: Exception) -> str:
34+
"""Classify an API error as permanent, transient, or unknown.
35+
36+
Checks both str(error) and str(error.__cause__) for known patterns.
37+
38+
Returns:
39+
"permanent" — 403/401, never retry.
40+
"transient" — 429/5xx/timeout, safe to retry.
41+
"unknown" — unrecognized, treated as transient by callers.
42+
"""
43+
texts = [str(error)]
44+
if error.__cause__ is not None:
45+
texts.append(str(error.__cause__))
46+
47+
for text in texts:
48+
for pattern in _PERMANENT_PATTERNS:
49+
if pattern in text:
50+
return "permanent"
51+
52+
for text in texts:
53+
for pattern in _TRANSIENT_PATTERNS:
54+
if pattern in text:
55+
return "transient"
56+
57+
return "unknown"
58+
59+
60+
# --- Circuit breaker ---
61+
62+
_STATE_CLOSED = "CLOSED"
63+
_STATE_OPEN = "OPEN"
64+
_STATE_HALF_OPEN = "HALF_OPEN"
65+
66+
class CircuitBreakerOpen(Exception):
    """Signals that the breaker has tripped and the call must not proceed."""
71+
class CircuitBreaker:
72+
"""Thread-safe circuit breaker for API call protection.
73+
74+
Trips after ``failure_threshold`` consecutive failures (or immediately for
75+
permanent errors like 403/401). After ``reset_timeout`` seconds, allows one
76+
probe request (HALF_OPEN). If the probe succeeds, the breaker closes; if it
77+
fails, the breaker reopens.
78+
"""
79+
80+
def __init__(self, failure_threshold: int = 5, reset_timeout: float = 300):
81+
self._failure_threshold = failure_threshold
82+
self._reset_timeout = reset_timeout
83+
self._lock = threading.Lock()
84+
self._state = _STATE_CLOSED
85+
self._failure_count = 0
86+
self._last_failure_time: float = 0
87+
88+
def check(self) -> None:
89+
"""Allow the request through, or raise ``CircuitBreakerOpen``."""
90+
with self._lock:
91+
if self._state == _STATE_CLOSED:
92+
return
93+
if self._state == _STATE_HALF_OPEN:
94+
return # allow probe request
95+
# OPEN — check if timeout elapsed
96+
elapsed = time.monotonic() - self._last_failure_time
97+
if elapsed >= self._reset_timeout:
98+
self._state = _STATE_HALF_OPEN
99+
logger.info("Circuit breaker transitioning OPEN -> HALF_OPEN (timeout elapsed)")
100+
return
101+
raise CircuitBreakerOpen(
102+
f"Circuit breaker is OPEN, retry after {self._reset_timeout - elapsed:.0f}s"
103+
)
104+
105+
@property
106+
def retry_after(self) -> float:
107+
"""Seconds until the breaker may transition to HALF_OPEN, capped at 30s.
108+
109+
Returns 0 if the breaker is CLOSED or HALF_OPEN.
110+
"""
111+
with self._lock:
112+
if self._state != _STATE_OPEN:
113+
return 0
114+
remaining = self._reset_timeout - (time.monotonic() - self._last_failure_time)
115+
return min(max(remaining, 0), 30)
116+
117+
def record_success(self) -> None:
118+
"""Record a successful API call. Resets failure count."""
119+
with self._lock:
120+
if self._state == _STATE_HALF_OPEN:
121+
logger.info("Circuit breaker transitioning HALF_OPEN -> CLOSED (probe succeeded)")
122+
self._failure_count = 0
123+
self._state = _STATE_CLOSED
124+
125+
def record_failure(self, error: Exception) -> None:
126+
"""Record a failed API call. May trip the breaker."""
127+
error_class = classify_api_error(error)
128+
with self._lock:
129+
self._failure_count += 1
130+
self._last_failure_time = time.monotonic()
131+
132+
if self._state == _STATE_HALF_OPEN:
133+
self._state = _STATE_OPEN
134+
logger.info(
135+
f"Circuit breaker transitioning HALF_OPEN -> OPEN (probe failed: {error})"
136+
)
137+
return
138+
139+
if error_class == "permanent":
140+
self._state = _STATE_OPEN
141+
logger.info(f"Circuit breaker tripped immediately on permanent error: {error}")
142+
return
143+
144+
if self._failure_count >= self._failure_threshold:
145+
self._state = _STATE_OPEN
146+
logger.info(
147+
f"Circuit breaker tripped after {self._failure_count} consecutive "
148+
f"failures: {error}"
149+
)

0 commit comments

Comments
 (0)