3 changes: 3 additions & 0 deletions src/memu/app/service.py
@@ -107,6 +107,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 chat_model=cfg.chat_model,
                 embed_model=cfg.embed_model,
                 embed_batch_size=cfg.embed_batch_size,
+                proxy=cfg.proxy,
             )
         elif backend == "httpx":
             return HTTPLLMClient(
@@ -116,6 +117,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 provider=cfg.provider,
                 endpoint_overrides=cfg.endpoint_overrides,
                 embed_model=cfg.embed_model,
+                proxy=cfg.proxy,
             )
         elif backend == "lazyllm_backend":
             from memu.llm.lazyllm_client import LazyLLMClient
@@ -129,6 +131,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 embed_model=cfg.embed_model,
                 vlm_model=cfg.lazyllm_source.vlm_model,
                 stt_model=cfg.lazyllm_source.stt_model,
+                proxy=cfg.proxy,
             )
         else:
             msg = f"Unknown llm_client_backend '{cfg.client_backend}'"
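
Note on the wiring above: each backend branch now forwards cfg.proxy explicitly. A minimal sketch of the call path, using HTTPLLMClient (the only client class named in these hunks) and assuming the remaining LLMConfig fields carry defaults; the argument values are illustrative:

from memu.app.settings import LLMConfig
from memu.llm.http_client import HTTPLLMClient  # module path per this PR

cfg = LLMConfig(proxy="http://proxy.example.com:8080")  # other fields defaulted
client = HTTPLLMClient(
    base_url="https://api.example.com/v1",  # illustrative
    api_key="...",
    chat_model="some-chat-model",
    provider=cfg.provider,
    endpoint_overrides=cfg.endpoint_overrides,
    embed_model=cfg.embed_model,
    proxy=cfg.proxy,  # the new kwarg threaded through by this PR
)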
4 changes: 4 additions & 0 deletions src/memu/app/settings.py
@@ -124,6 +124,10 @@ class LLMConfig(BaseModel):
         default=1,
         description="Maximum batch size for embedding API calls (used by SDK client backends).",
     )
+    proxy: str | None = Field(
+        default=None,
+        description="HTTP proxy URL for LLM requests (e.g., 'http://proxy.example.com:8080').",
+    )
 
     @model_validator(mode="after")
     def set_provider_defaults(self) -> "LLMConfig":
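
For reference, the new field is opt-in, so existing configs are unaffected. A minimal sketch, assuming the other LLMConfig fields (not shown in this hunk) have usable defaults:

from memu.app.settings import LLMConfig

cfg = LLMConfig(proxy="http://proxy.example.com:8080")
assert cfg.proxy == "http://proxy.example.com:8080"

# Default stays None, preserving the previous behaviour:
assert LLMConfig().proxy is None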
4 changes: 2 additions & 2 deletions src/memu/embedding/backends/doubao.py
@@ -37,7 +37,7 @@ class DoubaoEmbeddingBackend(EmbeddingBackend):
 
     def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
         """Build payload for standard text embeddings."""
-        return {"model": embed_model, "input": inputs, "encoding_format": "float"}
+        return {"model": embed_model, "inputs": inputs, "encoding_format": "float"}
 
     def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
         """Parse embedding response."""
@@ -64,7 +64,7 @@ def build_multimodal_embedding_payload(
         return {
             "model": embed_model,
             "encoding_format": encoding_format,
-            "input": [inp.to_dict() for inp in inputs],
+            "inputs": [inp.to_dict() for inp in inputs],
         }
 
     def parse_multimodal_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
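
To make the key rename concrete, the payloads this backend builds now look like the sketch below (values illustrative; that the Doubao endpoint expects "inputs" rather than "input" is the premise of this change, not something verified here):

backend = DoubaoEmbeddingBackend()

backend.build_embedding_payload(inputs=["hello"], embed_model="doubao-embedding")
# -> {"model": "doubao-embedding", "inputs": ["hello"], "encoding_format": "float"}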
7 changes: 5 additions & 2 deletions src/memu/llm/http_client.py
@@ -90,6 +90,7 @@ def __init__(
         endpoint_overrides: dict[str, str] | None = None,
         timeout: int = 60,
         embed_model: str | None = None,
+        proxy: str | None = None,
     ):
         # Ensure base_url ends with "/" so httpx doesn't discard the path
         # component when joining with endpoint paths.
@@ -101,7 +102,9 @@ def __init__(
         self.backend = self._load_backend(self.provider)
         self.embedding_backend = self._load_embedding_backend(self.provider)
         overrides = endpoint_overrides or {}
-        raw_summary_ep = overrides.get("chat") or overrides.get("summary") or self.backend.summary_endpoint
+        raw_summary_ep = (
+            overrides.get("chat") or overrides.get("summary") or self.backend.summary_endpoint
+        )
         raw_embedding_ep = (
             overrides.get("embeddings")
             or overrides.get("embedding")
@@ -114,7 +117,7 @@ def __init__(
         self.embedding_endpoint = raw_embedding_ep.lstrip("/")
         self.timeout = timeout
         self.embed_model = embed_model or chat_model
-        self.proxy = _load_proxy()
+        self.proxy = proxy or _load_proxy()
 
     async def chat(
         self,
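
The precedence introduced here: an explicit proxy argument wins, and _load_proxy() (presumably an environment-based fallback; its body is not shown in this diff) applies only when none is passed. A hedged sketch, with required constructor arguments beyond those named in this hunk assumed:

client = HTTPLLMClient(
    base_url="https://api.example.com/v1",  # illustrative
    api_key="...",
    chat_model="some-chat-model",
    proxy="http://proxy.example.com:8080",  # overrides any env-derived value
)
assert client.proxy == "http://proxy.example.com:8080"

# With proxy=None (the default), self.proxy falls back to _load_proxy().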
8 changes: 8 additions & 0 deletions src/memu/llm/lazyllm_client.py
@@ -22,6 +22,7 @@ def __init__(
         vlm_model: str | None = None,
         embed_model: str | None = None,
         stt_model: str | None = None,
+        proxy: str | None = None,
     ):
         self.llm_source = llm_source or self.DEFAULT_SOURCE
         self.vlm_source = vlm_source or self.DEFAULT_SOURCE
@@ -31,6 +32,13 @@ def __init__(
         self.vlm_model = vlm_model
         self.embed_model = embed_model
         self.stt_model = stt_model
+        self.proxy = proxy
+
+        # Set proxy for LazyLLM if provided
+        if proxy:
+            import os
+            os.environ["HTTP_PROXY"] = proxy
+            os.environ["HTTPS_PROXY"] = proxy
 
     async def _call_async(self, client: Any, *args: Any, **kwargs: Any) -> Any:
         """
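
One behaviour worth flagging: unlike the httpx-based clients, this path sets HTTP_PROXY/HTTPS_PROXY process-wide, so the proxy also affects any other library in the process that honours those variables. A sketch of the side effect (the constructor arguments other than proxy are optional per this hunk):

import os

from memu.llm.lazyllm_client import LazyLLMClient

LazyLLMClient(proxy="http://proxy.example.com:8080")
assert os.environ["HTTP_PROXY"] == "http://proxy.example.com:8080"
assert os.environ["HTTPS_PROXY"] == "http://proxy.example.com:8080"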
14 changes: 13 additions & 1 deletion src/memu/llm/openai_sdk.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast
 
+import httpx
 from openai import AsyncOpenAI
 from openai.types import CreateEmbeddingResponse
 from openai.types.chat import (
@@ -28,13 +29,24 @@ def __init__(
         chat_model: str,
         embed_model: str,
         embed_batch_size: int = 1,
+        proxy: str | None = None,
     ):
         self.base_url = base_url.rstrip("/")
         self.api_key = api_key or ""
         self.chat_model = chat_model
         self.embed_model = embed_model
         self.embed_batch_size = embed_batch_size
-        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+
+        # Create httpx client with proxy if provided
+        http_client = None
+        if proxy:
+            http_client = httpx.AsyncClient(proxy=proxy)
+
+        self.client = AsyncOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            http_client=http_client,
+        )
 
     async def chat(
         self,
Expand Down