main.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
import asyncio
import json
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import pandas as pd
import plotly.express as px
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from sql_generator import SQLGenerator
  13. app = FastAPI(title="Land Analysis API")
  14. # 配置CORS
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=["*"],
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. class QueryRequest(BaseModel):
  23. description: str
  24. class AnalysisResult(BaseModel):
  25. sql: str
  26. data: list
  27. visualization: dict = None
  28. similar_examples: List[Dict[str, Any]] = None
  29. sql_generator = SQLGenerator()
  30. @app.post("/land_analysis/stream")
  31. async def stream_land_analysis(request: QueryRequest):
  32. """
  33. 流式返回土地分析结果
  34. """
  35. async def generate_stream() -> AsyncGenerator[str, None]:
  36. try:
  37. similar_examples = None
  38. # 流式生成SQL
  39. async for chunk in sql_generator.generate_sql_stream(request.description):
  40. data = json.loads(chunk)
  41. if data["type"] == "similar_examples":
  42. similar_examples = data["content"]
  43. yield chunk
  44. elif data["type"] == "sql_generation":
  45. yield chunk
  46. elif data["type"] == "sql_result":
  47. # 获取到完整的SQL后执行
  48. sql = data["content"]
  49. result = await sql_generator.execute_sql(sql)
  50. if result["status"] == "error":
  51. yield json.dumps({
  52. "type": "error",
  53. "message": result["message"]
  54. }) + "\n"
  55. return
  56. # 生成可视化
  57. df = pd.DataFrame(result["data"])
  58. visualization = None
  59. if not df.empty:
  60. if len(df.columns) >= 2:
  61. if df.select_dtypes(include=['number']).columns.any():
  62. fig = px.bar(df, x=df.columns[0], y=df.columns[1])
  63. visualization = json.loads(fig.to_json())
  64. # 返回最终结果
  65. yield json.dumps({
  66. "type": "result",
  67. "data": {
  68. "sql": sql,
  69. "data": result["data"],
  70. "visualization": visualization,
  71. "similar_examples": similar_examples
  72. }
  73. }) + "\n"
  74. else:
  75. yield chunk
  76. except Exception as e:
  77. traceback.print_exc()
  78. yield json.dumps({
  79. "type": "error",
  80. "message": str(e)
  81. }) + "\n"
  82. return StreamingResponse(
  83. generate_stream(),
  84. media_type="text/event-stream"
  85. )
  86. @app.post("/land_analysis", response_model=AnalysisResult)
  87. async def generate_and_execute_sql(request: QueryRequest):
  88. try:
  89. # 获取相似示例
  90. similar_examples = sql_generator._get_similar_examples(request.description)
  91. # 构建增强提示词
  92. enhanced_prompt = f"""
  93. 基于以下相似示例:
  94. {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
  95. 请根据以下描述生成SQL查询:
  96. {request.description}
  97. """
  98. # 生成SQL
  99. sql = await sql_generator.chain.arun(enhanced_prompt)
  100. # 执行SQL
  101. result = await sql_generator.execute_sql(sql)
  102. if result["status"] == "error":
  103. raise HTTPException(status_code=400, detail=result["message"])
  104. # 生成可视化
  105. df = pd.DataFrame(result["data"])
  106. visualization = None
  107. if not df.empty:
  108. if len(df.columns) >= 2:
  109. if df.select_dtypes(include=['number']).columns.any():
  110. fig = px.bar(df, x=df.columns[0], y=df.columns[1])
  111. visualization = json.loads(fig.to_json())
  112. return AnalysisResult(
  113. sql=sql,
  114. data=result["data"],
  115. visualization=visualization,
  116. similar_examples=similar_examples
  117. )
  118. except Exception as e:
  119. traceback.print_exc()
  120. raise HTTPException(status_code=500, detail=str(e))
  121. @app.on_event("shutdown")
  122. async def shutdown_event():
  123. """
  124. 应用关闭时清理资源
  125. """
  126. await sql_generator.close()
  127. if __name__ == "__main__":
  128. uvicorn.run(app, host="0.0.0.0", port=8001)