
Modify the site-selection prompt

liutao 1 month ago
parent
commit
cb25eb918f

+ 7 - 6
landsite_agent/config.py

@@ -16,12 +16,10 @@ model_list: Dict[str, Dict[str, Any]] = {
         "api_base": "http://ac.zjugis.com:8511/v1",
         "model_name": "qwen2.5-instruct"
     },
-    "azure": {
-        "api_key": "your-azure-api-key",
-        "api_base": "https://your-azure-endpoint.openai.azure.com",
-        "model_name": "gpt-35-turbo",
-        "temperature": 0,
-        "max_tokens": 2000
+    "zjstai": {
+        "api_key": "none",
+        "api_base": "http://172.27.27.20:20331/v1",
+        "model_name": "DeepSeek-R1-Distill-Qwen-32B",
     },
     "local": {
         "api_key": "your-local-api-key",
@@ -32,6 +30,9 @@ model_list: Dict[str, Dict[str, Any]] = {
     }
 }
 
+# New global switch for the default model type
+DEFAULT_MODEL_TYPE = "openai"
+
 def get_model_config(model_type: str = "openai") -> ModelConfig:
     """
     Get the model configuration
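
Not part of the commit: a minimal usage sketch of the new DEFAULT_MODEL_TYPE switch together with get_model_config, using only names visible in this hunk; resolve_model_config is a hypothetical helper.

    from typing import Optional
    from config import DEFAULT_MODEL_TYPE, get_model_config, model_list

    def resolve_model_config(model_type: Optional[str] = None):
        # Fall back to the new global switch when no explicit type is given.
        chosen = model_type or DEFAULT_MODEL_TYPE
        if chosen not in model_list:
            raise KeyError(f"Unknown model type: {chosen}")
        return get_model_config(chosen)

    # resolve_model_config()          -> config for DEFAULT_MODEL_TYPE ("openai")
    # resolve_model_config("zjstai")  -> the DeepSeek-R1-Distill-Qwen-32B entry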

+ 8 - 0
landsite_agent/database.py

@@ -18,6 +18,14 @@ db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
     }
 }
 
+# Vector (embedding) model configuration
+vector_model_config: dict[str, dict[str, str]] = {
+    "m3e-base": {
+        "model_path": r"E:\项目临时\AI大模型\m3e-base",
+        "device": "cpu"
+    },
+    # Additional embedding models can be added here
+}
 
 class Database:
     def __init__(self, db_type: str = "pg"):
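
Not part of the commit: a sketch of how the new vector_model_config table can be consumed, mirroring the sql_generator.py change further down; build_embeddings is a hypothetical helper name.

    from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
    from database import vector_model_config

    def build_embeddings(name: str = "m3e-base") -> HuggingFaceEmbeddings:
        # Look up the model path and device from the shared config table.
        cfg = vector_model_config[name]
        return HuggingFaceEmbeddings(
            model_name=cfg["model_path"],
            model_kwargs={"device": cfg["device"]},
            encode_kwargs={"normalize_embeddings": True},
        )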

+ 5 - 0
landsite_agent/examples.json

@@ -28,5 +28,10 @@
     "query_type": "land_site_selection",
     "query": "帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划",
     "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.58 27.97)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "请在萧山区找出面积最大的商业用地,数据表是控制性详细规划",
+    "sql_code": "SELECT id FROM sde.kzxxxgh WHERE xzqmc = '萧山区' AND ydxz LIKE '%商业%' AND shape IS NOT NUL ORDER BY ydmj DESC LIMIT 1;"
   }
 ] 
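
Not part of the commit: a sketch of how few-shot entries like the one added above could be indexed for similarity search, assuming the Document/FAISS/embeddings setup shown in sql_generator.py; build_example_index and the relative file path are assumptions.

    import json
    from langchain_core.documents import Document
    from langchain_community.vectorstores import FAISS

    def build_example_index(embeddings, path: str = "examples.json") -> FAISS:
        # Each example's natural-language query becomes a searchable document;
        # the matching SQL is kept in metadata for prompt construction.
        with open(path, encoding="utf-8") as f:
            examples = json.load(f)
        docs = [
            Document(page_content=ex["query"], metadata={"sql_code": ex["sql_code"]})
            for ex in examples
        ]
        return FAISS.from_documents(docs, embeddings)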

+ 14 - 44
landsite_agent/main.py

@@ -1,16 +1,15 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-import pandas as pd
-import plotly.express as px
 import json
 from fastapi.responses import StreamingResponse
 from typing import AsyncGenerator, List, Dict, Any
 import uvicorn
 import traceback
-import asyncio
 from xuanzhi_query import router as xz_router
 from sql_generator import SQLGenerator
+from config import DEFAULT_MODEL_TYPE
+import re
 
 app = FastAPI(title="Land Analysis API")
 app.include_router(xz_router)
@@ -35,15 +34,18 @@ class AnalysisResult(BaseModel):
     similar_examples: List[Dict[str, Any]] = None
 
 
-sql_generator = SQLGenerator()
-
-
 @app.post("/land_analysis/stream")
 async def stream_land_analysis(request: QueryRequest):
     """
     Stream land analysis results
     """
 
+    def remove_think_tag(text: str) -> str:
+        # Strip <think> tags and their contents
+        return re.sub(r"<think>[\s\S]*?</think>\n*", "", text)
+
+    sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
+
     async def generate_stream() -> AsyncGenerator[str, None]:
         try:
             similar_examples = None
@@ -56,9 +58,13 @@ async def stream_land_analysis(request: QueryRequest):
                     similar_examples = data["content"]
                     yield chunk
                 elif data["type"] == "sql_generation":
+                    # For the zjstai model, strip the <think> block from the content
+                    # if DEFAULT_MODEL_TYPE == "zjstai":
+                    #     data["content"] = remove_think_tag(data["content"])
+                    #     yield json.dumps(data, ensure_ascii=False) + "\n"
+                    # else:
                     yield chunk
                 elif data["type"] == "sql_result":
-                    # Execute once the complete SQL has been received
                     sql = data["content"]
                     result = await sql_generator.execute_sql(sql)
 
@@ -69,7 +75,6 @@ async def stream_land_analysis(request: QueryRequest):
                         }, ensure_ascii=False) + "\n"
                         return
 
-                    # Return the final result
                     yield json.dumps({
                         "type": "result",
                         "data": {
@@ -93,48 +98,13 @@ async def stream_land_analysis(request: QueryRequest):
     )
 
 
-@app.post("/land_analysis", response_model=AnalysisResult)
-async def generate_and_execute_sql(request: QueryRequest):
-    try:
-        # Fetch similar examples
-        similar_examples = sql_generator._get_similar_examples(request.description)
-
-        # 构建增强提示词
-        enhanced_prompt = f"""
-        基于以下相似示例:
-        {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
-        
-        请根据以下描述生成SQL查询:
-        {request.description}
-        """
-
-        # Generate the SQL
-        sql = await sql_generator.chain.arun(enhanced_prompt)
-
-        # Execute the SQL
-        result = await sql_generator.execute_sql(sql)
-
-        if result["status"] == "error":
-            raise HTTPException(status_code=400, detail=result["message"])
-
-        return AnalysisResult(
-            sql=sql,
-            exec_result=result["data"]
-        )
-
-    except Exception as e:
-        traceback.print_exc()
-        raise HTTPException(status_code=500, detail=str(e))
-
-
 @app.on_event("shutdown")
 async def shutdown_event():
     """
     Clean up resources on application shutdown
     """
-    await sql_generator.close()
+    pass
 
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8521)
-
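
Not part of the commit: a minimal check of the remove_think_tag helper added above, written here with the regex r"<think>[\s\S]*?</think>\n*", to illustrate how <think> blocks from DeepSeek-R1-style output are stripped.

    import re

    def remove_think_tag(text: str) -> str:
        # Drop <think>...</think> blocks and any newlines that follow them.
        return re.sub(r"<think>[\s\S]*?</think>\n*", "", text)

    raw = "<think>reasoning about the schema...</think>\nSELECT id FROM sde.kzxxxgh LIMIT 1;"
    assert remove_think_tag(raw) == "SELECT id FROM sde.kzxxxgh LIMIT 1;"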

BIN
landsite_agent/requirements.txt


+ 8 - 9
landsite_agent/sql_generator.py

@@ -1,17 +1,15 @@
 from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnablePassthrough
 from langchain_core.documents import Document
 from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
-import os
 from prompt_template import get_prompt
 import json
 import traceback
 from config import get_model_config
 from typing import List, Dict, Any
-from database import Database
+from database import Database, vector_model_config
 import re
 
 class SQLGenerator:
@@ -42,10 +40,11 @@ class SQLGenerator:
         )
 
         # Initialize the local m3e-base embedding model
-        self.model_path = r"E:\项目临时\AI大模型\m3e-base"
+        model_cfg = vector_model_config["m3e-base"]
+        self.model_path = model_cfg["model_path"]
         self.embeddings = HuggingFaceEmbeddings(
             model_name=self.model_path,
-            model_kwargs={'device': 'cpu'},
+            model_kwargs={'device': model_cfg["device"]},
             encode_kwargs={'normalize_embeddings': True}
         )
 
@@ -138,7 +137,7 @@ class SQLGenerator:
         """
         try:
             # Start generating SQL
-            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
+            yield json.dumps({"type": "start", "content": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
 
             # Fetch similar examples
             similar_examples = self._get_similar_examples(query_description)
@@ -203,19 +202,19 @@ class SQLGenerator:
                         "content": full_response
                     }, ensure_ascii=False) + "\n"
 
-                yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
+                yield json.dumps({"type": "end", "content": "SQL生成完成"}, ensure_ascii=False) + "\n"
             except Exception as e:
                 print(f"Error details: {traceback.format_exc()}")
                 yield json.dumps({
                     "type": "error",
-                    "message": f"生成SQL时发生错误: {str(e)}"
+                    "content": f"生成SQL时发生错误: {str(e)}"
                 }, ensure_ascii=False) + "\n"
 
         except Exception as e:
             traceback.print_exc()
             yield json.dumps({
                 "type": "error",
-                "message": str(e)
+                "content": str(e)
             }, ensure_ascii=False) + "\n"
 
     async def execute_sql(self, sql: str) -> dict:
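
Not part of the commit: a sketch of a client reading the newline-delimited JSON stream after the "message" -> "content" key rename above, assuming the requests library and a locally running service on the port from main.py; stream_analysis and the host are illustrative, and the event type names are the ones appearing in this diff.

    import json
    import requests

    def stream_analysis(description: str, base_url: str = "http://localhost:8521"):
        resp = requests.post(
            f"{base_url}/land_analysis/stream",
            json={"description": description},
            stream=True,
        )
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue
            event = json.loads(line)
            if event["type"] in ("start", "sql_generation", "end", "error"):
                # Progress and error events now carry their text under "content".
                print(event.get("content"))
            elif event["type"] == "result":
                return event["data"]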