# planner.py
import asyncio
import json
import traceback
from typing import List, Literal, Optional, Tuple, Union

from httpx import RemoteProtocolError
from pydantic import BaseModel

from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
from qwen_agent.messages.context_message import PlanResponseContextManager, ChatResponseStreamChoice, ChatResponseChoice
from qwen_agent.messages.plan_message import PlanInfo


  12. class Planner:
  13. def __init__(self, llm_name, system_prompt, default_plan, retriever=None,
  14. llm: Optional[Union[LLMClient, LLMAsyncClient]]=None,
  15. stream=False, name='planner', max_retry_times=3):
  16. self.llm = llm
  17. self.llm_name = llm_name
  18. self.stream = stream
  19. self.system_prompt = system_prompt
  20. self.default_plan = default_plan
  21. self.retriever = retriever
  22. self.name = name
  23. self.max_retry_times = max_retry_times
  24. def get_system_prompt(self, user_request):
  25. return self.system_prompt
  26. async def run(self, plan_context: PlanResponseContextManager = None, messages=None):
  27. user_request = plan_context.user_request
  28. _messages = [{
  29. "role": "system",
  30. "content": self.get_system_prompt(user_request)
  31. }]
  32. if messages:
  33. _messages.extend(messages)
  34. if user_request:
  35. _messages.append({
  36. "role": "user",
  37. "content": user_request + '\n 请为Question生成执行计划。\n'
  38. })
  39. # for msg in _messages:
  40. for i, msg in enumerate(_messages):
  41. if not isinstance(msg, dict):
  42. msg = dict(msg)
  43. if msg['type'].value == 1:
  44. msg['role'] = 'user'
  45. msg['content'] = msg['data']
  46. else:
  47. msg['role'] = 'assistant'
  48. msg['content'] = dict(msg['data'])['exec_res'][0]
  49. msg['history'] = True
  50. del msg['data']
  51. del msg['type']
  52. _messages[i] = msg
  53. if 'history' in msg and msg['history']:
  54. print('is history messsage')
  55. else:
  56. yield ChatResponseChoice(role=msg['role'], content=msg['content'])
  57. retry_times = 0
  58. while retry_times < self.max_retry_times:
  59. try:
  60. rsp = await self.llm.chat(model=self.llm_name, messages=_messages, stream=self.stream)
  61. if self.stream:
  62. plan_rsp = ''
  63. async for chunk in rsp:
  64. if chunk:
  65. yield ChatResponseStreamChoice(role='assistant', delta=chunk)
  66. plan_rsp += chunk
  67. yield ChatResponseStreamChoice(role='assistant', finish_reason='stop')
  68. print('plan_rsp:', plan_rsp)
  69. else:
  70. yield ChatResponseChoice(role='assistant', content=rsp)
  71. plan_rsp = rsp
  72. plans, msg = self.parse_plan(plan_rsp)
  73. break
  74. except Exception as e:
  75. traceback.print_exc()
  76. print(f'{type(e)}, {isinstance(e, RemoteProtocolError)}')
  77. if self.stream:
  78. yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
  79. retry_times += 1
  80. if isinstance(e, RemoteProtocolError):
  81. await asyncio.sleep(2**self.max_retry_times + 2**(self.max_retry_times - retry_times + 1))
  82. # plans, msg = self.parse_plan(self.default_plan)
  83. yield ChatResponseChoice(role='info', content=msg)
  84. self.plans = plans
  85. self.exec_res = msg
  86. def parse_plan(self, plan_rsp) -> List[PlanInfo]:
  87. # plan_msg = ''
  88. # for trunk in response:
  89. # plan_msg += trunk
  90. plan_list = []
  91. try:
  92. plans = plan_rsp.split('Plan:')[-1].strip()
  93. if plans.startswith('```json'):
  94. plans = plans.replace('```json', '')
  95. if plans.endswith('```'):
  96. plans = plans.replace('```', '')
  97. plans = json.loads(plans)
  98. plan_msg = f'对Question进行分析,我认为需要用以下执行计划完成用户的查询需求:{json.dumps(plans, ensure_ascii=False)}\n'
  99. except Exception as e:
  100. plans = json.loads(self.default_plan)
  101. plan_msg = f"\n使用默认执行计划, Plans:{self.default_plan}\n"
  102. for plan in plans:
  103. plan_list.append(PlanInfo(**plan))
  104. return plan_list, plan_msg