123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
import asyncio
import json
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import pandas as pd
import plotly.express as px
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from sql_generator import SQLGenerator

app = FastAPI(title="Land Analysis API")

# Configure CORS so that browser front ends served from any origin can call
# this API with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests; consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class QueryRequest(BaseModel):
    """Request body accepted by both land-analysis endpoints."""

    # Natural-language description of the desired analysis. It is used to look
    # up similar examples and as the input for SQL generation.
    description: str
class AnalysisResult(BaseModel):
    """Response body returned by ``POST /land_analysis``."""

    # The SQL statement that was generated and executed.
    sql: str
    # Rows returned by executing ``sql``.
    data: list
    # Plotly figure serialized to a JSON-compatible dict; None when no chart
    # could be built. Must be Optional: the endpoint passes None explicitly,
    # which pydantic v2 rejects for a plain ``dict`` annotation.
    visualization: Optional[dict] = None
    # Similar examples used to build the generation prompt, if any.
    similar_examples: Optional[List[Dict[str, Any]]] = None
# Single SQL generator instance shared by every endpoint in this module; the
# shutdown handler closes it.
sql_generator: SQLGenerator = SQLGenerator()
@app.post("/land_analysis/stream")
async def stream_land_analysis(request: QueryRequest):
    """Stream land-analysis results to the client as newline-terminated JSON.

    Each chunk produced by ``sql_generator.generate_sql_stream`` is parsed as
    JSON and dispatched on its ``"type"`` field:

    * ``"similar_examples"``: the content is remembered for the final payload
      and the chunk is forwarded.
    * ``"sql_generation"`` and any unrecognised type: forwarded unchanged.
    * ``"sql_result"``: the completed SQL is executed.
      - If execution reports an error, an ``{"type": "error", ...}`` line is
        emitted and the stream ends.
      - Otherwise a ``{"type": "result", ...}`` line is emitted, carrying the
        SQL, the rows, an optional Plotly bar chart and the similar examples.

    Any exception raised while streaming is printed and reported to the
    client as a final ``{"type": "error", "message": ...}`` line.
    """

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            # Remembered so the final "result" message can echo it back.
            similar_examples = None

            async for chunk in sql_generator.generate_sql_stream(request.description):
                data = json.loads(chunk)
                message_type = data["type"]

                if message_type == "similar_examples":
                    similar_examples = data["content"]
                    yield chunk
                elif message_type == "sql_generation":
                    yield chunk
                elif message_type == "sql_result":
                    # The full SQL statement is available: execute it.
                    sql = data["content"]
                    result = await sql_generator.execute_sql(sql)

                    if result["status"] == "error":
                        yield json.dumps({
                            "type": "error",
                            "message": result["message"]
                        }) + "\n"
                        return

                    # Bar chart: first column on x, first *numeric* column
                    # after it on y. Skip the chart when there is no such
                    # column; plotting a non-numeric y is meaningless.
                    df = pd.DataFrame(result["data"])
                    visualization = None
                    if not df.empty:
                        value_columns = (
                            df.iloc[:, 1:].select_dtypes(include=["number"]).columns
                        )
                        if len(value_columns) > 0:
                            fig = px.bar(df, x=df.columns[0], y=value_columns[0])
                            visualization = json.loads(fig.to_json())

                    yield json.dumps({
                        "type": "result",
                        "data": {
                            "sql": sql,
                            "data": result["data"],
                            "visualization": visualization,
                            "similar_examples": similar_examples
                        }
                    }) + "\n"
                else:
                    # Unrecognised message types are passed through untouched.
                    yield chunk
        except Exception as e:
            # Top-level boundary of the stream: report the failure in-band,
            # because the HTTP status has already been sent.
            traceback.print_exc()
            yield json.dumps({
                "type": "error",
                "message": str(e)
            }) + "\n"

    # NOTE(review): the body is newline-delimited JSON rather than SSE
    # "data:" frames despite the text/event-stream media type; confirm what
    # clients expect before changing either side.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream"
    )
@app.post("/land_analysis", response_model=AnalysisResult)
async def generate_and_execute_sql(request: QueryRequest):
    """Generate SQL for a description, execute it and return rows plus a chart.

    Similar examples are retrieved and embedded in the prompt sent to
    ``sql_generator.chain``. The generated SQL is then executed, and a Plotly
    bar chart is built when the result contains a numeric value column.

    Raises:
        HTTPException: 400 when executing the generated SQL reports an error;
            500 for any other unexpected failure.
    """
    try:
        # Retrieve similar examples to ground the generation prompt.
        similar_examples = sql_generator._get_similar_examples(request.description)

        # Enhanced prompt: the examples followed by the user's description.
        enhanced_prompt = (
            "\n"
            "基于以下相似示例:\n"
            f"{json.dumps(similar_examples, ensure_ascii=False, indent=2)}\n"
            "\n"
            "请根据以下描述生成SQL查询:\n"
            f"{request.description}\n"
        )

        sql = await sql_generator.chain.arun(enhanced_prompt)
        result = await sql_generator.execute_sql(sql)

        if result["status"] == "error":
            raise HTTPException(status_code=400, detail=result["message"])

        # Bar chart: first column on x, first *numeric* column after it on y.
        # Skip the chart when there is no such column.
        df = pd.DataFrame(result["data"])
        visualization = None
        if not df.empty:
            value_columns = df.iloc[:, 1:].select_dtypes(include=["number"]).columns
            if len(value_columns) > 0:
                fig = px.bar(df, x=df.columns[0], y=value_columns[0])
                visualization = json.loads(fig.to_json())

        # NOTE(review): visualization may be None here, so AnalysisResult
        # must declare that field as Optional.
        return AnalysisResult(
            sql=sql,
            data=result["data"],
            visualization=visualization,
            similar_examples=similar_examples
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.on_event("shutdown")
async def shutdown_event():
    """Release the shared SQL generator's resources when the app stops."""
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of lifespan handlers; consider migrating.
    close_generator = sql_generator.close
    await close_generator()
if __name__ == "__main__":
    # Serve the API directly when this module is run as a script.
    uvicorn.run(app, port=8001, host="0.0.0.0")
|