Browse Source

智能选址代码重构

liutao 1 month ago
parent
commit
afef811612

+ 8 - 0
landsite_agent/.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 18 - 0
landsite_agent/.idea/inspectionProfiles/Project_Default.xml

@@ -0,0 +1,18 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="5">
+            <item index="0" class="java.lang.String" itemvalue="pydantic" />
+            <item index="1" class="java.lang.String" itemvalue="pydantic_core" />
+            <item index="2" class="java.lang.String" itemvalue="typing_extensions" />
+            <item index="3" class="java.lang.String" itemvalue="certifi" />
+            <item index="4" class="java.lang.String" itemvalue="numpy" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>

+ 6 - 0
landsite_agent/.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 8 - 0
landsite_agent/.idea/landsite_agent.iml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="landsite" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 7 - 0
landsite_agent/.idea/misc.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="landsite" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="landsite" project-jdk-type="Python SDK" />
+</project>

+ 8 - 0
landsite_agent/.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/landsite_agent.iml" filepath="$PROJECT_DIR$/.idea/landsite_agent.iml" />
+    </modules>
+  </component>
+</project>

+ 6 - 0
landsite_agent/.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
+  </component>
+</project>

+ 71 - 0
landsite_agent/README.md

@@ -0,0 +1,71 @@
+# SQL生成器Web应用
+
+这是一个基于FastAPI、OpenAI和LangChain的SQL生成器Web应用,可以根据自然语言描述生成SQL查询语句,执行查询并返回分析结果。
+
+## 功能特点
+
+- 基于自然语言生成SQL查询
+- 自动执行SQL查询
+- 数据可视化展示
+- RESTful API接口
+
+## 环境要求
+
+- Python 3.10+
+- PostgreSQL数据库
+- OpenAI API密钥
+
+## 安装步骤
+
+1. 克隆项目并安装依赖:
+```bash
+pip install -r requirements.txt
+```
+
+2. 配置环境变量:
+创建.env文件并设置以下变量(注意:当前代码尚未读取.env文件,实际生效的配置位于 `config.py` 的 `model_list` 和 `database.py` 的 `db_list`):
+```
+OPENAI_API_KEY=your_openai_api_key
+DATABASE_URL=postgresql://username:password@localhost:5432/your_database
+```
+
+3. 启动应用:
+```bash
+python main.py
+```
+
+## API使用
+
+### 生成SQL并执行查询
+
+POST `/land_analysis`(流式版本:POST `/land_analysis/stream`)
+
+请求体:
+```json
+{
+    "description": "查询描述"
+}
+```
+
+响应:
+```json
+{
+    "sql": "生成的SQL语句",
+    "data": "查询结果数据",
+    "visualization": "可视化数据(如果有)",
+    "similar_examples": "检索到的相似示例(如果有)"
+}
+```
+
+## 示例
+
+```python
+import requests
+
+response = requests.post(
+    "http://localhost:8001/land_analysis",
+    json={
+        "description": "查询所有已上架的公告地块"
+    }
+)
+print(response.json())
+``` 

+ 44 - 0
landsite_agent/config.py

@@ -0,0 +1,44 @@
+from typing import Dict, Any
+from pydantic import BaseModel
+
+class ModelConfig(BaseModel):
+    """Settings for one chat-model backend, validated by pydantic."""
+    api_key: str  # credential handed to the chat-model client
+    api_base: str  # base URL of the model endpoint
+    model_name: str  # model identifier sent to the endpoint
+    temperature: float = 0  # sampling temperature; defaults to 0
+    max_tokens: int = 2000  # completion token limit; defaults to 2000
+
+# Chat-model backends, keyed by the name accepted by get_model_config().
+# Each value holds the keyword arguments used to build a ModelConfig; entries
+# that omit "temperature"/"max_tokens" fall back to the ModelConfig defaults.
+# NOTE(review): the "azure" and "local" entries look like placeholders — confirm
+# before relying on them.
+model_list: Dict[str, Dict[str, Any]] = {
+    "openai": {
+        "api_key": "none",
+        "api_base": "http://ac.zjugis.com:8511/v1",
+        "model_name": "qwen2.5-instruct"
+    },
+    "azure": {
+        "api_key": "your-azure-api-key",
+        "api_base": "https://your-azure-endpoint.openai.azure.com",
+        "model_name": "gpt-35-turbo",
+        "temperature": 0,
+        "max_tokens": 2000
+    },
+    "local": {
+        "api_key": "your-local-api-key",
+        "api_base": "http://localhost:8000/v1",
+        "model_name": "local-model",
+        "temperature": 0,
+        "max_tokens": 2000
+    }
+}
+
+def get_model_config(model_type: str = "openai") -> ModelConfig:
+    """
+    获取模型配置
+    :param model_type: 模型类型,默认为openai
+    :return: 模型配置对象
+    """
+    if model_type not in model_list:
+        raise ValueError(f"Unsupported model type: {model_type}")
+    
+    return ModelConfig(**model_list[model_type]) 

+ 59 - 0
landsite_agent/database.py

@@ -0,0 +1,59 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.declarative import declarative_base
+from typing import Any, Dict
+
+
+# Database connection settings, keyed by the name accepted by get_engine().
+# NOTE(review): host and credentials are hard-coded in source — confirm whether
+# they should be supplied through configuration instead.
+db_list: Dict[str, Dict[str, Any]] = {
+    "mysql": {
+        # MySQL settings intentionally left empty until support is added
+    },
+    "pg": {
+        "host": "10.10.9.243",
+        "port": "5432",
+        "database": "sde",
+        "user": "sde",
+        "password": "sde",
+    }
+}
+
+# 创建数据库引擎
+def get_engine(db_type: str = "pg"):
+    """
+    获取数据库引擎
+    :param db_type: 数据库类型,默认为pg
+    :return: SQLAlchemy引擎
+    """
+    if db_type not in db_list:
+        raise ValueError(f"Unsupported database type: {db_type}")
+    
+    db_config = db_list[db_type]
+    
+    if db_type == "pg":
+        # PostgreSQL连接URL
+        db_url = f"postgresql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['database']}"
+        return create_engine(db_url)
+    elif db_type == "mysql":
+        # MySQL连接URL(待实现)
+        raise NotImplementedError("MySQL support is not implemented yet")
+    else:
+        raise ValueError(f"Unsupported database type: {db_type}")
+
+# Session factory bound to the default engine; get_engine() is called once, when
+# this module is imported.
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=get_engine())
+
+# Declarative base class for ORM models
+Base = declarative_base()
+
+# 获取数据库会话
+def get_db():
+    """
+    获取数据库会话
+    :return: 数据库会话
+    """
+    db = SessionLocal()
+    try:
+        yield db
+    finally:
+        db.close() 

+ 38 - 0
landsite_agent/examples.json

@@ -0,0 +1,38 @@
+[
+  {
+    "query_type": "land_site_selection",
+    "query": "帮我在萧山区推荐几块50亩左右的工业用地,数据表是控制性详细规划",
+    "plan": "Question: 帮我在萧山区推荐几块50亩左右的工业用地,数据表是控制性详细规划 \nThought: 用户问题中想查询城市为'萧山区',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select id from sde.kzxxxgh where xzqmc = '萧山区' and ydxz like '%工业%' and abs(ydmj - 50*0.0667) <= 1 and shape is not null order by ydmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "帮我在萧山区推荐一宗1公顷左右的学校用地,数据表是控制性详细规划",
+    "plan": "Question: 帮我在萧山区推荐一宗1公顷左右的学校用地,数据表是控制性详细规划\nThought: 用户问题中想查询城市为'萧山区',数量为'1'宗,面积为'1公顷'左右,用地性质'学校'的地块,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',数量为'1'宗,面积为'1公顷'左右,数据表是'控制性详细规划'表,用地性质'学校'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select id from sde.kzxxxgh where xzqmc = '萧山区' and ydxz like '%学校%' and abs(ydmj - 1) <= 1 and shape is not null order by ydmj nulls last limit 1"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "帮我在萧山区推荐几块50亩左右的工业用地,数据表是公告地块",
+    "plan": "Question: 帮我在萧山区推荐几块50亩左右的工业用地,数据表是公告地块 \nThought: 用户问题中想查询城市为'萧山区',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是公告地块表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询城市为'萧山区',面积为'50'亩左右,数据表是'公告地块'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select id from sde.ecgap_klyzy where xzqmc = '萧山区' and tdyt like '%工业%' and abs(dkmj-50) <= 1 and shape is not null and sfsj=1 order by dkmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "请在萧山机场附近选出30-100亩之间的工业用地,数据表是公告地块",
+    "plan": "Question: 请在萧山机场附近选出30-100亩之间的工业用地,数据表是公告地块 \nThought: 用户问题中想查询详细地点为'萧山机场',面积为'30-100'亩左右,用地性质'工业'的地块,数量未限制,数据表是公告地块表,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [\n {\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询位置是'萧山机场',面积为'30-100'亩左右,数据表是'公告地块'表,用地性质'工业'的地块\"},   {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select t.id from (select id,dkmj,round(st_distance(st_geometryfromtext('POINT (120.42827489304307 30.23751646603668)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.ecgap_klyzy where tdyt like '%工业%' and sfsj=1 and shape is not null and dkmj BETWEEN 30 and 100) as t where t.distance <= 10000  order by t.dkmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "帮我在萧山机场附近推荐几块50亩左右的工业用地,数据表是控制性详细规划",
+    "plan": "Question: 帮我在萧山机场附近推荐几块50亩左右的工业用地,数据表是控制性详细规划 \nThought: 用户问题中想查询详细地点为'萧山机场',面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [{\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询位置是'萧山机场',面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},\n    {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.42827489304307 30.23751646603668)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
+  },
+  {
+    "query_type": "land_site_selection",
+    "query": "帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划",
+    "plan": "Question: 帮我在温州南站附近推荐几块50亩左右的工业用地,温州南站的坐标为120.58,27.97,数据表是控制性详细规划 \nThought: 用户问题中想查询详细地点为'温州南站',具体经纬度坐标为120.58,27.97,面积为'50'亩左右,用地性质'工业'的地块,数量未限制,数据表是控制性详细规划,所以需要通过[LandSiteSelectionSqlAgent]查询图层信息,最后使用summary的Action来总结并输出。Plan: ```json\n    [\n {\"action_name\": \"LandSiteSelectionSqlAgent\", \"instruction\": \"你需要调用 [LandSiteSelectionSqlAgent],来查询位置是'温州南站',经纬度坐标为120.58,27.97,面积为'50'亩左右,数据表是'控制性详细规划'表,用地性质'工业'的地块\"},   {\"action_name\": \"summary\", \"instruction\": \"你需要根据用户的Question和查询的结果,回答用户问题。\"}]",
+    "sql_code": "select t.id from (select id,ydmj,round(st_distance(st_geometryfromtext('POINT (120.58 27.97)', 4490)::geography,shape::geography)::numeric,0) as distance from sde.kzxxxgh where ydxz like '%工业%'  and shape is not null and abs(ydmj - 50*0.0667) <= 1) as t where t.distance <= 10000  order by t.ydmj nulls last limit 5"
+  }
+] 

+ 160 - 0
landsite_agent/main.py

@@ -0,0 +1,160 @@
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from sqlalchemy.orm import Session
+from database import get_db
+from sql_generator import SQLGenerator
+from pydantic import BaseModel
+import pandas as pd
+import plotly.express as px
+import json
+from fastapi.responses import StreamingResponse
+from typing import AsyncGenerator, List, Dict, Any
+import uvicorn
+app = FastAPI(title="Land Analysis API")
+import traceback
+# Configure CORS.
+# NOTE(review): every origin, method and header is allowed while credentials are
+# enabled — confirm this wide-open policy is intended outside local development.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class QueryRequest(BaseModel):
+    """Request body accepted by the land-analysis endpoints."""
+    description: str  # natural-language description of the desired query
+
+class AnalysisResult(BaseModel):
+    sql: str
+    data: list
+    visualization: dict = None
+    similar_examples: List[Dict[str, Any]] = None
+
+# Single generator instance created at import time and used by the request handlers.
+sql_generator = SQLGenerator()
+
+@app.post("/land_analysis/stream")
+async def stream_land_analysis(
+    request: QueryRequest,
+    db: Session = Depends(get_db)
+):
+    """
+    流式返回土地分析结果
+    """
+    async def generate_stream() -> AsyncGenerator[str, None]:
+        try:
+            similar_examples = None
+            generated_sql = None
+            
+            # 流式生成SQL
+            async for chunk in sql_generator.generate_sql_stream(request.description):
+                data = json.loads(chunk)
+                
+                if data["type"] == "similar_examples":
+                    similar_examples = data["content"]
+                    yield chunk
+                elif data["type"] == "sql_generation":
+                    generated_sql = data["content"]
+                    yield chunk
+                else:
+                    yield chunk
+
+            if not generated_sql:
+                yield json.dumps({
+                    "type": "error",
+                    "message": "SQL生成失败"
+                }) + "\n"
+                return
+
+            # 执行SQL并返回结果
+            result = await sql_generator.execute_sql(generated_sql, db)
+            
+            if result["status"] == "error":
+                yield json.dumps({
+                    "type": "error",
+                    "message": result["message"]
+                }) + "\n"
+                return
+
+            # 生成可视化
+            df = pd.DataFrame(result["data"])
+            visualization = None
+            
+            if not df.empty:
+                if len(df.columns) >= 2:
+                    if df.select_dtypes(include=['number']).columns.any():
+                        fig = px.bar(df, x=df.columns[0], y=df.columns[1])
+                        visualization = json.loads(fig.to_json())
+
+            # 返回最终结果
+            yield json.dumps({
+                "type": "result",
+                "data": {
+                    "sql": generated_sql,
+                    "data": result["data"],
+                    "visualization": visualization,
+                    "similar_examples": similar_examples
+                }
+            }) + "\n"
+
+        except Exception as e:
+            traceback.print_exc()
+            yield json.dumps({
+                "type": "error",
+                "message": str(e)
+            }) + "\n"
+
+    return StreamingResponse(
+        generate_stream(),
+        media_type="text/event-stream"
+    )
+
+@app.post("/land_analysis", response_model=AnalysisResult)
+async def generate_and_execute_sql(
+    request: QueryRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        # 获取相似示例
+        similar_examples = sql_generator._get_similar_examples(request.description)
+        
+        # 构建增强提示词
+        enhanced_prompt = f"""
+        基于以下相似示例:
+        {json.dumps(similar_examples, ensure_ascii=False, indent=2)}
+        
+        请根据以下描述生成SQL查询:
+        {request.description}
+        """
+        
+        # 生成SQL
+        sql = await sql_generator.chain.arun(enhanced_prompt)
+        
+        # 执行SQL
+        result = await sql_generator.execute_sql(sql, db)
+        
+        if result["status"] == "error":
+            raise HTTPException(status_code=400, detail=result["message"])
+        
+        # 生成可视化
+        df = pd.DataFrame(result["data"])
+        visualization = None
+        
+        if not df.empty:
+            if len(df.columns) >= 2:
+                if df.select_dtypes(include=['number']).columns.any():
+                    fig = px.bar(df, x=df.columns[0], y=df.columns[1])
+                    visualization = json.loads(fig.to_json())
+        
+        return AnalysisResult(
+            sql=sql,
+            data=result["data"],
+            visualization=visualization,
+            similar_examples=similar_examples
+        )
+    
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Start the API with uvicorn when this module is run as a script.
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8001)  # all interfaces, port 8001

+ 86 - 0
landsite_agent/prompt_template.py

@@ -0,0 +1,86 @@
+"""
+SQL生成提示词模板
+基于控制性详细规划表(sde.kzxxxgh)和公告地块表(sde.ecgap_klyzy)的字段信息
+"""
+
+PROMPT_TEMPLATE = """
+请根据以下两个表的字段信息生成SQL查询语句:
+
+1. 控制性详细规划表 (sde.kzxxxgh):
+   - id: 主键ID
+   - xzqmc: 所属区县(行政区代码)
+   - xzqdm: 行政区代码(6位,前2位代表省,前4位代表市,前6位代表区县)
+   - dymc: 单元名称
+   - yddm: 用地代码
+   - ydxz: 用地性质
+   - ydmj: 用地面积(单位:公顷)
+   - pfwh: 批复文号
+   - pfsj: 批复时间
+   - rjlsx: 容积率上限
+   - rjlxx: 容积率下限
+   - jzmdsx: 建筑密度上限
+   - jzmdxx: 建筑密度下限
+   - jzgdsx: 建筑高度上限
+   - jzgdxx: 建筑高度下限
+   - ldlxx: 绿地率下限
+   - ldlxx: 绿地率下限
+   - shape: 地块图形wkt
+
+2. 公告地块表 (sde.ecgap_klyzy):
+   - id: 主键ID
+   - xzqmc: 所属区县(行政区代码)
+   - xzqdm: 行政区代码(6位,前2位代表省,前4位代表市,前6位代表区县)
+   - dkmc: 地块名称
+   - dkid: 地块id
+   - address: 土地坐落
+   - dkmj: 土地面积(单位:亩)
+   - tdyt: 土地用途
+   - shape: 地块图形wkt
+   - sfsj: 是否上架(1表示已上架,0表示未上架)
+
+重要注意事项:
+1. 只准生成查询的SQL语句,不可生成任何修改数据的语句,包括但不限于:
+   - UPDATE
+   - DELETE
+   - INSERT
+   - TRUNCATE
+   - DROP
+   - ALTER
+   等修改数据的操作
+
+2. 所有查询必须包含 shape is not null 条件,以过滤掉所有空图形数据
+
+3. 面积字段和单位说明:
+   - 控制性详细规划表(sde.kzxxxgh):使用 ydmj 字段,单位为公顷
+   - 公告地块表(sde.ecgap_klyzy):使用 dkmj 字段,单位为亩
+   - 注意单位换算:1公顷 = 15亩
+
+4. 其他注意事项:
+   - 确保SQL语句的语法正确性
+   - 注意字段名称的准确性
+   - 合理使用索引字段(如id、xzqmc等)
+   - 对于空间查询,注意使用正确的空间函数和坐标系
+
+历史对话和相似示例:
+{chat_history}
+
+用户问题:{question}
+
+请根据以上字段信息和注意事项,生成符合要求的SQL查询语句。在生成SQL时,请确保:
+1. 只使用SELECT语句
+2. 包含shape is not null条件
+3. 正确使用面积字段和单位
+4. 遵循其他注意事项
+
+请按照以下格式输出:
+1. Question: 分析用户问题
+2. Thought: 思考查询逻辑
+3. Plan: 制定查询计划
+4. SQL: 生成SQL代码
+"""
+
+def get_prompt():
+    """
+    获取提示词模板
+    """
+    return PROMPT_TEMPLATE 

BIN
landsite_agent/requirements.txt


+ 191 - 0
landsite_agent/sql_generator.py

@@ -0,0 +1,191 @@
+from langchain_openai import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
+from langchain.chains import LLMChain
+from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.documents import Document
+import os
+from prompt_template import get_prompt
+from fastapi.responses import StreamingResponse
+import json
+import asyncio
+import numpy as np
+from config import get_model_config
+import pandas as pd
+from typing import List, Dict, Any
+import traceback
+
+class SQLGenerator:
+    def __init__(self, model_type: str = "openai"):
+        # 获取模型配置
+        model_config = get_model_config(model_type)
+
+        # 初始化LLM
+        self.llm = ChatOpenAI(
+            model_name=model_config.model_name,
+            temperature=model_config.temperature,
+            api_key=model_config.api_key,
+            base_url=model_config.api_base,
+            max_tokens=model_config.max_tokens,
+            streaming=True
+        )
+
+        # 初始化提示词模板
+        self.prompt = ChatPromptTemplate.from_template(get_prompt())
+        
+        # 构建链式调用
+        self.chain = (
+            {
+                "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
+                "question": lambda x: x["question"]
+            }
+            | self.prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+        # 初始化本地m3e-base模型
+        self.model_path = r"E:\项目临时\AI大模型\m3e-base"
+        self.embeddings = HuggingFaceEmbeddings(
+            model_name=self.model_path,
+            model_kwargs={'device': 'cpu'},
+            encode_kwargs={'normalize_embeddings': True}
+        )
+
+        # 加载示例数据并创建向量数据库
+        self.examples = self._load_examples()
+        self.vectorstore = self._build_vectorstore()
+
+    def _load_examples(self):
+        """加载示例数据"""
+        with open('examples.json', 'r', encoding='utf-8') as f:
+            return json.load(f)
+
+    def _build_vectorstore(self):
+        """构建FAISS向量数据库"""
+        # 准备Document对象列表
+        documents = []
+        for example in self.examples:
+            doc = Document(
+                page_content=example['query'],
+                metadata={
+                    'query_type': example['query_type'],
+                    'plan': example['plan'],
+                    'sql_code': example['sql_code']
+                }
+            )
+            documents.append(doc)
+
+        # 创建FAISS向量数据库
+        vectorstore = FAISS.from_documents(
+            documents=documents,
+            embedding=self.embeddings
+        )
+
+        return vectorstore
+
+    def _get_similar_examples(self, query: str, k: int = 3):
+        """获取最相似的示例"""
+        # 使用FAISS搜索相似示例
+        docs = self.vectorstore.similarity_search_with_score(query, k=k)
+
+        # 格式化返回结果
+        similar_examples = []
+        for doc, score in docs:
+            similar_examples.append({
+                'query_type': doc.metadata['query_type'],
+                'query': doc.page_content,
+                'plan': doc.metadata['plan'],
+                'sql_code': doc.metadata['sql_code'],
+                'similarity_score': float(score)
+            })
+
+        return similar_examples
+
+    def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
+        """格式化聊天历史"""
+
+        if len(similar_examples) == 0:
+            return ""
+
+        chat_history = "基于以下相似示例:\n\n"
+        for i, example in enumerate(similar_examples, 1):
+            chat_history += f"示例 {i}:\n"
+            chat_history += f"问题: {example['query']}\n"
+            chat_history += f"计划: {example['plan']}\n"
+            chat_history += f"SQL: {example['sql_code']}\n"
+            chat_history += f"相似度: {example['similarity_score']:.2f}\n\n"
+
+        print(chat_history+"  !!!!!")
+        return chat_history
+
+    async def generate_sql_stream(self, query_description: str):
+        """
+        流式生成SQL查询语句
+        """
+        try:
+            # 开始生成SQL
+            yield json.dumps({"type": "start", "message": "开始生成SQL查询..."}) + "\n"
+
+            # 获取相似示例
+            similar_examples = self._get_similar_examples(query_description)
+            yield json.dumps({
+                "type": "similar_examples",
+                "content": similar_examples
+            }, ensure_ascii=False) + "\n"
+
+            # 准备输入数据
+            chain_input = {
+                "similar_examples": similar_examples,
+                "question": query_description
+            }
+
+            print("Chain input:", chain_input)
+
+            try:
+                # 流式生成SQL
+                async for chunk in self.chain.astream(chain_input):
+                    if chunk:
+                        yield json.dumps({
+                            "type": "sql_generation",
+                            "content": chunk
+                        }) + "\n"
+
+                yield json.dumps({"type": "end", "message": "SQL生成完成"}, ensure_ascii=False) + "\n"
+            except Exception as e:
+                print(f"Error during streaming: {str(e)}")
+                print(f"Error type: {type(e)}")
+                print(f"Error details: {traceback.format_exc()}")
+                yield json.dumps({
+                    "type": "error",
+                    "message": f"生成SQL时发生错误: {str(e)}"
+                }) + "\n"
+
+        except Exception as e:
+            traceback.print_exc()
+            yield json.dumps({
+                "type": "error",
+                "message": str(e)
+            }) + "\n"
+
+    async def execute_sql(self, sql: str, db_connection) -> dict:
+        """
+        执行SQL查询并返回结果
+        """
+        try:
+            result = db_connection.execute(sql)
+            columns = result.keys()
+            data = [dict(zip(columns, row)) for row in result.fetchall()]
+            return {
+                "status": "success",
+                "data": data,
+                "columns": columns
+            }
+        except Exception as e:
+            traceback.print_exc()
+            return {
+                "status": "error",
+                "message": str(e)
+            }