Browse Source

选址提示词修改

liutao 1 month ago
parent
commit
cb25eb918f

+ 7 - 6
landsite_agent/config.py

@@ -16,12 +16,10 @@ model_list: Dict[str, Dict[str, Any]] = {
         "api_base": "http://ac.zjugis.com:8511/v1",
         "api_base": "http://ac.zjugis.com:8511/v1",
         "model_name": "qwen2.5-instruct"
         "model_name": "qwen2.5-instruct"
     },
     },
-    "azure": {
-        "api_key": "your-azure-api-key",
-        "api_base": "https://your-azure-endpoint.openai.azure.com",
-        "model_name": "gpt-35-turbo",
-        "temperature": 0,
-        "max_tokens": 2000
+    "zjstai": {
+        "api_key": "none",
+        "api_base": "http://172.27.27.20:20331/v1",
+        "model_name": "DeepSeek-R1-Distill-Qwen-32B",
     },
     },
     "local": {
     "local": {
         "api_key": "your-local-api-key",
         "api_key": "your-local-api-key",
@@ -32,6 +30,9 @@ model_list: Dict[str, Dict[str, Any]] = {
     }
     }
 }
 }
 
 
+# 新增全局默认模型类型开关
+DEFAULT_MODEL_TYPE = "openai"
+
 def get_model_config(model_type: str = "openai") -> ModelConfig:
 def get_model_config(model_type: str = "openai") -> ModelConfig:
     """
     """
     获取模型配置
     获取模型配置

+ 8 - 0
landsite_agent/database.py

@@ -18,6 +18,14 @@ db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
     }
     }
 }
 }
 
 
+# 向量模型配置
+vector_model_config: dict[str, dict[str, str]] = {
+    "m3e-base": {
+        "model_path": r"E:\项目临时\AI大模型\m3e-base",
+        "device": "cpu"
+    },
+    # 可扩展其他向量模型
+}
 
 
 class Database:
 class Database:
     def __init__(self, db_type: str = "pg"):
     def __init__(self, db_type: str = "pg"):

+ 5 - 0
landsite_agent/examples.json

@@ -28,5 +28,10 @@
     "query_type": "land_site_selection",
     "query_type": "land_site_selection",
     "query": "帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划",
     "query": "帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划",
     "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.58 27.97)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
     "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.58 27.97)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "请在萧山区找出面积最大的商业用地,数据表是控制性详细规划",
+    "sql_code": "SELECT id FROM sde.kzxxxgh WHERE xzqmc = '萧山区' AND ydxz LIKE '%商业%' AND shape IS NOT NUL ORDER BY ydmj DESC LIMIT 1;"
   }
   }
 ] 
 ] 

+ 14 - 44
landsite_agent/main.py

@@ -1,16 +1,15 @@
 from fastapi import FastAPI, HTTPException
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from pydantic import BaseModel
-import pandas as pd
-import plotly.express as px
 import json
 import json
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
 from typing import AsyncGenerator, List, Dict, Any
 from typing import AsyncGenerator, List, Dict, Any
 import uvicorn
 import uvicorn
 import traceback
 import traceback
-import asyncio
 from xuanzhi_query import router as xz_router
 from xuanzhi_query import router as xz_router
 from sql_generator import SQLGenerator
 from sql_generator import SQLGenerator
+from config import DEFAULT_MODEL_TYPE
+import re
 
 
 app = FastAPI(title="Land Analysis API")
 app = FastAPI(title="Land Analysis API")
 app.include_router(xz_router)
 app.include_router(xz_router)
@@ -35,15 +34,18 @@ class AnalysisResult(BaseModel):
     similar_examples: List[Dict[str, Any]] = None
     similar_examples: List[Dict[str, Any]] = None
 
 
 
 
-sql_generator = SQLGenerator()
-
-
 @app.post("/land_analysis/stream")
 @app.post("/land_analysis/stream")
 async def stream_land_analysis(request: QueryRequest):
 async def stream_land_analysis(request: QueryRequest):
     """
     """
     流式返回土地分析结果
     流式返回土地分析结果
     """
     """
 
 
+    def remove_think_tag(text: str) -> str:
+        # 移除<think>标签及其内容
+        return re.sub(r"<think>[\s\S]*?</think>\\n*", "", text)
+
+    sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
+
     async def generate_stream() -> AsyncGenerator[str, None]:
     async def generate_stream() -> AsyncGenerator[str, None]:
         try:
         try:
             similar_examples = None
             similar_examples = None
@@ -56,9 +58,13 @@ async def stream_land_analysis(request: QueryRequest):
                     similar_examples = data["content"]
                     similar_examples = data["content"]
                     yield chunk
                     yield chunk
                 elif data["type"] == "sql_generation":
                 elif data["type"] == "sql_generation":
+                    # zjstai模型时,移除<think>标签内容
+                    # if DEFAULT_MODEL_TYPE == "zjstai":
+                    #     data["content"] = remove_think_tag(data["content"])
+                    #     yield json.dumps(data, ensure_ascii=False) + "\n"
+                    # else:
                     yield chunk
                     yield chunk
                 elif data["type"] == "sql_result":
                 elif data["type"] == "sql_result":
-                    # 获取到完整的SQL后执行
                     sql = data["content"]
                     sql = data["content"]
                     result = await sql_generator.execute_sql(sql)
                     result = await sql_generator.execute_sql(sql)
 
 
@@ -69,7 +75,6 @@ async def stream_land_analysis(request: QueryRequest):
                         }, ensure_ascii=False) + "\n"
                         }, ensure_ascii=False) + "\n"
                         return
                         return
 
 
-                    # 返回最终结果
                     yield json.dumps({
                     yield json.dumps({
                         "type": "result",
                         "type": "result",
                         "data": {
                         "data": {
@@ -93,48 +98,13 @@ async def stream_land_analysis(request: QueryRequest):
     )
     )
 
 
 
 
-@app.post("/land_analysis", response_model=AnalysisResult)
-async def generate_and_execute_sql(request: QueryRequest):
-    try:
-        # 获取相似示例
-        similar_examples = sql_generator._get_similar_examples(request.description)
-
-        # 构建增强提示词
-        enhanced_prompt = f"""
-        基于以下相似示例:
-        {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
-        
-        请根据以下描述生成SQL查询:
-        {request.description}
-        """
-
-        # 生成SQL
-        sql = await sql_generator.chain.arun(enhanced_prompt)
-
-        # 执行SQL
-        result = await sql_generator.execute_sql(sql)
-
-        if result["status"] == "error":
-            raise HTTPException(status_code=400, detail=result["message"])
-
-        return AnalysisResult(
-            sql=sql,
-            exec_result=result["data"]
-        )
-
-    except Exception as e:
-        traceback.print_exc()
-        raise HTTPException(status_code=500, detail=str(e))
-
-
 @app.on_event("shutdown")
 @app.on_event("shutdown")
 async def shutdown_event():
 async def shutdown_event():
     """
     """
     应用关闭时清理资源
     应用关闭时清理资源
     """
     """
-    await sql_generator.close()
+    pass
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8521)
     uvicorn.run(app, host="0.0.0.0", port=8521)
-

BIN
landsite_agent/requirements.txt


+ 8 - 9
landsite_agent/sql_generator.py

@@ -1,17 +1,15 @@
 from langchain_openai import ChatOpenAI
 from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
 from langchain.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnablePassthrough
 from langchain_core.documents import Document
 from langchain_core.documents import Document
 from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain_community.vectorstores import FAISS
-import os
 from prompt_template import get_prompt
 from prompt_template import get_prompt
 import json
 import json
 import traceback
 import traceback
 from config import get_model_config
 from config import get_model_config
 from typing import List, Dict, Any
 from typing import List, Dict, Any
-from database import Database
+from database import Database, vector_model_config
 import re
 import re
 
 
 class SQLGenerator:
 class SQLGenerator:
@@ -42,10 +40,11 @@ class SQLGenerator:
         )
         )
 
 
         # 初始化本地m3e-base模型
         # 初始化本地m3e-base模型
-        self.model_path = r"E:\项目临时\AI大模型\m3e-base"
+        model_cfg = vector_model_config["m3e-base"]
+        self.model_path = model_cfg["model_path"]
         self.embeddings = HuggingFaceEmbeddings(
         self.embeddings = HuggingFaceEmbeddings(
             model_name=self.model_path,
             model_name=self.model_path,
-            model_kwargs={'device': 'cpu'},
+            model_kwargs={'device': model_cfg["device"]},
             encode_kwargs={'normalize_embeddings': True}
             encode_kwargs={'normalize_embeddings': True}
         )
         )
 
 
@@ -138,7 +137,7 @@ class SQLGenerator:
         """
         """
         try:
         try:
             # 开始生成SQL
             # 开始生成SQL
-            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
+            yield json.dumps({"type": "start", "content": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
 
 
             # 获取相似示例
             # 获取相似示例
             similar_examples = self._get_similar_examples(query_description)
             similar_examples = self._get_similar_examples(query_description)
@@ -203,19 +202,19 @@ class SQLGenerator:
                         "content": full_response
                         "content": full_response
                     }, ensure_ascii=False) + "\n"
                     }, ensure_ascii=False) + "\n"
 
 
-                yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
+                yield json.dumps({"type": "end", "content": "SQL生成完成"}, ensure_ascii=False) + "\n"
             except Exception as e:
             except Exception as e:
                 print(f"Error details: {traceback.format_exc()}")
                 print(f"Error details: {traceback.format_exc()}")
                 yield json.dumps({
                 yield json.dumps({
                     "type": "error",
                     "type": "error",
-                    "message": f"生成SQL时发生错误: {str(e)}"
+                    "content": f"生成SQL时发生错误: {str(e)}"
                 }, ensure_ascii=False) + "\n"
                 }, ensure_ascii=False) + "\n"
 
 
         except Exception as e:
         except Exception as e:
             traceback.print_exc()
             traceback.print_exc()
             yield json.dumps({
             yield json.dumps({
                 "type": "error",
                 "type": "error",
-                "message": str(e)
+                "content": str(e)
             }, ensure_ascii=False) + "\n"
             }, ensure_ascii=False) + "\n"
 
 
     async def execute_sql(self, sql: str) -> dict:
     async def execute_sql(self, sql: str) -> dict: