similarity_search_jaccard.py

from qwen_agent.schema import RefMaterial
from qwen_agent.utils.util import get_split_word


class SSJaccard:
    """Similarity search over text chunks using word-overlap (Jaccard-style) scoring."""

    def __init__(self, llm=None, stream=False):
        self.llm = llm
        self.stream = stream

    def run(self, line, query):
        """
        Input: one line (a dict with 'url' and 'query' fields)
        Output: the text chunks most relevant to the query
        """
        wordlist = get_split_word(query)
        content = line['query']
        if isinstance(content, str):
            content = content.split('\n')
        res = []
        sims = []
        # Score every chunk against the query word list.
        for i, page in enumerate(content):
            sim = self.filter_section(page, wordlist)
            sims.append([i, sim])
        # Sort by score, highest first.
        sims.sort(key=lambda x: x[1], reverse=True)
        # print('sims: ', sims)
        max_sims = sims[0][1]
        if max_sims != 0:
            # Keep every top-scoring chunk, plus at most the first few lower-scoring ones.
            for i, x in enumerate(sims):
                if x[1] < max_sims and i > 3:
                    break
                page = content[x[0]]
                text = ''
                if isinstance(page, str):
                    text = content[x[0]]
                elif isinstance(page, dict):
                    text = page['page_content']
                res.append(text)
        # for x in res:
        #     print("=========")
        #     print(x)
        return RefMaterial(url=line['url'], text=res).to_dict()

    def filter_section(self, page, wordlist):
        # Extract plain text from a chunk, then score it against the query words.
        if isinstance(page, str):
            text = page
        elif isinstance(page, dict):
            text = page['page_content']
        else:
            print(type(page))
            raise TypeError
        pagelist = get_split_word(text)
        sim = self.jaccard_similarity(wordlist, pagelist)
        return sim

    def jaccard_similarity(self, list1, list2):
        s1 = set(list1)
        s2 = set(list2)
        return len(s1.intersection(s2))  # avoid text length impact
        # return len(s1.intersection(s2)) / len(s1.union(s2))  # jaccard similarity
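

# A minimal usage sketch, not part of the original file: the sample record below
# (its 'url' value and newline-separated 'query' text) is hypothetical, while
# RefMaterial and get_split_word are the qwen_agent helpers imported above, so
# this only illustrates how run() might be called.
if __name__ == '__main__':
    searcher = SSJaccard()
    line = {
        'url': 'https://example.com/doc',
        'query': 'Jaccard similarity compares word sets.\n'
                 'Cosine similarity compares embedding vectors.\n'
                 'An unrelated sentence about cooking.',
    }
    # Returns a dict with the source url and the chunks sharing the most words with the query.
    ref = searcher.run(line, query='jaccard similarity of word sets')
    print(ref)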