Browse Source

选址提示词修改

liutao 1 month ago
parent
commit
9395a71332

+ 10 - 0
landsite_agent/config.env

@@ -0,0 +1,10 @@
+# AI模型配置
+DEFAULT_MODEL_TYPE=openai
+MODEL_PATH=E:/项目临时/AI大模型/m3e-base
+
+# 数据库配置
+DB_HOST=10.10.9.243
+DB_PORT=5432
+DB_NAME=sde
+DB_USER=sde
+DB_PASSWORD=sde

+ 7 - 4
landsite_agent/config.py

@@ -1,5 +1,10 @@
 from typing import Dict, Any
 from pydantic import BaseModel
+import os
+from dotenv import load_dotenv
+
+# 加载config.env文件
+load_dotenv("config.env")
 
 class ModelConfig(BaseModel):
     """AI模型配置类"""
@@ -24,10 +29,8 @@ model_list: Dict[str, Dict[str, Any]] = {
     }
 }
 
-# 新增全局默认模型类型开关
-DEFAULT_MODEL_TYPE = "openai"
-# # 线上服务器配置
-# DEFAULT_MODEL_TYPE = "zjstai"
+# 从环境变量获取默认模型类型,如果未设置则使用"openai"作为默认值
+DEFAULT_MODEL_TYPE = os.getenv("DEFAULT_MODEL_TYPE", "openai")
 
 def get_model_config(model_type: str = "openai") -> ModelConfig:
     """

+ 13 - 16
landsite_agent/database.py

@@ -1,30 +1,27 @@
 import asyncio
+import os
+from dotenv import load_dotenv
 
 import asyncpg
 from typing import List, Dict, Any
 import json
 
+# 加载config.env文件
+load_dotenv("config.env")
+
 # 数据库配置
-db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
-    "mysql": {
-        # MySQL配置留空,等待后续添加
-    },
-    "pg": {
-        "host": "10.10.9.243",
-        "port": "5432",
-        "database": "sde",
-        "user": "sde",
-        "password": "sde"
-    }
+DB_CONFIG = {
+    "host": os.getenv("DB_HOST"),
+    "port": os.getenv("DB_PORT"),
+    "database": os.getenv("DB_NAME"),
+    "user": os.getenv("DB_USER"),
+    "password": os.getenv("DB_PASSWORD")
 }
 
-
 class Database:
-    def __init__(self, db_type: str = "pg"):
+    def __init__(self):
         self.pool = None
-        if db_type not in db_list:
-            raise ValueError(f"Unsupported database type: {db_type}")
-        self.config = db_list[db_type]
+        self.config = DB_CONFIG
 
     async def connect(self):
         """创建数据库连接池"""

+ 1 - 21
landsite_agent/main.py

@@ -42,10 +42,6 @@ async def stream_land_analysis(request: QueryRequest):
     流式返回土地分析结果
     """
 
-    def remove_think_tag(text: str) -> str:
-        # 移除<think>标签及其内容
-        return re.sub(r"<think>[\s\S]*?</think>\\n*", "", text)
-
     sql_generator = SQLGenerator(model_type=DEFAULT_MODEL_TYPE)
 
     async def generate_stream() -> AsyncGenerator[str, None]:
@@ -56,15 +52,7 @@ async def stream_land_analysis(request: QueryRequest):
             async for chunk in sql_generator.generate_sql_stream(request.description):
                 data = json.loads(chunk)
 
-                if data["type"] == "similar_examples":
-                    similar_examples = data["content"]
-                    yield chunk
-                elif data["type"] == "sql_generation":
-                    # zjstai模型时,移除<think>标签内容
-                    # if DEFAULT_MODEL_TYPE == "zjstai":
-                    #     data["content"] = remove_think_tag(data["content"])
-                    #     yield json.dumps(data, ensure_ascii=False) + "\n"
-                    # else:
+                if data["type"] == "sql_generation":
                     yield chunk
                 elif data["type"] == "sql_result":
                     sql = data["content"]
@@ -100,14 +88,6 @@ async def stream_land_analysis(request: QueryRequest):
     )
 
 
-@app.on_event("shutdown")
-async def shutdown_event():
-    """
-    应用关闭时清理资源
-    """
-    pass
-
-
 if __name__ == "__main__":
     uvicorn.run(app, host=server_host, port=server_port)
     # uvicorn.run("main:app", host=server_host, port=server_port, workers=5)

+ 2 - 1
landsite_agent/requirements.txt

@@ -8,4 +8,5 @@ langchain_openai==0.3.18
 pydantic==2.11.5
 Shapely==2.1.1
 uvicorn==0.34.2
-faiss-cpu==1.11.0
+faiss-cpu==1.11.0
+python-dotenv==1.0.1

+ 15 - 7
landsite_agent/sql_generator.py

@@ -11,11 +11,17 @@ from config import get_model_config
 from typing import List, Dict, Any
 from database import Database
 import re
+import os
+from dotenv import load_dotenv
+
+# 加载环境变量
+load_dotenv('config.env')
 
 
 class SQLGenerator:
-    def __init__(self, model_type: str = "openai"):
+    def __init__(self, model_type: str = None):
         # 获取模型配置
+        model_type = model_type or os.getenv('DEFAULT_MODEL_TYPE', 'openai')
         model_config = get_model_config(model_type)
 
         # 初始化LLM
@@ -40,9 +46,12 @@ class SQLGenerator:
                 | StrOutputParser()
         )
 
+        # 从环境变量获取模型路径
+        model_path = os.getenv('MODEL_PATH')
+        if not model_path:
+            raise ValueError("MODEL_PATH environment variable is not set")
+
         # 初始化本地m3e-base模型
-        model_path = r"E:\项目临时\AI大模型\m3e-base"
-        # model_path= r"/data/m3e-base"
         self.embeddings = HuggingFaceEmbeddings(
             model_name=model_path,
             model_kwargs={'device': "cpu"},
@@ -137,16 +146,15 @@ class SQLGenerator:
         流式生成SQL查询语句
         """
         try:
-            # 开始生成SQL
-            yield json.dumps({"type": "start", "content": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
-
             # 获取相似示例
             similar_examples = self._get_similar_examples(query_description)
-            yield json.dumps({
+            simliar_example_format_dump = json.dumps({
                 "type": "similar_examples",
                 "content": similar_examples
             }, ensure_ascii=False) + "\n"
 
+            print(simliar_example_format_dump)
+
             # 准备输入数据
             chain_input = {
                 "similar_examples": similar_examples,