# BaseSubAgent.py
import asyncio
import copy
import json
import re
import traceback
from typing import Dict, List, Literal, Optional, Union

from httpx import RemoteProtocolError
from openai import OpenAI
from pydantic import BaseModel, Field

from qwen_agent.actions.base import Action
from qwen_agent.llm.llm_client import LLMClient
from qwen_agent.messages.context_message import ChatResponseChoice, ChatResponseStreamChoice
from qwen_agent.tools.tools import call_plugin, tools_list
  15. class BaseSubAgent():
  16. def __init__(self, llm: Optional[LLMClient]=None, llm_name=None, stream=False, name='base_agent', max_retry_round=3):
  17. self.llm = llm
  18. if llm_name:
  19. self.llm_name = llm_name
  20. else:
  21. self.llm_name = llm.model
  22. # self.stream = False
  23. self.stream = stream
  24. self.is_success = False
  25. self.empty_data = False
  26. self.tool_list = []
  27. self.max_retry_round = max_retry_round
  28. self.REACT_INSTRUCTION = """
  29. #API描述
  30. {tools_text}
  31. 请依据API的描述,制定计划完成用户需求,按照如下格式返回:
  32. Thought: 生成计划的原因。
  33. Action: 当前需要使用的API,必须包含在[{tools_name_text}] 中。注意这里只需要放API的名字(name_for_model),不需要额外的信息
  34. Action Input: 当前API的输入参数。注意这里只需要放JSON格式的API的参数(parameters),不需要额外的信息
  35. Final: 以上是思考的结果。
  36. """
  37. self.TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
  38. self.SubAgent_PROMPT = ''
  39. self.name = name
  40. # self._client = OpenAI(
  41. # api_key="none", base_url=model_server
  42. # )
  43. # def chat(self, messages, functions):
  44. # response = self.llm.chat(model=self.llm_name, messages=messages, functions=functions, stream=False)
  45. # return response.choices[0].message
  46. # def chat_stream(self, messages):
  47. # response = self.llm.chat(model=self.llm_name, messages=messages, stream=True)
  48. # for chunk in response:
  49. # yield chunk
  50. # async def async_chat(self, messages, functions):
  51. # response = await self.llm.chat(model=self.llm_name, messages=messages, functions=functions, stream=False)
  52. # return response.choices[0].message
  53. # async def async_chat_stream(self, messages):
  54. # response = await self.llm.chat(model=self.llm_name, messages=messages, stream=True)
  55. # async for chunk in response:
  56. # yield chunk
  57. def gen_system_prompt(self,functions):
  58. if functions:
  59. tools_text = []
  60. tools_name_text = []
  61. for func_info in functions:
  62. name = func_info.get("name", "")
  63. name_m = func_info.get("name_for_model", name)
  64. name_h = func_info.get("name_for_human", name)
  65. desc = func_info.get("description", "")
  66. desc_m = func_info.get("description_for_model", desc)
  67. tool = self.TOOL_DESC.format(
  68. name_for_model=name_m,
  69. name_for_human=name_h,
  70. # Hint: You can add the following format requirements in description:
  71. # "Format the arguments as a JSON object."
  72. # "Enclose the code within triple backticks (`) at the beginning and end of the code."
  73. description_for_model=desc_m,
  74. parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
  75. )
  76. tools_text.append(tool)
  77. tools_name_text.append(name_m)
  78. tools_text = "\n\n".join(tools_text)
  79. tools_name_text = ", ".join(tools_name_text)
  80. system = self.SubAgent_PROMPT + "\n\n" + self.REACT_INSTRUCTION.format(
  81. tools_text=tools_text,
  82. tools_name_text=tools_name_text,
  83. )
  84. system = system.lstrip("\n").rstrip()
  85. return system
  86. def parse_response_func(self,response):
  87. func_name, func_args = "", ""
  88. i = response.find("Action:")
  89. j = response.find("\nAction Input:")
  90. k = response.find("\nFinal:")
  91. if 0 <= i < j: # If the text has `Action` and `Action input`,
  92. func_name = response[i + len("Action:") : j].strip()
  93. if k > 0:
  94. func_args = response[j + len("\nAction Input:") : k].strip()
  95. else:
  96. func_args = response[j + len("\nAction Input:") :].strip()
  97. if func_name:
  98. choice_data = {'role':"assistant","content":response[:i],
  99. "function_call":{"name": func_name, "arguments": func_args}
  100. }
  101. return choice_data
  102. return {'function_call':None,'content':response}
  103. def parse_response(self,rsp,prefix='python'):
  104. # print('parse_response_rsp:',rsp)
  105. if isinstance(rsp,str):
  106. rsp = self.parse_response_func(rsp)
  107. if rsp is not None and rsp['function_call'] is None:
  108. rsp['function_call'] = {}
  109. rsp['function_call']['name'] = None
  110. rsp['function_call']['arguments']=None
  111. triple_match = re.search(r'```[^\n]*\n(.+?)```',rsp['content'] , re.DOTALL)
  112. if triple_match:
  113. text = triple_match.group(1)
  114. if text.startswith(prefix):
  115. text=text.replace(prefix,'')
  116. rsp['function_call']['name'] = self.tool_list[0].get('name_for_model')
  117. rsp['function_call']['arguments'] = triple_match.group(1)
  118. return rsp
  119. return rsp
  120. async def _core(self, query,messages=[]):
  121. bot_msg,func_msg = None,None
  122. is_success = False
  123. local_message = copy.deepcopy(messages)
  124. local_message.insert(0,{'role': 'system','content':self.gen_system_prompt(self.tool_list)})
  125. local_message.append({'role': 'user','content':query})
  126. print(f"local_message:{local_message}")
  127. for msg in local_message:
  128. if 'history' in msg and msg['history']:
  129. print('is history messsage')
  130. else:
  131. yield ChatResponseChoice(role=msg['role'], content=msg['content'])
  132. observation = ''
  133. plugin_args = None
  134. self.retry_cnt = self.max_retry_round
  135. while not is_success and self.retry_cnt>0:
  136. # print('local_message:',local_message)
  137. try:
  138. if not self.stream:
  139. # rsp = self.chat(local_message,self.tool_list)
  140. rsp = await self.llm.chat(model=self.llm_name, messages=local_message, stream=self.stream)
  141. rsp = self.parse_response(rsp)
  142. else:
  143. self.stream_rsp = ''
  144. rsp = await self.llm.chat(model=self.llm_name, messages=local_message, stream=self.stream)
  145. async for r in rsp:
  146. if r:
  147. self.stream_rsp += r
  148. yield ChatResponseStreamChoice(role='assistant',delta=f"{r}")
  149. yield ChatResponseStreamChoice(role='assistant',finish_reason='stop')
  150. rsp = self.parse_response(self.stream_rsp)
  151. print('openai_rsp:',rsp)
  152. except Exception as e:
  153. traceback.print_exc()
  154. print(f'{type(e)}, {isinstance(e, RemoteProtocolError)}')
  155. if self.stream:
  156. yield ChatResponseStreamChoice(role='assistant', finish_reason='flush')
  157. if isinstance(e, RemoteProtocolError):
  158. await asyncio.sleep(2 ** self.max_retry_round + 2 ** (self.max_retry_round - self.retry_cnt + 1))
  159. self.retry_cnt -= 1
  160. if self.retry_cnt >= 0:
  161. print('retry')
  162. continue
  163. yield ChatResponseChoice(
  164. role='function',
  165. content=f"""```{rsp['function_call']['arguments']}```"""
  166. )
  167. # if rsp['function_call'] and rsp['function_call']['name'] in [l['name_for_model'] for l in self.tool_list]:
  168. if rsp['function_call']:
  169. bot_msg = {
  170. 'role': 'assistant',
  171. 'content': rsp['content'],
  172. 'function_call': {
  173. 'name': rsp['function_call']['name'],
  174. 'arguments': rsp['function_call']['arguments'],
  175. }
  176. }
  177. # yield ChatResponseChoice(**bot_msg)
  178. # res_params = await call_plugin(rsp['function_call']['name'], rsp['function_call']['arguments'])
  179. res_params = await self.run_function(rsp['function_call']['arguments'])
  180. observation, plugin_args, is_success = res_params
  181. func_msg = {
  182. 'role': 'function',
  183. 'name': rsp['function_call']['name'],
  184. 'content': observation,
  185. }
  186. yield ChatResponseChoice(**func_msg)
  187. if not is_success:
  188. self.retry_cnt -= 1
  189. user_msg = {
  190. 'role': 'user',
  191. 'content': f"""CODE:\n```\n{rsp['function_call']['arguments']}\n```\n
  192. TrackBack:\n{observation}\n
  193. 请根据以上报错信息,修改[Action Input],按照如下的格式返回:
  194. Thought: 修改的原因
  195. Action: {rsp['function_call']['name']}
  196. Action Input: 当前Action的输入参数。""",
  197. }
  198. local_message.append(bot_msg)
  199. local_message.append(user_msg)
  200. yield ChatResponseChoice(**user_msg)
  201. else:
  202. local_message.append(bot_msg)
  203. local_message.append(func_msg)
  204. # yield ChatResponseChoice(**func_msg)
  205. else:
  206. bot_msg = {
  207. 'role': 'assistant',
  208. 'content': f"{rsp}",
  209. }
  210. user_msg = {
  211. 'role': 'user',
  212. 'content': f"""
  213. [Action]中是未包含需要的API,请重新调整[Action]和[Action Input],按照如下的格式返回:
  214. Thought: 修改的原因
  215. Action: 当前需要使用的API
  216. Action Input: 当前API的输入参数。注意这里只需要放JSON格式的API的参数(parameters),不需要额外的信息。""",
  217. }
  218. local_message.append(bot_msg)
  219. local_message.append(user_msg)
  220. yield ChatResponseChoice(**user_msg)
  221. self.retry_cnt -= 1
  222. self.is_success = is_success
  223. self.plugin_args = plugin_args
  224. # return local_message, is_success
  225. async def continue_run(self,plan_context,user_input):
  226. pass
  227. async def run(self, plan_context,messages=[]):
  228. query = plan_context.get_context()
  229. # local_message, is_success = self._core(query,messages)
  230. async for msg in self._core(query,messages,self.stream):
  231. yield msg
  232. self.exec_res = msg.content
  233. # return local_message[-1]['content'], local_message,
  234. def parse_parameter(self,plugin_args):
  235. print("plugin_args", plugin_args)
  236. if plugin_args and plugin_args.startswith('```'):
  237. triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
  238. if triple_match:
  239. plugin_args = triple_match.group(1)
  240. else:
  241. triple_match = re.search(r'```[^\n]*\n(.+?)', plugin_args, re.DOTALL)
  242. plugin_args = triple_match.group(1)
  243. print("plugin_args_clean:", plugin_args)
  244. return plugin_args
  245. async def run_function(self,plugin_args):
  246. raise NotImplementedError
  247. def continue_run(self, plan_context, user_input, messages):
  248. raise NotImplementedError