Browse Source

智能选址代码重构

liutao 1 month ago
parent
commit
8e979e11d2

+ 97 - 44
landsite_agent/database.py

@@ -1,11 +1,11 @@
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.ext.declarative import declarative_base
-from typing import Any, Dict
+import asyncio
 
+import asyncpg
+from typing import List, Dict, Any
+import json
 
 # 数据库配置
-db_list: Dict[str, Dict[str, Any]] = {
+db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
     "mysql": {
         # MySQL配置留空,等待后续添加
     },
@@ -18,42 +18,95 @@ db_list: Dict[str, Dict[str, Any]] = {
     }
 }
 
-# 创建数据库引擎
-def get_engine(db_type: str = "pg"):
-    """
-    获取数据库引擎
-    :param db_type: 数据库类型,默认为pg
-    :return: SQLAlchemy引擎
-    """
-    if db_type not in db_list:
-        raise ValueError(f"Unsupported database type: {db_type}")
-    
-    db_config = db_list[db_type]
-    
-    if db_type == "pg":
-        # PostgreSQL连接URL
-        db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}"
-        return create_engine(db_url)
-    elif db_type == "mysql":
-        # MySQL连接URL(待实现)
-        raise NotImplementedError("MySQL support is not implemented yet")
-    else:
-        raise ValueError(f"Unsupported database type: {db_type}")
-
-# 创建会话工厂
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=get_engine())
-
-# 创建基类
-Base = declarative_base()
-
-# 获取数据库会话
-def get_db():
-    """
-    获取数据库会话
-    :return: 数据库会话
-    """
-    db = SessionLocal()
-    try:
-        yield db
-    finally:
-        db.close() 
+
+class Database:
+    def __init__(self, db_type: str = "pg"):
+        self.pool = None
+        if db_type not in db_list:
+            raise ValueError(f"Unsupported database type: {db_type}")
+        self.config = db_list[db_type]
+
+    async def connect(self):
+        """创建数据库连接池"""
+        if not self.pool:
+            self.pool = await asyncpg.create_pool(
+                host=self.config["host"],
+                port=self.config["port"],
+                user=self.config["user"],
+                password=self.config["password"],
+                database=self.config["database"],
+                min_size=1,
+                max_size=10
+            )
+
+    async def close(self):
+        """关闭数据库连接池"""
+        if self.pool:
+            await self.pool.close()
+            self.pool = None
+
+    async def execute_query(self, sql: str) -> List[Dict[str, Any]]:
+        """
+        执行SQL查询并返回结果
+        """
+        if not self.pool:
+            await self.connect()
+
+        try:
+            async with self.pool.acquire() as conn:
+                # 执行查询
+                rows = await conn.fetch(sql)
+
+                # 将结果转换为字典列表
+                result = []
+                for row in rows:
+                    # 处理每个字段的值
+                    row_dict = {}
+                    for key, value in row.items():
+                        # 处理特殊类型
+                        if isinstance(value, (dict, list)):
+                            row_dict[key] = json.dumps(value, ensure_ascii=False)
+                        else:
+                            row_dict[key] = value
+                    result.append(row_dict)
+
+                return result
+        except Exception as e:
+            print(f"Database error: {str(e)}")
+            raise
+
+    async def execute_transaction(self, sql_list: List[str]) -> bool:
+        """
+        执行事务
+        """
+        if not self.pool:
+            await self.connect()
+
+        try:
+            async with self.pool.acquire() as conn:
+                async with conn.transaction():
+                    for sql in sql_list:
+                        await conn.execute(sql)
+            return True
+        except Exception as e:
+            print(f"Transaction error: {str(e)}")
+            return False
+
+    async def test_connection(self) -> bool:
+        """
+        测试数据库连接
+        """
+        try:
+            if not self.pool:
+                await self.connect()
+            async with self.pool.acquire() as conn:
+                await conn.execute('SELECT 1')
+            return True
+        except Exception as e:
+            print(f"Connection test failed: {str(e)}")
+            return False
+
+
+if __name__ == "__main__":
+    db = Database()
+    asyncio.run(db.test_connection())

+ 0 - 6
landsite_agent/examples.json

@@ -2,37 +2,31 @@
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山区推荐几块50亩左右的工业用地,数据表是控制性详细规划",
-    "plan": "Question: 帮我在萧山区推荐几块50亩左右的工业用地,数据表是控制性详细规划 \nThought: 用户问题中想查询城市为'萧山区',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select id from sde.kzxxxgh where xzqmc = '萧山区' and ydxz like '%工业%' and abs(ydmj - 50*0.0667) <= 1 and shape is not null order by ydmj nulls last limit 5"
   },
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山区推荐一宗1公顷左右的学校用地,数据表是控制性详细规划",
-    "plan": "Question: 帮我在萧山区推荐一宗1公顷左右的学校用地,数据表是控制性详细规划\nThought: 用户问题中想查询城市为'萧山区',数量为'1'宗,面积为'1公顷'左右,用地性质'学校'的地块,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',数量为'1'宗,面积为'1公顷'左右,数据表是'控制性详细规划'表,用地性质'学校'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select id from sde.kzxxxgh where xzqmc = '萧山区' and ydxz like '%学校%' and abs(ydmj - 1) <= 1 and shape is not null order by ydmj nulls last limit 1"
   },
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山区推荐几块50亩左右的工业用地,数据表是公告地块",
-    "plan": "Question: 帮我在萧山区推荐几块50亩左右的工业用地,数据表是公告地块 \nThought: 用户问题中想查询城市为'萧山区',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是公告地块表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',面积为'50'亩左右,数据表是'公告地块'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select id from sde.ecgap_klyzy where xzqmc = '萧山区' and tdyt like '%工业%' and abs(dkmj-5) <= 1 and shape is not null and sfsj=1 order by dkmj nulls last limit 5"
   },
   {
     "query_type": "land_site_selection",
     "query": "请在萧山机场附近选出30-100亩之间的工业用地,数据表是公告地块",
-    "plan": "Question: 请在萧山机场附近选出30-100亩之间的工业用地,数据表是控制性详细规划 \nThought: 用户问题中想查询详细地点为'萧山机场',面积为'30-100'亩左右,用地性质'工业'的地块,数量未限制,数据表是公告地块表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [\n {\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询位置是'萧山机场',面积为'30-100'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},   {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select t.id from (select id,dkmj,round(st_distance(st_geometryfromtext('POINT (120.42827489304307 30.23751646603668)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.ecgap_klyzy where tdyt like '%工业%' and sfsj=1 and shape is not null and dkmj BETWEEN 30 and 100) as t where t.distance <= 10000  order by t.dkmj nulls last limit 5"
   },
   {
     "query_type": "land_site_selection",
     "query": "帮我在萧山机场附近推荐几块50亩左右的工业用地,数据表是控制性详细规划",
-    "plan": "Question: 帮我在萧山区推荐几块50亩左右的工业用地,数据表是控制性详细规划 \nThought: 用户问题中想查询城市为'萧山区',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.42827489304307 30.23751646603668)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
   },
   {
     "query_type": "land_site_selection",
     "query": "帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划",
-    "plan": "Question: 帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划 \nThought: 用户问题中想查询详细地点为'温州南站',具体经纬度坐标为120.58,27.97,面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [\n {\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询位置是'温州南站',经纬度坐标为120.58,27.97,面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},   {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
     "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.58 27.97)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
   }
 ] 

+ 49 - 54
landsite_agent/main.py

@@ -1,8 +1,5 @@
-from fastapi import FastAPI, Depends, HTTPException
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy.orm import Session
-from database import get_db
-from sql_generator import SQLGenerator
 from pydantic import BaseModel
 import pandas as pd
 import plotly.express as px
@@ -10,8 +7,13 @@ import json
 from fastapi.responses import StreamingResponse
 from typing import AsyncGenerator, List, Dict, Any
 import uvicorn
-app = FastAPI(title="Land Analysis API")
 import traceback
+import asyncio
+
+from sql_generator import SQLGenerator
+
+app = FastAPI(title="Land Analysis API")
+
 # 配置CORS
 app.add_middleware(
     CORSMiddleware,
@@ -33,17 +35,13 @@ class AnalysisResult(BaseModel):
 sql_generator = SQLGenerator()
 
 @app.post("/land_analysis/stream")
-async def stream_land_analysis(
-    request: QueryRequest,
-    db: Session = Depends(get_db)
-):
+async def stream_land_analysis(request: QueryRequest):
     """
     流式返回土地分析结果
     """
     async def generate_stream() -> AsyncGenerator[str, None]:
         try:
             similar_examples = None
-            generated_sql = None
             
             # 流式生成SQL
             async for chunk in sql_generator.generate_sql_stream(request.description):
@@ -53,48 +51,41 @@ async def stream_land_analysis(
                     similar_examples = data["content"]
                     yield chunk
                 elif data["type"] == "sql_generation":
-                    generated_sql = data["content"]
                     yield chunk
-                else:
-                    yield chunk
-
-            if not generated_sql:
-                yield json.dumps({
-                    "type": "error",
-                    "message": "SQL生成失败"
-                }) + "\n"
-                return
+                elif data["type"] == "sql_result":
+                    # 获取到完整的SQL后执行
+                    sql = data["content"]
+                    result = await sql_generator.execute_sql(sql)
+                    
+                    if result["status"] == "error":
+                        yield json.dumps({
+                            "type": "error",
+                            "message": result["message"]
+                        }) + "\n"
+                        return
 
-            # 执行SQL并返回结果
-            result = await sql_generator.execute_sql(generated_sql, db)
-            
-            if result["status"] == "error":
-                yield json.dumps({
-                    "type": "error",
-                    "message": result["message"]
-                }) + "\n"
-                return
-
-            # 生成可视化
-            df = pd.DataFrame(result["data"])
-            visualization = None
-            
-            if not df.empty:
-                if len(df.columns) >= 2:
-                    if df.select_dtypes(include=['number']).columns.any():
-                        fig = px.bar(df, x=df.columns[0], y=df.columns[1])
-                        visualization = json.loads(fig.to_json())
+                    # 生成可视化
+                    df = pd.DataFrame(result["data"])
+                    visualization = None
+                    
+                    if not df.empty:
+                        if len(df.columns) >= 2:
+                            if df.select_dtypes(include=['number']).columns.any():
+                                fig = px.bar(df, x=df.columns[0], y=df.columns[1])
+                                visualization = json.loads(fig.to_json())
 
-            # 返回最终结果
-            yield json.dumps({
-                "type": "result",
-                "data": {
-                    "sql": generated_sql,
-                    "data": result["data"],
-                    "visualization": visualization,
-                    "similar_examples": similar_examples
-                }
-            }) + "\n"
+                    # 返回最终结果
+                    yield json.dumps({
+                        "type": "result",
+                        "data": {
+                            "sql": sql,
+                            "data": result["data"],
+                            "visualization": visualization,
+                            "similar_examples": similar_examples
+                        }
+                    }) + "\n"
+                else:
+                    yield chunk
 
         except Exception as e:
             traceback.print_exc()
@@ -109,10 +100,7 @@ async def stream_land_analysis(
     )
 
 @app.post("/land_analysis", response_model=AnalysisResult)
-async def generate_and_execute_sql(
-    request: QueryRequest,
-    db: Session = Depends(get_db)
-):
+async def generate_and_execute_sql(request: QueryRequest):
     try:
         # 获取相似示例
         similar_examples = sql_generator._get_similar_examples(request.description)
@@ -130,7 +118,7 @@ async def generate_and_execute_sql(
         sql = await sql_generator.chain.arun(enhanced_prompt)
         
         # 执行SQL
-        result = await sql_generator.execute_sql(sql, db)
+        result = await sql_generator.execute_sql(sql)
         
         if result["status"] == "error":
             raise HTTPException(status_code=400, detail=result["message"])
@@ -156,5 +144,12 @@ async def generate_and_execute_sql(
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=str(e))
 
+@app.on_event("shutdown")
+async def shutdown_event():
+    """
+    应用关闭时清理资源
+    """
+    await sql_generator.close()
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8001)

+ 21 - 6
landsite_agent/prompt_template.py

@@ -55,7 +55,12 @@ PROMPT_TEMPLATE = """
    - 公告地块表(sde.ecgap_klyzy):使用 dkmj 字段,单位为亩
    - 注意单位换算:1公顷 = 15亩
 
-4. 其他注意事项:
+4. 查询字段限制:
+   - SELECT 语句中只能查询 id 字段
+   - 不允许使用 SELECT * 或其他字段
+   - 示例:SELECT id FROM table WHERE condition;
+
+5. 其他注意事项:
    - 确保SQL语句的语法正确性
    - 注意字段名称的准确性
    - 合理使用索引字段(如id、xzqmc等)
@@ -68,16 +73,26 @@ PROMPT_TEMPLATE = """
 
 请根据以上字段信息和注意事项,生成符合要求的SQL查询语句。在生成SQL时,请确保:
 1. 只使用SELECT语句
-2. 包含shape is not null条件
-3. 正确使用面积字段和单位
-4. 遵循其他注意事项
+2. 只查询id字段
+3. 包含shape is not null条件
+4. 正确使用面积字段和单位
+5. 遵循其他注意事项
+
+请按照以下格式输出,每个部分之间用空行分隔:
 
-请按照以下格式输出:
 1. Question: 分析用户问题
+
 2. Thought: 思考查询逻辑
+
 3. Plan: 制定查询计划
+
 4. SQL: 生成SQL代码
-"""
+```sql
+SELECT id FROM table
+WHERE condition;
+```
+
+请确保SQL语句是完整且可执行的,并且SQL代码块是独立的部分。"""
 
 def get_prompt():
     """

+ 84 - 30
landsite_agent/sql_generator.py

@@ -1,21 +1,18 @@
 from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
-from langchain.chains import LLMChain
-from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain_community.vectorstores import FAISS
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
 from langchain_core.documents import Document
+from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
 import os
 from prompt_template import get_prompt
-from fastapi.responses import StreamingResponse
 import json
-import asyncio
-import numpy as np
+import traceback
 from config import get_model_config
-import pandas as pd
 from typing import List, Dict, Any
-import traceback
+from database import Database
+import re
 
 class SQLGenerator:
     def __init__(self, model_type: str = "openai"):
@@ -25,10 +22,8 @@ class SQLGenerator:
         # 初始化LLM
         self.llm = ChatOpenAI(
             model_name=model_config.model_name,
-            temperature=model_config.temperature,
             api_key=model_config.api_key,
             base_url=model_config.api_base,
-            max_tokens=model_config.max_tokens,
             streaming=True
         )
 
@@ -58,6 +53,27 @@ class SQLGenerator:
         self.examples = self._load_examples()
         self.vectorstore = self._build_vectorstore()
 
+        # 初始化数据库连接
+        self.db = Database()
+
+    def _extract_sql_from_markdown(self, text: str) -> str:
+        """
+        从markdown格式的文本中提取SQL语句
+        """
+        # 使用正则表达式匹配markdown代码块中的SQL
+        sql_pattern = r"```sql\n(.*?)\n```"
+        match = re.search(sql_pattern, text, re.DOTALL)
+        
+        if match:
+            # 提取SQL语句并清理
+            sql = match.group(1).strip()
+            # 移除可能的注释
+            sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
+            # 移除多余的空行
+            sql = re.sub(r'\n\s*\n', '\n', sql)
+            return sql.strip()
+        return text.strip()
+
     def _load_examples(self):
         """加载示例数据"""
         with open('examples.json', 'r', encoding='utf-8') as f:
@@ -72,7 +88,6 @@ class SQLGenerator:
                 page_content=example['query'],
                 metadata={
                     'query_type': example['query_type'],
-                    'plan': example['plan'],
                     'sql_code': example['sql_code']
                 }
             )
@@ -97,7 +112,6 @@ class SQLGenerator:
             similar_examples.append({
                 'query_type': doc.metadata['query_type'],
                 'query': doc.page_content,
-                'plan': doc.metadata['plan'],
                 'sql_code': doc.metadata['sql_code'],
                 'similarity_score': float(score)
             })
@@ -106,7 +120,6 @@ class SQLGenerator:
 
     def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
         """格式化聊天历史"""
-
         if len(similar_examples) == 0:
             return ""
 
@@ -114,11 +127,9 @@ class SQLGenerator:
         for i, example in enumerate(similar_examples, 1):
             chat_history += f"示例 {i}:\n"
             chat_history += f"问题: {example['query']}\n"
-            chat_history += f"计划: {example['plan']}\n"
             chat_history += f"SQL: {example['sql_code']}\n"
             chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
 
-        print(chat_history+"  !!!!!")
         return chat_history
 
     async def generate_sql_stream(self, query_description: str):
@@ -127,7 +138,7 @@ class SQLGenerator:
         """
         try:
             # 开始生成SQL
-            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}) + "\n"
+            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}, ensure_ascii=False) + "\n"
 
             # 获取相似示例
             similar_examples = self._get_similar_examples(query_description)
@@ -145,43 +156,82 @@ class SQLGenerator:
             print("Chain input:", chain_input)
 
             try:
+                # 构建完整的提示词
+                formatted_prompt = self.prompt.format(
+                    chat_history=self._format_chat_history(similar_examples),
+                    question=query_description
+                )
+                print("Formatted prompt:", formatted_prompt)
+
                 # 流式生成SQL
+                full_response = ""
                 async for chunk in self.chain.astream(chain_input):
                     if chunk:
-                        yield json.dumps({
-                            "type": "sql_generation",
-                            "content": chunk
-                        }) + "\n"
+                        try:
+                            full_response += chunk
+                            yield json.dumps({
+                                "type": "sql_generation",
+                                "content": chunk
+                            }, ensure_ascii=False) + "\n"
+                        except Exception as chunk_error:
+                            print(f"Error processing chunk: {str(chunk_error)}")
+                            continue
+
+                # 使用正则表达式提取SQL代码块
+                sql_pattern = r"```sql\n(.*?)\n```"
+                sql_match = re.search(sql_pattern, full_response, re.DOTALL)
+                
+                if sql_match:
+                    # 提取并清理SQL
+                    sql_content = sql_match.group(1).strip()
+                    # 移除注释
+                    sql_content = re.sub(r'--.*$', '', sql_content, flags=re.MULTILINE)
+                    # 移除多余空行
+                    sql_content = re.sub(r'\n\s*\n', '\n', sql_content)
+                    # 确保SQL语句完整
+                    if not sql_content.strip().endswith(';'):
+                        sql_content = sql_content.strip() + ';'
+                    
+                    yield json.dumps({
+                        "type": "sql_result",
+                        "content": sql_content
+                    }, ensure_ascii=False) + "\n"
+                else:
+                    # 如果没有找到SQL代码块,返回完整响应
+                    yield json.dumps({
+                        "type": "sql_result",
+                        "content": full_response
+                    }, ensure_ascii=False) + "\n"
 
                 yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
             except Exception as e:
-                print(f"Error during streaming: {str(e)}")
-                print(f"Error type: {type(e)}")
                 print(f"Error details: {traceback.format_exc()}")
                 yield json.dumps({
                     "type": "error",
                     "message": f"生成SQL时发生错误: {str(e)}"
-                }) + "\n"
+                }, ensure_ascii=False) + "\n"
 
         except Exception as e:
             traceback.print_exc()
             yield json.dumps({
                 "type": "error",
                 "message": str(e)
-            }) + "\n"
+            }, ensure_ascii=False) + "\n"
 
-    async def execute_sql(self, sql: str, db_connection) -> dict:
+    async def execute_sql(self, sql: str) -> dict:
         """
         执行SQL查询并返回结果
         """
+        print(sql)
         try:
-            result = db_connection.execute(sql)
-            columns = result.keys()
-            data = [dict(zip(columns, row)) for row in result.fetchall()]
+            # 确保SQL是完整的
+            if not sql.strip().endswith(';'):
+                sql = sql.strip() + ';'
+            
+            result = await self.db.execute_query(sql)
             return {
                 "status": "success",
-                "data": data,
-                "columns": columns
+                "data": result
             }
         except Exception as e:
             traceback.print_exc()
@@ -189,3 +239,7 @@ class SQLGenerator:
                 "status": "error",
                 "message": str(e)
             }
+
+    async def close(self):
+        """关闭数据库连接"""
+        await self.db.close()