123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- from langchain_openai import ChatOpenAI
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.documents import Document
- from langchain_huggingface.embeddings.huggingface import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from prompt_template import get_prompt
- import json
- import traceback
- from config import get_model_config
- from typing import List, Dict, Any
- from database import Database
- import re
- import os
- from dotenv import load_dotenv
- # Load environment variables from config.env when this module is imported
- load_dotenv('config.env')
class SQLGenerator:
    """Turn natural-language questions into SQL with an LLM and retrieved examples.

    Construction wires a LangChain pipeline (prompt -> chat model -> string
    parser), loads ``examples.json``, indexes the example questions in a FAISS
    vector store and creates a ``Database`` handle.

    ``generate_sql_stream`` yields newline-delimited JSON events;
    ``execute_sql`` runs a statement through the database handle.
    """

    # Fenced SQL block in a markdown answer. Tolerates an upper-case tag,
    # trailing blanks after the tag, CRLF line endings and a closing fence
    # that is not preceded by a newline.
    _SQL_FENCE_RE = re.compile(r"```sql[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
    # ``--`` to end of line.
    # NOTE(review): not string-literal aware, so a ``--`` inside a quoted SQL
    # string is cut as well; kept as in the original implementation.
    _LINE_COMMENT_RE = re.compile(r"--.*$", re.MULTILINE)
    # A run of blank (whitespace-only) lines.
    _BLANK_LINES_RE = re.compile(r"\n\s*\n")
    # Reasoning markers removed from the stream for the model type below.
    _THINK_TAGS = ("<think>", "</think>")
    _THINK_MODEL_TYPE = "zjstai"

    def __init__(self, model_type: str = None):
        """Build the LLM chain, the example index and the database handle.

        Args:
            model_type: Key passed to ``get_model_config``. When falsy, the
                ``DEFAULT_MODEL_TYPE`` environment variable is used, falling
                back to ``'openai'``.

        Raises:
            ValueError: If the ``MODEL_PATH`` environment variable is not set.
        """
        # Resolve the model type once and remember it, so later decisions
        # (think-tag stripping) follow the model this instance really uses
        # rather than whatever the environment default happens to be.
        model_type = model_type or os.getenv('DEFAULT_MODEL_TYPE', 'openai')
        self.model_type = model_type
        model_config = get_model_config(model_type)

        # Chat model accessed through the OpenAI-compatible client.
        self.llm = ChatOpenAI(
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.api_base,
            streaming=True
        )

        # Prompt template text comes from the project's prompt_template module.
        self.prompt = ChatPromptTemplate.from_template(get_prompt())

        # Chain: map the input dict to the prompt variables, render the
        # prompt, call the model, and reduce the result to plain text.
        self.chain = (
            {
                "chat_history": lambda x: self._format_chat_history(x["similar_examples"]),
                "question": lambda x: x["question"]
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

        # The embedding model is loaded from a local path given by MODEL_PATH.
        model_path = os.getenv('MODEL_PATH')
        if not model_path:
            raise ValueError("MODEL_PATH environment variable is not set")

        # Local embedding model (m3e-base according to the original author's
        # note), run on CPU with normalised output vectors.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={'device': "cpu"},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Few-shot examples and the vector index built over their questions.
        self.examples = self._load_examples()
        self.vectorstore = self._build_vectorstore()

        # Database handle used by execute_sql / close.
        self.db = Database()

    @staticmethod
    def _event(event_type: str, content: Any) -> str:
        """Serialise one stream event as a single JSON line (newline-terminated)."""
        return json.dumps({"type": event_type, "content": content}, ensure_ascii=False) + "\n"

    @staticmethod
    def _terminate_statement(sql: str) -> str:
        """Return ``sql`` stripped of outer whitespace and ending in ``;``."""
        sql = sql.strip()
        return sql if sql.endswith(';') else sql + ';'

    @classmethod
    def _clean_sql(cls, sql: str) -> str:
        """Drop ``--`` line comments and blank lines from a SQL snippet."""
        # Normalise CRLF first so no stray carriage returns survive in the
        # statement when the model answered with Windows line endings.
        sql = sql.replace("\r\n", "\n").strip()
        sql = cls._LINE_COMMENT_RE.sub('', sql)
        sql = cls._BLANK_LINES_RE.sub('\n', sql)
        return sql.strip()

    @classmethod
    def _filter_think_tags(cls, text: str):
        """Remove complete think tags and hold back a possibly split one.

        Args:
            text: Text carried over from the previous call followed by the
                newly received chunk.

        Returns:
            A ``(emit, carry)`` tuple. ``emit`` is safe to forward now;
            ``carry`` is a trailing fragment that may be the beginning of a
            tag and must be prepended to the next chunk (or flushed as plain
            text when the stream ends).
        """
        for tag in cls._THINK_TAGS:
            text = text.replace(tag, "")

        # Length of the longest suffix that is a proper prefix of some tag.
        held = 0
        for tag in cls._THINK_TAGS:
            longest = min(len(tag) - 1, len(text))
            for size in range(longest, held, -1):
                if tag.startswith(text[-size:]):
                    held = size
                    break

        if held:
            return text[:-held], text[-held:]
        return text, ""

    def _extract_sql_from_markdown(self, text: str) -> str:
        """Extract the SQL statement from a markdown-formatted answer.

        Args:
            text: Model output, expected to contain a fenced ``sql`` block.

        Returns:
            The contents of the first fenced SQL block with ``--`` comments
            and blank lines removed, or ``text`` stripped of outer whitespace
            when no such block exists.
        """
        match = self._SQL_FENCE_RE.search(text)
        if match:
            return self._clean_sql(match.group(1))
        return text.strip()

    def _load_examples(self):
        """Load the few-shot examples from ``examples.json``.

        The path is relative, so it is resolved against the process's current
        working directory.
        """
        with open('examples.json', 'r', encoding='utf-8') as f:
            return json.load(f)

    def _build_vectorstore(self):
        """Build the FAISS index over the example questions.

        Each example must provide ``query``, ``query_type`` and ``sql_code``.
        The question text is embedded; the other two fields travel along as
        document metadata.

        Returns:
            The FAISS vector store, or ``None`` when there are no examples
            (FAISS cannot size an index without at least one embedding).
        """
        documents = [
            Document(
                page_content=example['query'],
                metadata={
                    'query_type': example['query_type'],
                    'sql_code': example['sql_code']
                }
            )
            for example in self.examples
        ]
        if not documents:
            return None

        return FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )

    def _get_similar_examples(self, query: str, k: int = 3):
        """Return up to ``k`` stored examples closest to ``query``.

        Args:
            query: Natural-language question to search for.
            k: Maximum number of examples to return.

        Returns:
            A list of dicts with ``query_type``, ``query``, ``sql_code`` and
            ``similarity_score`` keys; empty when no index was built.
        """
        if self.vectorstore is None:
            return []

        # NOTE(review): LangChain's FAISS store reports a distance here by
        # default (lower means closer), yet the value is exposed as
        # "similarity_score" and shown to the model as a similarity. Confirm
        # which meaning is intended.
        hits = self.vectorstore.similarity_search_with_score(query, k=k)
        return [
            {
                'query_type': doc.metadata['query_type'],
                'query': doc.page_content,
                'sql_code': doc.metadata['sql_code'],
                'similarity_score': float(score)
            }
            for doc, score in hits
        ]

    def _format_chat_history(self, similar_examples: List[Dict[str, Any]]) -> str:
        """Render retrieved examples as the prompt's ``chat_history`` text.

        Args:
            similar_examples: Dicts as produced by ``_get_similar_examples``.

        Returns:
            A header followed by one numbered section per example, or an
            empty string when there are no examples.
        """
        if not similar_examples:
            return ""

        parts = ["基于以下相似示例:\n\n"]
        for i, example in enumerate(similar_examples, 1):
            parts.append(
                f"示例 {i}:\n"
                f"问题: {example['query']}\n"
                f"SQL: {example['sql_code']}\n"
                f"相似度: {example['similarity_score']:.2f}\n\n"
            )
        return "".join(parts)

    async def generate_sql_stream(self, query_description: str):
        """Stream the generation of a SQL statement as JSON lines.

        Args:
            query_description: The user's natural-language question.

        Yields:
            Newline-terminated JSON strings with ``type`` and ``content``:

            * ``sql_generation`` - one per non-empty streamed text fragment;
            * ``sql_result`` - the SQL extracted from the full answer (with a
              trailing ``;`` ensured), or the full answer verbatim when it
              contains no fenced SQL block;
            * ``end`` - emitted after a successful run;
            * ``error`` - emitted instead when any step raises.
        """
        try:
            similar_examples = self._get_similar_examples(query_description)
            # Serialised for the console log only; this event is not yielded.
            similar_examples_dump = self._event("similar_examples", similar_examples)
            print(similar_examples_dump)

            chain_input = {
                "similar_examples": similar_examples,
                "question": query_description
            }
            print("Chain input:", chain_input)
        except Exception as e:
            traceback.print_exc()
            yield self._event("error", str(e))
            return

        try:
            # Rendered only so the exact prompt shows up in the console log;
            # the chain formats its own copy.
            formatted_prompt = self.prompt.format(
                chat_history=self._format_chat_history(similar_examples),
                question=query_description
            )
            print("Formatted prompt:", formatted_prompt)

            # Decide from the model this instance was built with, not from
            # the environment default.
            strip_think = self.model_type == self._THINK_MODEL_TYPE
            full_response = ""
            pending = ""  # tail of the stream that may be half of a think tag

            async for chunk in self.chain.astream(chain_input):
                if not chunk:
                    continue
                try:
                    if strip_think:
                        chunk, pending = self._filter_think_tags(pending + chunk)
                    full_response += chunk
                except Exception as chunk_error:
                    # Best effort: a malformed chunk is logged and skipped
                    # without aborting the stream.
                    print(f"Error processing chunk: {str(chunk_error)}")
                    continue
                if chunk:
                    yield self._event("sql_generation", chunk)

            # Whatever is still held back never completed a tag: it is text.
            if pending:
                full_response += pending
                yield self._event("sql_generation", pending)

            sql_match = self._SQL_FENCE_RE.search(full_response)
            if sql_match:
                sql_content = self._terminate_statement(self._clean_sql(sql_match.group(1)))
            else:
                # No fenced SQL block: hand back the complete answer untouched.
                sql_content = full_response
            yield self._event("sql_result", sql_content)

            yield self._event("end", "SQL生成完成")
        except Exception as e:
            traceback.print_exc()
            yield self._event("error", f"生成SQL时发生错误: {str(e)}")

    async def execute_sql(self, sql: str) -> dict:
        """Execute a SQL statement and report the outcome.

        Args:
            sql: Statement to run; a trailing ``;`` is appended when missing.

        Returns:
            ``{"status": "success", "data": ...}`` with the value returned by
            ``Database.execute_query``, or ``{"status": "error", "message":
            ...}`` when anything raises.
        """
        # SECURITY NOTE(review): the statement is executed verbatim. If it
        # originates from model output or end users, restrict the database
        # account (e.g. read-only) or validate the statement before this call.
        print(sql)
        try:
            result = await self.db.execute_query(self._terminate_statement(sql))
            return {
                "status": "success",
                "data": result
            }
        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e)
            }

    async def close(self):
        """Close the database connection."""
        await self.db.close()
|