# database.py - asyncpg connection-pool helper.

import asyncio
import json
import os
from typing import Any, Dict, List, Optional

import asyncpg
from dotenv import load_dotenv
  7. # 加载config.env文件
  8. load_dotenv("config.env")
  9. # 数据库配置
  10. DB_CONFIG = {
  11. "host": os.getenv("DB_HOST"),
  12. "port": os.getenv("DB_PORT"),
  13. "database": os.getenv("DB_NAME"),
  14. "user": os.getenv("DB_USER"),
  15. "password": os.getenv("DB_PASSWORD")
  16. }
  17. class Database:
  18. def __init__(self):
  19. self.pool = None
  20. self.config = DB_CONFIG
  21. async def connect(self):
  22. """创建数据库连接池"""
  23. if not self.pool:
  24. self.pool = await asyncpg.create_pool(
  25. host=self.config["host"],
  26. port=self.config["port"],
  27. user=self.config["user"],
  28. password=self.config["password"],
  29. database=self.config["database"],
  30. min_size=1,
  31. max_size=10
  32. )
  33. async def close(self):
  34. """关闭数据库连接池"""
  35. if self.pool:
  36. await self.pool.close()
  37. self.pool = None
  38. async def execute_query(self, sql: str) -> List[Dict[str, Any]]:
  39. """
  40. 执行SQL查询并返回结果
  41. """
  42. if not self.pool:
  43. await self.connect()
  44. try:
  45. async with self.pool.acquire() as conn:
  46. # 执行查询
  47. rows = await conn.fetch(sql)
  48. # 将结果转换为字典列表
  49. result = []
  50. for row in rows:
  51. # 处理每个字段的值
  52. row_dict = {}
  53. for key, value in row.items():
  54. # 处理特殊类型
  55. if isinstance(value, (dict, list)):
  56. row_dict[key] = json.dumps(value, ensure_ascii=False)
  57. else:
  58. row_dict[key] = value
  59. result.append(row_dict)
  60. return result
  61. except Exception as e:
  62. print(f"Database error: {str(e)}")
  63. raise
  64. async def execute_transaction(self, sql_list: List[str]) -> bool:
  65. """
  66. 执行事务
  67. """
  68. if not self.pool:
  69. await self.connect()
  70. try:
  71. async with self.pool.acquire() as conn:
  72. async with conn.transaction():
  73. for sql in sql_list:
  74. await conn.execute(sql)
  75. return True
  76. except Exception as e:
  77. print(f"Transaction error: {str(e)}")
  78. return False
  79. async def test_connection(self) -> bool:
  80. """
  81. 测试数据库连接
  82. """
  83. try:
  84. if not self.pool:
  85. await self.connect()
  86. async with self.pool.acquire() as conn:
  87. await conn.execute('SELECT 1')
  88. return True
  89. except Exception as e:
  90. print(f"Connection test failed: {str(e)}")
  91. return False
  92. if __name__ == "__main__":
  93. db = Database()
  94. asyncio.run(db.test_connection())