sql_generator.py 9.0 KB


  1. from langchain_openai import ChatOpenAI
  2. from langchain.prompts import ChatPromptTemplate
  3. from langchain_core.output_parsers import StrOutputParser
  4. from langchain_core.documents import Document
  5. from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
  6. from langchain_community.vectorstores import FAISS
  7. from prompt_template import get_prompt
  8. import json
  9. import traceback
  10. from config import get_model_config
  11. from typing import List, Dict, Any
  12. from database import Database
  13. import re
  14. import os
  15. from dotenv import load_dotenv
  # Load environment variables from the local 'config.env' file at import time,
  # before any os.getenv() lookup performed by the class below.
  load_dotenv('config.env')
  class SQLGenerator:
      """Turn natural-language questions into SQL using an LLM plus retrieved examples.

      Construction builds a prompt -> chat model -> string-parser chain, loads
      question/SQL example pairs from ``examples.json``, indexes them in a FAISS
      vector store using a local HuggingFace embedding model, and creates the
      ``Database`` handle used by :meth:`execute_sql`.
      """

      # Opening fence: language tag matched case-insensitively, optional trailing
      # blanks, LF or CRLF. The body runs lazily up to the next fence.
      _SQL_FENCE_RE = re.compile(r"```sql[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
      # One left-to-right pass: quoted literals/identifiers are captured so they can
      # be kept verbatim; a `--` comment is only matched outside of them.
      # Doubled quotes and backslash escapes are both accepted inside quotes.
      _SQL_COMMENT_RE = re.compile(
          r"('(?:[^'\\]|\\.|'')*')"
          r'|("(?:[^"\\]|\\.|"")*")'
          r"|--[^\n]*",
          re.DOTALL,
      )
      _BLANK_LINES_RE = re.compile(r'\n\s*\n')
      _THINK_TAGS = ("<think>", "</think>")

      def __init__(self, model_type: str = None):
          """Set up the LLM chain, the example vector store and the database handle.

          Args:
              model_type: Key passed to ``get_model_config``. When falsy, the
                  ``DEFAULT_MODEL_TYPE`` environment variable is used, falling
                  back to ``'openai'``.

          Raises:
              ValueError: If the ``MODEL_PATH`` environment variable is not set.
          """
          # Keep the resolved model type on the instance so per-model behaviour
          # (think-tag stripping while streaming) follows the model actually in
          # use rather than whatever the environment variable says.
          self.model_type = model_type or os.getenv('DEFAULT_MODEL_TYPE', 'openai')
          model_config = get_model_config(self.model_type)

          # Chat model reached through an OpenAI-compatible endpoint.
          self.llm = ChatOpenAI(
              model_name=model_config.model_name,
              api_key=model_config.api_key,
              base_url=model_config.api_base,
              streaming=True,
          )

          # Prompt template; it receives `chat_history` and `question`.
          self.prompt = ChatPromptTemplate.from_template(get_prompt())

          # Chain: map the input dict to prompt variables, then prompt -> LLM -> str.
          self.chain = (
              {
                  "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
                  "question": lambda x: x["question"],
              }
              | self.prompt
              | self.llm
              | StrOutputParser()
          )

          # The embedding model location comes from the environment.
          model_path = os.getenv('MODEL_PATH')
          if not model_path:
              raise ValueError("MODEL_PATH environment variable is not set")

          # Local embedding model, run on CPU with normalized vectors.
          self.embeddings = HuggingFaceEmbeddings(
              model_name=model_path,
              model_kwargs={'device': "cpu"},
              encode_kwargs={'normalize_embeddings': True},
          )

          # Load the examples and index them for similarity search.
          self.examples = self._load_examples()
          self.vectorstore = self._build_vectorstore()

          # Database handle used by execute_sql() / close().
          self.db = Database()

      @staticmethod
      def _event(event_type: str, content) -> str:
          """Serialize one stream event as a newline-terminated JSON line."""
          return json.dumps({"type": event_type, "content": content}, ensure_ascii=False) + "\n"

      @staticmethod
      def _strip_think_tags(pending: str, chunk: str) -> tuple:
          """Remove ``<think>`` / ``</think>`` tags from a streamed chunk.

          Only the tags are removed; the text between them is kept. Because a tag
          can be split across two chunks, a trailing fragment that could still
          grow into a tag is held back and returned as the new carry-over.

          Args:
              pending: Carry-over returned by the previous call ("" initially).
              chunk: Newly received text.

          Returns:
              ``(emit, carry)``: text that is safe to emit now, and the fragment
              to pass as ``pending`` with the next chunk. The final carry must be
              emitted by the caller once the stream ends.
          """
          text = pending + chunk
          for tag in SQLGenerator._THINK_TAGS:
              text = text.replace(tag, "")

          # Longest suffix of `text` that is a proper prefix of any tag.
          hold = 0
          for tag in SQLGenerator._THINK_TAGS:
              for size in range(min(len(tag) - 1, len(text)), hold, -1):
                  if text.endswith(tag[:size]):
                      hold = size
                      break
          if hold:
              return text[:-hold], text[-hold:]
          return text, ""

      def _find_sql_block(self, text: str):
          """Return the cleaned body of the first ```sql fence in ``text``.

          Cleaning removes ``--`` line comments (but not ``--`` inside quoted
          literals), collapses blank lines and strips surrounding whitespace.

          Returns:
              The cleaned SQL string (possibly empty), or ``None`` when ``text``
              contains no ```sql fenced block.
          """
          match = self._SQL_FENCE_RE.search(text)
          if match is None:
              return None
          sql = match.group(1).strip()
          # Keep quoted spans (a capture group matched), drop comments (none did).
          sql = self._SQL_COMMENT_RE.sub(lambda m: m.group(0) if m.lastindex else "", sql)
          sql = self._BLANK_LINES_RE.sub('\n', sql)
          return sql.strip()

      def _extract_sql_from_markdown(self, text: str) -> str:
          """Extract the SQL statement from markdown-formatted text.

          Returns the cleaned content of the first ```sql block, or the stripped
          input unchanged when no such block is present.
          """
          sql = self._find_sql_block(text)
          return text.strip() if sql is None else sql

      def _load_examples(self):
          """Load the example records from ``examples.json`` (relative to the CWD)."""
          with open('examples.json', 'r', encoding='utf-8') as f:
              return json.load(f)

      def _build_vectorstore(self):
          """Build the FAISS vector store from the loaded examples.

          Each example's ``query`` becomes the embedded document text; its
          ``query_type`` and ``sql_code`` are kept as metadata.
          """
          documents = [
              Document(
                  page_content=example['query'],
                  metadata={
                      'query_type': example['query_type'],
                      'sql_code': example['sql_code'],
                  },
              )
              for example in self.examples
          ]
          return FAISS.from_documents(documents=documents, embedding=self.embeddings)

      def _get_similar_examples(self, query: str, k: int = 3):
          """Return the ``k`` stored examples closest to ``query``.

          Each result is a dict with ``query_type``, ``query``, ``sql_code`` and
          ``similarity_score``.
          """
          # NOTE(review): LangChain documents the FAISS score as a distance (lower
          # means closer) by default, although it is exposed here, and later in the
          # prompt, under a "similarity" label - confirm this is intended.
          hits = self.vectorstore.similarity_search_with_score(query, k=k)
          return [
              {
                  'query_type': doc.metadata['query_type'],
                  'query': doc.page_content,
                  'sql_code': doc.metadata['sql_code'],
                  'similarity_score': float(score),
              }
              for doc, score in hits
          ]

      def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
          """Render the retrieved examples as the prompt's ``chat_history`` text.

          Returns an empty string when there are no examples.
          """
          if not similar_examples:
              return ""
          parts = ["基于以下相似示例:\n\n"]
          for i, example in enumerate(similar_examples, 1):
              parts.append(
                  f"示例 {i}:\n"
                  f"问题: {example['query']}\n"
                  f"SQL: {example['sql_code']}\n"
                  f"相似度: {example['similarity_score']:.2f}\n\n"
              )
          return "".join(parts)

      async def generate_sql_stream(self, query_description: str):
          """Stream the generation of a SQL statement for ``query_description``.

          Yields newline-terminated JSON lines of the form
          ``{"type": ..., "content": ...}`` with these types:

          * ``sql_generation`` - one per non-empty streamed chunk of model output;
          * ``sql_result`` - the extracted SQL (``;``-terminated), or the full
            response when no usable ```sql block was found;
          * ``end`` - emitted after a successful run;
          * ``error`` - emitted instead when an exception occurs.
          """
          try:
              # Retrieve the few-shot examples for this question.
              similar_examples = self._get_similar_examples(query_description)
              # NOTE(review): this payload is only printed, never yielded to the
              # consumer - confirm whether a "similar_examples" event was intended.
              similar_examples_dump = self._event("similar_examples", similar_examples)
              print(similar_examples_dump)

              chain_input = {
                  "similar_examples": similar_examples,
                  "question": query_description,
              }
              print("Chain input:", chain_input)

              try:
                  # Render the prompt separately only so it can be logged.
                  formatted_prompt = self.prompt.format(
                      chat_history=self._format_chat_history(similar_examples),
                      question=query_description,
                  )
                  print("Formatted prompt:", formatted_prompt)

                  # Think tags are stripped for the 'zjstai' model only. Fall back
                  # to the environment for instances created without __init__.
                  active_model = getattr(self, "model_type", None) or os.getenv('DEFAULT_MODEL_TYPE')
                  strip_think = active_model == 'zjstai'

                  pieces = []
                  pending = ""  # possible partial think tag carried between chunks
                  async for chunk in self.chain.astream(chain_input):
                      if not chunk:
                          continue
                      try:
                          if strip_think:
                              chunk, pending = self._strip_think_tags(pending, chunk)
                              if not chunk:
                                  continue
                          pieces.append(chunk)
                          yield self._event("sql_generation", chunk)
                      except Exception as chunk_error:
                          # Best effort: a bad chunk must not abort the stream.
                          print(f"Error processing chunk: {str(chunk_error)}")
                          continue

                  # Whatever is still held back never completed a tag; emit it.
                  if pending:
                      pieces.append(pending)
                      yield self._event("sql_generation", pending)

                  full_response = "".join(pieces)
                  sql_content = self._find_sql_block(full_response)
                  if sql_content:
                      # Make sure the statement is terminated.
                      if not sql_content.endswith(';'):
                          sql_content += ';'
                      yield self._event("sql_result", sql_content)
                  else:
                      # No usable SQL block: hand back the whole response.
                      yield self._event("sql_result", full_response)

                  yield self._event("end", "SQL生成完成")
              except Exception as e:
                  traceback.print_exc()
                  yield self._event("error", f"生成SQL时发生错误: {str(e)}")
          except Exception as e:
              traceback.print_exc()
              yield self._event("error", str(e))

      async def execute_sql(self, sql: str) -> dict:
          """Execute ``sql`` through the database handle and return the outcome.

          Returns:
              ``{"status": "success", "data": ...}`` on success, otherwise
              ``{"status": "error", "message": ...}``. Exceptions are not
              propagated. Empty or whitespace-only input is rejected without
              touching the database.
          """
          print(sql)
          try:
              statement = sql.strip()
              if not statement:
                  return {
                      "status": "error",
                      "message": "SQL statement is empty",
                  }
              # Make sure the statement is terminated.
              if not statement.endswith(';'):
                  statement += ';'
              # NOTE(review): the statement is executed as given, with no
              # validation here. If callers pass model-generated text, consider a
              # read-only database account or an allow-list check.
              result = await self.db.execute_query(statement)
              return {
                  "status": "success",
                  "data": result,
              }
          except Exception as e:
              traceback.print_exc()
              return {
                  "status": "error",
                  "message": str(e),
              }

      async def close(self):
          """Close the database connection."""
          await self.db.close()