main.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
import json
import re
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
from config import DEFAULT_MODEL_TYPE
  13. app = FastAPI(title="Land Analysis API")
  14. app.include_router(xz_router)
  15. # 配置CORS
  16. app.add_middleware(
  17. CORSMiddleware,
  18. allow_origins=["*"],
  19. allow_credentials=True,
  20. allow_methods=["*"],
  21. allow_headers=["*"],
  22. )
  23. server_host = "0.0.0.0"
  24. server_port = 8521
  25. class QueryRequest(BaseModel):
  26. description: str
  27. class AnalysisResult(BaseModel):
  28. sql: str
  29. data: list
  30. visualization: dict = None
  31. similar_examples: List[Dict[str, Any]] = None
  32. @app.post("/land_analysis/stream")
  33. async def stream_land_analysis(request: QueryRequest):
  34. """
  35. 流式返回土地分析结果
  36. """
  37. def remove_think_tag(text: str) -> str:
  38. # 移除<think>标签及其内容
  39. return re.sub(r"<think>[\s\S]*?</think>\\n*", "", text)
  40. sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
  41. async def generate_stream() -> AsyncGenerator[str, None]:
  42. try:
  43. similar_examples = None
  44. # 流式生成SQL
  45. async for chunk in sql_generator.generate_sql_stream(request.description):
  46. data = json.loads(chunk)
  47. if data["type"] == "similar_examples":
  48. similar_examples = data["content"]
  49. yield chunk
  50. elif data["type"] == "sql_generation":
  51. # zjstai模型时,移除<think>标签内容
  52. # if DEFAULT_MODEL_TYPE == "zjstai":
  53. # data["content"] = remove_think_tag(data["content"])
  54. # yield json.dumps(data, ensure_ascii=False) + "\n"
  55. # else:
  56. yield chunk
  57. elif data["type"] == "sql_result":
  58. sql = data["content"]
  59. result = await sql_generator.execute_sql(sql)
  60. if result["status"] == "error":
  61. yield json.dumps({
  62. "type": "error",
  63. "content": result["content"]
  64. }, ensure_ascii=False) + "\n"
  65. return
  66. yield json.dumps({
  67. "type": "result",
  68. "data": {
  69. "sql": sql,
  70. "exec_result": result["data"]
  71. }
  72. }, ensure_ascii=False) + "\n"
  73. else:
  74. yield chunk
  75. except Exception as e:
  76. traceback.print_exc()
  77. yield json.dumps({
  78. "type": "error",
  79. "content": str(e)
  80. }, ensure_ascii=False) + "\n"
  81. return StreamingResponse(
  82. generate_stream(),
  83. media_type="text/event-stream"
  84. )
  85. @app.on_event("shutdown")
  86. async def shutdown_event():
  87. """
  88. 应用关闭时清理资源
  89. """
  90. pass
  91. if __name__ == "__main__":
  92. uvicorn.run(app, host=server_host, port=server_port)
  93. # uvicorn.run("main:app", host=server_host, port=server_port, workers=5)