main.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
import json
import logging
import traceback
from typing import Any, AsyncGenerator, Dict, List

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
from config import DEFAULT_MODEL_TYPE
  12. app = FastAPI(title="Land Analysis API")
  13. app.include_router(xz_router)
  14. # 配置CORS
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=["*"],
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. server_host = "0.0.0.0"
  23. server_port = 8521
  24. class QueryRequest(BaseModel):
  25. description: str
  26. class AnalysisResult(BaseModel):
  27. sql: str
  28. data: list
  29. @app.post("/land_analysis/stream")
  30. async def stream_land_analysis(request: QueryRequest):
  31. """
  32. 流式返回土地分析结果
  33. """
  34. sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
  35. async def generate_stream() -> AsyncGenerator[str, None]:
  36. try:
  37. # 流式生成SQL
  38. async for chunk in sql_generator.generate_sql_stream(request.description):
  39. data = json.loads(chunk)
  40. if data["type"] == "sql_generation":
  41. yield "data: " + chunk + "\n"
  42. elif data["type"] == "sql_result":
  43. sql = data["content"]
  44. result = await sql_generator.execute_sql(sql)
  45. if result["status"] == "error":
  46. error_data = json.dumps({
  47. "type": "error",
  48. "content": result["content"]
  49. }, ensure_ascii=False)
  50. yield "data: " + error_data + "\n\n"
  51. return
  52. result_data = json.dumps({
  53. "type": "result",
  54. "data": result["data"]
  55. }, ensure_ascii=False)
  56. yield "data: " + result_data + "\n\n"
  57. else:
  58. yield "data: " + chunk + "\n"
  59. except Exception as e:
  60. traceback.print_exc()
  61. error_data = json.dumps({
  62. "type": "error",
  63. "content": str(e)
  64. }, ensure_ascii=False)
  65. yield "data: " + error_data + "\n"
  66. return StreamingResponse(
  67. generate_stream(),
  68. media_type="text/event-stream",
  69. headers={
  70. "Cache-Control": "no-cache",
  71. "Connection": "keep-alive",
  72. "X-Accel-Buffering": "no"
  73. }
  74. )
  75. if __name__ == "__main__":
  76. uvicorn.run(app, host=server_host, port=server_port)
  77. # uvicorn.run("main:app", host=server_host, port=server_port, workers=5)