main.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

import pandas as pd
import plotly.express as px
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from database import get_db
from sql_generator import SQLGenerator
  13. app = FastAPI(title="Land Analysis API")
  14. import traceback
  15. # 配置CORS
  16. app.add_middleware(
  17. CORSMiddleware,
  18. allow_origins=["*"],
  19. allow_credentials=True,
  20. allow_methods=["*"],
  21. allow_headers=["*"],
  22. )
  23. class QueryRequest(BaseModel):
  24. description: str
  25. class AnalysisResult(BaseModel):
  26. sql: str
  27. data: list
  28. visualization: dict = None
  29. similar_examples: List[Dict[str, Any]] = None
  30. sql_generator = SQLGenerator()
  31. @app.post("/land_analysis/stream")
  32. async def stream_land_analysis(
  33. request: QueryRequest,
  34. db: Session = Depends(get_db)
  35. ):
  36. """
  37. 流式返回土地分析结果
  38. """
  39. async def generate_stream() -> AsyncGenerator[str, None]:
  40. try:
  41. similar_examples = None
  42. generated_sql = None
  43. # 流式生成SQL
  44. async for chunk in sql_generator.generate_sql_stream(request.description):
  45. data = json.loads(chunk)
  46. if data["type"] == "similar_examples":
  47. similar_examples = data["content"]
  48. yield chunk
  49. elif data["type"] == "sql_generation":
  50. generated_sql = data["content"]
  51. yield chunk
  52. else:
  53. yield chunk
  54. if not generated_sql:
  55. yield json.dumps({
  56. "type": "error",
  57. "message": "SQL生成失败"
  58. }) + "\n"
  59. return
  60. # 执行SQL并返回结果
  61. result = await sql_generator.execute_sql(generated_sql, db)
  62. if result["status"] == "error":
  63. yield json.dumps({
  64. "type": "error",
  65. "message": result["message"]
  66. }) + "\n"
  67. return
  68. # 生成可视化
  69. df = pd.DataFrame(result["data"])
  70. visualization = None
  71. if not df.empty:
  72. if len(df.columns) >= 2:
  73. if df.select_dtypes(include=['number']).columns.any():
  74. fig = px.bar(df, x=df.columns[0], y=df.columns[1])
  75. visualization = json.loads(fig.to_json())
  76. # 返回最终结果
  77. yield json.dumps({
  78. "type": "result",
  79. "data": {
  80. "sql": generated_sql,
  81. "data": result["data"],
  82. "visualization": visualization,
  83. "similar_examples": similar_examples
  84. }
  85. }) + "\n"
  86. except Exception as e:
  87. traceback.print_exc()
  88. yield json.dumps({
  89. "type": "error",
  90. "message": str(e)
  91. }) + "\n"
  92. return StreamingResponse(
  93. generate_stream(),
  94. media_type="text/event-stream"
  95. )
  96. @app.post("/land_analysis", response_model=AnalysisResult)
  97. async def generate_and_execute_sql(
  98. request: QueryRequest,
  99. db: Session = Depends(get_db)
  100. ):
  101. try:
  102. # 获取相似示例
  103. similar_examples = sql_generator._get_similar_examples(request.description)
  104. # 构建增强提示词
  105. enhanced_prompt = f"""
  106. 基于以下相似示例:
  107. {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
  108. 请根据以下描述生成SQL查询:
  109. {request.description}
  110. """
  111. # 生成SQL
  112. sql = await sql_generator.chain.arun(enhanced_prompt)
  113. # 执行SQL
  114. result = await sql_generator.execute_sql(sql, db)
  115. if result["status"] == "error":
  116. raise HTTPException(status_code=400, detail=result["message"])
  117. # 生成可视化
  118. df = pd.DataFrame(result["data"])
  119. visualization = None
  120. if not df.empty:
  121. if len(df.columns) >= 2:
  122. if df.select_dtypes(include=['number']).columns.any():
  123. fig = px.bar(df, x=df.columns[0], y=df.columns[1])
  124. visualization = json.loads(fig.to_json())
  125. return AnalysisResult(
  126. sql=sql,
  127. data=result["data"],
  128. visualization=visualization,
  129. similar_examples=similar_examples
  130. )
  131. except Exception as e:
  132. traceback.print_exc()
  133. raise HTTPException(status_code=500, detail=str(e))
  134. if __name__ == "__main__":
  135. uvicorn.run(app, host="0.0.0.0", port=8001)