# agent.py

import copy
import json
import traceback
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, Union

from qwen_agent.llm import get_chat_model
from qwen_agent.llm.base import BaseChatModel
from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, ContentItem, Message
from qwen_agent.log import logger
from qwen_agent.tools import TOOL_REGISTRY, BaseTool
from qwen_agent.utils.utils import has_chinese_messages, merge_generate_cfgs


class Agent(ABC):
    """A base class for Agent.

    An agent can receive messages and provide a response by LLM or tools.

    Different agents have distinct workflows for processing messages and generating
    responses in the `_run` method.
    """

    def __init__(self,
                 function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
                 llm: Optional[Union[dict, BaseChatModel]] = None,
                 system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
                 name: Optional[str] = None,
                 description: Optional[str] = None,
                 **kwargs):
        """Initialize the agent.

        Args:
            function_list: A list of tool names, tool configurations, or Tool objects,
              such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter().
            llm: The LLM model configuration or LLM model object.
              Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}.
            system_message: The specified system message for LLM chat.
            name: The name of this agent.
            description: The description of this agent, which will be used for multi_agent.
        """
        if isinstance(llm, dict):
            self.llm = get_chat_model(llm)
        else:
            self.llm = llm
        self.extra_generate_cfg: dict = {}

        self.function_map = {}
        if function_list:
            for tool in function_list:
                self._init_tool(tool)

        # Fall back to the imported default if no system message is given.
        self.system_message = system_message or DEFAULT_SYSTEM_MESSAGE
        self.name = name
        self.description = description
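
    # Example configuration (a sketch with illustrative values, not defaults of this class;
    # `MyAgent` stands in for any concrete Agent subclass):
    #
    #   llm_cfg = {'model': 'qwen-max', 'model_server': 'dashscope', 'api_key': ''}
    #   agent = MyAgent(function_list=['code_interpreter',                            # registered name
    #                                  {'name': 'code_interpreter', 'timeout': 10}],  # config dict
    #                   llm=llm_cfg)
    #
    # A BaseTool instance may also be passed directly in `function_list`.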

    def run_nonstream(self, messages: List[Union[Dict, Message]], **kwargs) -> Union[List[Message], List[Dict]]:
        """Same as self.run, but with stream=False,
        meaning it returns the complete response directly
        instead of streaming the response incrementally."""
        *_, last_responses = self.run(messages, **kwargs)
        return last_responses

    def run(self, messages: List[Union[Dict, Message]],
            **kwargs) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]:
        """Return one response generator based on the received messages.

        This method performs a uniform type conversion for the inputted messages,
        and calls the _run method to generate a reply.

        Args:
            messages: A list of messages.

        Yields:
            The response generator.
        """
        messages = copy.deepcopy(messages)
        _return_message_type = 'dict'
        new_messages = []
        # Only return dict when all input messages are dict
        if not messages:
            _return_message_type = 'message'
        for msg in messages:
            if isinstance(msg, dict):
                new_messages.append(Message(**msg))
            else:
                new_messages.append(msg)
                _return_message_type = 'message'

        if 'lang' not in kwargs:
            if has_chinese_messages(new_messages):
                kwargs['lang'] = 'zh'
            else:
                kwargs['lang'] = 'en'

        for rsp in self._run(messages=new_messages, **kwargs):
            for i in range(len(rsp)):
                if not rsp[i].name and self.name:
                    rsp[i].name = self.name
            if _return_message_type == 'message':
                yield [Message(**x) if isinstance(x, dict) else x for x in rsp]
            else:
                yield [x.model_dump() if not isinstance(x, dict) else x for x in rsp]
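
    # Typical call pattern (sketch): each yielded list is the current snapshot of the
    # reply so far, so the final yield is the complete response (this is what
    # `run_nonstream` relies on above).
    #
    #   for responses in agent.run([{'role': 'user', 'content': 'hi'}]):
    #       print(responses)  # streaming snapshots
    #
    #   responses = agent.run_nonstream([{'role': 'user', 'content': 'hi'}])  # complete reply only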

    @abstractmethod
    def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]:
        """Return one response generator based on the received messages.

        The workflow for an agent to generate a reply.
        Each agent subclass needs to implement this method.

        Args:
            messages: A list of messages.
            lang: Language, which will be used to select the language of the prompt
              during the agent's execution process.

        Yields:
            The response generator.
        """
        raise NotImplementedError

    def _call_llm(
        self,
        messages: List[Message],
        functions: Optional[List[Dict]] = None,
        stream: bool = True,
        extra_generate_cfg: Optional[dict] = None,
    ) -> Iterator[List[Message]]:
        """The interface of calling LLM for the agent.

        We prepend the system_message of this agent to the messages, and call LLM.

        Args:
            messages: A list of messages.
            functions: The list of functions provided to LLM.
            stream: LLM streaming output or non-streaming output.
              For consistency, we default to using streaming output across all agents.

        Yields:
            The response generator of LLM.
        """
        messages = copy.deepcopy(messages)
        if messages[0][ROLE] != SYSTEM:
            messages.insert(0, Message(role=SYSTEM, content=self.system_message))
        elif isinstance(messages[0][CONTENT], str):
            messages[0][CONTENT] = self.system_message + '\n\n' + messages[0][CONTENT]
        else:
            assert isinstance(messages[0][CONTENT], list)
            messages[0][CONTENT] = [ContentItem(text=self.system_message + '\n\n')] + messages[0][CONTENT]
        return self.llm.chat(messages=messages,
                             functions=functions,
                             stream=stream,
                             extra_generate_cfg=merge_generate_cfgs(
                                 base_generate_cfg=self.extra_generate_cfg,
                                 new_generate_cfg=extra_generate_cfg,
                             ))

    def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> Union[str, List[ContentItem]]:
        """The interface of calling tools for the agent.

        Args:
            tool_name: The name of one tool.
            tool_args: Model-generated or user-given tool parameters.

        Returns:
            The output of tools.
        """
        if tool_name not in self.function_map:
            return f'Tool {tool_name} does not exist.'
        tool = self.function_map[tool_name]
        try:
            tool_result = tool.call(tool_args, **kwargs)
        except Exception as ex:
            exception_type = type(ex).__name__
            exception_message = str(ex)
            traceback_info = ''.join(traceback.format_tb(ex.__traceback__))
            error_message = f'An error occurred when calling tool `{tool_name}`:\n' \
                            f'{exception_type}: {exception_message}\n' \
                            f'Traceback:\n{traceback_info}'
            logger.warning(error_message)
            return error_message

        if isinstance(tool_result, str):
            return tool_result
        elif isinstance(tool_result, list) and all(isinstance(item, ContentItem) for item in tool_result):
            return tool_result  # multimodal tool results
        else:
            return json.dumps(tool_result, ensure_ascii=False, indent=4)
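
    # Example (sketch): `tool_args` usually arrives as the JSON string produced by the
    # model, though a dict is also accepted per the type hint above. Tool results other
    # than strings and ContentItem lists are serialized to JSON before being returned.
    #
    #   result = self._call_tool('code_interpreter', '{"code": "print(1 + 1)"}')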

    def _init_tool(self, tool: Union[str, Dict, BaseTool]):
        if isinstance(tool, BaseTool):
            tool_name = tool.name
            if tool_name in self.function_map:
                logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
            self.function_map[tool_name] = tool
        else:
            if isinstance(tool, dict):
                tool_name = tool['name']
                tool_cfg = tool
            else:
                tool_name = tool
                tool_cfg = None
            if tool_name not in TOOL_REGISTRY:
                raise ValueError(f'Tool {tool_name} is not registered.')
            if tool_name in self.function_map:
                logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
            self.function_map[tool_name] = TOOL_REGISTRY[tool_name](tool_cfg)

    def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]:
        """A built-in tool call detection for func_call format message.

        Args:
            message: one message generated by LLM.

        Returns:
            Need to call tool or not, tool name, tool args, text replies.
        """
        func_name = None
        func_args = None

        if message.function_call:
            func_call = message.function_call
            func_name = func_call.name
            func_args = func_call.arguments
        text = message.content
        if not text:
            text = ''

        return (func_name is not None), func_name, func_args, text
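
    # Sketch of how a subclass typically wires these hooks together inside `_run`
    # (simplified; real agents add looping and message bookkeeping, and
    # `function_schemas` below is a placeholder for the tools' JSON schemas):
    #
    #   *_, rsp = self._call_llm(messages, functions=function_schemas)
    #   use_tool, tool_name, tool_args, _ = self._detect_tool(rsp[-1])
    #   if use_tool:
    #       tool_result = self._call_tool(tool_name, tool_args)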


# The most basic form of an agent is just an LLM, not augmented with any tool or workflow.
class BasicAgent(Agent):

    def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]:
        return self._call_llm(messages)
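

# A minimal usage sketch (not part of the library API). The model name, server, and
# API key below are assumptions for illustration; replace them with a configuration
# that is valid for your environment.
if __name__ == '__main__':
    llm_cfg = {
        'model': 'qwen-max',          # assumed model name
        'model_server': 'dashscope',  # assumed backend
        'api_key': '',                # assumed to be read from the environment if empty
    }
    bot = BasicAgent(llm=llm_cfg, name='demo')
    messages = [{'role': 'user', 'content': 'Briefly introduce yourself.'}]

    # Stream progressively longer snapshots of the reply; the last one is complete.
    for responses in bot.run(messages):
        print(responses)

    # Or fetch only the final, complete response.
    print(bot.run_nonstream(messages))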