9 changes: 9 additions & 0 deletions env.example
@@ -164,6 +164,15 @@ LLM_MODEL=gpt-4o
LLM_BINDING_HOST=https://api.openai.com/v1
LLM_BINDING_API_KEY=your_api_key

#####################################
### Query-specific LLM Configuration
### Use a more powerful model for queries while keeping extraction economical
#####################################
QUERY_BINDING=openai
QUERY_MODEL=gpt-5
QUERY_BINDING_HOST=https://api.openai.com/v1
QUERY_BINDING_API_KEY=your_api_key

### Optional for Azure
# AZURE_OPENAI_API_VERSION=2024-08-01-preview
# AZURE_OPENAI_DEPLOYMENT=gpt-4o
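The intent of the new variables above is a split setup: an economical model for entity/relation extraction and a stronger one at query time. A minimal sketch of such a configuration, set from Python before the server starts (the model names are illustrative examples, not defaults; unset QUERY_* values fall back to the corresponding LLM_* values, as the config.py change below shows):

import os

# Extraction (ingestion-time) LLM: keep it economical.
os.environ["LLM_BINDING"] = "openai"
os.environ["LLM_MODEL"] = "gpt-4o-mini"              # illustrative example
os.environ["LLM_BINDING_API_KEY"] = "your_api_key"

# Query-time LLM: route user questions to a stronger model.
os.environ["QUERY_BINDING"] = "openai"
os.environ["QUERY_MODEL"] = "gpt-5"
# QUERY_BINDING_HOST and QUERY_BINDING_API_KEY are optional and inherit
# the LLM_* values when left unset.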
11 changes: 11 additions & 0 deletions lightrag/api/config.py
@@ -326,6 +326,17 @@ def parse_args() -> argparse.Namespace:

# Inject model configuration
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
args.query_binding = get_env_value("QUERY_BINDING", args.llm_binding)
args.query_model = get_env_value("QUERY_MODEL", args.llm_model)
args.query_binding_host = get_env_value(
"QUERY_BINDING_HOST",
get_default_host(args.query_binding)
if args.query_binding != args.llm_binding
else args.llm_binding_host,
)
args.query_binding_api_key = get_env_value(
"QUERY_BINDING_API_KEY", args.llm_binding_api_key
)
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)

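The only non-trivial fallback above is the host: when the query binding differs from the extraction binding, the default host for the query binding is used rather than the extraction host. A standalone sketch of that rule, with get_default_host stubbed using illustrative values:

def get_default_host(binding: str) -> str:
    # Stub with illustrative defaults; the real helper lives in LightRAG's config module.
    return {
        "openai": "https://api.openai.com/v1",
        "ollama": "http://localhost:11434",
    }.get(binding, "")

def resolve_query_host(query_binding: str, llm_binding: str, llm_binding_host: str,
                       explicit_host: str = "") -> str:
    # An explicit QUERY_BINDING_HOST always wins.
    if explicit_host:
        return explicit_host
    # Otherwise: new binding -> that binding's default host; same binding -> inherit.
    return get_default_host(query_binding) if query_binding != llm_binding else llm_binding_host

# Extraction on local Ollama, queries on OpenAI: the OpenAI default host is chosen.
assert resolve_query_host("openai", "ollama", "http://localhost:11434") == "https://api.openai.com/v1"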
79 changes: 79 additions & 0 deletions lightrag/api/lightrag_server.py
@@ -325,6 +325,10 @@ async def lifespan(app: FastAPI):
# Store background tasks
app.state.background_tasks = set()

# Store query LLM function in app state
app.state.query_llm_func = query_llm_func
app.state.query_llm_kwargs = query_llm_kwargs

try:
# Initialize database connections
await rag.initialize_storages()
@@ -497,6 +501,11 @@ async def optimized_azure_openai_model_complete(

return optimized_azure_openai_model_complete

llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
"EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
)

def create_llm_model_func(binding: str):
"""
Create LLM model function based on binding type.
@@ -524,6 +533,76 @@ def create_llm_model_func(binding: str):
except ImportError as e:
raise Exception(f"Failed to import {binding} LLM binding: {e}")

def create_query_llm_func(args, config_cache, llm_timeout):
"""
Create a query-specific LLM function when QUERY_BINDING or QUERY_MODEL differs from the extraction LLM configuration.
Returns a tuple of (query_llm_func, query_llm_kwargs), or (None, {}) when no separate query LLM is needed.
"""
# Check if query-specific LLM is configured
# Only skip if BOTH binding AND model are the same
if not hasattr(args, "query_binding") or (
args.query_binding == args.llm_binding
and args.query_model == args.llm_model
):
logger.info("Using same LLM for both extraction and query")
return None, {}

logger.info(
f"Creating separate query LLM: {args.query_binding}/{args.query_model}"
)

# Create a temporary args object for query LLM
class QueryArgs:
pass

query_args = QueryArgs()
query_args.llm_binding = args.query_binding
query_args.llm_model = args.query_model
query_args.llm_binding_host = args.query_binding_host
query_args.llm_binding_api_key = args.query_binding_api_key

# Create query-specific LLM function based on binding type
query_llm_func = None
query_llm_kwargs = {}

try:
if args.query_binding == "openai":
query_llm_func = create_optimized_openai_llm_func(
config_cache, query_args, llm_timeout
)
elif args.query_binding == "azure_openai":
query_llm_func = create_optimized_azure_openai_llm_func(
config_cache, query_args, llm_timeout
)
elif args.query_binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete

query_llm_func = ollama_model_complete
query_llm_kwargs = create_llm_model_kwargs(
args.query_binding, query_args, llm_timeout
)
elif args.query_binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete

query_llm_func = lollms_model_complete
query_llm_kwargs = create_llm_model_kwargs(
args.query_binding, query_args, llm_timeout
)
elif args.query_binding == "aws_bedrock":
query_llm_func = bedrock_model_complete
else:
raise ValueError(f"Unsupported query binding: {args.query_binding}")

logger.info(f"Query LLM configured: {args.query_model}")
return query_llm_func, query_llm_kwargs

except ImportError as e:
raise Exception(f"Failed to import {args.query_binding} LLM binding: {e}")

query_llm_func, query_llm_kwargs = create_query_llm_func(
args, config_cache, llm_timeout
)

def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
"""
Create LLM model kwargs based on binding type.
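Downstream, the routers prefer app.state.query_llm_func when it is set and fall back to the default LLM otherwise. A self-contained sketch of that selection pattern (stand-in async functions; in the real code the pair is stored on app.state during lifespan and checked per request):

from types import SimpleNamespace

async def extraction_llm(prompt, **kwargs):
    return "answer from the economical extraction model"

async def query_llm(prompt, **kwargs):
    return "answer from the stronger query model"

def pick_model_func(state, default_func):
    # Mirrors the router-side check: use the query LLM only when one was configured.
    override = getattr(state, "query_llm_func", None)
    return override or default_func

with_override = SimpleNamespace(query_llm_func=query_llm)
without_override = SimpleNamespace(query_llm_func=None)

assert pick_model_func(with_override, extraction_llm) is query_llm
assert pick_model_func(without_override, extraction_llm) is extraction_llm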
8 changes: 8 additions & 0 deletions lightrag/api/routers/ollama_api.py
@@ -683,6 +683,14 @@ async def stream_generator():
else:
first_chunk_time = time.time_ns()

# Check if query-specific LLM is configured
if (
hasattr(raw_request.app.state, "query_llm_func")
and raw_request.app.state.query_llm_func
):
query_param.model_func = raw_request.app.state.query_llm_func
logger.debug("Using query-specific LLM for Ollama chat")

# Determine if the request is prefixed with "/bypass" or originates from Open WebUI's session title and session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
30 changes: 26 additions & 4 deletions lightrag/api/routers/query_routes.py
@@ -6,7 +6,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from lightrag.base import QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
@@ -267,7 +267,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
},
},
)
async def query_text(request: QueryRequest):
async def query_text(request: QueryRequest, req: Request):
"""
Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored.

@@ -343,6 +343,13 @@ async def query_text(request: QueryRequest):
# Force stream=False for /query endpoint regardless of include_references setting
param.stream = False

# Use query-specific LLM if available
if (
hasattr(req.app.state, "query_llm_func")
and req.app.state.query_llm_func
):
param.model_func = req.app.state.query_llm_func

# Unified approach: always use aquery_llm for both cases
result = await rag.aquery_llm(request.query, param=param)

@@ -438,7 +445,7 @@ async def query_text(request: QueryRequest):
},
},
)
async def query_text_stream(request: QueryRequest):
async def query_text_stream(request: QueryRequest, req: Request):
"""
Advanced RAG query endpoint with flexible streaming response.

@@ -560,6 +567,13 @@ async def query_text_stream(request: QueryRequest):
stream_mode = request.stream if request.stream is not None else True
param = request.to_query_params(stream_mode)

# Use query-specific LLM if available
if (
hasattr(req.app.state, "query_llm_func")
and req.app.state.query_llm_func
):
param.model_func = req.app.state.query_llm_func

from fastapi.responses import StreamingResponse

# Unified approach: always use aquery_llm for all cases
@@ -907,7 +921,7 @@ async def stream_generator():
},
},
)
async def query_data(request: QueryRequest):
async def query_data(request: QueryRequest, req: Request):
"""
Advanced data retrieval endpoint for structured RAG analysis.

@@ -1002,6 +1016,14 @@ async def query_data(request: QueryRequest):
"""
try:
param = request.to_query_params(False) # No streaming for data endpoint

# Use query-specific LLM if available (for keyword extraction)
if (
hasattr(req.app.state, "query_llm_func")
and req.app.state.query_llm_func
):
param.model_func = req.app.state.query_llm_func

response = await rag.aquery_data(request.query, param=param)

# aquery_data returns the new format with status, message, data, and metadata
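End to end, a query issued against a server started with the configuration above should be answered by the query LLM, while document ingestion keeps using the extraction model. A hedged illustration of such a request (the default port 9621 and the query/mode payload fields are assumptions based on the standard LightRAG API; adjust to your deployment):

import requests

resp = requests.post(
    "http://localhost:9621/query",   # default LightRAG API port; adjust if changed
    json={
        "query": "Summarize the key entities in the ingested documents.",
        "mode": "hybrid",
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())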