123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- import traceback
- import json
- from typing import List, Literal, Optional
- from pydantic import BaseModel
- from typing import List
- import asyncio
- from httpx import RemoteProtocolError
- from typing import Union
- import re
- from qwen_agent.messages.context_message import PlanResponseContextManager, ChatResponseStreamChoice, ChatResponseChoice
- from qwen_agent.messages.plan_message import PlanInfo
- from qwen_agent.llm.llm_client import LLMClient, LLMAsyncClient
class Planner:
    """Generate an execution plan for a user request via an LLM.

    The plan is obtained by prompting the LLM with a system prompt plus the
    user request, streaming (or returning) the response, and parsing the
    JSON payload that follows the last ``Plan:`` marker into ``PlanInfo``
    objects.  When the LLM call keeps failing or its output cannot be
    parsed, the configured ``default_plan`` is used instead.
    """

    def __init__(self, llm_name, system_prompt, default_plan, retriever=None,
                 llm: Optional[Union[LLMClient, LLMAsyncClient]] = None,
                 stream=False, name='planner', max_retry_times=3):
        """Store planner configuration.

        Args:
            llm_name: model identifier passed to the LLM client.
            system_prompt: system message used for every planning request.
            default_plan: JSON string used when no LLM plan can be obtained.
            retriever: optional retriever (stored only; not used here).
            llm: sync or async LLM client; ``run`` awaits ``llm.chat``.
            stream: whether to request a streaming LLM response.
            name: human-readable name of this planner instance.
            max_retry_times: number of LLM call attempts before falling back.
        """
        self.llm = llm
        self.llm_name = llm_name
        self.stream = stream
        self.system_prompt = system_prompt
        self.default_plan = default_plan
        self.retriever = retriever
        self.name = name
        self.max_retry_times = max_retry_times

    def get_system_prompt(self, user_request):
        """Return the system prompt for this request (override hook)."""
        return self.system_prompt

    async def run(self, plan_context: PlanResponseContextManager = None, messages=None):
        """Async generator that plans a response for ``plan_context``.

        Yields ``ChatResponseChoice`` / ``ChatResponseStreamChoice`` items
        while the LLM is queried (with retries), then parses the plan and
        stores it on ``self.plans`` / ``self.exec_res``.
        """
        user_request = plan_context.user_request
        _messages = [{
            "role": "system",
            "content": self.get_system_prompt(user_request)
        }]
        if messages:
            _messages.extend(messages)
        if user_request:
            _messages.append({
                "role": "user",
                "content": user_request + '\n 请为Question生成执行计划。\n'
            })
        for i, msg in enumerate(_messages):
            if not isinstance(msg, dict):
                # Normalize non-dict history entries (context objects) into
                # plain chat dicts; type.value == 1 marks a user message.
                msg = dict(msg)
                if msg['type'].value == 1:
                    msg['role'] = 'user'
                    msg['content'] = msg['data']
                else:
                    # assumes 'data' carries an 'exec_res' list whose first
                    # element is the assistant text — TODO confirm upstream shape
                    msg['role'] = 'assistant'
                    msg['content'] = dict(msg['data'])['exec_res'][0]
                msg['history'] = True
                del msg['data']
                del msg['type']
                _messages[i] = msg
            if 'history' in msg and msg['history']:
                print('is history message')
            else:
                yield ChatResponseChoice(role=msg['role'], content=msg['content'])

        plans = None  # sentinel: stays None until a plan is parsed
        msg = None
        retry_times = 0
        while retry_times < self.max_retry_times:
            try:
                rsp = await self.llm.chat(model=self.llm_name, messages=_messages, stream=self.stream)
                if self.stream:
                    plan_rsp = ''
                    async for chunk in rsp:
                        if chunk:
                            yield ChatResponseStreamChoice(role='assistant', delta=chunk)
                            plan_rsp += chunk
                    yield ChatResponseStreamChoice(role='assistant', finish_reason='stop')
                    # Drop any <think>...</think> reasoning section the
                    # model may prepend before the actual plan.
                    plan_rsp = re.sub(r'<think>.*?</think>\n*', '', plan_rsp, flags=re.DOTALL)
                    print('plan_rsp:', plan_rsp)
                else:
                    yield ChatResponseChoice(role='assistant', content=rsp)
                    plan_rsp = rsp
                plans, msg = self.parse_plan(plan_rsp)
                break
            except Exception as e:
                traceback.print_exc()
                print(f'{type(e)}, {isinstance(e, RemoteProtocolError)}')
                if self.stream:
                    yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
                retry_times += 1
                if isinstance(e, RemoteProtocolError):
                    # Back off before retrying after a dropped connection.
                    await asyncio.sleep(2 ** self.max_retry_times + 2 ** (self.max_retry_times - retry_times + 1))

        if plans is None:
            # All retries failed: fall back to the default plan.  Previously
            # this path left `plans`/`msg` unbound and raised NameError.
            plans, msg = self.parse_plan(self.default_plan)
        yield ChatResponseChoice(role='info', content=msg)
        self.plans = plans
        self.exec_res = msg

    def parse_plan(self, plan_rsp) -> List[PlanInfo]:
        """Parse an LLM response into ``(list[PlanInfo], info_message)``.

        The plan is expected as JSON after the last ``Plan:`` marker,
        optionally wrapped in a ```json fenced block.  Any parse failure
        falls back to ``self.default_plan``.
        """
        plan_list = []
        try:
            plans = plan_rsp.split('Plan:')[-1].strip()
            # Strip markdown code fences by exact prefix/suffix slicing;
            # str.replace would also clobber backticks inside the payload.
            if plans.startswith('```json'):
                plans = plans[len('```json'):]
            if plans.endswith('```'):
                plans = plans[:-3]
            plans = json.loads(plans)
            plan_msg = f'对Question进行分析,我认为需要用以下执行计划完成用户的查询需求:{json.dumps(plans, ensure_ascii=False)}\n'
        except Exception:
            plans = json.loads(self.default_plan)
            plan_msg = f"\n使用默认执行计划, Plans:{self.default_plan}\n"
        for plan in plans:
            plan_list.append(PlanInfo(**plan))
        return plan_list, plan_msg
|