plan_continue_executor.py 4.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from typing import List
  2. from qwen_agent.planning.plan_executor import PlanExecutor
  3. from qwen_agent.messages.context_message import PlanResponseContextManager
  4. from qwen_agent.messages.context_message import ChatResponseChoice, ChatResponseStreamChoice, SystemSignal
  5. from qwen_agent.sub_agent.BaseSubAgent import BaseSubAgent
  6. from agent_config import ActionDict
  7. class PlanContinueExecutor(PlanExecutor):
  8. async def continue_run_agent(self, agent: BaseSubAgent, messages: List[str],
  9. plan_context: PlanResponseContextManager,
  10. clarification_data) -> List[str]:
  11. rsp = agent.continue_run(plan_context=plan_context,
  12. user_input=clarification_data, messages=messages)
  13. async for msg in rsp:
  14. if isinstance(msg, SystemSignal):
  15. yield msg
  16. return
  17. plan_context.add_message(agent.name, msg)
  18. # if msg.content:
  19. if (isinstance(msg, ChatResponseStreamChoice) and msg.delta) or \
  20. (isinstance(msg, ChatResponseChoice) and msg.content):
  21. yield plan_context.response_json()
  22. # set results
  23. exec_res = agent.exec_res
  24. plan_context.set_last_plan_execute(exec_res)
  25. if hasattr(agent, 'sql_code'):
  26. plan_context.executing_agent.sql_code = agent.sql_code
  27. yield plan_context.response_json()
  28. async def run(self, info, messages=[]):
  29. # recovery context
  30. plan_context = PlanResponseContextManager.model_validate_json(info)
  31. clarification_data = plan_context.clarification_data
  32. plan_context.system_signal = None
  33. plan_context.clarification_data = None
  34. # messages = plan_context.plan_response.history
  35. # continue run last agent
  36. for plan in plan_context.plans:
  37. if plan.executed: continue
  38. llm_name = self.llm_dict[plan.action_name]
  39. agent: BaseSubAgent = ActionDict[plan.action_name](
  40. llm=self.llm, llm_name=llm_name, stream=self.stream, name=plan.action_name
  41. )
  42. async for rsp in self.continue_run_agent(agent, messages, plan_context, clarification_data):
  43. yield rsp
  44. if isinstance(rsp, SystemSignal):
  45. plan_context.system_signal = rsp
  46. yield plan_context.model_dump_json(exclude_none=True)
  47. return
  48. plan.executed = True
  49. break
  50. # continue run left agents
  51. for idx, plan in enumerate(plan_context.plans):
  52. if plan.executed or plan.action_name not in ActionDict:
  53. continue
  54. llm_name = self.llm_dict[plan.action_name]
  55. agent: BaseSubAgent = ActionDict[plan.action_name](
  56. llm=self.llm, llm_name=llm_name, stream=self.stream, name=plan.action_name
  57. )
  58. if plan.action_name == 'generate_chart':
  59. plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction)
  60. elif plan.action_name == 'TenderResultSqlAgent' or plan.action_name == 'show_case':
  61. plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction, add_to_context=True, add_to_final=True)
  62. else:
  63. plan_context.add_executing_agent_info(agent.name, llm_name, plan.instruction, add_to_context=True)
  64. async for rsp in self.run_agent(agent, messages, plan_context=plan_context):
  65. if isinstance(rsp, SystemSignal):
  66. plan_context.system_signal = rsp
  67. yield plan_context.model_dump_json(exclude_none=True)
  68. return
  69. yield rsp
  70. plan.executed = True
  71. # if isinstance(agent, SqlAgent) and agent.is_success and hasattr(agent, 'empty_data') and agent.empty_data == False:
  72. # plan_context.plans.insert(idx+1, PlanInfo(action_name='TenderResultSqlAgent', instruction="已经获取到相关数据,我需要你再给我一些具体的招投标记录展示给用户,不要超过50条"))