Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions sdk/python/ragflow_sdk/modules/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,31 @@
# limitations under the License.
#

from typing import Any, Optional, TYPE_CHECKING
from .base import Base
from .session import Session

if TYPE_CHECKING:
from ..ragflow import RAGFlow

__all__ = 'Agent',

class Agent(Base):
def __init__(self, rag, res_dict):
__slots__ = (
'id',
'avatar',
'canvas_type',
'description',
'dsl',
)

id: Optional[str]
avatar: Optional[str]
canvas_type: Optional[str]
description: Optional[str]
dsl: Optional["Agent.Dsl"]

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = None
self.avatar = None
self.canvas_type = None
Expand All @@ -28,7 +47,25 @@ def __init__(self, rag, res_dict):
super().__init__(rag, res_dict)

class Dsl(Base):
def __init__(self, rag, res_dict):
__slots__ = (
'answer',
'components',
'graph',
'history',
'messages',
'path',
'reference',
)
# TODO: Proper typing including TypedDict for the dicts. Where is the specification of the DSL?
answer: list[Any]
components: dict[str, Any]
graph: dict[str, Any]
history: list[Any]
messages: list[Any]
path: list[Any]
reference: list[Any]

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.answer = []
self.components = {
"begin": {
Expand Down Expand Up @@ -65,8 +102,8 @@ def __init__(self, rag, res_dict):
self.reference = []
super().__init__(rag, res_dict)


def create_session(self, **kwargs) -> Session:
# TODO: Proper typing of kwargs. Where are these parameters defined?
def create_session(self, **kwargs: dict[str, Any]) -> Session:
res = self.post(f"/agents/{self.id}/sessions", json=kwargs)
res = res.json()
if res.get("code") == 0:
Expand All @@ -75,7 +112,7 @@ def create_session(self, **kwargs) -> Session:


def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str = None) -> list[Session]:
id: Optional[str] = None) -> list[Session]:
res = self.get(f"/agents/{self.id}/sessions",
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id})
res = res.json()
Expand All @@ -87,8 +124,8 @@ def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "crea
return result_list
raise Exception(res.get("message"))

def delete_sessions(self, ids: list[str] | None = None):
def delete_sessions(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/agents/{self.id}/sessions", {"ids": ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
raise Exception(res.get("message"))
32 changes: 21 additions & 11 deletions sdk/python/ragflow_sdk/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,31 @@
# limitations under the License.
#

from typing import Any, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from requests import Response
from requests.sessions import _Files, _Params
from ..ragflow import RAGFlow

class Base:
def __init__(self, rag, res_dict):
__slots__ = 'rag',

rag: "RAGFlow"

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.rag = rag
self._update_from_dict(rag, res_dict)

def _update_from_dict(self, rag, res_dict):
def _update_from_dict(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
for k, v in res_dict.items():
if isinstance(v, dict):
self.__dict__[k] = Base(rag, v)
setattr(self, k, Base(rag, v))
else:
self.__dict__[k] = v
setattr(self, k, v)

def to_json(self):
pr = {}
def to_json(self) -> dict[str, Any]:
pr: dict[str, Any] = {}
for name in dir(self):
value = getattr(self, name)
if not name.startswith("__") and not callable(value) and name != "rag":
Expand All @@ -38,21 +48,21 @@ def to_json(self):
pr[name] = value
return pr

def post(self, path, json=None, stream=False, files=None):
def post(self, path: str, json: Any=None, stream: bool=False, files: Optional["_Files"]=None) -> "Response":
res = self.rag.post(path, json, stream=stream, files=files)
return res

def get(self, path, params=None):
def get(self, path: str, params: Optional["_Params"]=None) -> "Response":
res = self.rag.get(path, params)
return res

def rm(self, path, json):
def rm(self, path: str, json: Any) -> "Response":
res = self.rag.delete(path, json)
return res

def put(self, path, json):
def put(self, path: str, json: Any) -> "Response":
res = self.rag.put(path, json)
return res

def __str__(self):
def __str__(self) -> str:
return str(self.to_json())
102 changes: 96 additions & 6 deletions sdk/python/ragflow_sdk/modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,62 @@
#


from typing import Any, NotRequired, Optional, TYPE_CHECKING, TypedDict
from .base import Base
from .session import Session

if TYPE_CHECKING:
from ..ragflow import RAGFlow

__all__ = 'Chat',

class Variable(TypedDict):
key: str
optional: NotRequired[bool]

LLMUpdateMessage = TypedDict('LLMUpdateMessage', {
"model_name": NotRequired[str],
"temperature": NotRequired[float],
"top_p": NotRequired[float],
"presence_penalty": NotRequired[float],
"frequency penalty": NotRequired[float],
})

class PromptUpdateMessage(TypedDict):
similarity_threshold: NotRequired[float]
keywords_similarity_weight: NotRequired[float]
top_n: NotRequired[int]
variables: NotRequired[list[Variable]]
rerank_model: NotRequired[str]
empty_response: NotRequired[str]
opener: NotRequired[str]
show_quote: NotRequired[bool]
prompt: NotRequired[str]

class UpdateMessage(TypedDict):
name: NotRequired[str]
avatar: NotRequired[str]
dataset_ids: NotRequired[list[str]]
llm: NotRequired[LLMUpdateMessage]
prompt: NotRequired[PromptUpdateMessage]


class Chat(Base):
def __init__(self, rag, res_dict):
__slots__ = (
'id',
'name',
'avatar',
'llm',
'prompt',
)

id: str
name: str
avatar: str
llm: "Chat.LLM"
prompt: "Chat.Prompt"

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = ""
self.name = "assistant"
self.avatar = "path/to/avatar"
Expand All @@ -29,7 +79,23 @@ def __init__(self, rag, res_dict):
super().__init__(rag, res_dict)

class LLM(Base):
def __init__(self, rag, res_dict):
__slots__ = (
'model_name',
'temperature',
'top_p',
'presence_penalty',
'frequency_penalty',
'max_tokens',
)

model_name: Optional[str]
temperature: float
top_p: float
presence_penalty: float
frequency_penalty: float
max_tokens: int

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.model_name = None
self.temperature = 0.1
self.top_p = 0.3
Expand All @@ -39,7 +105,31 @@ def __init__(self, rag, res_dict):
super().__init__(rag, res_dict)

class Prompt(Base):
def __init__(self, rag, res_dict):
__slots__ = (
'similarity_threshold',
'keywords_similarity_weight',
'top_n',
'top_k',
'variables',
'rerank_model',
'empty_response',
'opener',
'show_quote',
'prompt',
)

similarity_threshold: float
keywords_similarity_weight: float
top_n: int
top_k: int
variables: list[Variable]
rerank_model: str
empty_response: Optional[str]
opener: str
show_quote: bool
prompt: str

def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.similarity_threshold = 0.2
self.keywords_similarity_weight = 0.7
self.top_n = 8
Expand All @@ -57,7 +147,7 @@ def __init__(self, rag, res_dict):
)
super().__init__(rag, res_dict)

def update(self, update_message: dict):
def update(self, update_message: UpdateMessage) -> None:
res = self.put(f"/chats/{self.id}", update_message)
res = res.json()
if res.get("code") != 0:
Expand All @@ -70,7 +160,7 @@ def create_session(self, name: str = "New session") -> Session:
return Session(self.rag, res["data"])
raise Exception(res["message"])

def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None) -> list[Session]:
def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: Optional[str] = None, name: Optional[str] = None) -> list[Session]:
res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
res = res.json()
if res.get("code") == 0:
Expand All @@ -80,7 +170,7 @@ def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "crea
return result_list
raise Exception(res["message"])

def delete_sessions(self, ids: list[str] | None = None):
def delete_sessions(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids})
res = res.json()
if res.get("code") != 0:
Expand Down
Loading