|
@@ -1,16 +1,15 @@
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from pydantic import BaseModel
|
|
|
-import pandas as pd
|
|
|
-import plotly.express as px
|
|
|
import json
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from typing import AsyncGenerator, List, Dict, Any
|
|
|
import uvicorn
|
|
|
import traceback
|
|
|
-import asyncio
|
|
|
from xuanzhi_query import router as xz_router
|
|
|
from sql_generator import SQLGenerator
|
|
|
+from config import DEFAULT_MODEL_TYPE
|
|
|
+import re
|
|
|
|
|
|
app = FastAPI(title="Land Analysis API")
|
|
|
app.include_router(xz_router)
|
|
@@ -35,15 +34,18 @@ class AnalysisResult(BaseModel):
|
|
|
similar_examples: List[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
-sql_generator = SQLGenerator()
|
|
|
-
|
|
|
-
|
|
|
@app.post("/land_analysis/stream")
|
|
|
async def stream_land_analysis(request: QueryRequest):
|
|
|
"""
|
|
|
流式返回土地分析结果
|
|
|
"""
|
|
|
|
|
|
+ def remove_think_tag(text: str) -> str:
|
|
|
+ # 移除<think>标签及其内容
|
|
|
+ return re.sub(r"<think>[\s\S]*?</think>\\n*", "", text)
|
|
|
+
|
|
|
+ sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
|
|
|
+
|
|
|
async def generate_stream() -> AsyncGenerator[str, None]:
|
|
|
try:
|
|
|
similar_examples = None
|
|
@@ -56,9 +58,13 @@ async def stream_land_analysis(request: QueryRequest):
|
|
|
similar_examples = data["content"]
|
|
|
yield chunk
|
|
|
elif data["type"] == "sql_generation":
|
|
|
+ # zjstai模型时,移除<think>标签内容
|
|
|
+ # if DEFAULT_MODEL_TYPE == "zjstai":
|
|
|
+ # data["content"] = remove_think_tag(data["content"])
|
|
|
+ # yield json.dumps(data, ensure_ascii=False) + "\n"
|
|
|
+ # else:
|
|
|
yield chunk
|
|
|
elif data["type"] == "sql_result":
|
|
|
- # 获取到完整的SQL后执行
|
|
|
sql = data["content"]
|
|
|
result = await sql_generator.execute_sql(sql)
|
|
|
|
|
@@ -69,7 +75,6 @@ async def stream_land_analysis(request: QueryRequest):
|
|
|
}, ensure_ascii=False) + "\n"
|
|
|
return
|
|
|
|
|
|
- # 返回最终结果
|
|
|
yield json.dumps({
|
|
|
"type": "result",
|
|
|
"data": {
|
|
@@ -93,48 +98,13 @@ async def stream_land_analysis(request: QueryRequest):
|
|
|
)
|
|
|
|
|
|
|
|
|
-@app.post("/land_analysis", response_model=AnalysisResult)
|
|
|
-async def generate_and_execute_sql(request: QueryRequest):
|
|
|
- try:
|
|
|
- # 获取相似示例
|
|
|
- similar_examples = sql_generator._get_similar_examples(request.description)
|
|
|
-
|
|
|
- # 构建增强提示词
|
|
|
- enhanced_prompt = f"""
|
|
|
- 基于以下相似示例:
|
|
|
- {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
|
|
|
-
|
|
|
- 请根据以下描述生成SQL查询:
|
|
|
- {request.description}
|
|
|
- """
|
|
|
-
|
|
|
- # 生成SQL
|
|
|
- sql = await sql_generator.chain.arun(enhanced_prompt)
|
|
|
-
|
|
|
- # 执行SQL
|
|
|
- result = await sql_generator.execute_sql(sql)
|
|
|
-
|
|
|
- if result["status"] == "error":
|
|
|
- raise HTTPException(status_code=400, detail=result["message"])
|
|
|
-
|
|
|
- return AnalysisResult(
|
|
|
- sql=sql,
|
|
|
- exec_result=result["data"]
|
|
|
- )
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- traceback.print_exc()
|
|
|
- raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
-
|
|
|
@app.on_event("shutdown")
|
|
|
async def shutdown_event():
|
|
|
"""
|
|
|
应用关闭时清理资源
|
|
|
"""
|
|
|
- await sql_generator.close()
|
|
|
+ pass
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8521)
|
|
|
-
|