Przeglądaj źródła

选址提示词修改

liutao 1 miesiąc temu
rodzic
commit
8791160f7e
2 zmienionych plików z 33 dodań i 42 usunięć
  1. 23 39
      landsite_agent/main.py
  2. 10 3
      landsite_agent/prompt_template.py

+ 23 - 39
landsite_agent/main.py

@@ -23,30 +23,35 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class QueryRequest(BaseModel):
     description: str
 
+
 class AnalysisResult(BaseModel):
     sql: str
     data: list
     visualization: dict = None
     similar_examples: List[Dict[str, Any]] = None
 
+
 sql_generator = SQLGenerator()
 
+
 @app.post("/land_analysis/stream")
 async def stream_land_analysis(request: QueryRequest):
     """
     流式返回土地分析结果
     """
+
     async def generate_stream() -> AsyncGenerator[str, None]:
         try:
             similar_examples = None
-            
+
             # 流式生成SQL
             async for chunk in sql_generator.generate_sql_stream(request.description):
                 data = json.loads(chunk)
-                
+
                 if data["type"] == "similar_examples":
                     similar_examples = data["content"]
                     yield chunk
@@ -56,34 +61,22 @@ async def stream_land_analysis(request: QueryRequest):
                     # 获取到完整的SQL后执行
                     sql = data["content"]
                     result = await sql_generator.execute_sql(sql)
-                    
+
                     if result["status"] == "error":
                         yield json.dumps({
                             "type": "error",
                             "message": result["message"]
-                        }) + "\n"
+                        }, ensure_ascii=False) + "\n"
                         return
 
-                    # 生成可视化
-                    df = pd.DataFrame(result["data"])
-                    visualization = None
-                    
-                    if not df.empty:
-                        if len(df.columns) >= 2:
-                            if df.select_dtypes(include=['number']).columns.any():
-                                fig = px.bar(df, x=df.columns[0], y=df.columns[1])
-                                visualization = json.loads(fig.to_json())
-
                     # 返回最终结果
                     yield json.dumps({
                         "type": "result",
                         "data": {
                             "sql": sql,
-                            "data": result["data"],
-                            "visualization": visualization,
-                            "similar_examples": similar_examples
+                            "exec_result": result["data"]
                         }
-                    }) + "\n"
+                    }, ensure_ascii=False) + "\n"
                 else:
                     yield chunk
 
@@ -92,19 +85,20 @@ async def stream_land_analysis(request: QueryRequest):
             yield json.dumps({
                 "type": "error",
                 "message": str(e)
-            }) + "\n"
+            }, ensure_ascii=False) + "\n"
 
     return StreamingResponse(
         generate_stream(),
         media_type="text/event-stream"
     )
 
+
 @app.post("/land_analysis", response_model=AnalysisResult)
 async def generate_and_execute_sql(request: QueryRequest):
     try:
         # 获取相似示例
         similar_examples = sql_generator._get_similar_examples(request.description)
-        
+
         # 构建增强提示词
         enhanced_prompt = f"""
         基于以下相似示例:
@@ -113,37 +107,26 @@ async def generate_and_execute_sql(request: QueryRequest):
         请根据以下描述生成SQL查询:
         {request.description}
         """
-        
+
         # 生成SQL
         sql = await sql_generator.chain.arun(enhanced_prompt)
-        
+
         # 执行SQL
         result = await sql_generator.execute_sql(sql)
-        
+
         if result["status"] == "error":
             raise HTTPException(status_code=400, detail=result["message"])
-        
-        # 生成可视化
-        df = pd.DataFrame(result["data"])
-        visualization = None
-        
-        if not df.empty:
-            if len(df.columns) >= 2:
-                if df.select_dtypes(include=['number']).columns.any():
-                    fig = px.bar(df, x=df.columns[0], y=df.columns[1])
-                    visualization = json.loads(fig.to_json())
-        
+
         return AnalysisResult(
             sql=sql,
-            data=result["data"],
-            visualization=visualization,
-            similar_examples=similar_examples
+            exec_result=result["data"]
         )
-    
+
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @app.on_event("shutdown")
 async def shutdown_event():
     """
@@ -151,5 +134,6 @@ async def shutdown_event():
     """
     await sql_generator.close()
 
+
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    uvicorn.run(app, host="0.0.0.0", port=8001)

+ 10 - 3
landsite_agent/prompt_template.py

@@ -60,7 +60,12 @@ PROMPT_TEMPLATE = """
    - 不允许使用 SELECT * 或其他字段
    - 示例:SELECT id FROM table WHERE condition;
 
-5. 其他注意事项:
+5. 结果数量限制:
+   - 所有查询必须使用 LIMIT 5 限制返回结果数量
+   - 最多只返回5条数据
+   - 示例:SELECT id FROM table WHERE condition LIMIT 5;
+
+6. 其他注意事项:
    - 确保SQL语句的语法正确性
    - 注意字段名称的准确性
    - 合理使用索引字段(如id、xzqmc等)
@@ -76,7 +81,8 @@ PROMPT_TEMPLATE = """
 2. 只查询id字段
 3. 包含shape is not null条件
 4. 正确使用面积字段和单位
-5. 遵循其他注意事项
+5. 使用LIMIT 5限制返回结果数量
+6. 遵循其他注意事项
 
 请按照以下格式输出,每个部分之间用空行分隔:
 
@@ -89,7 +95,8 @@ PROMPT_TEMPLATE = """
 4. SQL: 生成SQL代码
 ```sql
 SELECT id FROM table
-WHERE condition;
+WHERE condition
+LIMIT 5;
 ```
 
 请确保SQL语句是完整且可执行的,并且SQL代码块是独立的部分。"""