Bladeren bron

选址提示词修改

liutao 1 maand geleden
bovenliggende
commit
f9ab347902
5 gewijzigde bestanden met toevoegingen van 27 en 34 verwijderingen
  1. 3 7
      landsite_agent/config.py
  2. 1 9
      landsite_agent/database.py
  3. 4 1
      landsite_agent/main.py
  4. 1 0
      landsite_agent/requirements.txt
  5. 18 17
      landsite_agent/sql_generator.py

+ 3 - 7
landsite_agent/config.py

@@ -9,6 +9,7 @@ class ModelConfig(BaseModel):
     temperature: float = 0
     max_tokens: int = 2000
 
+
 # AI模型配置
 model_list: Dict[str, Dict[str, Any]] = {
     "openai": {
@@ -20,18 +21,13 @@ model_list: Dict[str, Dict[str, Any]] = {
         "api_key": "none",
         "api_base": "http://172.27.27.20:20331/v1",
         "model_name": "DeepSeek-R1-Distill-Qwen-32B",
-    },
-    "local": {
-        "api_key": "your-local-api-key",
-        "api_base": "http://localhost:8000/v1",
-        "model_name": "local-model",
-        "temperature": 0,
-        "max_tokens": 2000
     }
 }
 
 # 新增全局默认模型类型开关
 DEFAULT_MODEL_TYPE = "openai"
+# # 线上服务器配置
+# DEFAULT_MODEL_TYPE = "zjstai"
 
 def get_model_config(model_type: str = "openai") -> ModelConfig:
     """

+ 1 - 9
landsite_agent/database.py

@@ -14,18 +14,10 @@ db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
         "port": "5432",
         "database": "sde",
         "user": "sde",
-        "password": "sde",
+        "password": "sde"
     }
 }
 
-# 向量模型配置
-vector_model_config: dict[str, dict[str, str]] = {
-    "m3e-base": {
-        "model_path": r"E:\项目临时\AI大模型\m3e-base",
-        "device": "cpu"
-    },
-    # 可扩展其他向量模型
-}
 
 class Database:
     def __init__(self, db_type: str = "pg"):

+ 4 - 1
landsite_agent/main.py

@@ -21,6 +21,8 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
+server_host = "0.0.0.0"
+server_port = 8521
 
 
 class QueryRequest(BaseModel):
@@ -107,4 +109,5 @@ async def shutdown_event():
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8521)
+    uvicorn.run(app, host=server_host, port=server_port)
+    # uvicorn.run("main:app", host=server_host, port=server_port, workers=5)

+ 1 - 0
landsite_agent/requirements.txt

@@ -8,3 +8,4 @@ langchain_openai==0.3.18
 pydantic==2.11.5
 Shapely==2.1.1
 uvicorn==0.34.2
+faiss-cpu==1.11.0

+ 18 - 17
landsite_agent/sql_generator.py

@@ -9,9 +9,10 @@ import json
 import traceback
 from config import get_model_config
 from typing import List, Dict, Any
-from database import Database, vector_model_config
+from database import Database
 import re
 
+
 class SQLGenerator:
     def __init__(self, model_type: str = "openai"):
         # 获取模型配置
@@ -27,24 +28,24 @@ class SQLGenerator:
 
         # 初始化提示词模板
         self.prompt = ChatPromptTemplate.from_template(get_prompt())
-        
+
         # 构建链式调用
         self.chain = (
-            {
-                "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
-                "question": lambda x: x["question"]
-            }
-            | self.prompt
-            | self.llm
-            | StrOutputParser()
+                {
+                    "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
+                    "question": lambda x: x["question"]
+                }
+                | self.prompt
+                | self.llm
+                | StrOutputParser()
         )
 
         # 初始化本地m3e-base模型
-        model_cfg = vector_model_config["m3e-base"]
-        self.model_path = model_cfg["model_path"]
+        model_path = r"E:\项目临时\AI大模型\m3e-base"
+        # model_path= r"/data/m3e-base"
         self.embeddings = HuggingFaceEmbeddings(
-            model_name=self.model_path,
-            model_kwargs={'device': model_cfg["device"]},
+            model_name=model_path,
+            model_kwargs={'device': "cpu"},
             encode_kwargs={'normalize_embeddings': True}
         )
 
@@ -62,7 +63,7 @@ class SQLGenerator:
         # 使用正则表达式匹配markdown代码块中的SQL
         sql_pattern = r"```sql\n(.*?)\n```"
         match = re.search(sql_pattern, text, re.DOTALL)
-        
+
         if match:
             # 提取SQL语句并清理
             sql = match.group(1).strip()
@@ -179,7 +180,7 @@ class SQLGenerator:
                 # 使用正则表达式提取SQL代码块
                 sql_pattern = r"```sql\n(.*?)\n```"
                 sql_match = re.search(sql_pattern, full_response, re.DOTALL)
-                
+
                 if sql_match:
                     # 提取并清理SQL
                     sql_content = sql_match.group(1).strip()
@@ -190,7 +191,7 @@ class SQLGenerator:
                     # 确保SQL语句完整
                     if not sql_content.strip().endswith(';'):
                         sql_content = sql_content.strip() + ';'
-                    
+
                     yield json.dumps({
                         "type": "sql_result",
                         "content": sql_content
@@ -226,7 +227,7 @@ class SQLGenerator:
             # 确保SQL是完整的
             if not sql.strip().endswith(';'):
                 sql = sql.strip() + ';'
-            
+
             result = await self.db.execute_query(sql)
             return {
                 "status": "success",