# plan_memory.py
# from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain.schema import Document
import jsonlines
import json
import os

from qwen_agent.memory.SqlMemory import embeddings
from qwen_agent.utils.util import get_data_from_jsons

os.environ["TOKENIZERS_PARALLELISM"] = "false"
  10. class PlanExampleRetrieval():
  11. def __init__(self, query_type='bidding') -> None:
  12. # self.EMBEDDING_MODEL = "text2vec" # embedding 模型,对应 embedding_model_dict
  13. # self.DEVICE = "cuda:2"
  14. # self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[self.EMBEDDING_MODEL],)
  15. # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
  16. # device=self.DEVICE)
  17. few_shot_docs = []
  18. # embeddings = OpenAIEmbeddings() LianqiaiAgent/qwen_agent/memory/data/ifbunitplan_examples.jsonl
  19. # 修改成自动读取下面的多个json,方便扩展
  20. data = get_data_from_jsons(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data'), 'plan_examples')
  21. for line in data:
  22. if line['query_type'] == query_type:
  23. few_shot_docs.append(Document(page_content=line['query'], metadata={'plan':line['plan']}))
  24. self.vector_db = FAISS.from_documents(few_shot_docs, embeddings)
  25. def get_relevant_documents(self,query,top_k=4):
  26. results=[]
  27. for r in self.vector_db.similarity_search(query, k=top_k):
  28. results.append((r.page_content, r.metadata['plan']))
  29. return results
  30. if __name__=="__main__":
  31. print(os.path.abspath(os.path.dirname(__file__)))
  32. print(get_data_from_jsons(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data'), 'plans'))
  33. # print(os.path.join(os.path.abspath(os.path.dirname(__file__)),'data/sqls.jsonl'))