Browse Source

智能选址agent修改

liutao 2 tháng trước cách đây
mục cha
commit
4dfd23d173

+ 21 - 19
aiAgent_gd/qwen_agent/config/db_config.py

@@ -1,28 +1,30 @@
 # 数据库存储配置
-db_list = {
+from typing import Dict, Any
+
+db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
     "mysql": {
 
     },
-    # "pg": {
-    #     "host": "10.249.168.231",
-    #     "port": "54321",
-    #     "database": "sde",
-    #     "user": "zjugis",
-    #     "password": "zjugis1402!",
-    # }
     "pg": {
-        "host": "172.27.27.16",
-        "port": "yzt",
+        "host": "10.249.168.231",
+        "port": "54321",
+        "database": "sde",
+        "user": "zjugis",
+        "password": "zjugis1402!",
+    },
+    # "pg1": {
+    #     "host": "172.27.27.16",
+    #     "port": "3433",
+    #     "database": "yzt",
+    #     "user": "zjgt_ww_readonly",
+    #     "password": "Zjgt_ww_16",
+    # },
+    "xzpg": {
+        "host": "10.10.9.243",
+        "port": "5432",
         "database": "sde",
-        "user": "zjgt_ww_readonly",
-        "password": "Zjgt_ww_16",
+        "user": "sde",
+        "password": "zjugis1402!",
     }
-    # "pg": {
-    #     "host": "10.10.9.243",
-    #     "port": "5432",
-    #     "database": "sde",
-    #     "user": "sde",
-    #     "password": "zjugis1402!",
-    # }
 }
 

+ 2 - 2
aiAgent_gd/qwen_agent/memory/data/sqls/sql_examples_智能选址.jsonl

@@ -2,11 +2,11 @@
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山区推荐几块50亩左右的工业用地",
-    "sql_code": "select objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from dlgis.gcs330000g2007_kzxxxgh_kgdk_kgy_dsgj where xzqmc = '萧山区' and ydxz like '%工业%' and abs(ydmj - 50*0.0667) <= 1 order by ydmj nulls last limit 10"
+    "sql_code": "select objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, shape, st_area(shape::geography) as pfmarea,st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from dlgis.gcs330000g2007_kzxxxgh_kgdk_kgy_dsgj as a where xzqmc = '萧山区' and ydxz like '%工业%' and abs(ydmj - 50*0.0667) <= 1 and NOT EXISTS (select 1 from dlgis.gcs330000k3003_zdzy_gd as b where st_intersects(a.shape, b.shape)) order by ydmj nulls last limit 10"
   },
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山区推荐一宗1公顷左右的学校用地",
-    "sql_code": "select objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from dlgis.gcs330000g2007_kzxxxgh_kgdk_kgy_dsgj where xzqmc = '萧山区' and ydxz like '%学校%' and abs(ydmj - 1) <= 1 order by ydmj nulls last limit 10"
+    "sql_code": "select objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, shape,st_area(shape::geography) as pfmarea, st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt from dlgis.gcs330000g2007_kzxxxgh_kgdk_kgy_dsgj as a where xzqmc = '萧山区' and ydxz like '%学校%' and abs(ydmj - 1) <= 1 and NOT EXISTS (select 1 from dlgis.gcs330000k3003_zdzy_gd as b where st_intersects(a.shape, b.shape)) order by ydmj nulls last limit 10"
   }
 ]

+ 8 - 5
aiAgent_gd/qwen_agent/sub_agent/sql/land_site_selection_sql_agent.py

@@ -10,7 +10,7 @@ from tabulate import tabulate
 from qwen_agent.memory.SqlMemory import SqlRetriever
 from qwen_agent.messages.context_message import ChatResponseChoice
 from qwen_agent.sub_agent.BaseSubAgent import BaseSubAgent
-from qwen_agent.tools.tools import async_db
+from qwen_agent.tools.tools import async_db,async_xzdb
 
 
 class LandSiteSelectionSqlAgent(BaseSubAgent):
@@ -38,7 +38,7 @@ class LandSiteSelectionSqlAgent(BaseSubAgent):
             `xzqdm` COMMENTS '行政区代码 6位,前2位代表省,前4位代表市,前6位代表区县',
             `dymc` COMMENTS '地块名称',
             `yddm` COMMENTS '用地代码',
-            `ydxz` COMMENTS '用地性质(土地用途)',
+            `ydxz` COMMENTS '用地性质',
             `ydmj` COMMENTS '用地面积 单位:公顷',
             `pfwh` COMMENTS '批复文号',
             `pfsj` COMMENTS '批复时间',
@@ -63,9 +63,12 @@ class LandSiteSelectionSqlAgent(BaseSubAgent):
         注意7: 问题中设计具体的地点时,需要使用round(st_distance(st_geometryfromtext('具体地点的wkt', 4490)::geography,shape::geography)::numeric,0)获取其distance, 如果问题未指定范围则使用 distance <= 5000 来限制在地点5公里内,并对其排序
         注意8: 生成sql时,只对涉及表结构中的字段进行条件设置,不可生成不在表字段列表中的查询条件,不可生成任何不在表字段中的条件,比如周边5公里有什么设施
         注意9: 生成sql时,必须使用 st_astext(st_centroid(shape)) as center_wkt 
-        注意10: 查询语句必须包含 objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt 这几个字段
+        注意10: 查询语句必须包含 objectid, xzqmc, xzqdm, dymc, yddm, ydxz, ydmj, rjlsx, rjlxx, jzmdsx, jzmdxx, jzgdsx, jzgdxx, ldlsx, ldlxx, pfwh, pfsj, shape,st_area(shape::geography) as pfmarea, st_astext(shape) as geom, st_astext(st_centroid(shape)) as center_wkt 这几个字段
         注意11: 只准生成查询 的sql 语句,不可生成任何 修改数据的语句, 比如:update, delete, insert, truncate 等
-        注意12:当用户问题中的土地用途是"工业用地"时,去掉"用地",使用ydxz进行模糊查询,比如ydxz like '%工业%'
+        注意12:当用户问题中的用地性质是"工业用地"时,去掉"用地",使用ydxz进行模糊查询,比如ydxz like '%工业%',工业用地没有二级分类
+        注意13:where语句中必须包含NOT EXISTS(select 1 from dlgis.gcs330000k3003_zdzy_gd as b where st_intersects(a.shape, b.shape)
+        注意14:数据表的schema是dlgis
+        注意15:from语句中给dlgis.gcs330000g2007_kzxxxgh_kgdk_kgy_dsgj设置别名a
         """
         self.retriever = SqlRetriever(query_type='land_site_selection')
 
@@ -166,7 +169,7 @@ class LandSiteSelectionSqlAgent(BaseSubAgent):
             print(f"sql_to_execute:{sql_to_execute}")
             self.sql_code = sql_to_execute
             a = time.time()
-            res_tuples = await async_db.run(sql_to_execute)
+            res_tuples = await async_xzdb.run(sql_to_execute)
             print('SQL Time Cost:', time.time() - a)
             result, success = res_tuples
             print(f"SQL 查询结果: success:{success}, result: {result}")

+ 8 - 5
aiAgent_gd/qwen_agent/tools/tools.py

@@ -211,7 +211,7 @@ class AsyncMySQLSearcher():
 
 
 class AsyncPGSearcher:
-    def __init__(self, db_name='zlzd') -> None:
+    def __init__(self, db_name='pg') -> None:
         # 建立数据库连接
         self.db_name = db_name
         self.db = None
@@ -220,7 +220,8 @@ class AsyncPGSearcher:
         # print(f'db name: {self.db_name}')
 
     async def connect(self):
-        pg_config = db_config.db_list.get("pg")
+        pg_config = db_config.db_list.get(self.db_name)
+        print("pg_config:" %pg_config);
         self.db = await asyncpg.connect(
             host=pg_config.get("host"),  # analysisDB
             user=pg_config.get("user"),
@@ -237,7 +238,7 @@ class AsyncPGSearcher:
         '''
         try:
             print("start to connect db!")
-            pg_config = db_config.db_list.get("pg")
+            pg_config = db_config.db_list.get(self.db_name)
             self.pool = await asyncpg.create_pool(host=pg_config.get("host"),
                                                   user=pg_config.get("user"),
                                                   password=pg_config.get("password"),
@@ -309,8 +310,8 @@ class AsyncPGSearcher:
 # async_db = AsyncMySQLSearcher()
 # asyncio.run(async_db.register())
 
-async_db = AsyncPGSearcher()
-
+async_db = AsyncPGSearcher('pg')
+async_xzdb = AsyncPGSearcher('xzpg')
 
 def modify_sql(sql_query):
     # 使用正则表达式寻找 GROUP BY 和其后面的内容,直到遇到 HAVING, ORDER BY 或者字符串结束
@@ -344,6 +345,7 @@ async def call_plugin(plugin_name: str, plugin_args: str):
     #     plugin_args = plugin_args.replace("json",'')
     import time
     print("plugin_args", plugin_args)
+    print("plugin_name", plugin_name)
     if plugin_args.startswith('```'):
         triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
         if triple_match:
@@ -371,6 +373,7 @@ async def call_plugin(plugin_name: str, plugin_args: str):
             # else:
             #     sql_to_execute = plugin_args
             print(f"sql_to_execute:{sql_to_execute}")
+            print("plugin_name", plugin_name)
             # 
             a = time.time()