# memory.py
from typing import List

from qwen_agent.schema import RefMaterial
from qwen_agent.tools.similarity_search import SimilaritySearch
from qwen_agent.utils.util import count_tokens
  5. class Memory:
  6. def __init__(self, open_ss, ss_type):
  7. self.open_ss = open_ss
  8. self.ss_type = ss_type
  9. def get(self, query: str, records: list, llm=None, stream=False, max_token=4000) -> List[RefMaterial]:
  10. if not self.open_ss:
  11. _ref_list = self.get_top(records)
  12. else:
  13. search_agent = SimilaritySearch(type=self.ss_type, llm=llm, stream=stream)
  14. _ref_list = []
  15. for record in records:
  16. now_ref_list = search_agent.run(record, query)
  17. if now_ref_list['text']:
  18. _ref_list.append(now_ref_list)
  19. if not _ref_list:
  20. _ref_list = self.get_top(records)
  21. # token number
  22. new_ref_list = []
  23. single_max_token = int(max_token/len(_ref_list))
  24. for _ref in _ref_list:
  25. tmp = {
  26. 'url': _ref['url'],
  27. 'text': []
  28. }
  29. now_token = 0
  30. print(len(_ref['text']))
  31. for x in _ref['text']:
  32. # lenx = len(x)
  33. lenx = count_tokens(x)
  34. if (now_token + lenx) <= single_max_token:
  35. tmp['text'].append(x)
  36. now_token += lenx
  37. else:
  38. use_rate = (single_max_token-now_token)/lenx
  39. tmp['text'].append(x[:int(len(x)*use_rate)])
  40. break
  41. new_ref_list.append(tmp)
  42. return new_ref_list
  43. def get_top(self, records: list, k=6):
  44. _ref_list = []
  45. for record in records:
  46. raw = record['raw']
  47. k = min(len(raw), k)
  48. _ref_list.append(RefMaterial(url=record['url'], text=[x['page_content'] for x in raw[:k]]).to_dict())
  49. return _ref_list