1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- from re import S
- from typing import List
- from qwen_agent.planning.plan_dispatcher import PlanDispatcher
- from qwen_agent.planning.planner import PlanInfo, Planner
- from qwen_agent.messages.context_message import PlanResponseContextManager
- from qwen_agent.messages.context_message import ChatResponseChoice, ChatResponseStreamChoice, SystemSignal
- from qwen_agent.sub_agent.BaseSubAgent import BaseSubAgent
- from agent_config import ActionDict
class PlanExecutor:
    """Run a user request through the plan dispatcher, the planner and the
    sub-agents selected by the resulting plan.

    ``run`` is the entry point; ``run_agent`` streams the output of a single
    agent while recording it on the shared plan context.
    """

    # Actions whose executing-agent entry is registered with both
    # ``add_to_context=True`` and ``add_to_final=True`` (see ``run``).
    _FINAL_RESULT_ACTIONS = frozenset({
        'LandSupplySqlAgent',
        'LandUseSqlAgent',
        'LandApprovalSqlAgent',
        'SpatialAnalysisAgent',
        'LandFindSqlAgent',
        'LandSiteSelectionSqlAgent',
        'GisSurroundingFacilitiesQueryAgent',
        'KfqEvalSqlAgent',
    })

    def __init__(self, enable_critic=False, llm=None, stream=False, llm_dict=None, max_plan_critic_num=2):
        """Store the executor configuration.

        Args:
            enable_critic: Stored on the instance; not read by this class.
            llm: LLM handle forwarded to the dispatcher and to every sub-agent.
            stream: Streaming flag forwarded to the dispatcher and sub-agents.
            llm_dict: Mapping used by ``run`` to look up the LLM name for the
                dispatcher (``'plan_dispatcher'``, falling back to
                ``'planner'``) and for each plan action name.
            max_plan_critic_num: Stored on the instance; not read by this class.
        """
        self.llm = llm
        self.stream = stream
        self.enable_critic = enable_critic
        self.max_plan_critic_num = max_plan_critic_num
        self.llm_dict = llm_dict

    async def run_agent(self, agent, messages: List[str],
                        plan_context: "PlanResponseContextManager"):
        """Stream one agent's output, recording it on ``plan_context``.

        This is an async generator. It yields either:

        * a ``SystemSignal`` instance, forwarded unchanged, after which the
          generator stops without recording results; or
        * ``plan_context.response_json()`` after each message that carries a
          non-empty ``delta`` (stream choice) or ``content`` (chat choice),
          and once more after the agent's ``exec_res`` has been stored.

        Args:
            agent: Object exposing ``run(plan_context=..., messages=...)`` as
                an async iterable, plus ``name`` and ``exec_res`` attributes.
            messages: Conversation messages passed through to the agent.
            plan_context: Shared context that accumulates messages/results.
        """
        rsp = agent.run(plan_context=plan_context, messages=messages)
        async for msg in rsp:
            if isinstance(msg, SystemSignal):
                # Hand the signal to the caller and stop consuming the agent.
                yield msg
                return
            plan_context.add_message(agent.name, msg)
            has_payload = (
                (isinstance(msg, ChatResponseStreamChoice) and msg.delta)
                or (isinstance(msg, ChatResponseChoice) and msg.content)
            )
            if has_payload:
                yield plan_context.response_json()

        # Record the agent's final result once its stream is exhausted.
        exec_res = agent.exec_res
        if exec_res:
            if hasattr(agent, 'sql_code'):
                plan_context.executing_agent.sql_code = agent.sql_code
            plan_context.set_last_plan_execute(exec_res)
            yield plan_context.response_json()

    async def run(self, user_request, messages=None):
        """Execute the full dispatcher -> planner -> plan-steps pipeline.

        This is an async generator yielding whatever ``run_agent`` yields for
        each stage. If a plan step emits a ``SystemSignal``, the signal is
        stored on the context, ``plan_context.model_dump_json(exclude_none=True)``
        is yielded, and execution stops.

        Args:
            user_request: The request passed to ``plan_context.init_chat``.
            messages: Prior messages; a falsy value is replaced by a new list.

        Raises:
            KeyError: If ``llm_dict`` has neither a truthy ``'plan_dispatcher'``
                entry nor a ``'planner'`` entry, or lacks an entry for an
                action name that is present in ``ActionDict``.
        """
        if not messages:
            messages = []
        plan_context = PlanResponseContextManager()
        plan_context.init_chat(user_request, messages)

        # Stage 1: plan dispatcher.
        dispatcher_llm_name = self.llm_dict.get('plan_dispatcher') or self.llm_dict['planner']
        plan_dispatcher = PlanDispatcher(
            llm_dict=self.llm_dict, llm_name=dispatcher_llm_name,
            llm=self.llm, stream=self.stream
        )
        plan_context.add_executing_agent_info(plan_dispatcher.name, plan_dispatcher.llm_name)
        # NOTE(review): a SystemSignal from the dispatcher or planner stage is
        # forwarded raw and execution continues, unlike the plan-step loop
        # below which stops. Confirm whether that asymmetry is intended.
        async for rsp in self.run_agent(plan_dispatcher, messages, plan_context):
            yield rsp

        # Stage 2: planner chosen by the dispatcher.
        planner: Planner = plan_dispatcher.planner
        plan_context.add_executing_agent_info(planner.name, planner.llm_name)
        async for rsp in self.run_agent(planner, messages, plan_context):
            yield rsp
        if not hasattr(planner, 'plans'):
            # The planner produced no plan list; nothing further to execute.
            return

        plan_context.plan_msg = planner.exec_res
        plan_context.plans = planner.plans  # List[PlanInfo]

        # Stage 3: run each plan step whose action has a registered agent.
        for plan in plan_context.plans:
            action_name = plan.action_name
            if action_name not in ActionDict:
                continue
            llm_name = self.llm_dict[action_name]
            agent: BaseSubAgent = ActionDict[action_name](
                llm=self.llm, llm_name=llm_name, stream=self.stream, name=action_name
            )
            if action_name == 'generate_chart':
                plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction)
            elif action_name in self._FINAL_RESULT_ACTIONS:
                plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction,
                                                      add_to_context=True, add_to_final=True)
            else:
                plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction,
                                                      add_to_context=True)

            async for rsp in self.run_agent(agent, messages, plan_context=plan_context):
                if isinstance(rsp, SystemSignal):
                    # Attach the signal to the context, emit the final
                    # snapshot and abort the remaining plan steps.
                    plan_context.system_signal = rsp
                    yield plan_context.model_dump_json(exclude_none=True)
                    return
                yield rsp
            plan.executed = True
|