|
@@ -1,21 +1,18 @@
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
from langchain.prompts import ChatPromptTemplate
|
|
|
-from langchain.chains import LLMChain
|
|
|
-from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
|
|
|
-from langchain_community.vectorstores import FAISS
|
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
|
from langchain_core.documents import Document
|
|
|
+from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
|
|
|
+from langchain_community.vectorstores import FAISS
|
|
|
import os
|
|
|
from prompt_template import get_prompt
|
|
|
-from fastapi.responses import StreamingResponse
|
|
|
import json
|
|
|
-import asyncio
|
|
|
-import numpy as np
|
|
|
+import traceback
|
|
|
from config import get_model_config
|
|
|
-import pandas as pd
|
|
|
from typing import List, Dict, Any
|
|
|
-import traceback
|
|
|
+from database import Database
|
|
|
+import re
|
|
|
|
|
|
class SQLGenerator:
|
|
|
def __init__(self, model_type: str = "openai"):
|
|
@@ -25,10 +22,8 @@ class SQLGenerator:
|
|
|
# 初始化LLM
|
|
|
self.llm = ChatOpenAI(
|
|
|
model_name=model_config.model_name,
|
|
|
- temperature=model_config.temperature,
|
|
|
api_key=model_config.api_key,
|
|
|
base_url=model_config.api_base,
|
|
|
- max_tokens=model_config.max_tokens,
|
|
|
streaming=True
|
|
|
)
|
|
|
|
|
@@ -58,6 +53,27 @@ class SQLGenerator:
|
|
|
self.examples = self._load_examples()
|
|
|
self.vectorstore = self._build_vectorstore()
|
|
|
|
|
|
+ # 初始化数据库连接
|
|
|
+ self.db = Database()
|
|
|
+
|
|
|
+ def _extract_sql_from_markdown(self, text: str) -> str:
|
|
|
+ """
|
|
|
+ 从markdown格式的文本中提取SQL语句
|
|
|
+ """
|
|
|
+ # 使用正则表达式匹配markdown代码块中的SQL
|
|
|
+ sql_pattern = r"```sql\n(.*?)\n```"
|
|
|
+ match = re.search(sql_pattern, text, re.DOTALL)
|
|
|
+
|
|
|
+ if match:
|
|
|
+ # 提取SQL语句并清理
|
|
|
+ sql = match.group(1).strip()
|
|
|
+ # 移除可能的注释
|
|
|
+ sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
|
|
|
+ # 移除多余的空行
|
|
|
+ sql = re.sub(r'\n\s*\n', '\n', sql)
|
|
|
+ return sql.strip()
|
|
|
+ return text.strip()
|
|
|
+
|
|
|
def _load_examples(self):
|
|
|
"""加载示例数据"""
|
|
|
with open('examples.json', 'r', encoding='utf-8') as f:
|
|
@@ -72,7 +88,6 @@ class SQLGenerator:
|
|
|
page_content=example['query'],
|
|
|
metadata={
|
|
|
'query_type': example['query_type'],
|
|
|
- 'plan': example['plan'],
|
|
|
'sql_code': example['sql_code']
|
|
|
}
|
|
|
)
|
|
@@ -97,7 +112,6 @@ class SQLGenerator:
|
|
|
similar_examples.append({
|
|
|
'query_type': doc.metadata['query_type'],
|
|
|
'query': doc.page_content,
|
|
|
- 'plan': doc.metadata['plan'],
|
|
|
'sql_code': doc.metadata['sql_code'],
|
|
|
'similarity_score': float(score)
|
|
|
})
|
|
@@ -106,7 +120,6 @@ class SQLGenerator:
|
|
|
|
|
|
def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
|
|
|
"""格式化聊天历史"""
|
|
|
-
|
|
|
if len(similar_examples) == 0:
|
|
|
return ""
|
|
|
|
|
@@ -114,11 +127,9 @@ class SQLGenerator:
|
|
|
for i, example in enumerate(similar_examples, 1):
|
|
|
chat_history += f"示例 {i}:\n"
|
|
|
chat_history += f"问题: {example['query']}\n"
|
|
|
- chat_history += f"计划: {example['plan']}\n"
|
|
|
chat_history += f"SQL: {example['sql_code']}\n"
|
|
|
chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
|
|
|
|
|
|
- print(chat_history+" !!!!!")
|
|
|
return chat_history
|
|
|
|
|
|
async def generate_sql_stream(self, query_description: str):
|
|
@@ -127,7 +138,7 @@ class SQLGenerator:
|
|
|
"""
|
|
|
try:
|
|
|
# 开始生成SQL
|
|
|
- yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}) + "\n"
|
|
|
+ yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
|
|
|
|
|
|
# 获取相似示例
|
|
|
similar_examples = self._get_similar_examples(query_description)
|
|
@@ -145,43 +156,82 @@ class SQLGenerator:
|
|
|
print("Chain input:", chain_input)
|
|
|
|
|
|
try:
|
|
|
+ # 构建完整的提示词
|
|
|
+ formatted_prompt = self.prompt.format(
|
|
|
+ chat_history=self._format_chat_history(similar_examples),
|
|
|
+ question=query_description
|
|
|
+ )
|
|
|
+ print("Formatted prompt:", formatted_prompt)
|
|
|
+
|
|
|
# 流式生成SQL
|
|
|
+ full_response = ""
|
|
|
async for chunk in self.chain.astream(chain_input):
|
|
|
if chunk:
|
|
|
- yield json.dumps({
|
|
|
- "type": "sql_generation",
|
|
|
- "content": chunk
|
|
|
- }) + "\n"
|
|
|
+ try:
|
|
|
+ full_response += chunk
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "sql_generation",
|
|
|
+ "content": chunk
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
+ except Exception as chunk_error:
|
|
|
+ print(f"Error processing chunk: {str(chunk_error)}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 使用正则表达式提取SQL代码块
|
|
|
+ sql_pattern = r"```sql\n(.*?)\n```"
|
|
|
+ sql_match = re.search(sql_pattern, full_response, re.DOTALL)
|
|
|
+
|
|
|
+ if sql_match:
|
|
|
+ # 提取并清理SQL
|
|
|
+ sql_content = sql_match.group(1).strip()
|
|
|
+ # 移除注释
|
|
|
+ sql_content = re.sub(r'--.*$', '', sql_content, flags=re.MULTILINE)
|
|
|
+ # 移除多余空行
|
|
|
+ sql_content = re.sub(r'\n\s*\n', '\n', sql_content)
|
|
|
+ # 确保SQL语句完整
|
|
|
+ if not sql_content.strip().endswith(';'):
|
|
|
+ sql_content = sql_content.strip() + ';'
|
|
|
+
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "sql_result",
|
|
|
+ "content": sql_content
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
+ else:
|
|
|
+ # 如果没有找到SQL代码块,返回完整响应
|
|
|
+ yield json.dumps({
|
|
|
+ "type": "sql_result",
|
|
|
+ "content": full_response
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
|
|
|
yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
|
|
|
except Exception as e:
|
|
|
- print(f"Error during streaming: {str(e)}")
|
|
|
- print(f"Error type: {type(e)}")
|
|
|
print(f"Error details: {traceback.format_exc()}")
|
|
|
yield json.dumps({
|
|
|
"type": "error",
|
|
|
"message": f"生成SQL时发生错误: {str(e)}"
|
|
|
- }) + "\n"
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
|
|
|
except Exception as e:
|
|
|
traceback.print_exc()
|
|
|
yield json.dumps({
|
|
|
"type": "error",
|
|
|
"message": str(e)
|
|
|
- }) + "\n"
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
|
|
|
- async def execute_sql(self, sql: str, db_connection) -> dict:
|
|
|
+ async def execute_sql(self, sql: str) -> dict:
|
|
|
"""
|
|
|
执行SQL查询并返回结果
|
|
|
"""
|
|
|
+ print(sql)
|
|
|
try:
|
|
|
- result = db_connection.execute(sql)
|
|
|
- columns = result.keys()
|
|
|
- data = [dict(zip(columns, row)) for row in result.fetchall()]
|
|
|
+ # 确保SQL是完整的
|
|
|
+ if not sql.strip().endswith(';'):
|
|
|
+ sql = sql.strip() + ';'
|
|
|
+
|
|
|
+ result = await self.db.execute_query(sql)
|
|
|
return {
|
|
|
"status": "success",
|
|
|
- "data": data,
|
|
|
- "columns": columns
|
|
|
+ "data": result
|
|
|
}
|
|
|
except Exception as e:
|
|
|
traceback.print_exc()
|
|
@@ -189,3 +239,7 @@ class SQLGenerator:
|
|
|
"status": "error",
|
|
|
"message": str(e)
|
|
|
}
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ """关闭数据库连接"""
|
|
|
+ await self.db.close()
|