retrieval.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Dict, Optional, Union
  2. import json5
  3. from qwen_agent.settings import DEFAULT_MAX_REF_TOKEN, DEFAULT_PARSER_PAGE_SIZE, DEFAULT_RAG_SEARCHERS
  4. from qwen_agent.tools.base import TOOL_REGISTRY, BaseTool, register_tool
  5. from qwen_agent.tools.doc_parser import DocParser, Record
  6. from qwen_agent.tools.simple_doc_parser import PARSER_SUPPORTED_FILE_TYPES
  7. def _check_deps_for_rag():
  8. try:
  9. import charset_normalizer # noqa
  10. import jieba # noqa
  11. import pdfminer # noqa
  12. import pdfplumber # noqa
  13. import rank_bm25 # noqa
  14. import snowballstemmer # noqa
  15. from bs4 import BeautifulSoup # noqa
  16. from docx import Document # noqa
  17. from pptx import Presentation # noqa
  18. except ImportError as e:
  19. raise ImportError('The dependencies for RAG support are not installed. '
  20. 'Please install the required dependencies by running: pip install qwen-agent[rag]') from e
  21. @register_tool('retrieval')
  22. class Retrieval(BaseTool):
  23. description = f'从给定文件列表中检索出和问题相关的内容,支持文件类型包括:{"/".join(PARSER_SUPPORTED_FILE_TYPES)}'
  24. parameters = [{
  25. 'name': 'query',
  26. 'type': 'string',
  27. 'description': '在这里列出关键词,用逗号分隔,目的是方便在文档中匹配到相关的内容,由于文档可能多语言,关键词最好中英文都有。',
  28. 'required': True
  29. }, {
  30. 'name': 'files',
  31. 'type': 'array',
  32. 'items': {
  33. 'type': 'string'
  34. },
  35. 'description': '待解析的文件路径列表,支持本地文件路径或可下载的http(s)链接。',
  36. 'required': True
  37. }]
  38. def __init__(self, cfg: Optional[Dict] = None):
  39. super().__init__(cfg)
  40. self.max_ref_token: int = self.cfg.get('max_ref_token', DEFAULT_MAX_REF_TOKEN)
  41. self.parser_page_size: int = self.cfg.get('parser_page_size', DEFAULT_PARSER_PAGE_SIZE)
  42. self.doc_parse = DocParser({'max_ref_token': self.max_ref_token, 'parser_page_size': self.parser_page_size})
  43. self.rag_searchers = self.cfg.get('rag_searchers', DEFAULT_RAG_SEARCHERS)
  44. if len(self.rag_searchers) == 1:
  45. self.search = TOOL_REGISTRY[self.rag_searchers[0]]({'max_ref_token': self.max_ref_token})
  46. else:
  47. from qwen_agent.tools.search_tools.hybrid_search import HybridSearch
  48. self.search = HybridSearch({'max_ref_token': self.max_ref_token, 'rag_searchers': self.rag_searchers})
  49. def call(self, params: Union[str, dict], **kwargs) -> list:
  50. """RAG tool.
  51. Step1: Parse and save files
  52. Step2: Retrieval related content according to query
  53. Args:
  54. params: The files and query.
  55. Returns:
  56. The parsed file list or retrieved file list.
  57. """
  58. # TODO: It this a good place to check the RAG deps?
  59. _check_deps_for_rag()
  60. params = self._verify_json_format_args(params)
  61. files = params.get('files', [])
  62. if isinstance(files, str):
  63. files = json5.loads(files)
  64. records = []
  65. for file in files:
  66. _record = self.doc_parse.call(params={'url': file}, **kwargs)
  67. records.append(_record)
  68. query = params.get('query', '')
  69. if records:
  70. return self.search.call(params={'query': query}, docs=[Record(**rec) for rec in records], **kwargs)
  71. else:
  72. return []