|
@@ -0,0 +1,191 @@
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+from langchain.prompts import ChatPromptTemplate
|
|
|
+from langchain.chains import LLMChain
|
|
|
+from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
|
|
|
+from langchain_community.vectorstores import FAISS
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+from langchain_core.runnables import RunnablePassthrough
|
|
|
+from langchain_core.documents import Document
|
|
|
+import os
|
|
|
+from prompt_template import get_prompt
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
+import json
|
|
|
+import asyncio
|
|
|
+import numpy as np
|
|
|
+from config import get_model_config
|
|
|
+import pandas as pd
|
|
|
+from typing import List, Dict, Any
|
|
|
+import traceback
|
|
|
+
|
|
|
+class SQLGenerator:
|
|
|
+ def __init__(self, model_type: str = "openai"):
|
|
|
+ # 获取模型配置
|
|
|
+ model_config = get_model_config(model_type)
|
|
|
+
|
|
|
+ # 初始化LLM
|
|
|
+ self.llm = ChatOpenAI(
|
|
|
+ model_name=model_config.model_name,
|
|
|
+ temperature=model_config.temperature,
|
|
|
+ api_key=model_config.api_key,
|
|
|
+ base_url=model_config.api_base,
|
|
|
+ max_tokens=model_config.max_tokens,
|
|
|
+ streaming=True
|
|
|
+ )
|
|
|
+
|
|
|
+ # 初始化提示词模板
|
|
|
+ self.prompt = ChatPromptTemplate.from_template(get_prompt())
|
|
|
+
|
|
|
+ # 构建链式调用
|
|
|
+ self.chain = (
|
|
|
+ {
|
|
|
+ "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
|
|
|
+ "question": lambda x: x["question"]
|
|
|
+ }
|
|
|
+ | self.prompt
|
|
|
+ | self.llm
|
|
|
+ | StrOutputParser()
|
|
|
+ )
|
|
|
+
|
|
|
+ # 初始化本地m3e-base模型
|
|
|
+ self.model_path = r"E:\项目临时\AI大模型\m3e-base"
|
|
|
+ self.embeddings = HuggingFaceEmbeddings(
|
|
|
+ model_name=self.model_path,
|
|
|
+ model_kwargs={'device': 'cpu'},
|
|
|
+ encode_kwargs={'normalize_embeddings': True}
|
|
|
+ )
|
|
|
+
|
|
|
+ # 加载示例数据并创建向量数据库
|
|
|
+ self.examples = self._load_examples()
|
|
|
+ self.vectorstore = self._build_vectorstore()
|
|
|
+
|
|
|
+ def _load_examples(self):
|
|
|
+ """加载示例数据"""
|
|
|
+ with open('examples.json', 'r', encoding='utf-8') as f:
|
|
|
+ return json.load(f)
|
|
|
+
|
|
|
+ def _build_vectorstore(self):
|
|
|
+ """构建FAISS向量数据库"""
|
|
|
+ # 准备Document对象列表
|
|
|
+ documents = []
|
|
|
+ for example in self.examples:
|
|
|
+ doc = Document(
|
|
|
+ page_content=example['query'],
|
|
|
+ metadata={
|
|
|
+ 'query_type': example['query_type'],
|
|
|
+ 'plan': example['plan'],
|
|
|
+ 'sql_code': example['sql_code']
|
|
|
+ }
|
|
|
+ )
|
|
|
+ documents.append(doc)
|
|
|
+
|
|
|
+ # 创建FAISS向量数据库
|
|
|
+ vectorstore = FAISS.from_documents(
|
|
|
+ documents=documents,
|
|
|
+ embedding=self.embeddings
|
|
|
+ )
|
|
|
+
|
|
|
+ return vectorstore
|
|
|
+
|
|
|
+ def _get_similar_examples(self, query: str, k: int = 3):
|
|
|
+ """获取最相似的示例"""
|
|
|
+ # 使用FAISS搜索相似示例
|
|
|
+ docs = self.vectorstore.similarity_search_with_score(query, k=k)
|
|
|
+
|
|
|
+ # 格式化返回结果
|
|
|
+ similar_examples = []
|
|
|
+ for doc, score in docs:
|
|
|
+ similar_examples.append({
|
|
|
+ 'query_type': doc.metadata['query_type'],
|
|
|
+ 'query': doc.page_content,
|
|
|
+ 'plan': doc.metadata['plan'],
|
|
|
+ 'sql_code': doc.metadata['sql_code'],
|
|
|
+ 'similarity_score': float(score)
|
|
|
+ })
|
|
|
+
|
|
|
+ return similar_examples
|
|
|
+
|
|
|
+ def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
|
|
|
+ """格式化聊天历史"""
|
|
|
+
|
|
|
+ if len(similar_examples) == 0:
|
|
|
+ return ""
|
|
|
+
|
|
|
+ chat_history = "基于以下相似示例:\n\n"
|
|
|
+ for i, example in enumerate(similar_examples, 1):
|
|
|
+ chat_history += f"示例 {i}:\n"
|
|
|
+ chat_history += f"问题: {example['query']}\n"
|
|
|
+ chat_history += f"计划: {example['plan']}\n"
|
|
|
+ chat_history += f"SQL: {example['sql_code']}\n"
|
|
|
+ chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
|
|
|
+
|
|
|
+ print(chat_history+" !!!!!")
|
|
|
+ return chat_history
|
|
|
+
|
|
|
+ async def generate_sql_stream(self, query_description: str):
|
|
|
+ """
|
|
|
+ 流式生成SQL查询语句
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 开始生成SQL
|
|
|
+ yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}) + "\n"
|
|
|
+
|
|
|
+ # 获取相似示例
|
|
|
+ similar_examples = self._get_similar_examples(query_description)
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "similar_examples",
|
|
|
+ "content": similar_examples
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
+
|
|
|
+ # 准备输入数据
|
|
|
+ chain_input = {
|
|
|
+ "similar_examples": similar_examples,
|
|
|
+ "question": query_description
|
|
|
+ }
|
|
|
+
|
|
|
+ print("Chain input:", chain_input)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 流式生成SQL
|
|
|
+ async for chunk in self.chain.astream(chain_input):
|
|
|
+ if chunk:
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "sql_generation",
|
|
|
+ "content": chunk
|
|
|
+ }) + "\n"
|
|
|
+
|
|
|
+ yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error during streaming: {str(e)}")
|
|
|
+ print(f"Error type: {type(e)}")
|
|
|
+ print(f"Error details: {traceback.format_exc()}")
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "error",
|
|
|
+ "message": f"生成SQL时发生错误: {str(e)}"
|
|
|
+ }) + "\n"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ traceback.print_exc()
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "error",
|
|
|
+ "message": str(e)
|
|
|
+ }) + "\n"
|
|
|
+
|
|
|
+ async def execute_sql(self, sql: str, db_connection) -> dict:
|
|
|
+ """
|
|
|
+ 执行SQL查询并返回结果
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ result = db_connection.execute(sql)
|
|
|
+ columns = result.keys()
|
|
|
+ data = [dict(zip(columns, row)) for row in result.fetchall()]
|
|
|
+ return {
|
|
|
+ "status": "success",
|
|
|
+ "data": data,
|
|
|
+ "columns": columns
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ traceback.print_exc()
|
|
|
+ return {
|
|
|
+ "status": "error",
|
|
|
+ "message": str(e)
|
|
|
+ }
|