|
@@ -11,11 +11,17 @@ from config import get_model_config
|
|
|
from typing import List, Dict, Any
|
|
|
from database import Database
|
|
|
import re
|
|
|
+import os
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
+# 加载环境变量
|
|
|
+load_dotenv('config.env')
|
|
|
|
|
|
|
|
|
class SQLGenerator:
|
|
|
- def __init__(self, model_type: str = "openai"):
|
|
|
+ def __init__(self, model_type: str = None):
|
|
|
# 获取模型配置
|
|
|
+ model_type = model_type or os.getenv('DEFAULT_MODEL_TYPE', 'openai')
|
|
|
model_config = get_model_config(model_type)
|
|
|
|
|
|
# 初始化LLM
|
|
@@ -40,9 +46,12 @@ class SQLGenerator:
|
|
|
| StrOutputParser()
|
|
|
)
|
|
|
|
|
|
+ # 从环境变量获取模型路径
|
|
|
+ model_path = os.getenv('MODEL_PATH')
|
|
|
+ if not model_path:
|
|
|
+ raise ValueError("MODEL_PATH environment variable is not set")
|
|
|
+
|
|
|
# 初始化本地m3e-base模型
|
|
|
- model_path = r"E:\项目临时\AI大模型\m3e-base"
|
|
|
- # model_path= r"/data/m3e-base"
|
|
|
self.embeddings = HuggingFaceEmbeddings(
|
|
|
model_name=model_path,
|
|
|
model_kwargs={'device': "cpu"},
|
|
@@ -137,16 +146,15 @@ class SQLGenerator:
|
|
|
流式生成SQL查询语句
|
|
|
"""
|
|
|
try:
|
|
|
- # 开始生成SQL
|
|
|
- yield json.dumps({"type": "start", "content": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
|
|
|
-
|
|
|
# 获取相似示例
|
|
|
similar_examples = self._get_similar_examples(query_description)
|
|
|
- yield json.dumps({
|
|
|
+ simliar_example_format_dump = json.dumps({
|
|
|
"type": "similar_examples",
|
|
|
"content": similar_examples
|
|
|
}, ensure_ascii=False) + "\n"
|
|
|
|
|
|
+ print(simliar_example_format_dump)
|
|
|
+
|
|
|
# 准备输入数据
|
|
|
chain_input = {
|
|
|
"similar_examples": similar_examples,
|