Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions examples/onebot_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Example of integrating memU with a OneBot v11 (NapCat) QQ bot for long-term memory."""

import sys
import types
import asyncio
import logging
from openai import AsyncOpenAI

# Mock the core Rust module if not compiled in the environment.
# NOTE: this must run BEFORE the `memu` imports below — importing memu.app
# would otherwise try to load the real compiled `memu._core` extension.
mock_core = types.ModuleType("memu._core")
mock_core.hello_from_bin = lambda: "Hello from mocked bin!"
sys.modules["memu._core"] = mock_core

from memu.app import MemoryService
from memu.app.settings import MemorizeConfig, CategoryConfig
from memu.integrations.onebot import OneBotAdapter, OneBotConfig

# Configure logging to match memU standards
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("memu.examples.onebot")

# ==============================================================================
# LLM Configuration
# ==============================================================================
# Two profiles: "default" drives the chat completions used for replies, and
# "embedding" drives the vector embeddings used inside memU for retrieval.
# Fill in real endpoints, keys, and model names before running this example.
LLM_PROFILES = {
    "default": {
        "base_url": "",  # Your LLM endpoint
        "api_key": "",  # Your LLM API key
        "chat_model": "",  # Your LLM chat model
        "client_backend": "sdk"
    },
    "embedding": {
        "base_url": "",  # Your LLM endpoint for embeddings
        "api_key": "",  # Your LLM API key for embeddings
        "embed_model": "",  # Your LLM embedding model
        "client_backend": "sdk"
    }
}

# Initialize a chat client for generating conversational responses.
# Uses the same endpoint/key as the "default" profile memU is configured with.
chat_client = AsyncOpenAI(
    base_url=LLM_PROFILES["default"]["base_url"],
    api_key=LLM_PROFILES["default"]["api_key"]
)

async def on_qq_message(event: dict, adapter: OneBotAdapter, memory_service: MemoryService):
    """Callback triggered for every inbound QQ message delivered by OneBot.

    Pipeline: retrieve related long-term memories -> build a memory-grounded
    system prompt -> generate a reply with the LLM -> send the reply back to
    the group/private chat -> persist the user's message as a new memory item.

    Args:
        event: OneBot v11 message event dict (with ``clean_text`` pre-filled
            by the adapter).
        adapter: Connected OneBot adapter used to send replies.
        memory_service: memU service used for retrieval and persistence.
    """
    raw_user_id = event.get("user_id")
    if raw_user_id is None:
        # Without a sender id we can neither scope memories nor reply.
        return
    user_id = str(raw_user_id)
    group_id = event.get("group_id")
    text = event.get("clean_text", "")
    message_type = event.get("message_type")

    # Ignore empty messages or messages sent by the bot itself.
    if not text or user_id == str(adapter.get_self_id()):
        return

    logger.info(f"Received message from User [{user_id}]: {text}")

    try:
        current_user = {"user_id": user_id}

        # Retrieve long-term memories related to the current message.
        logger.info("Retrieving related memories from memU...")
        retrieved_result = await memory_service.retrieve(
            queries=[{"role": "user", "content": text}],
            where=current_user
        )

        # memU may return either a plain dict or an object; normalize to a list.
        items = (
            retrieved_result.get("items", [])
            if isinstance(retrieved_result, dict)
            else getattr(retrieved_result, "items", [])
        )

        memory_context = ""
        if items:
            memory_context = "\n".join([
                f"- {m.get('summary', '')}" if isinstance(m, dict) else f"- {m.summary}"
                for m in items
            ])
            logger.info(f"Successfully retrieved {len(items)} related memory item(s).")
        else:
            logger.info("No related long-term memory found for the current topic.")

        # Generate a response using the LLM, augmented by retrieved memories.
        logger.info("Generating response using LLM with memory context...")
        system_prompt = (
            "You are an AI assistant. There is no short-term chat history between us. "
            "You MUST rely exclusively on the clues provided in the [Long-Term Memory Database] below to answer. "
            "If there is no relevant information in the memory, simply state that you don't know or don't remember. "
            "If the memory contradicts the user's current statement (e.g., they used to like apples but now say oranges), "
            "point out the change in a friendly conversational tone.\n\n"
            f"[Long-Term Memory Database]:\n{memory_context if memory_context else 'Empty'}\n"
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text}
        ]

        response = await chat_client.chat.completions.create(
            model=LLM_PROFILES["default"]["chat_model"],
            messages=messages,
            temperature=0.7
        )
        # The API may return None content (e.g. refusals); fall back to "".
        reply_text = response.choices[0].message.content or ""

        # Send the generated response back to the QQ group or private chat.
        # Both paths use the *_ack variants so delivery failures raise here
        # and get logged by the except block instead of being dropped silently.
        if message_type == "group" and group_id is not None:
            logger.info(f"Sending reply to group [{group_id}].")
            await adapter.send_group_msg_ack(int(group_id), f"[CQ:at,qq={user_id}] {reply_text}")
        else:
            logger.info(f"Sending private reply to user [{user_id}].")
            await adapter.send_private_msg_ack(int(user_id), reply_text)

        # Persist the current message into the memory database.
        logger.info("Committing current interaction to memU storage...")
        await memory_service.create_memory_item(
            memory_type="event",
            memory_content=text,
            memory_categories=["QQConversations"],
            user=current_user
        )

        logger.info("Interaction cycle completed successfully.")

    except Exception as e:
        # Broad catch is intentional at this boundary: one bad message must
        # not kill the bot's event loop.
        logger.error(f"Error occurred during interaction cycle: {e}", exc_info=True)


async def main():
    """Bootstrap memU, wire it to the OneBot adapter, and run the event loop."""
    logger.info("Initializing memU MemoryService...")
    conversation_categories = [
        CategoryConfig(name="QQConversations", description="Records of chats from QQ")
    ]
    memory_service = MemoryService(
        llm_profiles=LLM_PROFILES,
        memorize_config=MemorizeConfig(memory_categories=conversation_categories),
    )

    # Bind the memory service into the adapter's message callback.
    async def msg_handler(event: dict, adapter: OneBotAdapter):
        await on_qq_message(event, adapter, memory_service)

    logger.info("Initializing OneBot v11 Adapter...")
    onebot_config = OneBotConfig(
        ws_url="ws://127.0.0.1:3001",
        access_token=""  # Replace with environment variable in production
    )
    adapter = OneBotAdapter(config=onebot_config, on_message=msg_handler)

    logger.info("Starting OneBot engine and entering event loop...")
    # connect() loops forever (auto-reconnect), so this only returns on cancel.
    await adapter.connect()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Graceful exit on Ctrl+C instead of dumping a traceback.
        logger.info("Application shutdown requested by user.")
223 changes: 223 additions & 0 deletions src/memu/integrations/onebot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import asyncio
import json
import logging
import re
import time
import uuid
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)

@dataclass
class OneBotConfig:
    """Configuration for the OneBot (NapCat) adapter.

    Attributes:
        ws_url: WebSocket endpoint of the OneBot v11 server.
        access_token: Optional bearer token sent as the ``Authorization`` header.
    """

    # The dataclass keeps the original constructor signature
    # (ws_url, access_token=None) while adding __repr__/__eq__ for free.
    ws_url: str
    access_token: Optional[str] = None


class OneBotAdapter:
    """Adapter for connecting to a OneBot v11 implementation (e.g., NapCat) via WebSocket.

    Responsibilities:
      - persistent connection with capped exponential-backoff reconnection
      - application-level idle watchdog (protocol-level pings are disabled)
      - echo-correlated request/response for OneBot API actions
      - bounded best-effort queueing of outbound messages while disconnected
    """

    def __init__(
        self,
        config: "OneBotConfig",
        on_message: Optional[Callable[[Dict[str, Any], "OneBotAdapter"], Awaitable[None]]] = None
    ):
        self.config = config
        self.on_message = on_message

        self._ws = None
        self._self_id: Optional[int] = None
        self._is_alive = False

        self._reconnect_attempts = 0
        self._max_reconnect_delay = 60.0
        self._last_message_at = 0.0
        # Outbound requests queued while disconnected; flushed after reconnect.
        self._pending_messages: List[Dict[str, Any]] = []
        # echo id -> Future awaiting the matching API response.
        self._pending_requests: Dict[str, asyncio.Future] = {}
        # Strong references to fire-and-forget tasks: the event loop keeps
        # only weak refs, so unreferenced tasks may be garbage-collected
        # before they finish running.
        self._bg_tasks: set = set()

    def _spawn(self, coro):
        """Schedule a fire-and-forget coroutine, retaining a strong task ref."""
        task = asyncio.create_task(coro)
        self._bg_tasks.add(task)
        task.add_done_callback(self._bg_tasks.discard)
        return task

    def get_self_id(self) -> Optional[int]:
        """Get the connected bot's user ID (None until login info is fetched)."""
        return self._self_id

    def is_connected(self) -> bool:
        """Check if the WebSocket connection is currently active."""
        return self._is_alive and self._ws is not None

    async def connect(self):
        """Main loop for managing WebSocket connection and auto-reconnection.

        Runs forever: on any disconnect or error it cleans up, waits with
        capped exponential backoff, and reconnects.
        """
        import websockets
        kwargs = {}
        if self.config.access_token:
            kwargs["additional_headers"] = {"Authorization": f"Bearer {self.config.access_token}"}

        while True:
            try:
                logger.info(f"Connecting to OneBot server at {self.config.ws_url}...")

                # Disable underlying ping to fully simulate Node.js ws library
                # behavior; liveness is handled by _heartbeat_loop instead.
                async with websockets.connect(self.config.ws_url, ping_interval=None, **kwargs) as ws:
                    self._ws = ws
                    self._is_alive = True
                    self._reconnect_attempts = 0
                    self._last_message_at = time.time()
                    logger.info("Connected to OneBot server successfully.")

                    await self._flush_pending()

                    # Fetch bot info after a short delay to avoid hammering the
                    # server while it is still initializing.
                    async def fetch_bot_info():
                        try:
                            await asyncio.sleep(1.0)
                            info = await self.get_login_info()
                            logger.info(f"Bot logged in successfully. Basic info: {info}")
                        except Exception as e:
                            logger.warning(f"Failed to fetch Bot account info: {e}")

                    self._spawn(fetch_bot_info())

                    # Start application layer heartbeat detection task.
                    heartbeat_task = asyncio.create_task(self._heartbeat_loop())

                    try:
                        async for message in ws:
                            self._is_alive = True
                            self._last_message_at = time.time()
                            self._handle_raw_message(message)

                        # Log specific disconnection status codes for easier troubleshooting.
                        logger.warning(f"WebSocket loop exited cleanly. Close code: {ws.close_code}, reason: {ws.close_reason}")

                    except websockets.exceptions.ConnectionClosed as cc:
                        logger.warning(f"WebSocket connection closed by server: Code={cc.code}, Reason={cc.reason}")
                    finally:
                        heartbeat_task.cancel()

            except Exception as e:
                logger.error(f"WebSocket Error: {e}")

            # Execute cleanup and backoff reconnection logic. The exponent is
            # capped so the power stays bounded after many failed attempts
            # (2**10 already exceeds the 60s delay ceiling).
            self._cleanup()
            delay = min(1.0 * (2 ** min(self._reconnect_attempts, 10)), self._max_reconnect_delay)
            logger.info(f"Reconnecting in {delay}s (Attempt {self._reconnect_attempts + 1})...")
            await asyncio.sleep(delay)
            self._reconnect_attempts += 1

    async def _flush_pending(self):
        """Best-effort flush of outbound messages queued while disconnected."""
        if not self._pending_messages:
            return
        to_flush = self._pending_messages[:]
        self._pending_messages.clear()
        sent = 0
        for item in to_flush:
            try:
                await self._ws.send(json.dumps(item))
                sent += 1
            except Exception:
                # Best-effort semantics: a message that fails to flush is
                # dropped, same as a queue overflow while disconnected.
                pass
        if sent > 0:
            logger.info(f"Flushed {sent}/{len(to_flush)} queued outbound message(s).")

    def _cleanup(self):
        """Clean up connection state and fail all pending request Futures."""
        self._is_alive = False
        self._ws = None
        for future in self._pending_requests.values():
            if not future.done():
                future.set_exception(Exception("Connection closed"))
        self._pending_requests.clear()

    async def _heartbeat_loop(self):
        """Watchdog: force-close an idle connection so connect() reconnects."""
        while True:
            await asyncio.sleep(45)
            stale_ms = (time.time() - self._last_message_at) * 1000
            if stale_ms > 180000:  # No inbound traffic for over 3 minutes.
                logger.warning(f"No inbound traffic for {int(stale_ms/1000)}s, forcing reconnect...")
                if self._ws:
                    await self._ws.close()
                break

    def _handle_raw_message(self, message: str):
        """Parse one inbound frame and route it (API response or event)."""
        try:
            event = json.loads(message)
        except json.JSONDecodeError:
            # Non-JSON frames are silently ignored.
            return

        # Process API responses (echo callback mechanism). These return early,
        # so event-shaped handling below never sees correlated responses.
        if "echo" in event and event["echo"] in self._pending_requests:
            future = self._pending_requests.pop(event["echo"])
            if not future.done():
                if event.get("status") == "ok":
                    future.set_result(event.get("data"))
                else:
                    future.set_exception(Exception(event.get("msg", "API request failed")))
            return

        # Heartbeat meta events carry no business payload.
        if event.get("post_type") == "meta_event" and event.get("meta_event_type") == "heartbeat":
            return

        # Record the bot's own User ID from an un-correlated login-info-shaped
        # payload (guarded: "data" may not be a dict for other payloads).
        if event.get("status") == "ok" and isinstance(event.get("data"), dict) and "user_id" in event["data"]:
            self._self_id = event["data"]["user_id"]

        # Trigger business message hook.
        if event.get("post_type") == "message":
            # Clean CQ codes to extract pure text for LLM embedding and processing.
            clean_text = re.sub(r'\[CQ:.*?\]', '', event.get("raw_message", "")).strip()
            event['clean_text'] = clean_text
            if self.on_message:
                # _spawn keeps a strong reference; a bare create_task result
                # could be garbage-collected before the handler runs.
                self._spawn(self.on_message(event, self))

    # ================= API Implementations =================

    async def _send_with_response(self, action: str, params: Dict[str, Any], timeout_ms: int = 5000) -> Any:
        """Send an API request with an echo identifier and await its response.

        Raises:
            Exception: if the socket is closed, the server reports failure,
                or no response arrives within ``timeout_ms``.
        """
        if not self.is_connected():
            raise Exception("WebSocket not open")

        echo = str(uuid.uuid4())
        req = {"action": action, "params": params, "echo": echo}

        # get_running_loop() is the non-deprecated way to reach the loop from
        # coroutine context (get_event_loop() is deprecated here).
        future = asyncio.get_running_loop().create_future()
        self._pending_requests[echo] = future

        await self._ws.send(json.dumps(req))

        try:
            return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            # Drop the orphaned future so the table does not grow forever.
            self._pending_requests.pop(echo, None)
            raise Exception(f"Request timeout for action: {action}")

    def _send_safe(self, action: str, params: Dict[str, Any]):
        """Fire-and-forget send; queues the request (bounded) if disconnected."""
        req = {"action": action, "params": params}
        ws = self._ws
        if self._is_alive and ws is not None:
            # Capture the socket locally so a concurrent disconnect cannot
            # null out self._ws between this check and the send.
            self._spawn(ws.send(json.dumps(req)))
        elif len(self._pending_messages) < 200:
            # Bounded queue: beyond 200 entries, new messages are dropped.
            self._pending_messages.append(req)

    async def get_login_info(self) -> Any:
        """Fetch the bot's login information and cache its own user id.

        Caching must happen here: the response is echo-correlated, so
        _handle_raw_message returns early and its self-id detection branch
        never sees it.
        """
        info = await self._send_with_response("get_login_info", {})
        if isinstance(info, dict) and "user_id" in info:
            self._self_id = info["user_id"]
        return info

    def send_private_msg(self, user_id: int, message: str):
        """Send a private message silently (no ack)."""
        self._send_safe("send_private_msg", {"user_id": user_id, "message": message})

    async def send_private_msg_ack(self, user_id: int, message: str) -> Any:
        """Send a private message and wait for the acknowledgment."""
        return await self._send_with_response("send_private_msg", {"user_id": user_id, "message": message}, 15000)

    def send_group_msg(self, group_id: int, message: str):
        """Send a group message silently (no ack)."""
        self._send_safe("send_group_msg", {"group_id": group_id, "message": message})

    async def send_group_msg_ack(self, group_id: int, message: str) -> Any:
        """Send a group message and wait for the acknowledgment."""
        return await self._send_with_response("send_group_msg", {"group_id": group_id, "message": message}, 15000)