memory.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import json
  2. from importlib import import_module
  3. from typing import Dict, Iterator, List, Optional, Union
  4. import json5
  5. from qwen_agent import Agent
  6. from qwen_agent.llm import BaseChatModel
  7. from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, USER, Message
  8. from qwen_agent.log import logger
  9. from qwen_agent.settings import (DEFAULT_MAX_REF_TOKEN, DEFAULT_PARSER_PAGE_SIZE, DEFAULT_RAG_KEYGEN_STRATEGY,
  10. DEFAULT_RAG_SEARCHERS)
  11. from qwen_agent.tools import BaseTool
  12. from qwen_agent.tools.simple_doc_parser import PARSER_SUPPORTED_FILE_TYPES
  13. from qwen_agent.utils.utils import extract_files_from_messages, extract_text_from_message, get_file_type
  class Memory(Agent):
      """Memory is a special agent for file management.

      By default, this memory can use the retrieval tool for RAG.
      """

      def __init__(self,
                   function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
                   llm: Optional[Union[Dict, BaseChatModel]] = None,
                   system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
                   files: Optional[List[str]] = None,
                   rag_cfg: Optional[Dict] = None):
          """Initialize the memory.

          Args:
              function_list: Extra tools, appended after the built-in 'retrieval' and
                'doc_parser' tool configs that this class always registers.
              llm: The LLM config or model object, forwarded to the base agent.
              system_message: The system message, forwarded to the base agent.
              files: Files that are always considered for retrieval, in addition to
                the files found in the messages of each run.
              rag_cfg: The config for RAG. One example is:
                {
                  'max_ref_token': 4000,
                  'parser_page_size': 500,
                  'rag_keygen_strategy': 'SplitQueryThenGenKeyword',
                  'rag_searchers': ['keyword_search', 'front_page_search']
                }
                Any key that is omitted falls back to the matching DEFAULT_* constant
                from qwen_agent.settings (the example above is described by the
                original author as the default settings).
          """
          self.cfg = rag_cfg or {}
          self.max_ref_token: int = self.cfg.get('max_ref_token', DEFAULT_MAX_REF_TOKEN)
          self.parser_page_size: int = self.cfg.get('parser_page_size', DEFAULT_PARSER_PAGE_SIZE)
          self.rag_searchers = self.cfg.get('rag_searchers', DEFAULT_RAG_SEARCHERS)
          self.rag_keygen_strategy = self.cfg.get('rag_keygen_strategy', DEFAULT_RAG_KEYGEN_STRATEGY)

          builtin_tools = [{
              'name': 'retrieval',
              'max_ref_token': self.max_ref_token,
              'parser_page_size': self.parser_page_size,
              'rag_searchers': self.rag_searchers,
          }, {
              'name': 'doc_parser',
              'max_ref_token': self.max_ref_token,
              'parser_page_size': self.parser_page_size,
          }]
          # list(...) so that any sequence of tools (e.g. a tuple) can be concatenated.
          super().__init__(function_list=builtin_tools + list(function_list or []),
                           llm=llm,
                           system_message=system_message)

          self.system_files = files or []

      def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]:
          """Retrieve the parts of the available files that are relevant to the query.

          The files are collected with `get_rag_files` (system files plus files found
          in the messages). The query is the text of the last message if it comes from
          the user, optionally expanded with generated keywords, and is handed to the
          'retrieval' tool together with the files.

          Only files whose type is in PARSER_SUPPORTED_FILE_TYPES are used (listed by
          the original author as: .pdf, .docx, .pptx, .txt, .csv, .tsv, .xlsx, .xls and html).

          Args:
              messages: A list of messages.
              lang: Language. Not read by the code in this method.
              **kwargs: Forwarded unchanged to the 'retrieval' tool call.

          Yields:
              A single-element list holding an assistant message named 'memory' whose
              content is the retrieved documents, or an empty string when there is no
              file to retrieve from.
          """
          rag_files = self.get_rag_files(messages)
          if not rag_files:
              # Nothing to retrieve from: answer with an empty memory message.
              yield [Message(role=ASSISTANT, content='', name='memory')]
              return

          query = ''
          # Only retrieve content according to the last user query, if it exists.
          if messages and messages[-1].role == USER:
              query = extract_text_from_message(messages[-1], add_upload_info=False)

          if query:
              query = self._expand_query_with_keywords(query, rag_files)

          content = self.function_map['retrieval'].call(
              {
                  'query': query,
                  'files': rag_files
              },
              **kwargs,
          )
          # The message content must be text, so serialize structured results.
          if not isinstance(content, str):
              content = json.dumps(content, ensure_ascii=False, indent=4)

          yield [Message(role=ASSISTANT, content=content, name='memory')]

      def _expand_query_with_keywords(self, query: str, rag_files: List[str]) -> str:
          """Expand the query with keywords produced by the configured keygen strategy.

          This is best-effort: whenever the strategy is disabled or its output cannot
          be used, the original query is returned unchanged.

          Args:
              query: The non-empty text of the last user query.
              rag_files: The files passed to the keygen strategy.

          Returns:
              A JSON string of the generated keyword dict (with the original query under
              the 'text' key when the strategy did not provide one), or `query` itself.
          """
          strategy = self.rag_keygen_strategy
          # A missing/None strategy is treated like the explicit 'none' setting.
          if not strategy or strategy.lower() == 'none':
              return query

          # The strategy name is resolved to a class of the keygen_strategies module.
          module = import_module('qwen_agent.agents.keygen_strategies')
          keygen = getattr(module, strategy)(llm=self.llm)

          # Drain the response iterator; only the final item is used.
          last = None
          for last in keygen.run([Message(USER, query)], files=rag_files):
              continue
          if last:
              keyword = last[-1].content.strip()
          else:
              keyword = ''

          # Remove a surrounding markdown json code fence, if present.
          if keyword.startswith('```json'):
              keyword = keyword[len('```json'):]
          if keyword.endswith('```'):
              keyword = keyword[:-3]

          try:
              keyword_dict = json5.loads(keyword)
          except Exception as ex:
              # Deliberately broad: the text comes from a model and may be anything.
              logger.warning('Failed to parse generated keywords (%s); using the raw query.', ex)
              return query

          if not isinstance(keyword_dict, dict):
              # Valid JSON that is not an object (string, list, number, ...) is unusable.
              logger.warning('Generated keywords are not a JSON object; using the raw query.')
              return query

          if 'text' not in keyword_dict:
              keyword_dict['text'] = query
          expanded_query = json.dumps(keyword_dict, ensure_ascii=False)
          logger.info(expanded_query)
          return expanded_query

      def get_rag_files(self, messages: List[Message]) -> List[str]:
          """Collect the files that can be used for retrieval.

          Args:
              messages: A list of messages whose (non-image) files are extracted.

          Returns:
              The system files followed by the files of the messages, keeping only
              those whose type is in PARSER_SUPPORTED_FILE_TYPES, without duplicates
              and in first-seen order.
          """
          session_files = extract_files_from_messages(messages, include_images=False)

          rag_files: List[str] = []
          seen = set()  # every distinct file is type-checked only once
          for file in self.system_files + session_files:
              if file in seen:
                  continue
              seen.add(file)
              if get_file_type(file) in PARSER_SUPPORTED_FILE_TYPES:
                  rag_files.append(file)
          return rag_files