# extract_doc_vocabulary.py
  1. import json
  2. import os
  3. from typing import Dict, Optional, Union
  4. import json5
  5. from qwen_agent.settings import DEFAULT_WORKSPACE
  6. from qwen_agent.tools.base import BaseTool, register_tool
  7. from qwen_agent.tools.search_tools.keyword_search import WORDS_TO_IGNORE, string_tokenizer
  8. from qwen_agent.tools.simple_doc_parser import SimpleDocParser
  9. from qwen_agent.tools.storage import KeyNotExistsError, Storage
  10. @register_tool('extract_doc_vocabulary')
  11. class ExtractDocVocabulary(BaseTool):
  12. description = '提取文档的词表。'
  13. parameters = [{
  14. 'name': 'files',
  15. 'type': 'array',
  16. 'items': {
  17. 'type': 'string'
  18. },
  19. 'description': '文件路径列表,支持本地文件路径或可下载的http(s)链接。',
  20. 'required': True
  21. }]
  22. def __init__(self, cfg: Optional[Dict] = None):
  23. super().__init__(cfg)
  24. self.simple_doc_parse = SimpleDocParser()
  25. self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name))
  26. self.db = Storage({'storage_root_path': self.data_root})
  27. def call(self, params: Union[str, dict], **kwargs) -> str:
  28. params = self._verify_json_format_args(params)
  29. files = params.get('files', [])
  30. document_id = str(files)
  31. if isinstance(files, str):
  32. files = json5.loads(files)
  33. docs = []
  34. for file in files:
  35. _doc = self.simple_doc_parse.call(params={'url': file}, **kwargs)
  36. docs.append(_doc)
  37. try:
  38. all_voc = self.db.call({'operate': 'get', 'key': document_id})
  39. except KeyNotExistsError:
  40. try:
  41. from sklearn.feature_extraction.text import TfidfVectorizer
  42. except ModuleNotFoundError:
  43. raise ModuleNotFoundError('Please install sklearn by: `pip install scikit-learn`')
  44. vectorizer = TfidfVectorizer(tokenizer=string_tokenizer, stop_words=WORDS_TO_IGNORE)
  45. tfidf_matrix = vectorizer.fit_transform(docs)
  46. sorted_items = sorted(zip(vectorizer.get_feature_names_out(),
  47. tfidf_matrix.toarray().flatten()),
  48. key=lambda x: x[1],
  49. reverse=True)
  50. all_voc = ', '.join([term for term, score in sorted_items])
  51. if document_id:
  52. self.db.call({'operate': 'put', 'key': document_id, 'value': json.dumps(all_voc, ensure_ascii=False)})
  53. return all_voc