base.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import json
  2. import os
  3. from abc import ABC, abstractmethod
  4. from typing import Dict, List, Optional, Union
  5. from qwen_agent.llm.schema import ContentItem
  6. from qwen_agent.settings import DEFAULT_WORKSPACE
  7. from qwen_agent.utils.utils import has_chinese_chars, json_loads, logger, print_traceback, save_url_to_local_work_dir
  # Global registry mapping a tool name to its tool class; populated by @register_tool.
  TOOL_REGISTRY = {}


  def register_tool(name: str, allow_overwrite: bool = False):
      """Return a class decorator that registers a tool class under ``name``.

      The decorated class is stored in ``TOOL_REGISTRY[name]`` and its ``name``
      attribute is set to ``name``. The class itself is returned unchanged
      otherwise, so the decorator can be stacked with others.

      Args:
          name: Registry key for the tool; also assigned to ``cls.name``.
          allow_overwrite: If True, registering an already-used ``name`` logs a
              warning and replaces the previous class instead of raising.

      Raises:
          ValueError: When the decorator is applied, if ``name`` is already
              registered and ``allow_overwrite`` is False, or if the decorated
              class itself declares a different non-empty ``name``.
      """

      def decorator(cls):
          if name in TOOL_REGISTRY:
              if allow_overwrite:
                  logger.warning(f'Tool `{name}` already exists! Overwriting with class {cls}.')
              else:
                  raise ValueError(f'Tool `{name}` already exists! Please ensure that the tool name is unique.')
          # Only a ``name`` declared on this class itself can conflict. A name
          # inherited from an already-registered parent (set on that parent by
          # this decorator) must not prevent registering the subclass under a
          # new key.
          own_name = cls.__dict__.get('name')
          if own_name and (own_name != name):
              raise ValueError(f'{cls.__name__}.name="{own_name}" conflicts with @register_tool(name="{name}").')
          cls.name = name
          TOOL_REGISTRY[name] = cls
          return cls

      return decorator
  def is_tool_schema(obj: dict) -> bool:
      """Check whether ``obj`` is a valid OpenAI-compatible tool (function) schema.

      ``obj`` must be a dict with exactly the keys ``name`` (a non-blank str),
      ``description`` (a str) and ``parameters``. ``parameters`` must be a dict
      with exactly the keys ``type`` (equal to ``"object"``), ``properties`` (a
      dict) and ``required`` (a list whose entries are all names of declared
      properties). Finally, ``parameters`` must itself be a well-formed JSON
      Schema according to ``jsonschema``.

      Example valid schema:
      {
          "name": "get_current_weather",
          "description": "Get the current weather in a given location",
          "parameters": {
              "type": "object",
              "properties": {
                  "location": {
                      "type": "string",
                      "description": "The city and state, e.g. San Francisco, CA"
                  },
                  "unit": {
                      "type": "string",
                      "enum": ["celsius", "fahrenheit"]
                  }
              },
              "required": ["location"]
          }
      }

      Returns:
          True if ``obj`` passes every check above, False otherwise.
      """
      # Explicit checks instead of ``assert``: assertions are stripped under
      # ``python -O``, which would let any object through this function.
      if not isinstance(obj, dict) or set(obj.keys()) != {'name', 'description', 'parameters'}:
          return False
      name = obj['name']
      if not isinstance(name, str) or not name.strip():
          return False
      if not isinstance(obj['description'], str):
          return False
      parameters = obj['parameters']
      if not isinstance(parameters, dict) or set(parameters.keys()) != {'type', 'properties', 'required'}:
          return False
      if parameters['type'] != 'object':
          return False
      properties = parameters['properties']
      required = parameters['required']
      if not isinstance(properties, dict) or not isinstance(required, list):
          return False
      # Every required entry must be the (string) name of a declared property.
      if not all(isinstance(key, str) and key in properties for key in required):
          return False

      # Imported lazily, and only once the cheap structural checks have passed.
      import jsonschema
      # Check that ``parameters`` is a valid schema (this is the same check
      # ``jsonschema.validate`` performs before validating an instance).
      validator_cls = jsonschema.validators.validator_for(parameters)
      try:
          validator_cls.check_schema(parameters)
      except jsonschema.exceptions.SchemaError:
          return False
      return True
  class BaseTool(ABC):
      """Abstract base class for agent-callable tools.

      Subclasses must provide a non-empty ``name`` (normally assigned by
      ``@register_tool``), a ``description``, ``parameters`` and an
      implementation of :meth:`call`. ``parameters`` may be either:

      * a list of per-parameter dicts, each with a ``'name'`` key and an
        optional truthy ``'required'`` flag; or
      * a dict, which must be an OpenAI-compatible JSON schema (checked with
        ``is_tool_schema`` when the tool is instantiated).
      """

      name: str = ''
      description: str = ''
      # NOTE: this list is a class attribute shared by every subclass that does
      # not override it; subclasses should assign their own value rather than
      # mutate this one.
      parameters: Union[List[dict], dict] = []

      def __init__(self, cfg: Optional[dict] = None):
          """Store the configuration and validate the class-level tool metadata.

          Args:
              cfg: Optional per-instance configuration. Keys read by this class
                  are ``name_for_human`` and ``args_format``. Defaults to ``{}``.

          Raises:
              ValueError: If ``name`` is empty, or if ``parameters`` is a dict
                  that is not a valid tool schema.
          """
          self.cfg = cfg or {}
          if not self.name:
              raise ValueError(
                  f'You must set {self.__class__.__name__}.name, either by @register_tool(name=...) or explicitly setting {self.__class__.__name__}.name'
              )
          if isinstance(self.parameters, dict):
              if not is_tool_schema({'name': self.name, 'description': self.description, 'parameters': self.parameters}):
                  raise ValueError(
                      'The parameters, when provided as a dict, must conform to a valid openai-compatible JSON schema.')

      @abstractmethod
      def call(self, params: Union[str, dict], **kwargs) -> Union[str, list, dict, List[ContentItem]]:
          """Run the tool; every concrete tool must implement this.

          Args:
              params: The tool-call arguments, either as a JSON string or as an
                  already-parsed dict.
              kwargs: Additional tool-specific keyword arguments.

          Returns:
              The tool's result; its exact form is defined by the subclass.
          """
          raise NotImplementedError

      def _verify_json_format_args(self, params: Union[str, dict], strict_json: bool = False) -> dict:
          """Parse tool-call arguments and check them against ``self.parameters``.

          Args:
              params: Arguments as a JSON string or an already-parsed dict.
              strict_json: If True, strings are parsed with the standard
                  ``json.loads``; otherwise with the project helper ``json_loads``
                  (presumably more tolerant - see its definition).

          Returns:
              The parsed arguments.

          Raises:
              ValueError: If a string ``params`` is not valid JSON or does not
                  decode to a JSON object; if a required parameter is missing
                  (list-style ``parameters``); or if ``parameters`` is neither a
                  list nor a dict.
              jsonschema.ValidationError: If the arguments violate a dict-style
                  ``parameters`` schema.
          """
          params_json: dict
          if isinstance(params, str):
              try:
                  params_json = json.loads(params) if strict_json else json_loads(params)
              except json.decoder.JSONDecodeError as err:
                  raise ValueError('Parameters must be formatted as a valid JSON!') from err
              # Valid JSON such as a list, number or null is still not a usable
              # argument mapping; reject it here instead of failing obscurely below.
              if not isinstance(params_json, dict):
                  raise ValueError('Parameters must be formatted as a valid JSON object!')
          else:
              params_json = params

          if isinstance(self.parameters, list):
              for param in self.parameters:
                  if param.get('required') and param['name'] not in params_json:
                      raise ValueError('Parameters %s is required!' % param['name'])
          elif isinstance(self.parameters, dict):
              import jsonschema
              jsonschema.validate(instance=params_json, schema=self.parameters)
          else:
              raise ValueError(f'{self.__class__.__name__}.parameters must be a list or a dict, '
                               f'got {type(self.parameters).__name__}.')
          return params_json

      @property
      def function(self) -> dict:  # Bad naming. It should be `function_info`.
          """Metadata describing this tool, as a plain dict."""
          return {
              'name_for_human': self.name_for_human,
              'name': self.name,
              'description': self.description,
              'parameters': self.parameters,
              'args_format': self.args_format
          }

      @property
      def name_for_human(self) -> str:
          """Display name: ``cfg['name_for_human']`` if set, otherwise ``name``."""
          return self.cfg.get('name_for_human', self.name)

      @property
      def args_format(self) -> str:
          """Instruction text describing how the tool's arguments should be formatted.

          Uses ``cfg['args_format']`` when it is set (and not None). Otherwise the
          text is Chinese if ``has_chinese_chars`` detects Chinese in the tool's
          name, display name, description or parameters, and English if not.
          """
          fmt = self.cfg.get('args_format')
          if fmt is None:
              if has_chinese_chars([self.name_for_human, self.name, self.description, self.parameters]):
                  # "The input to this tool should be a JSON object."
                  fmt = '此工具的输入应为JSON对象。'
              else:
                  fmt = 'Format the arguments as a JSON object.'
          return fmt

      @property
      def file_access(self) -> bool:
          """Whether this tool accesses files; False for the base class."""
          return False
  class BaseToolWithFileAccess(BaseTool, ABC):
      """Base class for tools that work on files supplied with the call.

      Each instance gets a working directory: ``cfg['work_dir']`` if it is
      configured, otherwise ``<DEFAULT_WORKSPACE>/tools/<name>``. The base
      :meth:`call` only copies the given files into that directory. Subclasses
      presumably override ``call`` and invoke this one via ``super()`` before
      processing the files (not visible here - confirm against subclasses).
      """

      def __init__(self, cfg: Optional[Dict] = None):
          super().__init__(cfg)
          # Already enforced by BaseTool.__init__; kept as a guard for the path below.
          assert self.name
          default_work_dir = os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)
          self.work_dir: str = self.cfg.get('work_dir', default_work_dir)

      @property
      def file_access(self) -> bool:
          """Tools derived from this class access files."""
          return True

      def call(self, params: Union[str, dict], files: Optional[List[str]] = None, **kwargs) -> str:
          """Copy ``files`` into ``self.work_dir`` on a best-effort basis.

          Args:
              params: Tool-call arguments; not used by this base implementation.
              files: Optional file locations to copy into the working directory
                  (presumably URLs or local paths, whatever
                  ``save_url_to_local_work_dir`` accepts).
              kwargs: Ignored here.

          Returns:
              ``None`` in this base implementation, despite the ``str``
              annotation intended for subclasses.
          """
          # Copy remote files to the working directory:
          if files:
              os.makedirs(self.work_dir, exist_ok=True)
              for file_url in files:
                  try:
                      save_url_to_local_work_dir(file_url, self.work_dir)
                  except Exception:
                      # Deliberate best effort: a file that cannot be copied is
                      # reported via print_traceback() but does not abort the call.
                      print_traceback()
          # Then do something with the files:
          # ...