database.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
import asyncio
import json
import logging
import os
from typing import Any, Dict, List

import asyncpg
  5. # 数据库配置
  6. db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
  7. "mysql": {
  8. # MySQL配置留空,等待后续添加
  9. },
  10. "pg": {
  11. "host": "10.10.9.243",
  12. "port": "5432",
  13. "database": "sde",
  14. "user": "sde",
  15. "password": "sde",
  16. }
  17. }
  18. # 向量模型配置
  19. vector_model_config: dict[str, dict[str, str]] = {
  20. "m3e-base": {
  21. "model_path": r"E:\项目临时\AI大模型\m3e-base",
  22. "device": "cpu"
  23. },
  24. # 可扩展其他向量模型
  25. }
  26. class Database:
  27. def __init__(self, db_type: str = "pg"):
  28. self.pool = None
  29. if db_type not in db_list:
  30. raise ValueError(f"Unsupported database type: {db_type}")
  31. self.config = db_list[db_type]
  32. async def connect(self):
  33. """创建数据库连接池"""
  34. if not self.pool:
  35. self.pool = await asyncpg.create_pool(
  36. host=self.config["host"],
  37. port=self.config["port"],
  38. user=self.config["user"],
  39. password=self.config["password"],
  40. database=self.config["database"],
  41. min_size=1,
  42. max_size=10
  43. )
  44. async def close(self):
  45. """关闭数据库连接池"""
  46. if self.pool:
  47. await self.pool.close()
  48. self.pool = None
  49. async def execute_query(self, sql: str) -> List[Dict[str, Any]]:
  50. """
  51. 执行SQL查询并返回结果
  52. """
  53. if not self.pool:
  54. await self.connect()
  55. try:
  56. async with self.pool.acquire() as conn:
  57. # 执行查询
  58. rows = await conn.fetch(sql)
  59. # 将结果转换为字典列表
  60. result = []
  61. for row in rows:
  62. # 处理每个字段的值
  63. row_dict = {}
  64. for key, value in row.items():
  65. # 处理特殊类型
  66. if isinstance(value, (dict, list)):
  67. row_dict[key] = json.dumps(value, ensure_ascii=False)
  68. else:
  69. row_dict[key] = value
  70. result.append(row_dict)
  71. return result
  72. except Exception as e:
  73. print(f"Database error: {str(e)}")
  74. raise
  75. async def execute_transaction(self, sql_list: List[str]) -> bool:
  76. """
  77. 执行事务
  78. """
  79. if not self.pool:
  80. await self.connect()
  81. try:
  82. async with self.pool.acquire() as conn:
  83. async with conn.transaction():
  84. for sql in sql_list:
  85. await conn.execute(sql)
  86. return True
  87. except Exception as e:
  88. print(f"Transaction error: {str(e)}")
  89. return False
  90. async def test_connection(self) -> bool:
  91. """
  92. 测试数据库连接
  93. """
  94. try:
  95. if not self.pool:
  96. await self.connect()
  97. async with self.pool.acquire() as conn:
  98. await conn.execute('SELECT 1')
  99. return True
  100. except Exception as e:
  101. print(f"Connection test failed: {str(e)}")
  102. return False
  103. if __name__ == "__main__":
  104. db = Database()
  105. asyncio.run(db.test_connection())