18
18
from tools .semantic_citation_tool import SemanticCitationTool
19
19
from tools .summarization_tool import TextSummarizationTool
20
20
from tools .task_planning_tool import TaskPlanningTool
21
- from tools .utils import FaissSearch , ReportCallbackHandler , build_index
21
+ from tools .utils import FaissSearch , ReportCallbackHandler , build_index , setup_logging
22
22
23
23
from erniebot_agent .chat_models import ERNIEBot
24
24
from erniebot_agent .extensions .langchain .embeddings import ErnieEmbeddings
25
25
from erniebot_agent .memory import SystemMessage
26
26
from erniebot_agent .retrieval import BaizhongSearch
27
- from erniebot_agent .utils .logging import setup_logging
28
27
29
28
parser = argparse .ArgumentParser ()
30
29
parser .add_argument ("--api_type" , type = str , default = "aistudio" )
79
78
args = parser .parse_args ()
80
79
os .environ ["api_type" ] = args .api_type
81
80
access_token = os .environ .get ("EB_AGENT_ACCESS_TOKEN" , None )
82
- os .environ ["EB_AGENT_LOGGING_FILE" ] = args .log_path
81
+ # os.environ["EB_AGENT_LOGGING_FILE"] = args.log_path
83
82
# sh = logging.StreamHandler()
84
83
# logging.basicConfig(filename=args.log_path, level=logging.INFO)
85
- logger = setup_logging (use_fileformatter = False )
84
+ logger = setup_logging (args . log_path )
86
85
87
86
88
87
def get_logs (path = args .log_path ):
@@ -148,8 +147,8 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
148
147
system_message = SystemMessage ("你是一个报告生成助手。你可以根据用户的指定内容生成一份报告手稿" ),
149
148
dir_path = dir_path ,
150
149
report_type = args .report_type ,
151
- retriever_abstract_tool = retriever_sets ["abstract" ],
152
- retriever_tool = retriever_sets ["full_text" ],
150
+ retriever_abstract_db = retriever_sets ["abstract" ],
151
+ retriever_fulltext_db = retriever_sets ["full_text" ],
153
152
intent_detection_tool = tool_sets ["intent_detection" ],
154
153
task_planning_tool = tool_sets ["task_planning" ],
155
154
report_writing_tool = tool_sets ["report_writing" ],
@@ -176,22 +175,20 @@ def get_agents(retriever_sets, tool_sets, llm, llm_long, dir_path, target_path):
176
175
name = "polish" ,
177
176
llm = llm ,
178
177
llm_long = llm_long ,
179
- faiss_name_citation = args .index_name_citation ,
178
+ citation_index_name = args .index_name_citation ,
180
179
embeddings = retriever_sets ["embeddings" ],
181
180
dir_path = target_path ,
182
181
report_type = args .report_type ,
183
182
citation_tool = tool_sets ["semantic_citation" ],
184
183
callbacks = ReportCallbackHandler (logger = logger ),
185
184
)
186
- team_actor = ResearchTeam (
187
- ranker_actor = ranker_actor ,
188
- research_actor = research_actor ,
189
- editor_actor = editor_actor ,
190
- reviser_actor = reviser_actor ,
191
- polish_actor = polish_actor ,
192
- use_reflection = True ,
193
- )
194
- return team_actor
185
+ return {
186
+ "research_actor" : research_actor ,
187
+ "editor_actor" : editor_actor ,
188
+ "reviser_actor" : reviser_actor ,
189
+ "ranker_actor" : ranker_actor ,
190
+ "polish_actor" : polish_actor ,
191
+ }
195
192
196
193
197
194
def generate_report (query , history = []):
@@ -203,7 +200,8 @@ def generate_report(query, history=[]):
203
200
llm_long = ERNIEBot (model = "ernie-longtext" )
204
201
retriever_sets = get_retrievers ()
205
202
tool_sets = get_tools (llm , llm_long )
206
- team_actor = get_agents (retriever_sets , tool_sets , llm , llm_long , dir_path , target_path )
203
+ agent_sets = get_agents (retriever_sets , tool_sets , llm , llm_long , dir_path , target_path )
204
+ team_actor = ResearchTeam (** agent_sets , use_reflection = True )
207
205
report , path = asyncio .run (team_actor .run (query , args .iterations ))
208
206
return report , path
209
207
0 commit comments