12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
from typing import AsyncIterator, List, Union

from qwen_agent.planning.plan_executor import PlanExecutor
from qwen_agent.messages.context_message import PlanResponseContextManager
from qwen_agent.messages.context_message import ChatResponseChoice, ChatResponseStreamChoice, SystemSignal
from qwen_agent.sub_agent.BaseSubAgent import BaseSubAgent

from agent_config import ActionDict
class PlanContinueExecutor(PlanExecutor):
    """Resume a plan from a serialized ``PlanResponseContextManager``.

    The first not-yet-executed plan step is *continued* (its agent is resumed
    with the stored ``clarification_data``); every remaining step is then run
    from scratch via ``PlanExecutor.run_agent``.
    """

    def _build_agent(self, action_name):
        """Instantiate the sub-agent registered for *action_name*.

        Returns:
            ``(agent, llm_name)``, where ``llm_name`` is looked up in
            ``self.llm_dict`` and passed to the agent's constructor.

        Raises:
            KeyError: if *action_name* is missing from ``self.llm_dict`` or
                from ``ActionDict``.
        """
        llm_name = self.llm_dict[action_name]
        agent: BaseSubAgent = ActionDict[action_name](
            llm=self.llm, llm_name=llm_name, stream=self.stream, name=action_name
        )
        return agent, llm_name

    async def continue_run_agent(
        self,
        agent: BaseSubAgent,
        messages: List[str],
        plan_context: PlanResponseContextManager,
        clarification_data,
    ) -> AsyncIterator[Union[str, SystemSignal]]:
        """Resume *agent* with *clarification_data* and stream progress.

        Every message produced by ``agent.continue_run`` is recorded in
        *plan_context*. A snapshot (``plan_context.response_json()``) is
        yielded whenever the message carries non-empty text.

        If the agent emits a ``SystemSignal``, it is yielded as-is and the
        generator stops at once. In that case the agent's results are NOT
        stored on *plan_context*, and the caller is expected to handle the
        signal.

        Otherwise, once the agent finishes, ``agent.exec_res`` (and
        ``agent.sql_code``, when the agent has that attribute) are stored on
        *plan_context* and a final snapshot is yielded.
        """
        stream = agent.continue_run(
            plan_context=plan_context, user_input=clarification_data, messages=messages
        )
        async for msg in stream:
            if isinstance(msg, SystemSignal):
                yield msg
                return
            plan_context.add_message(agent.name, msg)
            # Only emit a snapshot when the message actually carries text:
            # stream chunks keep it in ``delta``, complete choices in ``content``.
            has_content = (isinstance(msg, ChatResponseStreamChoice) and msg.delta) or (
                isinstance(msg, ChatResponseChoice) and msg.content
            )
            if has_content:
                yield plan_context.response_json()

        # Record the finished agent's results on the shared context.
        plan_context.set_last_plan_execute(agent.exec_res)
        if hasattr(agent, 'sql_code'):
            plan_context.executing_agent.sql_code = agent.sql_code
        yield plan_context.response_json()

    async def run(self, info, messages=None):
        """Resume execution of a previously interrupted plan.

        Args:
            info: JSON text restored with
                ``PlanResponseContextManager.model_validate_json``.
            messages: Messages forwarded to every sub-agent. Defaults to a
                fresh empty list on each call.

        Yields:
            The items streamed by the sub-agents. If any agent emits a
            ``SystemSignal``, the signal is stored on the context, the whole
            context is yielded as ``model_dump_json(exclude_none=True)``, and
            the generator stops. The raw signal object itself is never
            yielded.
        """
        if messages is None:
            # Fresh list per call. The former ``messages=[]`` default was a
            # single list shared by all calls, so any in-place change by an
            # agent would have leaked into later runs.
            messages = []

        # Restore the persisted context and consume the pending clarification
        # so it is not re-applied if the context is serialized again.
        plan_context = PlanResponseContextManager.model_validate_json(info)
        clarification_data = plan_context.clarification_data
        plan_context.system_signal = None
        plan_context.clarification_data = None

        # 1) Continue the interrupted step: the first plan not executed yet.
        pending = next((p for p in plan_context.plans if not p.executed), None)
        if pending is not None:
            agent, _ = self._build_agent(pending.action_name)
            async for rsp in self.continue_run_agent(agent, messages, plan_context, clarification_data):
                # Handle the signal before forwarding anything, matching the
                # loop below. Previously the raw SystemSignal object was
                # yielded to the caller first, ahead of the serialized context.
                if isinstance(rsp, SystemSignal):
                    plan_context.system_signal = rsp
                    yield plan_context.model_dump_json(exclude_none=True)
                    return
                yield rsp
            pending.executed = True

        # 2) Run every remaining step from scratch. Unknown actions are skipped.
        for plan in plan_context.plans:
            if plan.executed or plan.action_name not in ActionDict:
                continue
            agent, llm_name = self._build_agent(plan.action_name)
            if plan.action_name == 'generate_chart':
                plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction)
            elif plan.action_name in ('TenderResultSqlAgent', 'show_case'):
                plan_context.add_executing_agent_info(
                    agent.name, llm_name, plan.instruction, add_to_context=True, add_to_final=True
                )
            else:
                plan_context.add_executing_agent_info(
                    agent.name, llm_name, plan.instruction, add_to_context=True
                )
            async for rsp in self.run_agent(agent, messages, plan_context=plan_context):
                if isinstance(rsp, SystemSignal):
                    plan_context.system_signal = rsp
                    yield plan_context.model_dump_json(exclude_none=True)
                    return
                yield rsp
            plan.executed = True
            # NOTE(review): a disabled experiment used to insert a follow-up
            # 'TenderResultSqlAgent' step here after a successful SqlAgent that
            # returned non-empty data. See VCS history if it is revived.
|