assistant.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import copy
  2. import datetime
  3. from typing import Dict, Iterator, List, Literal, Optional, Union
  4. import json5
  5. from qwen_agent.agents.fncall_agent import FnCallAgent
  6. from qwen_agent.llm import BaseChatModel
  7. from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, Message
  8. from qwen_agent.log import logger
  9. from qwen_agent.tools import BaseTool
  10. from qwen_agent.utils.utils import get_basename_from_url, print_traceback
  11. KNOWLEDGE_TEMPLATE_ZH = """# 知识库
  12. {knowledge}"""
  13. KNOWLEDGE_TEMPLATE_EN = """# Knowledge Base
  14. {knowledge}"""
  15. KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN}
  16. KNOWLEDGE_SNIPPET_ZH = """## 来自 {source} 的内容:
  17. ```
  18. {content}
  19. ```"""
  20. KNOWLEDGE_SNIPPET_EN = """## The content from {source}:
  21. ```
  22. {content}
  23. ```"""
  24. KNOWLEDGE_SNIPPET = {'zh': KNOWLEDGE_SNIPPET_ZH, 'en': KNOWLEDGE_SNIPPET_EN}
  25. def format_knowledge_to_source_and_content(result: Union[str, List[dict]]) -> List[dict]:
  26. knowledge = []
  27. if isinstance(result, str):
  28. result = f'{result}'.strip()
  29. try:
  30. docs = json5.loads(result)
  31. except Exception:
  32. print_traceback()
  33. knowledge.append({'source': '上传的文档', 'content': result})
  34. return knowledge
  35. else:
  36. docs = result
  37. try:
  38. _tmp_knowledge = []
  39. assert isinstance(docs, list)
  40. for doc in docs:
  41. url, snippets = doc['url'], doc['text']
  42. assert isinstance(snippets, list)
  43. _tmp_knowledge.append({
  44. 'source': f'[文件]({get_basename_from_url(url)})',
  45. 'content': '\n\n...\n\n'.join(snippets)
  46. })
  47. knowledge.extend(_tmp_knowledge)
  48. except Exception:
  49. print_traceback()
  50. knowledge.append({'source': '上传的文档', 'content': result})
  51. return knowledge
  52. class Assistant(FnCallAgent):
  53. """This is a widely applicable agent integrated with RAG capabilities and function call ability."""
  54. def __init__(self,
  55. function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
  56. llm: Optional[Union[Dict, BaseChatModel]] = None,
  57. system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
  58. name: Optional[str] = None,
  59. description: Optional[str] = None,
  60. files: Optional[List[str]] = None,
  61. rag_cfg: Optional[Dict] = None):
  62. super().__init__(function_list=function_list,
  63. llm=llm,
  64. system_message=system_message,
  65. name=name,
  66. description=description,
  67. files=files,
  68. rag_cfg=rag_cfg)
  69. def _run(self,
  70. messages: List[Message],
  71. lang: Literal['en', 'zh'] = 'en',
  72. knowledge: str = '',
  73. **kwargs) -> Iterator[List[Message]]:
  74. """Q&A with RAG and tool use abilities.
  75. Args:
  76. knowledge: If an external knowledge string is provided,
  77. it will be used directly without retrieving information from files in messages.
  78. """
  79. new_messages = self._prepend_knowledge_prompt(messages=messages, lang=lang, knowledge=knowledge, **kwargs)
  80. return super()._run(messages=new_messages, lang=lang, **kwargs)
  81. def _prepend_knowledge_prompt(self,
  82. messages: List[Message],
  83. lang: Literal['en', 'zh'] = 'en',
  84. knowledge: str = '',
  85. **kwargs) -> List[Message]:
  86. messages = copy.deepcopy(messages)
  87. if not knowledge:
  88. # Retrieval knowledge from files
  89. *_, last = self.mem.run(messages=messages, lang=lang, **kwargs)
  90. knowledge = last[-1][CONTENT]
  91. logger.debug(f'Retrieved knowledge of type `{type(knowledge).__name__}`:\n{knowledge}')
  92. if knowledge:
  93. knowledge = format_knowledge_to_source_and_content(knowledge)
  94. logger.debug(f'Formatted knowledge into type `{type(knowledge).__name__}`:\n{knowledge}')
  95. else:
  96. knowledge = []
  97. snippets = []
  98. for k in knowledge:
  99. snippets.append(KNOWLEDGE_SNIPPET[lang].format(source=k['source'], content=k['content']))
  100. knowledge_prompt = ''
  101. if snippets:
  102. knowledge_prompt = KNOWLEDGE_TEMPLATE[lang].format(knowledge='\n\n'.join(snippets))
  103. if knowledge_prompt:
  104. if messages[0][ROLE] == SYSTEM:
  105. messages[0][CONTENT] += '\n\n' + knowledge_prompt
  106. else:
  107. messages = [Message(role=SYSTEM, content=knowledge_prompt)] + messages
  108. return messages
  109. def get_current_date_str(
  110. lang: Literal['en', 'zh'] = 'en',
  111. hours_from_utc: Optional[int] = None,
  112. ) -> str:
  113. if hours_from_utc is None:
  114. cur_time = datetime.datetime.now()
  115. else:
  116. cur_time = datetime.datetime.utcnow() + datetime.timedelta(hours=hours_from_utc)
  117. if lang == 'en':
  118. date_str = 'Current date: ' + cur_time.strftime('%A, %B %d, %Y')
  119. elif lang == 'zh':
  120. cur_time = cur_time.timetuple()
  121. date_str = f'当前时间:{cur_time.tm_year}年{cur_time.tm_mon}月{cur_time.tm_mday}日,星期'
  122. date_str += ['一', '二', '三', '四', '五', '六', '日'][cur_time.tm_wday]
  123. date_str += '。'
  124. else:
  125. raise NotImplementedError
  126. return date_str