sql_generator.py 8.6 KB


  1. from langchain_openai import ChatOpenAI
  2. from langchain.prompts import ChatPromptTemplate
  3. from langchain_core.output_parsers import StrOutputParser
  4. from langchain_core.documents import Document
  5. from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
  6. from langchain_community.vectorstores import FAISS
  7. from prompt_template import get_prompt
  8. import json
  9. import traceback
  10. from config import get_model_config
  11. from typing import List, Dict, Any
  12. from database import Database, vector_model_config
  13. import re
  class SQLGenerator:
      """Turn natural-language questions into SQL using few-shot retrieval.

      On construction the instance builds an LLM chain (prompt -> chat model ->
      string parser), loads example question/SQL pairs from JSON, indexes them
      in a FAISS vector store, and creates a ``Database`` handle.

      ``generate_sql_stream`` emits newline-delimited JSON events, each shaped
      ``{"type": ..., "content": ...}``; ``execute_sql`` runs a statement
      through the database handle.
      """

      # A fenced Markdown block tagged ``sql``. The tag is matched
      # case-insensitively, and trailing blanks after the tag as well as CRLF
      # line endings are tolerated so slightly off-format replies still parse.
      _SQL_FENCE_RE = re.compile(
          r"```sql[ \t]*\r?\n(.*?)\r?\n[ \t]*```",
          re.DOTALL | re.IGNORECASE,
      )

      # Quoted strings are listed *before* the comment alternative so that a
      # ``--`` inside a literal (e.g. ``'a--b'``) is consumed as part of the
      # string and never mistaken for the start of a comment. Quotes are
      # escaped by doubling, per standard SQL; backslash escapes are not
      # interpreted.
      _SQL_COMMENT_RE = re.compile(
          r"'(?:[^']|'')*'"      # single-quoted string literal
          r'|"(?:[^"]|"")*"'     # double-quoted identifier / literal
          r"|--[^\n]*"           # line comment up to end of line
      )

      # A newline, optional whitespace, and another newline: one or more
      # blank lines.
      _BLANK_LINES_RE = re.compile(r"\n\s*\n")

      def __init__(self, model_type: str = "openai"):
          """Build the LLM chain, the example index and the database handle.

          Args:
              model_type: Key passed to ``get_model_config`` to select the
                  chat model settings (model name, API key, base URL).
          """
          model_config = get_model_config(model_type)

          # Chat model, configured for streaming output.
          self.llm = ChatOpenAI(
              model_name=model_config.model_name,
              api_key=model_config.api_key,
              base_url=model_config.api_base,
              streaming=True
          )

          self.prompt = ChatPromptTemplate.from_template(get_prompt())

          # Chain input is a dict with "similar_examples" and "question"; the
          # examples are rendered into the prompt's "chat_history" variable.
          self.chain = (
              {
                  "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
                  "question": lambda x: x["question"]
              }
              | self.prompt
              | self.llm
              | StrOutputParser()
          )

          # Local m3e-base embedding model; vectors are normalized.
          model_cfg = vector_model_config["m3e-base"]
          self.model_path = model_cfg["model_path"]
          self.embeddings = HuggingFaceEmbeddings(
              model_name=self.model_path,
              model_kwargs={'device': model_cfg["device"]},
              encode_kwargs={'normalize_embeddings': True}
          )

          # Few-shot examples and their vector index.
          self.examples = self._load_examples()
          self.vectorstore = self._build_vectorstore()

          self.db = Database()

      @staticmethod
      def _event(event_type: str, content: Any) -> str:
          """Serialize one stream event as a single JSON line."""
          return json.dumps(
              {"type": event_type, "content": content}, ensure_ascii=False
          ) + "\n"

      @staticmethod
      def _clean_sql(sql: str) -> str:
          """Drop ``--`` line comments and blank lines from *sql*.

          Quoted strings are copied through untouched, so a ``--`` sequence
          inside a literal is preserved.
          """

          def keep_literals(match):
              token = match.group(0)
              # Only the comment alternative can start with "--"; quoted
              # tokens always start with a quote character.
              return "" if token.startswith("--") else token

          sql = SQLGenerator._SQL_COMMENT_RE.sub(keep_literals, sql.strip())
          sql = SQLGenerator._BLANK_LINES_RE.sub("\n", sql)
          return sql.strip()

      def _extract_sql_from_markdown(self, text: str) -> str:
          """Return the SQL held in the first fenced sql block of *text*.

          Comments and blank lines are removed from the extracted SQL. When
          *text* has no fenced sql block, the stripped text is returned as-is.
          """
          match = self._SQL_FENCE_RE.search(text)
          if match:
              return self._clean_sql(match.group(1))
          return text.strip()

      def _finalize_sql(self, response: str) -> str:
          """Produce the final ``sql_result`` payload for a full LLM reply.

          Returns the cleaned SQL from the first fenced sql block, terminated
          with ``;``. If there is no such block, or it is empty once comments
          are removed, the raw response is returned unchanged so the caller
          still sees what the model said (rather than a bare ``;``).
          """
          match = self._SQL_FENCE_RE.search(response)
          if match:
              sql = self._clean_sql(match.group(1))
              if sql:
                  return sql if sql.endswith(';') else sql + ';'
          return response

      def _load_examples(self, path: str = 'examples.json'):
          """Load the few-shot examples from a UTF-8 JSON file.

          Args:
              path: File to read. The default is relative, so it is resolved
                  against the process's current working directory.

          Returns:
              The decoded JSON content.
          """
          with open(path, 'r', encoding='utf-8') as f:
              return json.load(f)

      def _build_vectorstore(self):
          """Index the loaded examples in a FAISS vector store.

          Each example's ``query`` becomes the embedded text; ``query_type``
          and ``sql_code`` travel along as metadata.

          Returns:
              The FAISS store, or ``None`` when there are no examples.
          """
          documents = [
              Document(
                  page_content=example['query'],
                  metadata={
                      'query_type': example['query_type'],
                      'sql_code': example['sql_code']
                  }
              )
              for example in self.examples
          ]
          if not documents:
              # NOTE(review): building a FAISS index from zero documents is
              # expected to fail inside the library, so skip it; retrieval
              # then simply yields no examples.
              return None
          return FAISS.from_documents(
              documents=documents,
              embedding=self.embeddings
          )

      def _get_similar_examples(self, query: str, k: int = 3):
          """Return up to *k* stored examples closest to *query*.

          Each item is a dict with ``query_type``, ``query``, ``sql_code`` and
          ``similarity_score``. An empty list is returned when no index exists.
          """
          if self.vectorstore is None:
              return []
          hits = self.vectorstore.similarity_search_with_score(query, k=k)
          # NOTE(review): with the default FAISS index this score is likely a
          # distance (lower = closer) rather than a similarity — confirm before
          # relying on the "similarity_score" name or sorting by it.
          return [
              {
                  'query_type': doc.metadata['query_type'],
                  'query': doc.page_content,
                  'sql_code': doc.metadata['sql_code'],
                  # Cast so the value is a plain float and JSON-serializable.
                  'similarity_score': float(score)
              }
              for doc, score in hits
          ]

      def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
          """Render retrieved examples as the prompt's ``chat_history`` text.

          Returns an empty string when there are no examples.
          """
          if not similar_examples:
              return ""
          parts = ["基于以下相似示例:\n\n"]
          for i, example in enumerate(similar_examples, 1):
              parts.append(
                  f"示例 {i}:\n"
                  f"问题: {example['query']}\n"
                  f"SQL: {example['sql_code']}\n"
                  f"相似度: {example['similarity_score']:.2f}\n\n"
              )
          return "".join(parts)

      async def generate_sql_stream(self, query_description: str):
          """Stream the SQL generation for *query_description* as JSON lines.

          Yields one JSON document per line, in this order on success:
          ``start``, ``similar_examples``, zero or more ``sql_generation``
          chunks, ``sql_result`` and ``end``. Any failure yields a single
          ``error`` event instead of raising; no ``end`` event follows it.
          """
          try:
              yield self._event("start", "开始生成SQL查询...")

              similar_examples = self._get_similar_examples(query_description)
              yield self._event("similar_examples", similar_examples)

              chain_input = {
                  "similar_examples": similar_examples,
                  "question": query_description
              }
              print("Chain input:", chain_input)

              try:
                  # Rendered only for the debug print below; the chain formats
                  # the prompt itself.
                  formatted_prompt = self.prompt.format(
                      chat_history=self._format_chat_history(similar_examples),
                      question=query_description
                  )
                  print("Formatted prompt:", formatted_prompt)

                  full_response = ""
                  async for chunk in self.chain.astream(chain_input):
                      if not chunk:
                          continue
                      try:
                          full_response += chunk
                          event = self._event("sql_generation", chunk)
                      except Exception as chunk_error:
                          # Best effort: a malformed chunk is skipped rather
                          # than aborting the whole stream.
                          print(f"Error processing chunk: {str(chunk_error)}")
                          continue
                      yield event

                  yield self._event("sql_result", self._finalize_sql(full_response))
                  yield self._event("end", "SQL生成完成")
              except Exception as e:
                  print(f"Error details: {traceback.format_exc()}")
                  yield self._event("error", f"生成SQL时发生错误: {str(e)}")
          except Exception as e:
              traceback.print_exc()
              yield self._event("error", str(e))

      async def execute_sql(self, sql: str) -> dict:
          """Execute *sql* and return a status dict.

          Returns:
              ``{"status": "success", "data": ...}`` with the query result, or
              ``{"status": "error", "message": ...}`` when the statement is
              empty or execution raises.
          """
          print(sql)
          try:
              statement = sql.strip() if sql else ""
              if not statement:
                  # Do not send a bare ";" to the database.
                  return {
                      "status": "error",
                      "message": "SQL语句为空"
                  }
              if not statement.endswith(';'):
                  statement += ';'
              # NOTE(review): the statement is executed verbatim. If it comes
              # from the LLM or a client it is untrusted input — consider a
              # read-only connection or an allow-list of statement types.
              result = await self.db.execute_query(statement)
              return {
                  "status": "success",
                  "data": result
              }
          except Exception as e:
              traceback.print_exc()
              return {
                  "status": "error",
                  "message": str(e)
              }

      async def close(self):
          """Close the database connection."""
          await self.db.close()