|
@@ -9,9 +9,10 @@ import json
|
|
|
import traceback
|
|
|
from config import get_model_config
|
|
|
from typing import List, Dict, Any
|
|
|
-from database import Database, vector_model_config
|
|
|
+from database import Database
|
|
|
import re
|
|
|
|
|
|
+
|
|
|
class SQLGenerator:
|
|
|
def __init__(self, model_type: str = "openai"):
|
|
|
# 获取模型配置
|
|
@@ -27,24 +28,24 @@ class SQLGenerator:
|
|
|
|
|
|
# 初始化提示词模板
|
|
|
self.prompt = ChatPromptTemplate.from_template(get_prompt())
|
|
|
-
|
|
|
+
|
|
|
# 构建链式调用
|
|
|
self.chain = (
|
|
|
- {
|
|
|
- "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
|
|
|
- "question": lambda x: x["question"]
|
|
|
- }
|
|
|
- | self.prompt
|
|
|
- | self.llm
|
|
|
- | StrOutputParser()
|
|
|
+ {
|
|
|
+ "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
|
|
|
+ "question": lambda x: x["question"]
|
|
|
+ }
|
|
|
+ | self.prompt
|
|
|
+ | self.llm
|
|
|
+ | StrOutputParser()
|
|
|
)
|
|
|
|
|
|
# 初始化本地m3e-base模型
|
|
|
- model_cfg = vector_model_config["m3e-base"]
|
|
|
- self.model_path = model_cfg["model_path"]
|
|
|
+ model_path = r"E:\项目临时\AI大模型\m3e-base"
|
|
|
+ # model_path= r"/data/m3e-base"
|
|
|
self.embeddings = HuggingFaceEmbeddings(
|
|
|
- model_name=self.model_path,
|
|
|
- model_kwargs={'device': model_cfg["device"]},
|
|
|
+ model_name=model_path,
|
|
|
+ model_kwargs={'device': "cpu"},
|
|
|
encode_kwargs={'normalize_embeddings': True}
|
|
|
)
|
|
|
|
|
@@ -62,7 +63,7 @@ class SQLGenerator:
|
|
|
# 使用正则表达式匹配markdown代码块中的SQL
|
|
|
sql_pattern = r"```sql\n(.*?)\n```"
|
|
|
match = re.search(sql_pattern, text, re.DOTALL)
|
|
|
-
|
|
|
+
|
|
|
if match:
|
|
|
# 提取SQL语句并清理
|
|
|
sql = match.group(1).strip()
|
|
@@ -179,7 +180,7 @@ class SQLGenerator:
|
|
|
# 使用正则表达式提取SQL代码块
|
|
|
sql_pattern = r"```sql\n(.*?)\n```"
|
|
|
sql_match = re.search(sql_pattern, full_response, re.DOTALL)
|
|
|
-
|
|
|
+
|
|
|
if sql_match:
|
|
|
# 提取并清理SQL
|
|
|
sql_content = sql_match.group(1).strip()
|
|
@@ -190,7 +191,7 @@ class SQLGenerator:
|
|
|
# 确保SQL语句完整
|
|
|
if not sql_content.strip().endswith(';'):
|
|
|
sql_content = sql_content.strip() + ';'
|
|
|
-
|
|
|
+
|
|
|
yield json.dumps({
|
|
|
"type": "sql_result",
|
|
|
"content": sql_content
|
|
@@ -226,7 +227,7 @@ class SQLGenerator:
|
|
|
# 确保SQL是完整的
|
|
|
if not sql.strip().endswith(';'):
|
|
|
sql = sql.strip() + ';'
|
|
|
-
|
|
|
+
|
|
|
result = await self.db.execute_query(sql)
|
|
|
return {
|
|
|
"status": "success",
|