123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
import asyncio
import json
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional

import pandas as pd
import plotly.express as px
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from xuanzhi_query import router as xz_router
from sql_generator import SQLGenerator
# Application object; the routes from xuanzhi_query are mounted on it first.
app = FastAPI(title="Land Analysis API")
app.include_router(xz_router)

# CORS configuration.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The FastAPI CORS documentation says origins, methods and headers must be
# listed explicitly when credentials are allowed -- confirm whether
# credentialed cross-origin requests are really needed here.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class QueryRequest(BaseModel):
    """Request body shared by the land-analysis endpoints."""

    # Free-text description that the SQL generator turns into a query.
    description: str
class AnalysisResult(BaseModel):
    """Response body of the non-streaming ``/land_analysis`` endpoint."""

    # SQL text that was generated and executed.
    sql: str
    # Data returned by executing the SQL (required).
    data: list
    # Optional fields: they default to None, so they are annotated as
    # Optional instead of relying on an implicit-Optional default.
    visualization: Optional[dict] = None
    similar_examples: Optional[List[Dict[str, Any]]] = None
# Single module-level SQLGenerator instance used by every endpoint below and
# closed in the shutdown handler.
sql_generator = SQLGenerator()
@app.post("/land_analysis/stream")
async def stream_land_analysis(request: QueryRequest):
    """Stream SQL-generation progress and the execution result.

    Every chunk produced by ``sql_generator.generate_sql_stream`` is parsed as
    JSON. Chunks whose ``"type"`` is anything other than ``"sql_result"`` are
    forwarded to the client unchanged. A ``"sql_result"`` chunk carries the
    finished SQL in ``"content"``; that SQL is executed and either a
    ``"result"`` event (SQL plus execution data) or an ``"error"`` event is
    emitted. An ``"error"`` event ends the stream.

    Args:
        request: Request body holding the natural-language description.

    Returns:
        A ``StreamingResponse`` whose body is produced by the inner generator.

    NOTE(review): the response is declared as ``text/event-stream`` while the
    events built here are newline-terminated JSON objects without SSE
    ``data:`` framing -- confirm what the client expects.
    """

    def _event(payload: Dict[str, Any]) -> str:
        # One JSON object per line; keep non-ASCII text readable.
        return json.dumps(payload, ensure_ascii=False) + "\n"

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            async for chunk in sql_generator.generate_sql_stream(request.description):
                message = json.loads(chunk)

                # Everything except the final SQL is passed through verbatim.
                # .get() lets a chunk without a "type" key pass through too
                # instead of aborting the whole stream.
                if message.get("type") != "sql_result":
                    yield chunk
                    continue

                # The complete SQL has arrived: execute it.
                sql = message["content"]
                result = await sql_generator.execute_sql(sql)

                if result["status"] == "error":
                    # The two endpoints in this file read the error text from
                    # different keys ("content" here, "message" in the other),
                    # so accept either one.
                    yield _event({
                        "type": "error",
                        "content": result.get("content") or result.get("message"),
                    })
                    return

                # Final result event.
                yield _event({
                    "type": "result",
                    "data": {
                        "sql": sql,
                        "exec_result": result["data"],
                    },
                })
        except Exception as e:
            # Top-level boundary of the stream: report the failure to the
            # client as an error event rather than dropping the connection.
            traceback.print_exc()
            yield _event({"type": "error", "content": str(e)})

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
    )
@app.post("/land_analysis", response_model=AnalysisResult)
async def generate_and_execute_sql(request: QueryRequest):
    """Generate SQL for the description, execute it, and return the outcome.

    Similar examples are looked up for the description and embedded in the
    prompt sent to ``sql_generator.chain``. The generated SQL is then run
    through ``sql_generator.execute_sql``.

    Args:
        request: Request body holding the natural-language description.

    Returns:
        An ``AnalysisResult`` with the generated SQL and the execution data.

    Raises:
        HTTPException: 400 when SQL execution reports an error status; 500 for
            any other failure.
    """
    try:
        # Fetch similar examples for the prompt.
        similar_examples = sql_generator._get_similar_examples(request.description)
        examples_json = json.dumps(similar_examples, ensure_ascii=False, indent=2)

        # Build the augmented prompt. The leading and trailing newlines mirror
        # the original triple-quoted template.
        enhanced_prompt = (
            "\n"
            "基于以下相似示例:\n"
            f"{examples_json}\n"
            "\n"
            "请根据以下描述生成SQL查询:\n"
            f"{request.description}\n"
        )

        # Generate, then execute, the SQL.
        sql = await sql_generator.chain.arun(enhanced_prompt)
        result = await sql_generator.execute_sql(sql)

        if result["status"] == "error":
            # The streaming endpoint reads the error text from "content" and
            # this one from "message"; accept either key.
            raise HTTPException(
                status_code=400,
                detail=result.get("message") or result.get("content"),
            )

        # AnalysisResult's required field is `data`. Passing `exec_result`
        # left it unset, so building the response always failed validation.
        # NOTE(review): assumes result["data"] is a list -- confirm against
        # execute_sql.
        return AnalysisResult(sql=sql, data=result["data"])
    except HTTPException:
        # Keep the deliberate 400 above instead of rewrapping it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on application shutdown by closing the SQL generator."""
    await sql_generator.close()
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8521.
    uvicorn.run(app, port=8521, host="0.0.0.0")
|