database.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
import asyncio
import json
import os
from typing import Any, Dict, List

import asyncpg
  5. # 数据库配置
  6. db_list: dict[str, dict[Any, Any] | dict[str, str]] = {
  7. "mysql": {
  8. # MySQL配置留空,等待后续添加
  9. },
  10. "pg": {
  11. "host": "10.10.9.243",
  12. "port": "5432",
  13. "database": "sde",
  14. "user": "sde",
  15. "password": "sde"
  16. }
  17. }
  18. class Database:
  19. def __init__(self, db_type: str = "pg"):
  20. self.pool = None
  21. if db_type not in db_list:
  22. raise ValueError(f"Unsupported database type: {db_type}")
  23. self.config = db_list[db_type]
  24. async def connect(self):
  25. """创建数据库连接池"""
  26. if not self.pool:
  27. self.pool = await asyncpg.create_pool(
  28. host=self.config["host"],
  29. port=self.config["port"],
  30. user=self.config["user"],
  31. password=self.config["password"],
  32. database=self.config["database"],
  33. min_size=1,
  34. max_size=10
  35. )
  36. async def close(self):
  37. """关闭数据库连接池"""
  38. if self.pool:
  39. await self.pool.close()
  40. self.pool = None
  41. async def execute_query(self, sql: str) -> List[Dict[str, Any]]:
  42. """
  43. 执行SQL查询并返回结果
  44. """
  45. if not self.pool:
  46. await self.connect()
  47. try:
  48. async with self.pool.acquire() as conn:
  49. # 执行查询
  50. rows = await conn.fetch(sql)
  51. # 将结果转换为字典列表
  52. result = []
  53. for row in rows:
  54. # 处理每个字段的值
  55. row_dict = {}
  56. for key, value in row.items():
  57. # 处理特殊类型
  58. if isinstance(value, (dict, list)):
  59. row_dict[key] = json.dumps(value, ensure_ascii=False)
  60. else:
  61. row_dict[key] = value
  62. result.append(row_dict)
  63. return result
  64. except Exception as e:
  65. print(f"Database error: {str(e)}")
  66. raise
  67. async def execute_transaction(self, sql_list: List[str]) -> bool:
  68. """
  69. 执行事务
  70. """
  71. if not self.pool:
  72. await self.connect()
  73. try:
  74. async with self.pool.acquire() as conn:
  75. async with conn.transaction():
  76. for sql in sql_list:
  77. await conn.execute(sql)
  78. return True
  79. except Exception as e:
  80. print(f"Transaction error: {str(e)}")
  81. return False
  82. async def test_connection(self) -> bool:
  83. """
  84. 测试数据库连接
  85. """
  86. try:
  87. if not self.pool:
  88. await self.connect()
  89. async with self.pool.acquire() as conn:
  90. await conn.execute('SELECT 1')
  91. return True
  92. except Exception as e:
  93. print(f"Connection test failed: {str(e)}")
  94. return False
  95. if __name__ == "__main__":
  96. db = Database()
  97. asyncio.run(db.test_connection())