plan_memory.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from langchain.vectorstores.faiss import FAISS
  2. from langchain.schema import Document
  3. import os
  4. from qwen_agent.memory.SqlMemory import embeddings
  5. from qwen_agent.utils.util import get_data_from_jsons
  6. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  7. class PlanExampleRetrieval():
  8. def __init__(self, query_type='bidding') -> None:
  9. few_shot_docs = []
  10. # 修改成自动读取下面的多个json,方便扩展
  11. data = get_data_from_jsons(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data'), 'plan_examples')
  12. # 将memory/data/plans下的对应的query_type的所有plans作为知识库
  13. for line in data:
  14. if line['query_type'] == query_type:
  15. few_shot_docs.append(Document(page_content=line['query'], metadata={'plan': line['plan']}))
  16. # 将这些rags进行向量化,保存到FAISS数据库中
  17. self.vector_db = FAISS.from_documents(few_shot_docs, embeddings)
  18. def get_relevant_documents(self, query, top_k=4):
  19. results = []
  20. for r in self.vector_db.similarity_search(query, k=top_k):
  21. results.append((r.page_content, r.metadata['plan']))
  22. return results
  23. if __name__ == "__main__":
  24. print(os.path.abspath(os.path.dirname(__file__)))
  25. # for data in get_data_from_jsons(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data'), 'plan'):
  26. # print(data)
  27. plan_retrieval = PlanExampleRetrieval("land_site_selection")
  28. results = plan_retrieval.get_relevant_documents("萧山区推荐几块工业用地", )
  29. for r in results:
  30. print(r)