70 changes: 47 additions & 23 deletions docling_serve/app.py
@@ -28,6 +28,8 @@
 )
 from fastapi.responses import JSONResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
+from prometheus_client.core import REGISTRY
+from prometheus_fastapi_instrumentator import Instrumentator
 from scalar_fastapi import get_scalar_api_reference

 from docling.datamodel.base_models import DocumentStream
@@ -70,7 +72,8 @@
 from docling_serve.helper_functions import FormDepends
 from docling_serve.orchestrator_factory import get_async_orchestrator
 from docling_serve.response_preparation import prepare_response
-from docling_serve.settings import docling_serve_settings
+from docling_serve.rq_metrics_collector import RQCollector, get_redis_connection
+from docling_serve.settings import AsyncEngine, docling_serve_settings
 from docling_serve.storage import get_scratch
 from docling_serve.websocket_notifier import WebsocketNotifier

@@ -108,33 +111,50 @@ def format(self, record):


 # Context manager to initialize and clean up the lifespan of the FastAPI app
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    scratch_dir = get_scratch()
-
-    orchestrator = get_async_orchestrator()
-    notifier = WebsocketNotifier(orchestrator)
-    orchestrator.bind_notifier(notifier)
-
-    # Warm up processing cache
-    if docling_serve_settings.load_models_at_boot:
-        await orchestrator.warm_up_caches()
-
-    # Start the background queue processor
-    queue_task = asyncio.create_task(orchestrator.process_queue())
-
-    yield
-
-    # Cancel the background queue processor on shutdown
-    queue_task.cancel()
-    try:
-        await queue_task
-    except asyncio.CancelledError:
-        _log.info("Queue processor cancelled.")
-
-    # Remove scratch directory in case it was a tempfile
-    if docling_serve_settings.scratch_path is not None:
-        shutil.rmtree(scratch_dir, ignore_errors=True)
+def create_lifespan_handler(instrumentator: Instrumentator):
+    """
+    Create a FastAPI lifespan handler for the application
+
+    @param instrumentator: A prometheus instrumentator used to expose metrics
+    """
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        scratch_dir = get_scratch()
+
+        orchestrator = get_async_orchestrator()
+        notifier = WebsocketNotifier(orchestrator)
+        orchestrator.bind_notifier(notifier)
+
+        # Warm up processing cache
+        if docling_serve_settings.load_models_at_boot:
+            await orchestrator.warm_up_caches()
+
+        # Start the background queue processor
+        queue_task = asyncio.create_task(orchestrator.process_queue())
+
+        if docling_serve_settings.eng_kind == AsyncEngine.RQ:
+            connection = get_redis_connection(
+                url=docling_serve_settings.eng_rq_redis_url
+            )
+            REGISTRY.register(RQCollector(connection))
+
+        instrumentator.expose(app)
+
+        yield
+
+        # Cancel the background queue processor on shutdown
+        queue_task.cancel()
+        try:
+            await queue_task
+        except asyncio.CancelledError:
+            _log.info("Queue processor cancelled.")
+
+        # Remove scratch directory in case it was a tempfile
+        if docling_serve_settings.scratch_path is not None:
+            shutil.rmtree(scratch_dir, ignore_errors=True)
+
+    return lifespan


 ##################################
@@ -159,11 +179,13 @@ def create_app():  # noqa: C901
         _log.info("Found static assets.")

     require_auth = APIKeyAuth(docling_serve_settings.api_key)
+
+    instrumentator = Instrumentator()
     app = FastAPI(
         title="Docling Serve",
         docs_url=None if offline_docs_assets else "/swagger",
         redoc_url=None if offline_docs_assets else "/docs",
-        lifespan=lifespan,
+        lifespan=create_lifespan_handler(instrumentator),
         version=version,
     )

@@ -179,6 +201,8 @@ def create_app():  # noqa: C901
         allow_headers=headers,
     )

+    instrumentator.instrument(app).expose(app)
+
     # Mount the Gradio app
     if docling_serve_settings.enable_ui:
         try:
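Taken together, the app.py changes mean every instance now serves Prometheus metrics over HTTP. A minimal smoke test of the endpoint, as a hedged sketch: the base URL is an assumption, not part of this diff, and the metric-name prefixes filtered for are only those introduced here.

```python
# Hedged sketch: scrape the /metrics endpoint mounted by
# Instrumentator.expose(). The base URL is an assumption; point it
# at wherever a running docling-serve instance actually listens.
import httpx

resp = httpx.get("http://localhost:5001/metrics")
resp.raise_for_status()

# Print the HTTP request metrics plus, when the RQ engine is active,
# the rq_* families registered by RQCollector.
for line in resp.text.splitlines():
    if line.startswith(("http_", "rq_")):
        print(line)
```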
132 changes: 132 additions & 0 deletions docling_serve/rq_metrics_collector.py
@@ -0,0 +1,132 @@
# Heavily based on https://github.com/mdawar/rq-exporter, thank you <3
import logging

from prometheus_client import Summary
from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from rq import Queue, Worker
from rq.job import JobStatus

logger = logging.getLogger(__name__)


def get_redis_connection(url: str):
    return Redis.from_url(url)


def get_workers_stats(connection):
    """Get the RQ workers stats."""

    # Pass the connection by keyword; this is robust across rq versions
    # whose Worker.all() parameter order differs.
    workers = Worker.all(connection=connection)

    return [
        {
            "name": w.name,
            "queues": w.queue_names(),
            "state": w.get_state(),
            "successful_job_count": w.successful_job_count,
            "failed_job_count": w.failed_job_count,
            "total_working_time": w.total_working_time,
        }
        for w in workers
    ]


def get_queue_jobs(connection, queue_name):
    """Get the jobs by status of a Queue."""

    queue = Queue(connection=connection, name=queue_name)

    return {
        JobStatus.QUEUED: queue.count,
        JobStatus.STARTED: queue.started_job_registry.count,
        JobStatus.FINISHED: queue.finished_job_registry.count,
        JobStatus.FAILED: queue.failed_job_registry.count,
        JobStatus.DEFERRED: queue.deferred_job_registry.count,
        JobStatus.SCHEDULED: queue.scheduled_job_registry.count,
    }


def get_jobs_by_queue(connection):
    """Get the current jobs by queue."""

    queues = Queue.all(connection=connection)

    return {q.name: get_queue_jobs(connection, q.name) for q in queues}


class RQCollector(Collector):
    """RQ stats collector."""

    def __init__(self, connection=None):
        self.connection = connection

        # RQ data collection count and time in seconds
        self.summary = Summary(
            "rq_request_processing_seconds", "Time spent collecting RQ data"
        )

    def collect(self):
        """Collect RQ Metrics."""
        logger.debug("Collecting the RQ metrics...")

        with self.summary.time():
            rq_workers = GaugeMetricFamily(
                "rq_workers",
                "RQ workers",
                labels=["name", "state", "queues"],
            )
            rq_workers_success = CounterMetricFamily(
                "rq_workers_success",
                "RQ workers success count",
                labels=["name", "queues"],
            )
            rq_workers_failed = CounterMetricFamily(
                "rq_workers_failed",
                "RQ workers fail count",
                labels=["name", "queues"],
            )
            rq_workers_working_time = CounterMetricFamily(
                "rq_workers_working_time",
                "RQ workers spent seconds",
                labels=["name", "queues"],
            )
            rq_jobs = GaugeMetricFamily(
                "rq_jobs",
                "RQ jobs by state",
                labels=["queue", "status"],
            )

            workers = get_workers_stats(self.connection)
            for worker in workers:
                label_queues = ",".join(worker["queues"])
                rq_workers.add_metric(
                    [worker["name"], worker["state"], label_queues],
                    1,
                )
                rq_workers_success.add_metric(
                    [worker["name"], label_queues],
                    worker["successful_job_count"],
                )
                rq_workers_failed.add_metric(
                    [worker["name"], label_queues],
                    worker["failed_job_count"],
                )
                rq_workers_working_time.add_metric(
                    [worker["name"], label_queues],
                    worker["total_working_time"],
                )

            yield rq_workers
            yield rq_workers_success
            yield rq_workers_failed
            yield rq_workers_working_time

            for queue_name, jobs in get_jobs_by_queue(self.connection).items():
                for status, count in jobs.items():
                    rq_jobs.add_metric([queue_name, status], count)

            yield rq_jobs

        logger.debug("RQ metrics collection finished")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
     "websockets~=14.0",
     "scalar-fastapi>=1.0.3",
     "docling-mcp>=1.0.0",
+    "prometheus-fastapi-instrumentator>=7.1.0",
 ]

 [project.optional-dependencies]
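The new dependency supplies the `Instrumentator` used in app.py above. In isolation, the wiring reduces to the pattern below: a minimal sketch on a bare FastAPI app, mirroring only the calls that appear in this diff.

```python
# Minimal sketch of prometheus-fastapi-instrumentator on a bare app,
# mirroring the instrument()/expose() calls added in app.py above.
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()

# instrument() installs the request middleware; expose() mounts GET /metrics.
Instrumentator().instrument(app).expose(app)
```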
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default.