"""FastAPI service for land analysis.

Generates SQL from a text description via ``SQLGenerator``, executes it, and
returns the rows together with an optional Plotly bar chart. Exposes a
streaming endpoint (``/land_analysis/stream``) and a single-response endpoint
(``/land_analysis``).
"""
import json
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import pandas as pd
import plotly.express as px
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from database import get_db
from sql_generator import SQLGenerator

app = FastAPI(title="Land Analysis API")

# Configure CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins if cookies/auth are used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueryRequest(BaseModel):
    """Request body: the analysis description to turn into SQL."""

    description: str


class AnalysisResult(BaseModel):
    """Response body of ``POST /land_analysis``."""

    sql: str  # the SQL that was executed
    data: list  # rows returned by sql_generator.execute_sql
    visualization: Optional[Dict[str, Any]] = None  # Plotly figure as a dict, if any
    similar_examples: Optional[List[Dict[str, Any]]] = None  # examples used in the prompt


sql_generator = SQLGenerator()


def _build_visualization(rows: Any) -> Optional[Dict[str, Any]]:
    """Build a Plotly bar chart (as a JSON-compatible dict) for query rows.

    The first column is used as the x axis. The second column is used as the
    bar height when it is numeric; otherwise the first other numeric column is
    used.

    Args:
        rows: Query result rows, anything accepted by ``pd.DataFrame``.

    Returns:
        The figure as a dict, or None when there are no rows, fewer than two
        columns, or no numeric column besides the x column.
    """
    df = pd.DataFrame(rows)
    if df.empty or len(df.columns) < 2:
        return None

    x_col = df.columns[0]
    numeric_cols = [
        col for col in df.select_dtypes(include="number").columns if col != x_col
    ]
    if not numeric_cols:
        return None

    # Keep the original "second column as y" layout whenever it is numeric.
    y_col = df.columns[1] if df.columns[1] in numeric_cols else numeric_cols[0]
    fig = px.bar(df, x=x_col, y=y_col)
    return json.loads(fig.to_json())


@app.post("/land_analysis/stream")
async def stream_land_analysis(
    request: QueryRequest,
    db: Session = Depends(get_db),
):
    """Stream land-analysis results as newline-terminated JSON events.

    Every event produced by ``sql_generator.generate_sql_stream`` is forwarded
    unchanged. Afterwards the generated SQL is executed and one final event is
    emitted: ``{"type": "result", "data": {...}}`` on success, or
    ``{"type": "error", "message": ...}`` on any failure.
    """
    # NOTE(review): `db` is used inside the streaming generator, i.e. after this
    # function has returned. Whether the session from get_db is still open then
    # depends on get_db and on the FastAPI version (some versions run
    # yield-dependency teardown before the response body is streamed) — confirm.

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            similar_examples = None
            generated_sql = None

            # Stream the SQL generation, remembering the pieces needed later.
            async for chunk in sql_generator.generate_sql_stream(request.description):
                event = json.loads(chunk)
                event_type = event.get("type")
                if event_type == "similar_examples":
                    similar_examples = event["content"]
                elif event_type == "sql_generation":
                    # NOTE(review): only the last such event is kept; this assumes
                    # each one carries the complete SQL, not a fragment — confirm.
                    generated_sql = event["content"]
                # Every upstream event is forwarded to the client as-is.
                yield chunk

            if not generated_sql:
                yield json.dumps({
                    "type": "error",
                    "message": "SQL生成失败",
                }) + "\n"
                return

            # NOTE(review): generated SQL is executed as-is; unless execute_sql or
            # the DB role restricts it to read-only statements, a crafted
            # description could run arbitrary SQL.
            result = await sql_generator.execute_sql(generated_sql, db)
            if result["status"] == "error":
                yield json.dumps({
                    "type": "error",
                    "message": result["message"],
                }) + "\n"
                return

            payload = {
                "type": "result",
                "data": {
                    "sql": generated_sql,
                    "data": result["data"],
                    "visualization": _build_visualization(result["data"]),
                    "similar_examples": similar_examples,
                },
            }
            # jsonable_encoder converts values plain json.dumps rejects (e.g.
            # Decimal, datetime), matching the serialization FastAPI applies to
            # the non-streaming endpoint's response model.
            yield json.dumps(jsonable_encoder(payload)) + "\n"
        except Exception as e:
            traceback.print_exc()
            yield json.dumps({
                "type": "error",
                "message": str(e),
            }) + "\n"

    # NOTE: the body is newline-delimited JSON, not SSE "data:" frames, despite
    # the media type; left unchanged so existing clients are unaffected.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
    )


@app.post("/land_analysis", response_model=AnalysisResult)
async def generate_and_execute_sql(
    request: QueryRequest,
    db: Session = Depends(get_db),
):
    """Generate SQL for ``request.description``, execute it, and return the result.

    Raises:
        HTTPException: 400 when execution reports ``status == "error"``;
            500 when no SQL was generated or on any other unexpected failure.
    """
    try:
        # Fetch similar examples to ground the prompt.
        similar_examples = sql_generator._get_similar_examples(request.description)

        # Build the example-augmented prompt.
        enhanced_prompt = f"""
基于以下相似示例:
{json.dumps(similar_examples, ensure_ascii=False, indent=2)}
请根据以下描述生成SQL查询:
{request.description}
"""

        sql = await sql_generator.chain.arun(enhanced_prompt)
        if not sql:
            raise HTTPException(status_code=500, detail="SQL生成失败")

        # NOTE(review): generated SQL is executed as-is — see the stream endpoint.
        result = await sql_generator.execute_sql(sql, db)
        if result["status"] == "error":
            raise HTTPException(status_code=400, detail=result["message"])

        return AnalysisResult(
            sql=sql,
            data=result["data"],
            visualization=_build_visualization(result["data"]),
            similar_examples=similar_examples,
        )
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged instead
        # of letting the generic handler below rewrite them as 500s.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)