123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
import json
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
from config import DEFAULT_MODEL_TYPE
app = FastAPI(title="Land Analysis API")
app.include_router(xz_router)  # mount the routes defined in xuanzhi_query

# Configure CORS: any origin, HTTP method and request header is accepted, and
# credentialed requests are allowed.
# NOTE(review): the CORS spec forbids a literal "*" Access-Control-Allow-Origin
# on credentialed responses; confirm how CORSMiddleware resolves "*" together
# with allow_credentials=True, and pin allow_origins to the known frontend
# origins if cookies or auth headers are in use.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Address and port the server binds to when this module is run directly.
server_host, server_port = "0.0.0.0", 8521
class QueryRequest(BaseModel):
    """Request body for ``POST /land_analysis/stream``.

    Attributes:
        description: Natural-language description of the desired analysis;
            it is handed to the SQL generator as the prompt.
    """

    description: str
class AnalysisResult(BaseModel):
    """Model describing a complete land-analysis result.

    Attributes:
        sql: The SQL statement that was generated.
        data: Result payload as a list (presumably the rows produced by
            executing ``sql`` -- not used in this module; confirm with callers).
        visualization: Optional visualization description; ``None`` when absent.
        similar_examples: Optional list of example records (dicts); ``None``
            when absent.
    """

    sql: str
    data: list
    # Annotated as Optional (not a bare `dict` / `List`) so that an explicit
    # None validates: pydantic v2 no longer infers Optional from a None default.
    visualization: Optional[dict] = None
    similar_examples: Optional[List[Dict[str, Any]]] = None
@app.post("/land_analysis/stream")
async def stream_land_analysis(request: QueryRequest):
    """Stream land-analysis results for a natural-language query.

    Each chunk produced by ``SQLGenerator.generate_sql_stream`` is decoded as
    JSON. Chunks whose ``type`` is not ``"sql_result"`` are forwarded to the
    client unchanged. For a ``"sql_result"`` chunk, the SQL in its ``content``
    is executed and one line is emitted in its place:

    * ``{"type": "result", "data": {"sql": ..., "exec_result": ...}}`` on success;
    * ``{"type": "error", "content": ...}`` when execution reports
      ``status == "error"``; the stream then ends.

    Any exception raised while streaming is printed to stderr and reported to
    the client as a final ``{"type": "error", "content": str(exc)}`` line.

    NOTE(review): the body is newline-delimited JSON rather than SSE ``data:``
    frames, even though the media type is ``text/event-stream``; a browser
    ``EventSource`` will not parse it. Confirm what the frontend expects before
    changing either side.
    NOTE(review): ``str(exc)`` is sent to the client verbatim and may expose
    internal details such as database error text.
    """
    sqlgen = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)

    def as_line(payload: dict) -> str:
        # One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
        return json.dumps(payload, ensure_ascii=False) + "\n"

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            async for chunk in sqlgen.generate_sql_stream(request.description):
                message = json.loads(chunk)
                if message["type"] != "sql_result":
                    # "sql_generation" and every other chunk type pass through as-is.
                    yield chunk
                    continue

                sql = message["content"]
                outcome = await sqlgen.execute_sql(sql)
                if outcome["status"] == "error":
                    yield as_line({"type": "error", "content": outcome["content"]})
                    return
                yield as_line(
                    {
                        "type": "result",
                        "data": {"sql": sql, "exec_result": outcome["data"]},
                    }
                )
        except Exception as exc:
            traceback.print_exc()
            yield as_line({"type": "error", "content": str(exc)})

    return StreamingResponse(generate_stream(), media_type="text/event-stream")
if __name__ == "__main__":
    # Serve the app object directly in a single uvicorn process.
    # To run several workers, uvicorn must receive the app as an import string,
    # e.g. uvicorn.run("main:app", host=server_host, port=server_port, workers=5)
    uvicorn.run(app, port=server_port, host=server_host)
|