Browse Source

选址提示词修改

liutao 1 tháng trước cách đây
mục cha
commit
a525853942

+ 0 - 3
landsite_agent/main.py

@@ -9,7 +9,6 @@ import traceback
 from xuanzhi_query import router as xz_router
 from sql_generator import SQLGenerator
 from config import DEFAULT_MODEL_TYPE
-import re
 
 app = FastAPI(title="Land Analysis API")
 app.include_router(xz_router)
@@ -46,8 +45,6 @@ async def stream_land_analysis(request: QueryRequest):
 
     async def generate_stream() -> AsyncGenerator[str, None]:
         try:
-            similar_examples = None
-
             # 流式生成SQL
             async for chunk in sql_generator.generate_sql_stream(request.description):
                 data = json.loads(chunk)

+ 4 - 4
landsite_agent/prompt_template.py

@@ -87,13 +87,13 @@ PROMPT_TEMPLATE = """
 
 请按照以下格式输出,每个部分之间用空行分隔:
 
-1. Question: 分析用户问题
+#### Question: 分析用户问题
 
-2. Thought: 思考查询逻辑
+#### Thought: 思考查询逻辑
 
-3. Plan: 制定查询计划
+#### Plan: 制定查询计划
 
-4. SQL: 生成SQL代码
+#### SQL: 生成SQL代码
 ```sql
 SELECT id FROM table
 WHERE condition

+ 6 - 3
landsite_agent/sql_generator.py

@@ -14,6 +14,8 @@ import re
 import os
 from dotenv import load_dotenv
 
+
+
 # 加载环境变量
 load_dotenv('config.env')
 
@@ -154,7 +156,6 @@ class SQLGenerator:
             }, ensure_ascii=False) + "\n"
 
             print(simliar_example_format_dump)
-
             # 准备输入数据
             chain_input = {
                 "similar_examples": similar_examples,
@@ -162,7 +163,6 @@ class SQLGenerator:
             }
 
             print("Chain input:", chain_input)
-
             try:
                 # 构建完整的提示词
                 formatted_prompt = self.prompt.format(
@@ -170,12 +170,15 @@ class SQLGenerator:
                     question=query_description
                 )
                 print("Formatted prompt:", formatted_prompt)
-
                 # 流式生成SQL
                 full_response = ""
                 async for chunk in self.chain.astream(chain_input):
                     if chunk:
                         try:
+                            # Remove think tags if using zjstai model
+                            if os.getenv('DEFAULT_MODEL_TYPE') == 'zjstai':
+                                chunk = re.sub(r'<think>', '', chunk, flags=re.DOTALL)
+                                chunk = re.sub(r'</think>', '', chunk, flags=re.DOTALL)
                             full_response += chunk
                             yield json.dumps({
                                 "type": "sql_generation",