123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
import json
import re
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
from config import DEFAULT_MODEL_TYPE
app = FastAPI(title="Land Analysis API")
app.include_router(xz_router)

# Configure CORS: any origin, any method, any header, with credentials allowed.
# NOTE(review): this is maximally permissive (wildcard origin + credentials);
# consider an explicit origin allow-list outside local development.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Bind address and port handed to uvicorn when this module is run as a script.
server_host = "0.0.0.0"
server_port = 8521
class QueryRequest(BaseModel):
    """Request body accepted by the land-analysis streaming endpoint."""

    # Free-text description of the desired analysis; it is handed to the SQL
    # generator as its input.
    description: str
class AnalysisResult(BaseModel):
    """Analysis response schema.

    Holds the generated ``sql``, a ``data`` list, and the optional
    ``visualization`` and ``similar_examples`` payloads.

    NOTE(review): not referenced by the code visible in this module — confirm
    whether other modules use it.
    """

    sql: str
    data: list
    # Declared Optional so an explicit ``None`` also validates. Under pydantic v2,
    # a bare ``dict = None`` only tolerates None when the field is omitted;
    # passing ``visualization=None`` raises a ValidationError.
    visualization: Optional[dict] = None
    similar_examples: Optional[List[Dict[str, Any]]] = None
@app.post("/land_analysis/stream")
async def stream_land_analysis(request: QueryRequest):
    """Stream land-analysis results back to the client.

    Every chunk yielded by ``sql_generator.generate_sql_stream`` is parsed as
    JSON and forwarded unchanged, except chunks whose ``type`` is
    ``"sql_result"``. For those, the SQL in ``content`` is executed and a
    ``{"type": "result", "data": {"sql": ..., "exec_result": ...}}`` message is
    emitted instead.

    If execution reports ``status == "error"``, or any exception escapes while
    streaming, a single ``{"type": "error", "content": ...}`` message is emitted
    and the stream ends.

    Messages built here are newline-terminated JSON. Forwarded chunks are
    passed through verbatim (presumably in the same format — confirm against
    SQLGenerator).
    """

    def remove_think_tag(text: str) -> str:
        """Remove ``<think>...</think>`` blocks and any newlines right after them."""
        # Pattern uses `\n*`, not `\\n*`: inside a raw string `\\n` matches a
        # literal backslash followed by "n". The previous pattern therefore only
        # matched when "</think>" was followed by a backslash, and never stripped
        # ordinary model output.
        return re.sub(r"<think>[\s\S]*?</think>\n*", "", text)

    def _json_line(payload: dict) -> str:
        # One newline-terminated JSON message. ensure_ascii=False leaves
        # non-ASCII characters unescaped.
        return json.dumps(payload, ensure_ascii=False) + "\n"

    sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            async for chunk in sql_generator.generate_sql_stream(request.description):
                data = json.loads(chunk)
                # Everything except the final SQL (e.g. "similar_examples",
                # "sql_generation", or any other type) goes out unchanged.
                # NOTE: stripping <think> blocks from "sql_generation" content via
                # remove_think_tag (for the "zjstai" model) is currently disabled.
                if data["type"] != "sql_result":
                    yield chunk
                    continue

                sql = data["content"]
                result = await sql_generator.execute_sql(sql)
                if result["status"] == "error":
                    yield _json_line({"type": "error", "content": result["content"]})
                    return
                yield _json_line({
                    "type": "result",
                    "data": {"sql": sql, "exec_result": result["data"]},
                })
        except Exception as e:
            # Top-level boundary of the stream. The response status is already
            # committed once streaming starts, so report the failure in-band.
            traceback.print_exc()
            yield _json_line({"type": "error", "content": str(e)})

    # NOTE(review): the body is newline-delimited JSON, not SSE "data:" frames.
    # The text/event-stream media type is kept unchanged, presumably for
    # existing clients.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
    )
@app.on_event("shutdown")
async def shutdown_event():
    """Run when the application shuts down; currently a no-op.

    Intended as the place to release resources on shutdown. NOTE(review): newer
    FastAPI releases deprecate ``on_event`` in favor of lifespan handlers.
    """
if __name__ == "__main__":
    # Single-process server using the bind address configured above.
    # For multiple workers, uvicorn needs the app as an import string instead of
    # the object, e.g. uvicorn.run("main:app", host=..., port=..., workers=5).
    uvicorn.run(
        app,
        host=server_host,
        port=server_port,
    )
|