react_chat.py

import json
from typing import Dict, Iterator, List, Literal, Optional, Tuple, Union

from qwen_agent.agents.fncall_agent import FnCallAgent
from qwen_agent.llm import BaseChatModel
from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, Message
from qwen_agent.settings import MAX_LLM_CALL_PER_RUN
from qwen_agent.tools import BaseTool
from qwen_agent.utils.utils import format_as_text_message, merge_generate_cfgs

TOOL_DESC = (
    '{name_for_model}: Call this tool to interact with the {name_for_human} API. '
    'What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} {args_format}')
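
# An illustrative rendering of TOOL_DESC for a hypothetical `image_gen` tool;
# the tool name, description, parameter schema, and args_format below are
# assumptions for illustration, not part of this module:
#
#   image_gen: Call this tool to interact with the image_gen API. What is the
#   image_gen API useful for? Generates an image from a text prompt.
#   Parameters: {"type": "object", "properties": {"prompt": {"type": "string"}}}
#   Format the arguments as a JSON object.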

PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:

{tool_descs}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {query}
Thought: """
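
# A completed trajectory in the format above might look like this; the
# `calculator` tool and its output are hypothetical, shown only to make the
# protocol concrete:
#
#   Question: What is 137 * 41?
#   Thought: I should compute this with the calculator.
#   Action: calculator
#   Action Input: {"expression": "137 * 41"}
#   Observation: 5617
#   Thought: I now know the final answer
#   Final Answer: 137 * 41 = 5617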


class ReActChat(FnCallAgent):
    """This agent uses the ReAct format to call tools."""

    def __init__(self,
                 function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
                 llm: Optional[Union[Dict, BaseChatModel]] = None,
                 system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
                 name: Optional[str] = None,
                 description: Optional[str] = None,
                 files: Optional[List[str]] = None,
                 **kwargs):
        super().__init__(function_list=function_list,
                         llm=llm,
                         system_message=system_message,
                         name=name,
                         description=description,
                         files=files,
                         **kwargs)
        # Stop generation at 'Observation:' so that tool results are filled in
        # by the framework rather than hallucinated by the model.
        self.extra_generate_cfg = merge_generate_cfgs(
            base_generate_cfg=self.extra_generate_cfg,
            new_generate_cfg={'stop': ['Observation:', 'Observation:\n']},
        )

    def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwargs) -> Iterator[List[Message]]:
        text_messages = self._prepend_react_prompt(messages, lang=lang)

        num_llm_calls_available = MAX_LLM_CALL_PER_RUN
        response: str = 'Thought: '
        while num_llm_calls_available > 0:
            num_llm_calls_available -= 1

            # Display the streaming response
            output = []
            for output in self._call_llm(messages=text_messages):
                if output:
                    yield [Message(role=ASSISTANT, content=response + output[-1].content)]

            # Accumulate the current response
            if output:
                response += output[-1].content

            has_action, action, action_input, thought = self._detect_tool(output[-1].content)
            if not has_action:
                break

            # Add the tool result
            observation = self._call_tool(action, action_input, messages=messages, **kwargs)
            observation = f'\nObservation: {observation}\nThought: '
            response += observation
            yield [Message(role=ASSISTANT, content=response)]

            if (not text_messages[-1].content.endswith('\nThought: ')) and (not thought.startswith('\n')):
                # Add the '\n' between '\nQuestion:' and the first 'Thought:'
                text_messages[-1].content += '\n'
            if action_input.startswith('```'):
                # Add a newline for proper markdown rendering of code
                action_input = '\n' + action_input
            text_messages[-1].content += thought + f'\nAction: {action}\nAction Input: {action_input}' + observation

    def _prepend_react_prompt(self, messages: List[Message], lang: Literal['en', 'zh']) -> List[Message]:
        tool_descs = []
        for f in self.function_map.values():
            function = f.function
            name = function.get('name', None)
            name_for_human = function.get('name_for_human', name)
            name_for_model = function.get('name_for_model', name)
            assert name_for_human and name_for_model
            args_format = function.get('args_format', '')
            tool_descs.append(
                TOOL_DESC.format(name_for_human=name_for_human,
                                 name_for_model=name_for_model,
                                 description_for_model=function['description'],
                                 parameters=json.dumps(function['parameters'], ensure_ascii=False),
                                 args_format=args_format).rstrip())
        tool_descs = '\n\n'.join(tool_descs)
        tool_names = ','.join(tool.name for tool in self.function_map.values())

        text_messages = [format_as_text_message(m, add_upload_info=True, lang=lang) for m in messages]
        text_messages[-1].content = PROMPT_REACT.format(
            tool_descs=tool_descs,
            tool_names=tool_names,
            query=text_messages[-1].content,
        )
        return text_messages

    def _detect_tool(self, text: str) -> Tuple[bool, str, str, str]:
        special_func_token = '\nAction:'
        special_args_token = '\nAction Input:'
        special_obs_token = '\nObservation:'
        func_name, func_args = None, None
        i = text.rfind(special_func_token)
        j = text.rfind(special_args_token)
        k = text.rfind(special_obs_token)
        if 0 <= i < j:  # If the text has `Action` and `Action Input`,
            if k < j:  # but does not contain `Observation`,
                # then it is likely that `Observation` was omitted by the LLM,
                # because the output text may have discarded the stop word.
                text = text.rstrip() + special_obs_token  # Add it back.
                k = text.rfind(special_obs_token)
            func_name = text[i + len(special_func_token):j].strip()
            func_args = text[j + len(special_args_token):k].strip()
            text = text[:i]  # Return the response before the tool call, i.e., the `Thought`
        return (func_name is not None), func_name, func_args, text
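

# --- Usage sketch (not part of the original module) ---
# A minimal example of driving ReActChat, assuming a configured model backend;
# the model name and the `code_interpreter` tool below are assumptions for
# illustration, so adapt them to your deployment.
if __name__ == '__main__':
    from qwen_agent.agents import ReActChat

    bot = ReActChat(
        llm={'model': 'qwen-max'},           # assumed LLM config
        function_list=['code_interpreter'],  # assumed built-in tool name
    )
    messages = [{'role': 'user', 'content': 'How many primes are below 100?'}]
    # `run` streams snapshots of the growing response; keep only the last one.
    responses = []
    for responses in bot.run(messages=messages):
        pass
    print(responses)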