125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
||
"""数据库连接器 - 支持直连和SSH隧道两种模式"""
|
||
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
from contextlib import contextmanager
|
||
|
||
try:
|
||
import pymysql
|
||
except ImportError as e:
|
||
print(f"缺少依赖: {e}")
|
||
print("请运行: pip install pymysql")
|
||
sys.exit(1)
|
||
|
||
|
||
def load_config():
    """Load the JSON settings file and return its parsed contents.

    The file is looked up at ``<parent of this file's directory>/config/settings.json``.
    If it does not exist, a hint is printed and the process exits with
    status 1.

    Returns:
        The object produced by decoding the settings file as JSON.
    """
    config_path = Path(__file__).parent.parent / "config" / "settings.json"

    # Happy path first: read the whole file as UTF-8 and decode it.
    if config_path.exists():
        return json.loads(config_path.read_text(encoding="utf-8"))

    # Missing file: tell the user how to create it, then abort.
    print(f"配置文件不存在: {config_path}")
    print("请复制 settings.json.example 为 settings.json 并填写配置")
    sys.exit(1)
|
||
|
||
|
||
@contextmanager
def get_db_connection():
    """Yield a PyMySQL connection, going through an SSH tunnel when configured.

    The settings are read via ``load_config()``. When the ``ssh`` section is
    present and has a truthy ``host``, an ``SSHTunnelForwarder`` is started
    and the connection is made to its local bind port on 127.0.0.1;
    otherwise the connection goes straight to the configured database
    host/port. The connection is closed (and the tunnel stopped, if any)
    when the ``with`` block exits.

    Yields:
        An open ``pymysql`` connection using ``DictCursor`` as cursor class.
    """
    settings = load_config()
    ssh_settings = settings.get("ssh")
    db_settings = settings["database"]

    def open_connection(host, port):
        # Shared connect call; only the endpoint differs between the modes.
        return pymysql.connect(
            host=host,
            port=port,
            user=db_settings["username"],
            password=db_settings["password"],
            database=db_settings["database"],
            charset="utf8mb4",
            cursorclass=pymysql.cursors.DictCursor,
            connect_timeout=settings["query"]["timeout"],
        )

    # Direct mode: no SSH section, or no SSH host given.
    if not (ssh_settings and ssh_settings.get("host")):
        direct = open_connection(db_settings["host"], db_settings["port"])
        try:
            yield direct
        finally:
            direct.close()
        return

    # Tunnel mode: the SSH dependencies are optional, so import them lazily.
    try:
        from sshtunnel import SSHTunnelForwarder
        import paramiko
    except ImportError:
        print("SSH隧道需要额外依赖: pip install paramiko sshtunnel")
        sys.exit(1)

    forwarder = SSHTunnelForwarder(
        (ssh_settings["host"], ssh_settings["port"]),
        ssh_username=ssh_settings["username"],
        ssh_password=ssh_settings.get("password"),
        ssh_pkey=ssh_settings.get("key_file"),
        remote_bind_address=(db_settings["host"], db_settings["port"]),
        # Host-only tuple: no local port is specified here.
        local_bind_address=("127.0.0.1",),
    )

    try:
        forwarder.start()
        tunneled = open_connection("127.0.0.1", forwarder.local_bind_port)
        try:
            yield tunneled
        finally:
            tunneled.close()
    finally:
        # Runs even if starting the tunnel or connecting failed.
        forwarder.stop()
|
||
|
||
|
||
def test_connection():
    """Open a database connection, run two probe queries, and report the outcome.

    Prints which connection mode is being used, then runs ``SELECT 1`` and
    ``SHOW TABLES`` and prints their results.

    Returns:
        True if everything succeeded; False if any exception was raised
        (its message is printed).
    """
    config = load_config()
    ssh_section = config.get("ssh")
    use_ssh = ssh_section and ssh_section.get("host")

    print("正在建立SSH隧道..." if use_ssh else "正在直连数据库...")

    try:
        with get_db_connection() as conn:
            print("数据库连接成功!")
            with conn.cursor() as cursor:
                # Minimal round-trip to prove queries work.
                cursor.execute("SELECT 1 as test")
                result = cursor.fetchone()
                print(f"测试查询结果: {result}")

                cursor.execute("SHOW TABLES")
                tables = cursor.fetchall()
                print(f"\n数据库中共有 {len(tables)} 张表:")
                for row in tables:
                    # Rows are dicts; the table name is the first value.
                    table_name = list(row.values())[0]
                    print(f" - {table_name}")
    except Exception as e:
        print(f"连接失败: {e}")
        return False
    else:
        return True
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate the test outcome as the process exit status so shells and
    # scripts can detect a failed connection (0 = success, 1 = failure).
    sys.exit(0 if test_connection() else 1)
|