1234567891011121314151617181920212223242526272829303132333435363738394041 |
- from langchain.vectorstores.faiss import FAISS
- from langchain.schema import Document
- import os
- from qwen_agent.memory.SqlMemory import embeddings
- from qwen_agent.utils.util import get_data_from_jsons
# Set at import time, before any retriever is constructed.
# NOTE(review): presumably disables the HuggingFace tokenizers library's
# internal parallelism for this process (e.g. to avoid its fork warning)
# — confirm against the embedding model actually used.
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
class PlanExampleRetrieval():
    """Retrieve few-shot plan examples that are similar to a user query.

    On construction, the plan examples returned by ``get_data_from_jsons`` for
    the ``data`` directory next to this module are filtered by ``query_type``.
    The query text of each matching example is embedded and stored in an
    in-memory FAISS index, with the example's plan kept in
    ``metadata['plan']``.

    ``self.vector_db`` holds the FAISS store, or ``None`` when no example
    matched ``query_type``.
    """

    def __init__(self, query_type='bidding') -> None:
        """Build the vector index for one query type.

        Args:
            query_type: Only examples whose ``'query_type'`` field equals this
                value are indexed.
        """
        import logging

        # Examples are read from the JSON files under ./data so that new
        # files can be added without changing this code.
        data_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
        examples = get_data_from_jsons(data_dir, 'plan_examples')

        # All plans of the requested query_type form the knowledge base: the
        # example query is the searchable text, the plan rides along as
        # metadata.
        few_shot_docs = [
            Document(page_content=line['query'], metadata={'plan': line['plan']})
            for line in examples
            if line['query_type'] == query_type
        ]

        if few_shot_docs:
            # Embed the examples and store them in a FAISS index.
            self.vector_db = FAISS.from_documents(few_shot_docs, embeddings)
        else:
            # FAISS.from_documents cannot build an index from an empty list
            # (it fails with an opaque IndexError), so keep no index and let
            # lookups return no examples instead of crashing here.
            logging.getLogger(__name__).warning(
                "No plan examples found for query_type=%r; retrieval will return no results.",
                query_type,
            )
            self.vector_db = None

    def get_relevant_documents(self, query, top_k=4):
        """Return the indexed examples most similar to ``query``.

        Args:
            query: Text to search the index with.
            top_k: Number of results requested from the similarity search
                (passed through as ``k``).

        Returns:
            A list of ``(example_query, plan)`` tuples in the order produced
            by the similarity search. Empty when no examples were indexed.
        """
        if self.vector_db is None:
            return []
        return [
            (doc.page_content, doc.metadata['plan'])
            for doc in self.vector_db.similarity_search(query, k=top_k)
        ]
if __name__ == "__main__":
    # Manual smoke test: show where this module lives, then run one lookup
    # against the "land_site_selection" examples and print every hit.
    module_dir = os.path.abspath(os.path.dirname(__file__))
    print(module_dir)

    # Debug aid: to inspect the raw examples, iterate over
    # get_data_from_jsons(os.path.join(module_dir, 'data'), 'plan')
    # and print each item.

    retriever = PlanExampleRetrieval("land_site_selection")
    matches = retriever.get_relevant_documents("萧山区推荐几块工业用地")
    for match in matches:
        print(match)
|