|
@@ -14,6 +14,8 @@ import re
|
|
import os
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
+
|
|
|
|
+
|
|
# 加载环境变量
|
|
# 加载环境变量
|
|
load_dotenv('config.env')
|
|
load_dotenv('config.env')
|
|
|
|
|
|
@@ -154,7 +156,6 @@ class SQLGenerator:
|
|
}, ensure_ascii=False) + "\n"
|
|
}, ensure_ascii=False) + "\n"
|
|
|
|
|
|
print(simliar_example_format_dump)
|
|
print(simliar_example_format_dump)
|
|
-
|
|
|
|
# 准备输入数据
|
|
# 准备输入数据
|
|
chain_input = {
|
|
chain_input = {
|
|
"similar_examples": similar_examples,
|
|
"similar_examples": similar_examples,
|
|
@@ -162,7 +163,6 @@ class SQLGenerator:
|
|
}
|
|
}
|
|
|
|
|
|
print("Chain input:", chain_input)
|
|
print("Chain input:", chain_input)
|
|
-
|
|
|
|
try:
|
|
try:
|
|
# 构建完整的提示词
|
|
# 构建完整的提示词
|
|
formatted_prompt = self.prompt.format(
|
|
formatted_prompt = self.prompt.format(
|
|
@@ -170,12 +170,15 @@ class SQLGenerator:
|
|
question=query_description
|
|
question=query_description
|
|
)
|
|
)
|
|
print("Formatted prompt:", formatted_prompt)
|
|
print("Formatted prompt:", formatted_prompt)
|
|
-
|
|
|
|
# 流式生成SQL
|
|
# 流式生成SQL
|
|
full_response = ""
|
|
full_response = ""
|
|
async for chunk in self.chain.astream(chain_input):
|
|
async for chunk in self.chain.astream(chain_input):
|
|
if chunk:
|
|
if chunk:
|
|
try:
|
|
try:
|
|
|
|
+ # Remove think tags if using zjstai model
|
|
|
|
+ if os.getenv('DEFAULT_MODEL_TYPE') == 'zjstai':
|
|
|
|
+ chunk = re.sub(r'<think>', '', chunk, flags=re.DOTALL)
|
|
|
|
+ chunk = re.sub(r'</think>', '', chunk, flags=re.DOTALL)
|
|
full_response += chunk
|
|
full_response += chunk
|
|
yield json.dumps({
|
|
yield json.dumps({
|
|
"type": "sql_generation",
|
|
"type": "sql_generation",
|