import asyncio
import json
import traceback
from typing import Any, Dict, List

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI

from config import get_model_config
from prompt_template import get_prompt


class SQLGenerator:
    def __init__(self, model_type: str = "openai"):
        # Load the model configuration for the requested backend
        model_config = get_model_config(model_type)

        # Initialize the LLM
        self.llm = ChatOpenAI(
            model_name=model_config.model_name,
            temperature=model_config.temperature,
            api_key=model_config.api_key,
            base_url=model_config.api_base,
            max_tokens=model_config.max_tokens,
            streaming=True
        )

        # Initialize the prompt template
        self.prompt = ChatPromptTemplate.from_template(get_prompt())

        # Build the LCEL chain: input mapping -> prompt -> LLM -> string output
        self.chain = (
            {
                "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
                "question": lambda x: x["question"]
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

        # Initialize the local m3e-base embedding model
        self.model_path = r"E:\项目临时\AI大模型\m3e-base"
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.model_path,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Load the example data and build the vector store
        self.examples = self._load_examples()
        self.vectorstore = self._build_vectorstore()

    def _load_examples(self) -> List[Dict[str, Any]]:
        """Load the example data."""
        with open('examples.json', 'r', encoding='utf-8') as f:
            return json.load(f)

    def _build_vectorstore(self) -> FAISS:
        """Build the FAISS vector store from the loaded examples."""
        # Wrap each example in a Document: the natural-language query is the
        # searchable text, everything else goes into the metadata
        documents = []
        for example in self.examples:
            doc = Document(
                page_content=example['query'],
                metadata={
                    'query_type': example['query_type'],
                    'plan': example['plan'],
                    'sql_code': example['sql_code']
                }
            )
            documents.append(doc)

        # Create the FAISS vector store
        vectorstore = FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )
        return vectorstore

    def _get_similar_examples(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Retrieve the k most similar examples for the given query."""
        # Search FAISS for similar examples; with the default index settings the
        # returned score is a distance, so lower means more similar
        docs = self.vectorstore.similarity_search_with_score(query, k=k)

        # Format the results
        similar_examples = []
        for doc, score in docs:
            similar_examples.append({
                'query_type': doc.metadata['query_type'],
                'query': doc.page_content,
                'plan': doc.metadata['plan'],
                'sql_code': doc.metadata['sql_code'],
                'similarity_score': float(score)
            })
        return similar_examples

    def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
        """Format the retrieved examples into a chat-history string for the prompt."""
        if not similar_examples:
            return ""
        chat_history = "基于以下相似示例:\n\n"
        for i, example in enumerate(similar_examples, 1):
            chat_history += f"示例 {i}:\n"
            chat_history += f"问题: {example['query']}\n"
            chat_history += f"计划: {example['plan']}\n"
            chat_history += f"SQL: {example['sql_code']}\n"
            chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
        return chat_history

    async def generate_sql_stream(self, query_description: str):
        """Generate a SQL query as a stream of newline-delimited JSON events."""
        try:
            # Announce the start of generation
            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"

            # Retrieve similar examples
            similar_examples = self._get_similar_examples(query_description)
            yield json.dumps({
                "type": "similar_examples",
                "content": similar_examples
            }, ensure_ascii=False) + "\n"

            # Prepare the chain input
            chain_input = {
                "similar_examples": similar_examples,
                "question": query_description
            }

            try:
                # Stream the generated SQL chunk by chunk
                async for chunk in self.chain.astream(chain_input):
                    if chunk:
                        yield json.dumps({
                            "type": "sql_generation",
                            "content": chunk
                        }, ensure_ascii=False) + "\n"

                yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"

            except Exception as e:
                print(f"Error during streaming: {str(e)}")
                print(f"Error type: {type(e)}")
                print(f"Error details: {traceback.format_exc()}")
                yield json.dumps({
                    "type": "error",
                    "message": f"生成SQL时发生错误: {str(e)}"
                }, ensure_ascii=False) + "\n"

        except Exception as e:
            traceback.print_exc()
            yield json.dumps({
                "type": "error",
                "message": str(e)
            }, ensure_ascii=False) + "\n"

    async def execute_sql(self, sql: str, db_connection) -> dict:
        """Execute a SQL query and return the results."""
        try:
            # Assumes a DB-API/SQLAlchemy-style connection whose execute()
            # returns a result exposing keys() and fetchall()
            result = db_connection.execute(sql)
            columns = list(result.keys())
            data = [dict(zip(columns, row)) for row in result.fetchall()]
            return {
                "status": "success",
                "data": data,
                "columns": columns
            }
        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e)
            }
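

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): the NDJSON
# events yielded by generate_sql_stream() are intended to be wrapped in
# FastAPI's StreamingResponse by the route layer, e.g.
#
#     return StreamingResponse(
#         generator.generate_sql_stream(question),
#         media_type="application/x-ndjson",
#     )
#
# The block below is a minimal local smoke test that consumes the stream
# directly. The question text is illustrative, and running it requires
# examples.json, the local m3e-base model, and a valid model config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        generator = SQLGenerator(model_type="openai")
        async for line in generator.generate_sql_stream("查询最近7天的订单总金额"):
            print(line, end="")

    asyncio.run(_demo())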