# group_chat.py (13 KB)
  1. import copy
  2. import random
  3. from typing import Dict, Iterator, List, Optional, Union
  4. from qwen_agent import Agent, MultiAgentHub
  5. from qwen_agent.agents.assistant import Assistant
  6. from qwen_agent.agents.group_chat_auto_router import GroupChatAutoRouter
  7. from qwen_agent.agents.user_agent import PENDING_USER_INPUT, UserAgent
  8. from qwen_agent.llm import BaseChatModel
  9. from qwen_agent.llm.schema import Message
  10. from qwen_agent.log import logger
  11. from qwen_agent.tools import BaseTool
  12. class GroupChat(Agent, MultiAgentHub):
  13. """This is an agent for multi-agent management.
  14. This agent can accept a list of agents, manage their speaking order, and output the response of each agent.
  15. """
  16. _VALID_AGENT_SELECTION_METHODS = ['manual', 'round_robin', 'random', 'auto']
  17. def __init__(self,
  18. agents: Union[List[Agent], Dict],
  19. agent_selection_method: Optional[str] = 'auto',
  20. function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
  21. llm: Optional[Union[Dict, BaseChatModel]] = None,
  22. **kwargs):
  23. """Initialization the agent.
  24. Args:
  25. agents: A list of agents of agent configurations. One configuration example is:
  26. {
  27. 'background': 'An interest group',
  28. 'agents': [{
  29. 'name': 'Tang Xiao',
  30. 'description': 'A hardworking worker, addicted to work every day, gradually losing weight.',
  31. 'is_human': True # mark this as a real person
  32. }, {
  33. 'name': 'Tou Da',
  34. 'description': 'A sports student',
  35. 'instructions': 'You are a sports student who loves sports.',
  36. 'knowledge_files': ['http://example.html'],
  37. 'selected_tools': ['image_gen']
  38. }]
  39. }
  40. agent_selection_method: The method of select speaker:
  41. (1) auto: Using one host agent to choose the speaker according to the context.
  42. (2) round_robin: Speak in order.
  43. (3) random: Random speech.
  44. function_list: The tools for inputting to the host.
  45. llm: The LLM for inputting to the host.
  46. """
  47. super().__init__(**kwargs)
  48. assert agent_selection_method in self._VALID_AGENT_SELECTION_METHODS, f'You must choose agent_selection_method from {", ".join(self._VALID_AGENT_SELECTION_METHODS)}'
  49. self.agent_selection_method = agent_selection_method
  50. if isinstance(agents, dict):
  51. self._agents = self._init_agents_from_config(agents, llm=llm)
  52. else:
  53. self._agents = agents
  54. if self.agent_selection_method == 'auto':
  55. assert llm is not None, 'Need to provide LLM to the host in auto mode'
  56. self.host = GroupChatAutoRouter(function_list=function_list, llm=llm, agents=self.agents, name='host')
  57. def _run(self,
  58. messages: List[Message] = None,
  59. lang: str = 'zh',
  60. max_round: Optional[int] = 3,
  61. need_batch_response: bool = True,
  62. mentioned_agents_name: List[str] = None,
  63. **kwargs) -> Iterator[List[Message]]:
  64. messages = copy.deepcopy(messages)
  65. for message in messages:
  66. if message.role == 'assistant':
  67. assert message.name, 'In group chat, each agent must be given a name'
  68. # Name will be used for router
  69. # Todo: Dealing with situations where there are no real players
  70. if not message.name:
  71. message.name = message.role
  72. if need_batch_response:
  73. return self._gen_batch_response(messages=messages,
  74. lang=lang,
  75. max_round=max_round,
  76. mentioned_agents_name=mentioned_agents_name,
  77. **kwargs)
  78. else:
  79. return self._gen_one_response(messages=messages,
  80. lang=lang,
  81. mentioned_agents_name=mentioned_agents_name,
  82. **kwargs)
  83. def _gen_batch_response(self,
  84. messages: List[Message] = None,
  85. lang: str = 'zh',
  86. max_round: Optional[int] = 3,
  87. mentioned_agents_name: List[str] = None,
  88. **kwargs) -> Iterator[List[Message]]:
  89. # Record all mentioned agents: reply in order
  90. mentioned_agents_name = mentioned_agents_name or []
  91. messages = copy.deepcopy(messages)
  92. response = []
  93. for i in range(max_round):
  94. if isinstance(messages[-1].content, list):
  95. content = '\n'.join([x.text if x.text else '' for x in messages[-1].content]).strip()
  96. else:
  97. content = messages[-1].content.strip()
  98. if '@' in content:
  99. for x in content.split('@'):
  100. for agent in self.agents:
  101. if x.startswith(agent.name):
  102. if agent not in mentioned_agents_name:
  103. mentioned_agents_name.append(agent.name)
  104. break
  105. rsp = []
  106. for rsp in self._gen_one_response(messages=messages,
  107. lang=lang,
  108. mentioned_agents_name=mentioned_agents_name,
  109. **kwargs):
  110. yield response + rsp
  111. if not rsp:
  112. # The topic ends
  113. break
  114. if mentioned_agents_name:
  115. assert rsp[-1].name == mentioned_agents_name[0]
  116. mentioned_agents_name.pop(0)
  117. response += rsp
  118. if rsp[-1].content == PENDING_USER_INPUT:
  119. # Terminate group chat and wait for user input
  120. break
  121. messages.extend(rsp)
  122. yield response
  123. def _gen_one_response(self,
  124. messages: List[Message] = None,
  125. lang: str = 'zh',
  126. mentioned_agents_name: List[str] = None,
  127. **kwargs) -> Iterator[List[Message]]:
  128. selected_agent = self._select_agent(messages, mentioned_agents_name, lang)
  129. if selected_agent:
  130. logger.info(f'selected_agent_name: {selected_agent.name}')
  131. new_messages = self._manage_messages(messages, selected_agent.name)
  132. for rsp in selected_agent.run(messages=new_messages, **kwargs):
  133. yield rsp
  134. else:
  135. yield []
  136. def _select_agent(self,
  137. messages: List[Message],
  138. mentioned_agents_name: List[str] = None,
  139. lang: str = 'zh') -> Union[Agent, None]:
  140. agents_map = {x.name: x for x in self.agents}
  141. if mentioned_agents_name:
  142. # Manually select agent
  143. return agents_map[mentioned_agents_name[0]]
  144. if self.agent_selection_method == 'auto':
  145. *_, last = self.host.run(messages=messages, lang=lang)
  146. auto_selected_agent = None
  147. if isinstance(last[-1]['content'], str):
  148. auto_selected_agent = last[-1]['content']
  149. else:
  150. assert isinstance(last[-1]['content'], list)
  151. if 'text' in last[-1]['content'][0]:
  152. auto_selected_agent = last[-1]['content'][0]['text']
  153. if auto_selected_agent in agents_map.keys():
  154. return agents_map[auto_selected_agent]
  155. elif auto_selected_agent == '[STOP]':
  156. return None
  157. if self.agent_selection_method == 'random':
  158. agent = random.choice(list(self.agents))
  159. return agent
  160. if self.agent_selection_method == 'manual':
  161. for i in range(3):
  162. agent_key = input('Please enter the selected agent name: ')
  163. if agent_key in agents_map.keys():
  164. return agents_map[agent_key]
  165. else:
  166. logger.warning(f'Please select one agent from {str(list(agents_map.keys()))}')
  167. # round_robin
  168. if messages:
  169. agents_list = [x.name for x in self.agents]
  170. try:
  171. last_agent_index = agents_list.index(messages[-1]['name'])
  172. except ValueError:
  173. last_agent_index = -1
  174. else:
  175. last_agent_index = -1
  176. return self.agents[(last_agent_index + 1) % len(self.agents)]
  177. def _manage_messages(self, messages: List[Message], name: str) -> List[Message]:
  178. new_messages = []
  179. new_msg = None
  180. i = 0
  181. while i < len(messages):
  182. msg = messages[i]
  183. if msg.name == name:
  184. if new_msg:
  185. # Have 'user' before 'assistant'
  186. new_messages.append(new_msg)
  187. if not msg.function_call and ( # noqa
  188. (not new_messages) or (new_messages[-1].name == name)): # noqa
  189. new_messages.append(Message('user', f'{name}: '))
  190. new_msg = copy.deepcopy(msg)
  191. new_msg.role = 'assistant'
  192. new_messages.append(new_msg)
  193. new_msg = None
  194. if msg.function_call:
  195. # Append the function call msg
  196. assert messages[i + 1].role == 'function'
  197. new_messages.append(copy.deepcopy(messages[i + 1]))
  198. i += 1
  199. else:
  200. if isinstance(msg.content, list):
  201. content = '\n'.join([x.text if x.text else '' for x in msg.content]).strip()
  202. else:
  203. content = msg.content.strip()
  204. if content.strip():
  205. if not new_msg:
  206. new_msg = Message('user', f'{msg.name}: {content.strip()}')
  207. else:
  208. new_msg.content += f'\n{msg.name}: {content.strip()}'
  209. if msg.function_call:
  210. # Skip the function call msg
  211. assert messages[i + 1].role == 'function'
  212. assert messages[i + 2].role == 'assistant' and messages[i + 2].name == msg.name
  213. i += 1
  214. i += 1
  215. if new_msg:
  216. new_messages.append(new_msg)
  217. if new_messages and new_messages[-1].role == 'user':
  218. new_messages[-1].content += f'\n{name}: '
  219. else:
  220. new_messages.append(Message('user', f'{name}: '))
  221. return new_messages
  222. def _init_agents_from_config(self, cfgs: Dict, llm: Optional[Union[Dict, BaseChatModel]] = None) -> List[Agent]:
  223. def _build_system_from_role_config(config: Dict):
  224. role_chat_prompt = """你是{name}。{description}\n\n{instructions}"""
  225. name = config.get('name', '').strip()
  226. description = config.get('description', '').lstrip('\n').rstrip()
  227. instructions = config.get('instructions', '').lstrip('\n').rstrip()
  228. if len(instructions) >= len(description):
  229. description = '' # redundant, as we already have instructions
  230. else:
  231. description = f'你的简介是:{description}'
  232. prompt = role_chat_prompt.format(name=name, description=description, instructions=instructions)
  233. knowledge_files = config.get('knowledge_files', [])
  234. selected_tools = config.get('selected_tools', [])
  235. return prompt, knowledge_files, selected_tools
  236. agents = []
  237. groupchat_background = '你在一个群聊中,'
  238. if cfgs.get('background', ''):
  239. groupchat_background += f'群聊背景为:{cfgs["background"]}'
  240. for cfg in cfgs['agents']:
  241. system, knowledge_files, selected_tools = _build_system_from_role_config(cfg)
  242. if 'is_human' in cfg and cfg['is_human']:
  243. # Append human agent
  244. agents.append(UserAgent(name=cfg['name'], description=cfg['description']))
  245. else:
  246. # Create npc agent by config
  247. other_agents = []
  248. for x in cfgs['agents']:
  249. if x['name'] != cfg['name']:
  250. other_agents.append(x['name'])
  251. agents.append(
  252. Assistant(llm=llm,
  253. system_message=groupchat_background + system +
  254. f'\n\n群里其他成员包括:{", ".join(other_agents)},如果你想和别人对话,可以@成员名字。\n' +
  255. '\n\n讲话时请直接输出内容,不要输出你的名字。\n\n其他群友的发言历史以如下格式展示:\n角色名: 说话内容',
  256. files=knowledge_files,
  257. function_list=selected_tools,
  258. name=cfg['name'],
  259. description=cfg['description']))
  260. return agents