# fncall_agent.py
  1. import copy
  2. from typing import Dict, Iterator, List, Literal, Optional, Union
  3. from qwen_agent import Agent
  4. from qwen_agent.llm import BaseChatModel
  5. from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, FUNCTION, Message
  6. from qwen_agent.memory import Memory
  7. from qwen_agent.settings import MAX_LLM_CALL_PER_RUN
  8. from qwen_agent.tools import BaseTool
  9. from qwen_agent.utils.utils import extract_files_from_messages
  10. class FnCallAgent(Agent):
  11. """This is a widely applicable function call agent integrated with llm and tool use ability."""
  12. def __init__(self,
  13. function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
  14. llm: Optional[Union[Dict, BaseChatModel]] = None,
  15. system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
  16. name: Optional[str] = None,
  17. description: Optional[str] = None,
  18. files: Optional[List[str]] = None,
  19. **kwargs):
  20. """Initialization the agent.
  21. Args:
  22. function_list: One list of tool name, tool configuration or Tool object,
  23. such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter().
  24. llm: The LLM model configuration or LLM model object.
  25. Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}.
  26. system_message: The specified system message for LLM chat.
  27. name: The name of this agent.
  28. description: The description of this agent, which will be used for multi_agent.
  29. files: A file url list. The initialized files for the agent.
  30. """
  31. super().__init__(function_list=function_list,
  32. llm=llm,
  33. system_message=system_message,
  34. name=name,
  35. description=description)
  36. if not hasattr(self, 'mem'):
  37. # Default to use Memory to manage files
  38. self.mem = Memory(llm=self.llm, files=files, **kwargs)
  39. def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwargs) -> Iterator[List[Message]]:
  40. messages = copy.deepcopy(messages)
  41. num_llm_calls_available = MAX_LLM_CALL_PER_RUN
  42. response = []
  43. while True and num_llm_calls_available > 0:
  44. num_llm_calls_available -= 1
  45. output_stream = self._call_llm(messages=messages,
  46. functions=[func.function for func in self.function_map.values()],
  47. extra_generate_cfg={'lang': lang})
  48. output: List[Message] = []
  49. for output in output_stream:
  50. if output:
  51. yield response + output
  52. if output:
  53. response.extend(output)
  54. messages.extend(output)
  55. used_any_tool = False
  56. for out in output:
  57. use_tool, tool_name, tool_args, _ = self._detect_tool(out)
  58. if use_tool:
  59. tool_result = self._call_tool(tool_name, tool_args, messages=messages, **kwargs)
  60. fn_msg = Message(
  61. role=FUNCTION,
  62. name=tool_name,
  63. content=tool_result,
  64. )
  65. messages.append(fn_msg)
  66. response.append(fn_msg)
  67. yield response
  68. used_any_tool = True
  69. if not used_any_tool:
  70. break
  71. yield response
  72. def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
  73. if tool_name not in self.function_map:
  74. return f'Tool {tool_name} does not exists.'
  75. # Temporary plan: Check if it is necessary to transfer files to the tool
  76. # Todo: This should be changed to parameter passing, and the file URL should be determined by the model
  77. if self.function_map[tool_name].file_access:
  78. assert 'messages' in kwargs
  79. files = extract_files_from_messages(kwargs['messages'], include_images=True) + self.mem.system_files
  80. return super()._call_tool(tool_name, tool_args, files=files, **kwargs)
  81. else:
  82. return super()._call_tool(tool_name, tool_args, **kwargs)