# sql_generator.py (8.6 KB)


import asyncio
import json
import logging
import os
import re
import traceback
from typing import Any, Dict, List, Optional

from langchain.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

from config import get_model_config
from database import Database
from prompt_template import get_prompt
  16. class SQLGenerator:
  17. def __init__(self, model_type: str = "openai"):
  18. # 获取模型配置
  19. model_config = get_model_config(model_type)
  20. # 初始化LLM
  21. self.llm = ChatOpenAI(
  22. model_name=model_config.model_name,
  23. api_key=model_config.api_key,
  24. base_url=model_config.api_base,
  25. streaming=True
  26. )
  27. # 初始化提示词模板
  28. self.prompt = ChatPromptTemplate.from_template(get_prompt())
  29. # 构建链式调用
  30. self.chain = (
  31. {
  32. "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
  33. "question": lambda x: x["question"]
  34. }
  35. | self.prompt
  36. | self.llm
  37. | StrOutputParser()
  38. )
  39. # 初始化本地m3e-base模型
  40. self.model_path = r"E:\项目临时\AI大模型\m3e-base"
  41. self.embeddings = HuggingFaceEmbeddings(
  42. model_name=self.model_path,
  43. model_kwargs={'device': 'cpu'},
  44. encode_kwargs={'normalize_embeddings': True}
  45. )
  46. # 加载示例数据并创建向量数据库
  47. self.examples = self._load_examples()
  48. self.vectorstore = self._build_vectorstore()
  49. # 初始化数据库连接
  50. self.db = Database()
  51. def _extract_sql_from_markdown(self, text: str) -> str:
  52. """
  53. 从markdown格式的文本中提取SQL语句
  54. """
  55. # 使用正则表达式匹配markdown代码块中的SQL
  56. sql_pattern = r"```sql\n(.*?)\n```"
  57. match = re.search(sql_pattern, text, re.DOTALL)
  58. if match:
  59. # 提取SQL语句并清理
  60. sql = match.group(1).strip()
  61. # 移除可能的注释
  62. sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
  63. # 移除多余的空行
  64. sql = re.sub(r'\n\s*\n', '\n', sql)
  65. return sql.strip()
  66. return text.strip()
  67. def _load_examples(self):
  68. """加载示例数据"""
  69. with open('examples.json', 'r', encoding='utf-8') as f:
  70. return json.load(f)
  71. def _build_vectorstore(self):
  72. """构建FAISS向量数据库"""
  73. # 准备Document对象列表
  74. documents = []
  75. for example in self.examples:
  76. doc = Document(
  77. page_content=example['query'],
  78. metadata={
  79. 'query_type': example['query_type'],
  80. 'sql_code': example['sql_code']
  81. }
  82. )
  83. documents.append(doc)
  84. # 创建FAISS向量数据库
  85. vectorstore = FAISS.from_documents(
  86. documents=documents,
  87. embedding=self.embeddings
  88. )
  89. return vectorstore
  90. def _get_similar_examples(self, query: str, k: int = 3):
  91. """获取最相似的示例"""
  92. # 使用FAISS搜索相似示例
  93. docs = self.vectorstore.similarity_search_with_score(query, k=k)
  94. # 格式化返回结果
  95. similar_examples = []
  96. for doc, score in docs:
  97. similar_examples.append({
  98. 'query_type': doc.metadata['query_type'],
  99. 'query': doc.page_content,
  100. 'sql_code': doc.metadata['sql_code'],
  101. 'similarity_score': float(score)
  102. })
  103. return similar_examples
  104. def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
  105. """格式化聊天历史"""
  106. if len(similar_examples) == 0:
  107. return ""
  108. chat_history = "基于以下相似示例:\n\n"
  109. for i, example in enumerate(similar_examples, 1):
  110. chat_history += f"示例 {i}:\n"
  111. chat_history += f"问题: {example['query']}\n"
  112. chat_history += f"SQL: {example['sql_code']}\n"
  113. chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
  114. return chat_history
  115. async def generate_sql_stream(self, query_description: str):
  116. """
  117. 流式生成SQL查询语句
  118. """
  119. try:
  120. # 开始生成SQL
  121. yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
  122. # 获取相似示例
  123. similar_examples = self._get_similar_examples(query_description)
  124. yield json.dumps({
  125. "type": "similar_examples",
  126. "content": similar_examples
  127. }, ensure_ascii=False) + "\n"
  128. # 准备输入数据
  129. chain_input = {
  130. "similar_examples": similar_examples,
  131. "question": query_description
  132. }
  133. print("Chain input:", chain_input)
  134. try:
  135. # 构建完整的提示词
  136. formatted_prompt = self.prompt.format(
  137. chat_history=self._format_chat_history(similar_examples),
  138. question=query_description
  139. )
  140. print("Formatted prompt:", formatted_prompt)
  141. # 流式生成SQL
  142. full_response = ""
  143. async for chunk in self.chain.astream(chain_input):
  144. if chunk:
  145. try:
  146. full_response += chunk
  147. yield json.dumps({
  148. "type": "sql_generation",
  149. "content": chunk
  150. }, ensure_ascii=False) + "\n"
  151. except Exception as chunk_error:
  152. print(f"Error processing chunk: {str(chunk_error)}")
  153. continue
  154. # 使用正则表达式提取SQL代码块
  155. sql_pattern = r"```sql\n(.*?)\n```"
  156. sql_match = re.search(sql_pattern, full_response, re.DOTALL)
  157. if sql_match:
  158. # 提取并清理SQL
  159. sql_content = sql_match.group(1).strip()
  160. # 移除注释
  161. sql_content = re.sub(r'--.*$', '', sql_content, flags=re.MULTILINE)
  162. # 移除多余空行
  163. sql_content = re.sub(r'\n\s*\n', '\n', sql_content)
  164. # 确保SQL语句完整
  165. if not sql_content.strip().endswith(';'):
  166. sql_content = sql_content.strip() + ';'
  167. yield json.dumps({
  168. "type": "sql_result",
  169. "content": sql_content
  170. }, ensure_ascii=False) + "\n"
  171. else:
  172. # 如果没有找到SQL代码块,返回完整响应
  173. yield json.dumps({
  174. "type": "sql_result",
  175. "content": full_response
  176. }, ensure_ascii=False) + "\n"
  177. yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
  178. except Exception as e:
  179. print(f"Error details: {traceback.format_exc()}")
  180. yield json.dumps({
  181. "type": "error",
  182. "message": f"生成SQL时发生错误: {str(e)}"
  183. }, ensure_ascii=False) + "\n"
  184. except Exception as e:
  185. traceback.print_exc()
  186. yield json.dumps({
  187. "type": "error",
  188. "message": str(e)
  189. }, ensure_ascii=False) + "\n"
  190. async def execute_sql(self, sql: str) -> dict:
  191. """
  192. 执行SQL查询并返回结果
  193. """
  194. print(sql)
  195. try:
  196. # 确保SQL是完整的
  197. if not sql.strip().endswith(';'):
  198. sql = sql.strip() + ';'
  199. result = await self.db.execute_query(sql)
  200. return {
  201. "status": "success",
  202. "data": result
  203. }
  204. except Exception as e:
  205. traceback.print_exc()
  206. return {
  207. "status": "error",
  208. "message": str(e)
  209. }
  210. async def close(self):
  211. """关闭数据库连接"""
  212. await self.db.close()