import asyncio
import json
import re
import traceback
from typing import List, Optional, Tuple, Union

from httpx import RemoteProtocolError

from qwen_agent.llm.llm_client import LLMAsyncClient, LLMClient
from qwen_agent.messages.context_message import (
    ChatResponseChoice,
    ChatResponseStreamChoice,
    PlanResponseContextManager,
)
from qwen_agent.messages.plan_message import PlanInfo


class Planner:
    """Asks an LLM to produce an execution plan for a user request.

    The LLM reply is parsed into a list of ``PlanInfo`` objects; on repeated
    LLM failures the configured ``default_plan`` (a JSON string) is used
    instead, so callers always receive a usable plan.
    """

    def __init__(self, llm_name, system_prompt, default_plan, retriever=None,
                 llm: Optional[Union[LLMClient, LLMAsyncClient]] = None,
                 stream=False, name='planner', max_retry_times=3):
        """Store planner configuration.

        Args:
            llm_name: Model name passed to ``llm.chat``.
            system_prompt: System prompt used for every planning request.
            default_plan: JSON string with the fallback plan list.
            retriever: Optional retriever (not used in this class's visible code).
            llm: Sync or async LLM client; ``run`` awaits ``llm.chat``.
            stream: If True, consume the LLM response as an async stream.
            name: Planner instance name.
            max_retry_times: How many times ``run`` retries a failed LLM call.
        """
        self.llm = llm
        self.llm_name = llm_name
        self.stream = stream
        self.system_prompt = system_prompt
        self.default_plan = default_plan
        self.retriever = retriever
        self.name = name
        self.max_retry_times = max_retry_times

    def get_system_prompt(self, user_request):
        """Return the system prompt for this request (currently static)."""
        return self.system_prompt

    async def run(self, plan_context: PlanResponseContextManager = None, messages=None):
        """Yield chat choices while generating a plan for ``plan_context.user_request``.

        Yields ``ChatResponseChoice`` / ``ChatResponseStreamChoice`` items as the
        conversation and LLM output progress, then a final ``role='info'`` choice
        describing the chosen plan. Side effects: sets ``self.plans`` and
        ``self.exec_res``.
        """
        user_request = plan_context.user_request
        _messages = [{
            "role": "system",
            "content": self.get_system_prompt(user_request)
        }]
        if messages:
            _messages.extend(messages)
        if user_request:
            # Runtime prompt text intentionally left in Chinese.
            _messages.append({
                "role": "user",
                "content": user_request + '\n 请为Question生成执行计划。\n'
            })
        for i, msg in enumerate(_messages):
            if not isinstance(msg, dict):
                # Normalize project message objects into plain chat dicts.
                msg = dict(msg)
                # NOTE(review): assumes type.value == 1 means "user message" — confirm
                # against the message enum's definition.
                if msg['type'].value == 1:
                    msg['role'] = 'user'
                    msg['content'] = msg['data']
                else:
                    msg['role'] = 'assistant'
                    msg['content'] = dict(msg['data'])['exec_res'][0]
                    msg['history'] = True
                del msg['data']
                del msg['type']
                _messages[i] = msg
            if 'history' in msg and msg['history']:
                # History messages are not re-emitted to the caller.
                print('is history messsage')
            else:
                yield ChatResponseChoice(role=msg['role'], content=msg['content'])

        # Fallback sentinels: remain None if every retry fails.
        plans, msg = None, None
        retry_times = 0
        while retry_times < self.max_retry_times:
            try:
                rsp = await self.llm.chat(model=self.llm_name,
                                          messages=_messages,
                                          stream=self.stream)
                if self.stream:
                    plan_rsp = ''
                    async for chunk in rsp:
                        if chunk:
                            yield ChatResponseStreamChoice(role='assistant', delta=chunk)
                            plan_rsp += chunk
                    yield ChatResponseStreamChoice(role='assistant', finish_reason='stop')
                    # NOTE(review): this pattern looks garbled — it likely once
                    # stripped a reasoning block (e.g. r'<think>.*?</think>\n*')
                    # whose tags were lost in transit. Kept as-is to preserve
                    # current behavior; confirm against upstream history.
                    pattern = r'.*?\n*'
                    plan_rsp = re.sub(pattern, '', plan_rsp, flags=re.DOTALL)
                    print('plan_rsp:', plan_rsp)
                else:
                    yield ChatResponseChoice(role='assistant', content=rsp)
                    plan_rsp = rsp
                plans, msg = self.parse_plan(plan_rsp)
                break
            except Exception as e:
                traceback.print_exc()
                print(f'{type(e)}, {isinstance(e, RemoteProtocolError)}')
                if self.stream:
                    yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
                retry_times += 1
                if isinstance(e, RemoteProtocolError):
                    # Exponential-style backoff before retrying a dropped connection.
                    await asyncio.sleep(2 ** self.max_retry_times
                                        + 2 ** (self.max_retry_times - retry_times + 1))
        if plans is None:
            # BUGFIX: previously this fallback was commented out, so exhausting
            # all retries raised NameError on the lines below. Fall back to the
            # configured default plan instead.
            plans, msg = self.parse_plan(self.default_plan)
        yield ChatResponseChoice(role='info', content=msg)
        self.plans = plans
        self.exec_res = msg

    def parse_plan(self, plan_rsp) -> Tuple[List[PlanInfo], str]:
        """Parse an LLM reply into plan objects plus a human-readable summary.

        Accepts the text after the last ``'Plan:'`` marker, optionally wrapped
        in a ```` ```json ```` fence. Falls back to ``self.default_plan`` when
        the reply is not valid JSON.

        Returns:
            (list of ``PlanInfo``, summary message string).
        """
        plan_list = []
        try:
            plans = plan_rsp.split('Plan:')[-1].strip()
            # BUGFIX: strip only the leading/trailing code fence. The old code
            # used str.replace, which deleted EVERY occurrence of ``` inside
            # the JSON payload as well.
            if plans.startswith('```json'):
                plans = plans[len('```json'):]
            if plans.endswith('```'):
                plans = plans[:-len('```')]
            plans = json.loads(plans)
            plan_msg = f'对Question进行分析,我认为需要用以下执行计划完成用户的查询需求:{json.dumps(plans, ensure_ascii=False)}\n'
        except Exception:
            # Unparseable reply: fall back to the configured default plan.
            plans = json.loads(self.default_plan)
            plan_msg = f"\n使用默认执行计划, Plans:{self.default_plan}\n"
        for plan in plans:
            plan_list.append(PlanInfo(**plan))
        return plan_list, plan_msg