# qwen_react.py (2.7 KB)

  1. import json
  2. import os
  3. from prompt.react import ReAct
  4. QWEN_TOOLS_LIST = [
  5. {
  6. 'name_for_human': '代码解释器',
  7. 'name_for_model': 'code_interpreter',
  8. 'description_for_model': '代码解释器,可用于执行Python代码。',
  9. 'parameters': [{'name': 'code', 'type': 'string', 'description': '待执行的代码'}],
  10. 'args_format': 'code'
  11. },
  12. ]
  13. TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
  14. class QwenReAct(ReAct):
  15. def __init__(self, query, lang='en', upload_file_paths=[]):
  16. super().__init__(query, lang, upload_file_paths)
  17. self.upload_file_paths = [f'{os.path.basename(fname)}' for fname in upload_file_paths]
  18. self.list_of_plugin_info = QWEN_TOOLS_LIST
  19. self.fname_template = {
  20. 'zh': '[上传文件{fname_str}]',
  21. 'en': '[Upload file {fname_str}]',
  22. 'en_multi': '[Upload file {fname_str}]'
  23. }
  24. def build_prompt(self):
  25. im_start = '<|im_start|>'
  26. im_end = '<|im_end|>'
  27. prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
  28. query = super().build_prompt()
  29. query = query.lstrip('\n').rstrip()
  30. prompt += f'\n{im_start}user\n{query}{im_end}'
  31. if f'{im_start}assistant' not in query:
  32. prompt += f'\n{im_start}assistant\n{im_end}'
  33. assert prompt.endswith(f'\n{im_start}assistant\n{im_end}')
  34. prompt = prompt[: -len(f'{im_end}')]
  35. self.prompt = prompt
  36. return prompt
  37. def _build_tools_text(self):
  38. # tool info
  39. tools_text = []
  40. for plugin_info in self.list_of_plugin_info:
  41. tool = TOOL_DESC.format(
  42. name_for_model=plugin_info['name_for_model'],
  43. name_for_human=plugin_info['name_for_human'],
  44. description_for_model=plugin_info['description_for_model'],
  45. parameters=json.dumps(plugin_info['parameters'], ensure_ascii=False),
  46. )
  47. if plugin_info.get('args_format', 'json') == 'json':
  48. tool += ' Format the arguments as a JSON object.'
  49. elif plugin_info['args_format'] == 'code':
  50. tool += ' Enclose the code within triple backticks (`) at the beginning and end of the code.'
  51. else:
  52. raise NotImplementedError
  53. tools_text.append(tool)
  54. tools_text = '\n\n'.join(tools_text)
  55. return tools_text
  56. def _build_tools_name_text(self):
  57. return ', '.join([plugin_info['name_for_model'] for plugin_info in self.list_of_plugin_info])