# sql_generator.py
  1. from langchain_openai import ChatOpenAI
  2. from langchain.prompts import ChatPromptTemplate
  3. from langchain.chains import LLMChain
  4. from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
  5. from langchain_community.vectorstores import FAISS
  6. from langchain_core.output_parsers import StrOutputParser
  7. from langchain_core.runnables import RunnablePassthrough
  8. from langchain_core.documents import Document
  9. import os
  10. from prompt_template import get_prompt
  11. from fastapi.responses import StreamingResponse
  12. import json
  13. import asyncio
  14. import numpy as np
  15. from config import get_model_config
  16. import pandas as pd
  17. from typing import List, Dict, Any
  18. import traceback
  19. class SQLGenerator:
  20. def __init__(self, model_type: str = "openai"):
  21. # 获取模型配置
  22. model_config = get_model_config(model_type)
  23. # 初始化LLM
  24. self.llm = ChatOpenAI(
  25. model_name=model_config.model_name,
  26. temperature=model_config.temperature,
  27. api_key=model_config.api_key,
  28. base_url=model_config.api_base,
  29. max_tokens=model_config.max_tokens,
  30. streaming=True
  31. )
  32. # 初始化提示词模板
  33. self.prompt = ChatPromptTemplate.from_template(get_prompt())
  34. # 构建链式调用
  35. self.chain = (
  36. {
  37. "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
  38. "question": lambda x: x["question"]
  39. }
  40. | self.prompt
  41. | self.llm
  42. | StrOutputParser()
  43. )
  44. # 初始化本地m3e-base模型
  45. self.model_path = r"E:\项目临时\AI大模型\m3e-base"
  46. self.embeddings = HuggingFaceEmbeddings(
  47. model_name=self.model_path,
  48. model_kwargs={'device': 'cpu'},
  49. encode_kwargs={'normalize_embeddings': True}
  50. )
  51. # 加载示例数据并创建向量数据库
  52. self.examples = self._load_examples()
  53. self.vectorstore = self._build_vectorstore()
  54. def _load_examples(self):
  55. """加载示例数据"""
  56. with open('examples.json', 'r', encoding='utf-8') as f:
  57. return json.load(f)
  58. def _build_vectorstore(self):
  59. """构建FAISS向量数据库"""
  60. # 准备Document对象列表
  61. documents = []
  62. for example in self.examples:
  63. doc = Document(
  64. page_content=example['query'],
  65. metadata={
  66. 'query_type': example['query_type'],
  67. 'plan': example['plan'],
  68. 'sql_code': example['sql_code']
  69. }
  70. )
  71. documents.append(doc)
  72. # 创建FAISS向量数据库
  73. vectorstore = FAISS.from_documents(
  74. documents=documents,
  75. embedding=self.embeddings
  76. )
  77. return vectorstore
  78. def _get_similar_examples(self, query: str, k: int = 3):
  79. """获取最相似的示例"""
  80. # 使用FAISS搜索相似示例
  81. docs = self.vectorstore.similarity_search_with_score(query, k=k)
  82. # 格式化返回结果
  83. similar_examples = []
  84. for doc, score in docs:
  85. similar_examples.append({
  86. 'query_type': doc.metadata['query_type'],
  87. 'query': doc.page_content,
  88. 'plan': doc.metadata['plan'],
  89. 'sql_code': doc.metadata['sql_code'],
  90. 'similarity_score': float(score)
  91. })
  92. return similar_examples
  93. def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
  94. """格式化聊天历史"""
  95. if len(similar_examples) == 0:
  96. return ""
  97. chat_history = "基于以下相似示例:\n\n"
  98. for i, example in enumerate(similar_examples, 1):
  99. chat_history += f"示例 {i}:\n"
  100. chat_history += f"问题: {example['query']}\n"
  101. chat_history += f"计划: {example['plan']}\n"
  102. chat_history += f"SQL: {example['sql_code']}\n"
  103. chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
  104. print(chat_history+" !!!!!")
  105. return chat_history
  106. async def generate_sql_stream(self, query_description: str):
  107. """
  108. 流式生成SQL查询语句
  109. """
  110. try:
  111. # 开始生成SQL
  112. yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}) + "\n"
  113. # 获取相似示例
  114. similar_examples = self._get_similar_examples(query_description)
  115. yield json.dumps({
  116. "type": "similar_examples",
  117. "content": similar_examples
  118. }, ensure_ascii=False) + "\n"
  119. # 准备输入数据
  120. chain_input = {
  121. "similar_examples": similar_examples,
  122. "question": query_description
  123. }
  124. print("Chain input:", chain_input)
  125. try:
  126. # 流式生成SQL
  127. async for chunk in self.chain.astream(chain_input):
  128. if chunk:
  129. yield json.dumps({
  130. "type": "sql_generation",
  131. "content": chunk
  132. }) + "\n"
  133. yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
  134. except Exception as e:
  135. print(f"Error during streaming: {str(e)}")
  136. print(f"Error type: {type(e)}")
  137. print(f"Error details: {traceback.format_exc()}")
  138. yield json.dumps({
  139. "type": "error",
  140. "message": f"生成SQL时发生错误: {str(e)}"
  141. }) + "\n"
  142. except Exception as e:
  143. traceback.print_exc()
  144. yield json.dumps({
  145. "type": "error",
  146. "message": str(e)
  147. }) + "\n"
  148. async def execute_sql(self, sql: str, db_connection) -> dict:
  149. """
  150. 执行SQL查询并返回结果
  151. """
  152. try:
  153. result = db_connection.execute(sql)
  154. columns = result.keys()
  155. data = [dict(zip(columns, row)) for row in result.fetchall()]
  156. return {
  157. "status": "success",
  158. "data": data,
  159. "columns": columns
  160. }
  161. except Exception as e:
  162. traceback.print_exc()
  163. return {
  164. "status": "error",
  165. "message": str(e)
  166. }