function_calling.py

import copy
from abc import ABC
from typing import Dict, Iterator, List, Literal, Optional, Union

from qwen_agent.llm.base import BaseChatModel
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, USER, ContentItem, Message


class BaseFnCallModel(BaseChatModel, ABC):

    def __init__(self, cfg: Optional[Dict] = None):
        super().__init__(cfg)
        fncall_prompt_type = self.generate_cfg.get('fncall_prompt_type', 'qwen')
        if fncall_prompt_type == 'qwen':
            from qwen_agent.llm.fncall_prompts.qwen_fncall_prompt import FN_STOP_WORDS, QwenFnCallPrompt
            self.fncall_prompt = QwenFnCallPrompt()
            # Append the function-call stop words to any user-configured stop words, without duplicates:
            stop = self.generate_cfg.get('stop', [])
            self.generate_cfg['stop'] = stop + [x for x in FN_STOP_WORDS if x not in stop]
        else:
            raise NotImplementedError
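
    # Usage sketch (illustrative, not part of this module): a concrete subclass
    # is normally constructed with a cfg dict, and 'fncall_prompt_type' is read
    # from its 'generate_cfg' entry; only 'qwen' is implemented here. Assuming a
    # hypothetical subclass MyFnCallModel(BaseFnCallModel):
    #
    #   model = MyFnCallModel({'generate_cfg': {'stop': ['Observation:']}})
    #   # model.generate_cfg['stop'] now also contains FN_STOP_WORDS.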

    def _preprocess_messages(
        self,
        messages: List[Message],
        lang: Literal['en', 'zh'],
        generate_cfg: dict,
        functions: Optional[List[Dict]] = None,
    ) -> List[Message]:
        messages = super()._preprocess_messages(messages, lang=lang, generate_cfg=generate_cfg)
        if (not functions) or (generate_cfg.get('function_choice', 'auto') == 'none'):
            messages = self._remove_fncall_messages(messages, lang=lang)
        else:
            validate_num_fncall_results(
                messages=messages,
                support_multimodal_input=self.support_multimodal_input,
            )
            messages = self.fncall_prompt.preprocess_fncall_messages(
                messages=messages,
                functions=functions,
                lang=lang,
                parallel_function_calls=generate_cfg.get('parallel_function_calls', False),
                function_choice=generate_cfg.get('function_choice', 'auto'),
            )
        return messages
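
    # Illustration (hedged): when no functions are given, or function_choice is
    # 'none', any fncall history is rewritten into plain user/assistant turns;
    # otherwise the function schemas are folded into the prompt. For example:
    #
    #   msgs = [Message(role=USER, content='What is the weather in Beijing?')]
    #   model._preprocess_messages(msgs, lang='en', generate_cfg={}, functions=weather_fns)
    #   # -> messages augmented by QwenFnCallPrompt with the tool descriptions,
    #   #    where weather_fns is a hypothetical list of JSON function schemas.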

    def _postprocess_messages(
        self,
        messages: List[Message],
        fncall_mode: bool,
        generate_cfg: dict,
    ) -> List[Message]:
        messages = super()._postprocess_messages(messages, fncall_mode=fncall_mode, generate_cfg=generate_cfg)
        if fncall_mode:
            messages = self.fncall_prompt.postprocess_fncall_messages(
                messages=messages,
                parallel_function_calls=generate_cfg.get('parallel_function_calls', False),
                function_choice=generate_cfg.get('function_choice', 'auto'),
            )
        return messages
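
    # Note (hedged): postprocessing is the inverse of preprocessing. In fncall
    # mode, the assistant's raw output is parsed back into structured messages,
    # so a generated tool-call span ends up as, roughly,
    # Message(role=ASSISTANT, function_call=FunctionCall(name=..., arguments=...)).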

    def _remove_fncall_messages(self, messages: List[Message], lang: Literal['en', 'zh']) -> List[Message]:
        # Change function calls into user messages so that the model won't try
        # to generate function calls when given functions and function_choice="none".
        new_messages = []
        for msg in messages:
            if (msg.role == FUNCTION) or msg.function_call:
                if (not new_messages) or (new_messages[-1].role != USER):
                    new_messages.append(Message(role=USER, content=[]))
                if msg.function_call:
                    tool_name = msg.function_call.name
                    tool_args = msg.function_call.arguments
                    if lang == 'zh':
                        tool_text = f'\n\n工具"{tool_name}"被调用时使用了以下参数:\n{tool_args}'
                    else:
                        tool_text = f'\n\nThe tool "{tool_name}" was called with these arguments:\n{tool_args}'
                else:
                    assert msg.role == FUNCTION
                    if msg.content:
                        assert len(msg.content) == 1
                        assert isinstance(msg.content[0], ContentItem)
                        assert isinstance(msg.content[0].text, str)
                        tool_result = msg.content[0].text
                    else:
                        tool_result = 'No result.'
                    if lang == 'zh':
                        tool_text = f'\n\n该工具返回了以下结果:\n{tool_result}'
                    else:
                        tool_text = f'\n\nThe tool has returned the following result:\n{tool_result}'
                new_messages[-1].content.append(ContentItem(text=tool_text))
            else:
                if (msg.role == USER) and new_messages and (new_messages[-1].role == USER):
                    # Separate two user messages with an assistant message to make the bot focus on the latter:
                    new_messages.append(Message(role=ASSISTANT, content=[ContentItem(text='...')]))
                new_messages.append(msg)
        return new_messages
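
    # Example (illustrative): a tool call and its result are folded into a
    # single user message. Given roughly this history (FunctionCall is from
    # qwen_agent.llm.schema):
    #
    #   [Message(role=ASSISTANT, function_call=FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')),
    #    Message(role=FUNCTION, name='get_weather', content=[ContentItem(text='Sunny, 25°C')])]
    #
    # the resulting user message text reads:
    #
    #   The tool "get_weather" was called with these arguments:
    #   {"city": "Beijing"}
    #
    #   The tool has returned the following result:
    #   Sunny, 25°C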

    def _chat_with_functions(
        self,
        messages: List[Message],
        functions: List[Dict],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
        lang: Literal['en', 'zh'],
    ) -> Union[List[Message], Iterator[List[Message]]]:
        if delta_stream:
            raise NotImplementedError('Please use stream=True with delta_stream=False, '
                                      'because delta_stream=True is not implemented for function calling.')
        # These kwargs are consumed by the prompt templating in _preprocess_messages;
        # strip them so they are not passed on to the underlying model service:
        generate_cfg = copy.deepcopy(generate_cfg)
        for k in ['parallel_function_calls', 'function_choice']:
            if k in generate_cfg:
                del generate_cfg[k]
        return self._continue_assistant_response(messages, generate_cfg=generate_cfg, stream=stream)
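
    # Typical invocation (hedged sketch; normally reached via the public chat
    # API of BaseChatModel rather than called directly; weather_fns as in the
    # earlier sketch):
    #
    #   for rsp in model._chat_with_functions(msgs, functions=weather_fns, stream=True,
    #                                         delta_stream=False, generate_cfg={}, lang='en'):
    #       ...  # with delta_stream=False, each rsp is the accumulated response so far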

    def _continue_assistant_response(
        self,
        messages: List[Message],
        generate_cfg: dict,
        stream: bool,
    ) -> Iterator[List[Message]]:
        messages = simulate_response_completion_with_chat(messages)
        return self._chat(messages, stream=stream, delta_stream=False, generate_cfg=generate_cfg)


def simulate_response_completion_with_chat(messages: List[Message]) -> List[Message]:
    # If the conversation ends with a partial assistant reply, merge it into the
    # preceding user message so that a chat-style endpoint continues the reply
    # instead of starting a new one:
    if messages and (messages[-1].role == ASSISTANT):
        assert (len(messages) > 1) and (messages[-2].role == USER)
        assert messages[-1].function_call is None
        usr = messages[-2].content
        bot = messages[-1].content
        sep = '\n\n'
        if isinstance(usr, str) and isinstance(bot, str):
            usr = usr + sep + bot
        elif isinstance(usr, list) and isinstance(bot, list):
            usr = usr + [ContentItem(text=sep)] + bot
        else:
            raise NotImplementedError
        text_to_complete = copy.deepcopy(messages[-2])
        text_to_complete.content = usr
        messages = messages[:-2] + [text_to_complete]
    return messages
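
# Example (illustrative):
#
#   msgs = [Message(role=USER, content='Write a haiku.'),
#           Message(role=ASSISTANT, content='An old silent pond')]
#   simulate_response_completion_with_chat(msgs)
#   # -> [Message(role=USER, content='Write a haiku.\n\nAn old silent pond')]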


def validate_num_fncall_results(messages: List[Message], support_multimodal_input: bool):
    # Walk backwards over the trailing role="function" messages, then over the
    # function_call messages that precede them, and check that they pair up.
    fn_results = []
    i = len(messages) - 1
    while messages[i].role == FUNCTION:
        fn_results = [messages[i].name] + fn_results
        content = messages[i].content
        if isinstance(content, list):
            for item in content:
                if item.file:
                    raise ValueError('Tool call results with content type="file" are not supported.')
                if item.image and (not support_multimodal_input):
                    raise ValueError('The current model service does not accept images as tool results.')
        i -= 1

    fn_calls = []
    while messages[i].function_call:
        fn_calls = [messages[i].function_call.name] + fn_calls
        i -= 1

    if len(fn_calls) != len(fn_results):
        raise ValueError(f'Expected {len(fn_calls)} function results (i.e., messages with role="function") '
                         f'but received {len(fn_results)}. '
                         'The number of function results must match that of the function_call messages.')
    for fc_name, fr_name in zip(fn_calls, fn_results):
        if fr_name and (fc_name != fr_name):
            raise ValueError('The function results (i.e., the messages with role="function") must be '
                             'in the same order as the function_call messages, and the function names must match. '
                             f'The function results are currently {fn_results}, but {fn_calls} are expected.')
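

# Example (illustrative): the validator accepts one role="function" result per
# function_call, in the same order and with matching names.
#
#   ok:   [..., call('a'), call('b'), result('a'), result('b')]
#   bad:  [..., call('a'), call('b'), result('b')]               # count mismatch
#   bad:  [..., call('a'), call('b'), result('b'), result('a')]  # order mismatch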