planner.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
import asyncio
import json
import logging
import re
import traceback
from typing import List, Literal, Optional, Tuple, Union

from httpx import RemoteProtocolError
from pydantic import BaseModel

from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
from qwen_agent.messages.context_message import PlanResponseContextManager, ChatResponseStreamChoice, ChatResponseChoice
from qwen_agent.messages.plan_message import PlanInfo
  13. class Planner:
  14. def __init__(self, llm_name, system_prompt, default_plan, retriever=None,
  15. llm: Optional[Union[LLMClient, LLMAsyncClient]]=None,
  16. stream=False, name='planner', max_retry_times=3):
  17. self.llm = llm
  18. self.llm_name = llm_name
  19. self.stream = stream
  20. self.system_prompt = system_prompt
  21. self.default_plan = default_plan
  22. self.retriever = retriever
  23. self.name = name
  24. self.max_retry_times = max_retry_times
  25. def get_system_prompt(self, user_request):
  26. return self.system_prompt
  27. async def run(self, plan_context: PlanResponseContextManager = None, messages=None):
  28. user_request = plan_context.user_request
  29. _messages = [{
  30. "role": "system",
  31. "content": self.get_system_prompt(user_request)
  32. }]
  33. if messages:
  34. _messages.extend(messages)
  35. if user_request:
  36. _messages.append({
  37. "role": "user",
  38. "content": user_request + '\n 请为Question生成执行计划。\n'
  39. })
  40. # for msg in _messages:
  41. for i, msg in enumerate(_messages):
  42. if not isinstance(msg, dict):
  43. msg = dict(msg)
  44. if msg['type'].value == 1:
  45. msg['role'] = 'user'
  46. msg['content'] = msg['data']
  47. else:
  48. msg['role'] = 'assistant'
  49. msg['content'] = dict(msg['data'])['exec_res'][0]
  50. msg['history'] = True
  51. del msg['data']
  52. del msg['type']
  53. _messages[i] = msg
  54. if 'history' in msg and msg['history']:
  55. print('is history messsage')
  56. else:
  57. yield ChatResponseChoice(role=msg['role'], content=msg['content'])
  58. retry_times = 0
  59. while retry_times < self.max_retry_times:
  60. try:
  61. rsp = await self.llm.chat(model=self.llm_name, messages=_messages, stream=self.stream)
  62. if self.stream:
  63. plan_rsp = ''
  64. async for chunk in rsp:
  65. if chunk:
  66. yield ChatResponseStreamChoice(role='assistant', delta=chunk)
  67. plan_rsp += chunk
  68. yield ChatResponseStreamChoice(role='assistant', finish_reason='stop')
  69. pattern = r'<think>.*?</think>\n*'
  70. plan_rsp = re.sub(pattern, '', plan_rsp, flags=re.DOTALL)
  71. print('plan_rsp:', plan_rsp)
  72. else:
  73. yield ChatResponseChoice(role='assistant', content=rsp)
  74. plan_rsp = rsp
  75. plans, msg = self.parse_plan(plan_rsp)
  76. break
  77. except Exception as e:
  78. traceback.print_exc()
  79. print(f'{type(e)}, {isinstance(e, RemoteProtocolError)}')
  80. if self.stream:
  81. yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
  82. retry_times += 1
  83. if isinstance(e, RemoteProtocolError):
  84. await asyncio.sleep(2**self.max_retry_times + 2**(self.max_retry_times - retry_times + 1))
  85. # plans, msg = self.parse_plan(self.default_plan)
  86. yield ChatResponseChoice(role='info', content=msg)
  87. self.plans = plans
  88. self.exec_res = msg
  89. def parse_plan(self, plan_rsp) -> List[PlanInfo]:
  90. # plan_msg = ''
  91. # for trunk in response:
  92. # plan_msg += trunk
  93. plan_list = []
  94. try:
  95. plans = plan_rsp.split('Plan:')[-1].strip()
  96. if plans.startswith('```json'):
  97. plans = plans.replace('```json', '')
  98. if plans.endswith('```'):
  99. plans = plans.replace('```', '')
  100. plans = json.loads(plans)
  101. plan_msg = f'对Question进行分析,我认为需要用以下执行计划完成用户的查询需求:{json.dumps(plans, ensure_ascii=False)}\n'
  102. except Exception as e:
  103. plans = json.loads(self.default_plan)
  104. plan_msg = f"\n使用默认执行计划, Plans:{self.default_plan}\n"
  105. for plan in plans:
  106. plan_list.append(PlanInfo(**plan))
  107. return plan_list, plan_msg