import json
import re
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Local modules are kept in their original import order in case they have
# import-time side effects that depend on it.
from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
from config import DEFAULT_MODEL_TYPE

app = FastAPI(title="Land Analysis API")
app.include_router(xz_router)

# Configure CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests to this API; consider an explicit origin
# allow-list before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

server_host = "0.0.0.0"
server_port = 8521

# Matches a <think>...</think> section (non-greedy, spanning newlines) plus any
# newlines that immediately follow it.
_THINK_TAG_RE = re.compile(r"<think>[\s\S]*?</think>\n*")


def remove_think_tag(text: str) -> str:
    """Remove every ``<think>...</think>`` section (and trailing newlines) from *text*."""
    return _THINK_TAG_RE.sub("", text)


def _ndjson_line(payload: Dict[str, Any]) -> str:
    """Serialize *payload* as one newline-terminated JSON line, keeping non-ASCII text unescaped."""
    return json.dumps(payload, ensure_ascii=False) + "\n"


class QueryRequest(BaseModel):
    """Request body for ``POST /land_analysis/stream``."""

    # Free-text description handed to SQLGenerator.generate_sql_stream.
    description: str


class AnalysisResult(BaseModel):
    """Analysis result model: generated SQL, its data, and optional extras.

    Not referenced by the endpoint below; presumably used elsewhere — confirm.
    """

    sql: str
    data: list
    # Optional so that an explicit ``null`` also validates (required under pydantic v2).
    visualization: Optional[dict] = None
    similar_examples: Optional[List[Dict[str, Any]]] = None


@app.post("/land_analysis/stream")
async def stream_land_analysis(request: QueryRequest):
    """Stream land-analysis results as newline-delimited JSON messages.

    Every chunk produced by ``SQLGenerator.generate_sql_stream`` is parsed as
    JSON and forwarded verbatim, except chunks whose ``type`` is
    ``"sql_result"``. For those, the SQL in ``content`` is executed and replaced
    by one ``{"type": "result", "data": {"sql": ..., "exec_result": ...}}``
    message.

    If execution reports ``status == "error"``, or any exception is raised
    while streaming, a single ``{"type": "error", "content": ...}`` message is
    emitted and the stream ends.
    """
    sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            async for chunk in sql_generator.generate_sql_stream(request.description):
                data = json.loads(chunk)

                if data["type"] != "sql_result":
                    # "similar_examples", "sql_generation" and any other chunk
                    # types are passed through unchanged.
                    # NOTE: stripping <think> sections via remove_think_tag()
                    # for DEFAULT_MODEL_TYPE == "zjstai" is currently disabled.
                    yield chunk
                    continue

                sql = data["content"]
                result = await sql_generator.execute_sql(sql)
                if result["status"] == "error":
                    yield _ndjson_line({"type": "error", "content": result["content"]})
                    return

                yield _ndjson_line({
                    "type": "result",
                    "data": {"sql": sql, "exec_result": result["data"]},
                })
        except Exception as e:
            # Top-level boundary of the stream: log the traceback server-side and
            # report the failure to the client as a final error message.
            traceback.print_exc()
            yield _ndjson_line({"type": "error", "content": str(e)})

    # NOTE(review): the body is newline-delimited JSON rather than SSE-framed
    # ("data: ...\n\n"); the media type is kept unchanged so existing clients
    # continue to work.
    return StreamingResponse(generate_stream(), media_type="text/event-stream")


@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook for resource cleanup; there is currently nothing to release."""
    # NOTE(review): on_event is deprecated in newer FastAPI releases in favor of
    # lifespan handlers.
    pass


if __name__ == "__main__":
    uvicorn.run(app, host=server_host, port=server_port)
    # Multi-worker alternative (uvicorn needs an import string, not the app object, for workers):
    # uvicorn.run("main:app", host=server_host, port=server_port, workers=5)