from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from prompt_template import get_prompt
import json
import traceback
from config import get_model_config
from typing import List, Dict, Any, Optional
from database import Database
import re
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv('config.env')


class SQLGenerator:
    def __init__(self, model_type: Optional[str] = None):
        # Resolve the model configuration
        model_type = model_type or os.getenv('DEFAULT_MODEL_TYPE', 'openai')
        model_config = get_model_config(model_type)

        # Initialize the LLM
        self.llm = ChatOpenAI(
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.api_base,
            streaming=True
        )

        # Initialize the prompt template
        self.prompt = ChatPromptTemplate.from_template(get_prompt())

        # Build the chain: similar examples -> prompt -> LLM -> string output
        self.chain = (
            {
                "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
                "question": lambda x: x["question"]
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

        # Read the embedding model path from the environment
        model_path = os.getenv('MODEL_PATH')
        if not model_path:
            raise ValueError("MODEL_PATH environment variable is not set")

        # Initialize the local m3e-base embedding model
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={'device': "cpu"},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Load the example data and build the vector store
        self.examples = self._load_examples()
        self.vectorstore = self._build_vectorstore()

        # Initialize the database connection
        self.db = Database()

    def _extract_sql_from_markdown(self, text: str) -> str:
        """Extract the SQL statement from markdown-formatted text."""
        # Match the SQL inside a markdown ```sql code block
        sql_pattern = r"```sql\n(.*?)\n```"
        match = re.search(sql_pattern, text, re.DOTALL)
        if match:
            # Extract and clean the SQL statement
            sql = match.group(1).strip()
            # Remove any line comments
            sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
            # Remove extra blank lines
            sql = re.sub(r'\n\s*\n', '\n', sql)
            return sql.strip()
        return text.strip()

    def _load_examples(self):
        """Load the example data."""
        with open('examples.json', 'r', encoding='utf-8') as f:
            return json.load(f)

    def _build_vectorstore(self):
        """Build the FAISS vector store."""
        # Prepare the list of Document objects
        documents = []
        for example in self.examples:
            doc = Document(
                page_content=example['query'],
                metadata={
                    'query_type': example['query_type'],
                    'sql_code': example['sql_code']
                }
            )
            documents.append(doc)

        # Create the FAISS vector store
        vectorstore = FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )
        return vectorstore

    def _get_similar_examples(self, query: str, k: int = 3):
        """Retrieve the most similar examples."""
        # Search for similar examples with FAISS
        docs = self.vectorstore.similarity_search_with_score(query, k=k)

        # Format the results
        similar_examples = []
        for doc, score in docs:
            similar_examples.append({
                'query_type': doc.metadata['query_type'],
                'query': doc.page_content,
                'sql_code': doc.metadata['sql_code'],
                'similarity_score': float(score)
            })
        return similar_examples

    def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
        """Format the similar examples as chat history."""
        if len(similar_examples) == 0:
            return ""

        chat_history = "基于以下相似示例:\n\n"
        for i, example in enumerate(similar_examples, 1):
            chat_history += f"示例 {i}:\n"
            chat_history += f"问题: {example['query']}\n"
            chat_history += f"SQL: {example['sql_code']}\n"
            chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
        return chat_history
    async def generate_sql_stream(self, query_description: str):
        """Stream the generation of a SQL query."""
        try:
            # Retrieve similar examples for the question
            similar_examples = self._get_similar_examples(query_description)
            similar_example_format_dump = json.dumps({
                "type": "similar_examples",
                "content": similar_examples
            }, ensure_ascii=False) + "\n"
            print(similar_example_format_dump)

            # Prepare the chain input
            chain_input = {
                "similar_examples": similar_examples,
                "question": query_description
            }
            print("Chain input:", chain_input)

            try:
                # Build the full prompt (printed for debugging)
                formatted_prompt = self.prompt.format(
                    chat_history=self._format_chat_history(similar_examples),
                    question=query_description
                )
                print("Formatted prompt:", formatted_prompt)

                # Stream the SQL generation
                full_response = ""
                async for chunk in self.chain.astream(chain_input):
                    if chunk:
                        try:
                            full_response += chunk
                            yield json.dumps({
                                "type": "sql_generation",
                                "content": chunk
                            }, ensure_ascii=False) + "\n"
                        except Exception as chunk_error:
                            print(f"Error processing chunk: {str(chunk_error)}")
                            continue

                # Extract the SQL code block with a regex
                sql_pattern = r"```sql\n(.*?)\n```"
                sql_match = re.search(sql_pattern, full_response, re.DOTALL)

                if sql_match:
                    # Extract and clean the SQL
                    sql_content = sql_match.group(1).strip()
                    # Remove line comments
                    sql_content = re.sub(r'--.*$', '', sql_content, flags=re.MULTILINE)
                    # Remove extra blank lines
                    sql_content = re.sub(r'\n\s*\n', '\n', sql_content)
                    # Make sure the SQL statement is terminated
                    if not sql_content.strip().endswith(';'):
                        sql_content = sql_content.strip() + ';'

                    yield json.dumps({
                        "type": "sql_result",
                        "content": sql_content
                    }, ensure_ascii=False) + "\n"
                else:
                    # If no SQL code block was found, return the full response
                    yield json.dumps({
                        "type": "sql_result",
                        "content": full_response
                    }, ensure_ascii=False) + "\n"

                yield json.dumps({"type": "end", "content": "SQL生成完成"}, ensure_ascii=False) + "\n"

            except Exception as e:
                print(f"Error details: {traceback.format_exc()}")
                yield json.dumps({
                    "type": "error",
                    "content": f"生成SQL时发生错误: {str(e)}"
                }, ensure_ascii=False) + "\n"

        except Exception as e:
            traceback.print_exc()
            yield json.dumps({
                "type": "error",
                "content": str(e)
            }, ensure_ascii=False) + "\n"

    async def execute_sql(self, sql: str) -> dict:
        """Execute the SQL query and return the result."""
        print(sql)
        try:
            # Make sure the SQL statement is terminated
            if not sql.strip().endswith(';'):
                sql = sql.strip() + ';'

            result = await self.db.execute_query(sql)
            return {
                "status": "success",
                "data": result
            }
        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e)
            }

    async def close(self):
        """Close the database connection."""
        await self.db.close()
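

# Usage sketch (illustrative, not part of the original module): drives the
# streaming generator end to end and then executes the resulting SQL. It
# assumes config.env supplies valid model credentials, MODEL_PATH points to a
# local m3e-base checkpoint, examples.json is present, and the Database
# configured in database.py is reachable; the example question is hypothetical.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        generator = SQLGenerator()
        try:
            final_sql = None
            # Each streamed line is a JSON message of type sql_generation,
            # sql_result, end, or error.
            async for line in generator.generate_sql_stream(
                "Count the number of orders per day for the last 7 days"
            ):
                message = json.loads(line)
                print(message["type"], ":", message["content"])
                if message["type"] == "sql_result":
                    final_sql = message["content"]

            # Optionally run the generated SQL against the configured database.
            if final_sql:
                print(await generator.execute_sql(final_sql))
        finally:
            await generator.close()

    asyncio.run(_demo())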