|
@@ -23,30 +23,35 @@ app.add_middleware(
|
|
|
allow_headers=["*"],
|
|
|
)
|
|
|
|
|
|
+
|
|
|
class QueryRequest(BaseModel):
|
|
|
description: str
|
|
|
|
|
|
+
|
|
|
class AnalysisResult(BaseModel):
|
|
|
sql: str
|
|
|
data: list
|
|
|
visualization: dict = None
|
|
|
similar_examples: List[Dict[str, Any]] = None
|
|
|
|
|
|
+
|
|
|
sql_generator = SQLGenerator()
|
|
|
|
|
|
+
|
|
|
@app.post("/land_analysis/stream")
|
|
|
async def stream_land_analysis(request: QueryRequest):
|
|
|
"""
|
|
|
流式返回土地分析结果
|
|
|
"""
|
|
|
+
|
|
|
async def generate_stream() -> AsyncGenerator[str, None]:
|
|
|
try:
|
|
|
similar_examples = None
|
|
|
-
|
|
|
+
|
|
|
# 流式生成SQL
|
|
|
async for chunk in sql_generator.generate_sql_stream(request.description):
|
|
|
data = json.loads(chunk)
|
|
|
-
|
|
|
+
|
|
|
if data["type"] == "similar_examples":
|
|
|
similar_examples = data["content"]
|
|
|
yield chunk
|
|
@@ -56,34 +61,22 @@ async def stream_land_analysis(request: QueryRequest):
|
|
|
# 获取到完整的SQL后执行
|
|
|
sql = data["content"]
|
|
|
result = await sql_generator.execute_sql(sql)
|
|
|
-
|
|
|
+
|
|
|
if result["status"] == "error":
|
|
|
yield json.dumps({
|
|
|
"type": "error",
|
|
|
"message": result["message"]
|
|
|
- }) + "\n"
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
return
|
|
|
|
|
|
- # 生成可视化
|
|
|
- df = pd.DataFrame(result["data"])
|
|
|
- visualization = None
|
|
|
-
|
|
|
- if not df.empty:
|
|
|
- if len(df.columns) >= 2:
|
|
|
- if df.select_dtypes(include=['number']).columns.any():
|
|
|
- fig = px.bar(df, x=df.columns[0], y=df.columns[1])
|
|
|
- visualization = json.loads(fig.to_json())
|
|
|
-
|
|
|
# 返回最终结果
|
|
|
yield json.dumps({
|
|
|
"type": "result",
|
|
|
"data": {
|
|
|
"sql": sql,
|
|
|
- "data": result["data"],
|
|
|
- "visualization": visualization,
|
|
|
- "similar_examples": similar_examples
|
|
|
+ "exec_result": result["data"]
|
|
|
}
|
|
|
- }) + "\n"
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
else:
|
|
|
yield chunk
|
|
|
|
|
@@ -92,19 +85,20 @@ async def stream_land_analysis(request: QueryRequest):
|
|
|
yield json.dumps({
|
|
|
"type": "error",
|
|
|
"message": str(e)
|
|
|
- }) + "\n"
|
|
|
+ }, ensure_ascii=False) + "\n"
|
|
|
|
|
|
return StreamingResponse(
|
|
|
generate_stream(),
|
|
|
media_type="text/event-stream"
|
|
|
)
|
|
|
|
|
|
+
|
|
|
@app.post("/land_analysis", response_model=AnalysisResult)
|
|
|
async def generate_and_execute_sql(request: QueryRequest):
|
|
|
try:
|
|
|
# 获取相似示例
|
|
|
similar_examples = sql_generator._get_similar_examples(request.description)
|
|
|
-
|
|
|
+
|
|
|
# 构建增强提示词
|
|
|
enhanced_prompt = f"""
|
|
|
基于以下相似示例:
|
|
@@ -113,37 +107,26 @@ async def generate_and_execute_sql(request: QueryRequest):
|
|
|
请根据以下描述生成SQL查询:
|
|
|
{request.description}
|
|
|
"""
|
|
|
-
|
|
|
+
|
|
|
# 生成SQL
|
|
|
sql = await sql_generator.chain.arun(enhanced_prompt)
|
|
|
-
|
|
|
+
|
|
|
# 执行SQL
|
|
|
result = await sql_generator.execute_sql(sql)
|
|
|
-
|
|
|
+
|
|
|
if result["status"] == "error":
|
|
|
raise HTTPException(status_code=400, detail=result["message"])
|
|
|
-
|
|
|
- # 生成可视化
|
|
|
- df = pd.DataFrame(result["data"])
|
|
|
- visualization = None
|
|
|
-
|
|
|
- if not df.empty:
|
|
|
- if len(df.columns) >= 2:
|
|
|
- if df.select_dtypes(include=['number']).columns.any():
|
|
|
- fig = px.bar(df, x=df.columns[0], y=df.columns[1])
|
|
|
- visualization = json.loads(fig.to_json())
|
|
|
-
|
|
|
+
|
|
|
return AnalysisResult(
|
|
|
sql=sql,
|
|
|
- data=result["data"],
|
|
|
- visualization=visualization,
|
|
|
- similar_examples=similar_examples
|
|
|
+ exec_result=result["data"]
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
except Exception as e:
|
|
|
traceback.print_exc()
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
+
|
|
|
@app.on_event("shutdown")
|
|
|
async def shutdown_event():
|
|
|
"""
|
|
@@ -151,5 +134,6 @@ async def shutdown_event():
|
|
|
"""
|
|
|
await sql_generator.close()
|
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
- uvicorn.run(app, host="0.0.0.0", port=8001)
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=8001)
|