refactor: 数据层重构 — 统一 SQLite 访问层 + 多脚本双写

新建 mofin_db.py 共享数据库模块:
- get_conn() 统一连接管理 (WAL + Row factory + 外键)
- init_all_tables() 幂等建表 (12张表: market/sector/stock/kline/fundamentals/sectors/holdings/strategies/watchlist/candidates/score_history/events/evaluations)
- write_market_snapshot() 市场快照双写
- write_klines() K线数据双写 (stocks + daily/weekly/monthly + fundamentals)
- write_price_event() 价格事件双写
- migrate_stock_sectors() 一次性迁移 stock_sector_map.json
- query_*() 通用查询函数 (sector_trend/top_inflow/consecutive_inflow/market_mood/db_stats)

重构现有脚本:
- market_watch.py: 删除内联 DB 代码,改用 mofin_db
- multi_timeframe.py: _save_local_history() 加 SQLite 双写
- price_monitor.py: record_event() 加 SQLite 双写
- mofin_query.py: 改用 mofin_db 查询函数

新增:
- migrate_sectors.py: 一次性迁移脚本

清理:
- get_realtime_prices.py: 死代码 (只读 portfolio.json,不调API)
This commit is contained in:
hmo
2026-06-20 16:25:36 +08:00
parent 8926b11090
commit 0924cf3124
7 changed files with 568 additions and 308 deletions
+462
View File
@@ -0,0 +1,462 @@
#!/usr/bin/env python3
"""mofin_db.py — MoFin 统一数据库访问层
所有脚本通过此模块访问 mofin.db,避免重复建表/连接逻辑。
用法:
from mofin_db import get_conn, write_market_snapshot, write_klines, ...
设计原则:
- 幂等建表(CREATE TABLE IF NOT EXISTS
- WAL 模式 + 外键约束
- 所有写操作返回 (success: bool, detail: str)
- JSON 写入由调用方负责,本模块只写 SQLite
"""
import sqlite3
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
DATA_DIR = Path(__file__).parent / "data"
DB_PATH = DATA_DIR / "mofin.db"
# ═══════════════════════════════════════════════════════════
# 连接管理
# ═══════════════════════════════════════════════════════════
def get_conn() -> sqlite3.Connection:
"""获取数据库连接(WAL 模式,外键约束,Row 工厂)"""
DATA_DIR.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
# ═══════════════════════════════════════════════════════════
# 建表(幂等)
# ═══════════════════════════════════════════════════════════
def init_all_tables(conn: sqlite3.Connection):
"""创建全部表(幂等,已存在则跳过)"""
conn.executescript("""
-- 市场快照
CREATE TABLE IF NOT EXISTS market_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
source TEXT NOT NULL DEFAULT 'ths',
up_ratio REAL,
mood TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE INDEX IF NOT EXISTS idx_snapshots_time ON market_snapshots(timestamp);
-- 板块快照
CREATE TABLE IF NOT EXISTS sector_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
snapshot_id INTEGER NOT NULL REFERENCES market_snapshots(id),
name TEXT NOT NULL,
change_pct REAL,
up_count INTEGER,
down_count INTEGER,
net_inflow REAL,
lead_stock TEXT,
lead_stock_change REAL,
volume REAL,
turnover REAL
);
CREATE INDEX IF NOT EXISTS idx_sector_name ON sector_snapshots(name);
CREATE INDEX IF NOT EXISTS idx_sector_snapshot ON sector_snapshots(snapshot_id);
CREATE INDEX IF NOT EXISTS idx_sector_name_time ON sector_snapshots(name, snapshot_id);
-- 个股
CREATE TABLE IF NOT EXISTS stocks (
code TEXT PRIMARY KEY,
name TEXT NOT NULL,
exchange TEXT DEFAULT 'SH',
type TEXT DEFAULT 'A',
updated_at TEXT
);
-- K线(日/周/月)
CREATE TABLE IF NOT EXISTS stock_daily (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL, amount REAL,
PRIMARY KEY (code, date)
);
CREATE TABLE IF NOT EXISTS stock_weekly (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL,
PRIMARY KEY (code, date)
);
CREATE TABLE IF NOT EXISTS stock_monthly (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL,
PRIMARY KEY (code, date)
);
-- 基本面
CREATE TABLE IF NOT EXISTS stock_fundamentals (
code TEXT PRIMARY KEY REFERENCES stocks(code),
pe REAL, pb REAL, eps REAL,
mcap_total REAL, mcap_flow REAL,
updated_at TEXT
);
-- 板块成分映射
CREATE TABLE IF NOT EXISTS stock_sectors (
code TEXT NOT NULL REFERENCES stocks(code),
sector_name TEXT NOT NULL,
source TEXT DEFAULT 'ths',
updated_at TEXT DEFAULT (datetime('now','localtime')),
PRIMARY KEY (code, sector_name)
);
CREATE INDEX IF NOT EXISTS idx_stock_sector ON stock_sectors(sector_name);
-- 持仓
CREATE TABLE IF NOT EXISTS holdings (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
shares INTEGER NOT NULL,
cost REAL,
position_pct REAL,
added_at TEXT,
is_active INTEGER DEFAULT 1,
closed_at TEXT,
close_pnl REAL
);
-- 持仓策略
CREATE TABLE IF NOT EXISTS holding_strategies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES holdings(code),
version INTEGER DEFAULT 1,
stop_loss REAL,
take_profit REAL,
entry_low REAL,
entry_high REAL,
strategy_type TEXT DEFAULT 'holding',
source TEXT,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime')),
superseded_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_strategy_code ON holding_strategies(code);
-- 自选股
CREATE TABLE IF NOT EXISTS watchlist_stocks (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
added_at TEXT DEFAULT (datetime('now','localtime')),
is_active INTEGER DEFAULT 1
);
-- 候选池
CREATE TABLE IF NOT EXISTS candidates (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
sector TEXT,
reason TEXT,
entry_range TEXT,
stop_loss REAL,
target REAL,
zhiwei_star REAL,
zhiwei_reviewed INTEGER DEFAULT 0,
zhiwei_reviewed_at TEXT,
promoted INTEGER DEFAULT 0,
promoted_at TEXT,
dropped INTEGER DEFAULT 0,
drop_reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
-- 候选评分历史
CREATE TABLE IF NOT EXISTS candidate_score_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES candidates(code),
score REAL NOT NULL,
source TEXT NOT NULL,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE INDEX IF NOT EXISTS idx_candidate_history ON candidate_score_history(code, created_at);
-- 价格事件
CREATE TABLE IF NOT EXISTS price_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES stocks(code),
name TEXT,
event_type TEXT NOT NULL,
price REAL,
trigger_value TEXT,
event_label TEXT,
created_at TEXT DEFAULT (datetime('now','localtime')),
date TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_code ON price_events(code);
CREATE INDEX IF NOT EXISTS idx_events_date ON price_events(date);
-- 策略评估记录
CREATE TABLE IF NOT EXISTS strategy_evaluations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES stocks(code),
eval_type TEXT NOT NULL,
status TEXT DEFAULT 'pending',
old_stop_loss REAL,
new_stop_loss REAL,
old_tp REAL,
new_tp REAL,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
""")
conn.commit()
# ═══════════════════════════════════════════════════════════
# 市场快照写入
# ═══════════════════════════════════════════════════════════
def write_market_snapshot(conn: sqlite3.Connection, market_data: dict) -> tuple[bool, str, Optional[int]]:
"""写入一次市场采集到 market_snapshots + sector_snapshots
Returns: (ok, message, snapshot_id)
"""
try:
cur = conn.execute(
"INSERT INTO market_snapshots (timestamp, source, up_ratio, mood) VALUES (?, ?, ?, ?)",
(market_data["timestamp"], market_data.get("source", "unknown"),
market_data.get("up_ratio", 0), market_data.get("mood", "unknown")),
)
sid = cur.lastrowid
sectors = market_data.get("sectors", [])
rows = [(sid, s.get("name", ""), s.get("change", 0),
s.get("up_count"), s.get("down_count"), s.get("net_inflow"),
s.get("lead_stock"), s.get("lead_stock_change"),
s.get("volume"), s.get("turnover")) for s in sectors]
if rows:
conn.executemany(
"INSERT INTO sector_snapshots (snapshot_id, name, change_pct, up_count, down_count, "
"net_inflow, lead_stock, lead_stock_change, volume, turnover) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", rows)
conn.commit()
return True, f"snapshot_id={sid}, sectors={len(rows)}", sid
except Exception as e:
try:
conn.rollback()
except Exception:
pass
return False, str(e), None
# ═══════════════════════════════════════════════════════════
# K线写入
# ═══════════════════════════════════════════════════════════
def write_klines(conn: sqlite3.Connection, code: str, name: str,
daily: list = None, weekly: list = None, monthly: list = None,
fundamentals: dict = None) -> bool:
"""将个股K线数据双写 SQLite
Args:
code: 股票代码
name: 股票名称
daily/weekly/monthly: [{date, open, close, high, low, volume}, ...]
fundamentals: {pe, pb, eps, mcap_total, mcap_flow}
"""
try:
# 判断交易所
raw = str(code)
if len(raw) == 5 and raw.isdigit():
exchange, stype = "HK", "H"
elif raw.startswith(("6", "5", "9")):
exchange, stype = "SH", "A"
else:
exchange, stype = "SZ", "A"
# stocks 表(INSERT OR REPLACE
conn.execute(
"INSERT OR REPLACE INTO stocks (code, name, exchange, type, updated_at) VALUES (?, ?, ?, ?, ?)",
(code, name, exchange, stype, datetime.now().isoformat()))
# K线数据
for period, table, data in [
("daily", "stock_daily", daily),
("weekly", "stock_weekly", weekly),
("monthly", "stock_monthly", monthly),
]:
if not data:
continue
rows = [(code, d.get("date", ""), d.get("open"), d.get("close"),
d.get("high"), d.get("low"), d.get("volume"),
d.get("amount") if period == "daily" else None) for d in data]
if period == "daily":
conn.executemany(
f"INSERT OR REPLACE INTO {table} (code, date, open, close, high, low, volume, amount) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)", rows)
else:
conn.executemany(
f"INSERT OR REPLACE INTO {table} (code, date, open, close, high, low, volume) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
[(r[0], r[1], r[2], r[3], r[4], r[5], r[6]) for r in rows])
# 基本面
if fundamentals:
conn.execute(
"INSERT OR REPLACE INTO stock_fundamentals (code, pe, pb, eps, mcap_total, mcap_flow, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
(code, fundamentals.get("pe"), fundamentals.get("pb"),
fundamentals.get("eps"), fundamentals.get("mcap_total"),
fundamentals.get("mcap_flow"), datetime.now().isoformat()))
conn.commit()
return True
except Exception as e:
try:
conn.rollback()
except Exception:
pass
return False
# ═══════════════════════════════════════════════════════════
# 价格事件写入
# ═══════════════════════════════════════════════════════════
def write_price_event(conn: sqlite3.Connection, code: str, name: str,
event_type: str, price: float, trigger_value: str,
event_label: str = "") -> bool:
"""写入一条价格事件"""
try:
now = datetime.now()
conn.execute(
"INSERT INTO price_events (code, name, event_type, price, trigger_value, event_label, date) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
(code, name, event_type, round(price, 2), trigger_value,
event_label, now.strftime("%Y-%m-%d")))
conn.commit()
return True
except Exception:
try:
conn.rollback()
except Exception:
pass
return False
# ═══════════════════════════════════════════════════════════
# 板块成分迁移
# ═══════════════════════════════════════════════════════════
def migrate_stock_sectors(conn: sqlite3.Connection) -> tuple[int, int]:
"""从 stock_sector_map.json 迁移到 stock_sectors 表
Returns: (migrated_stocks, total_mappings)
"""
sector_map_path = DATA_DIR / "stock_sector_map.json"
if not sector_map_path.exists():
return 0, 0
try:
with open(sector_map_path, encoding="utf-8") as f:
data = json.load(f)
except Exception:
return 0, 0
# 过滤元数据字段
mappings = [(code, sectors) for code, sectors in data.items()
if not code.startswith("_") and isinstance(sectors, list)]
total = 0
for code, sectors in mappings:
for sector in sectors:
try:
conn.execute(
"INSERT OR IGNORE INTO stock_sectors (code, sector_name, source) VALUES (?, ?, 'ths')",
(code, sector))
total += 1
except Exception:
pass
conn.commit()
return len(mappings), total
# ═══════════════════════════════════════════════════════════
# 查询辅助
# ═══════════════════════════════════════════════════════════
def query_sector_trend(conn: sqlite3.Connection, name: str, limit: int = 5) -> list[dict]:
"""板块最近N次趋势"""
rows = conn.execute("""
SELECT s.timestamp, ss.change_pct, ss.net_inflow,
ss.up_count, ss.down_count, ss.lead_stock, ss.lead_stock_change
FROM sector_snapshots ss
JOIN market_snapshots s ON ss.snapshot_id = s.id
WHERE ss.name = ? ORDER BY s.timestamp DESC LIMIT ?
""", (name, limit)).fetchall()
return [dict(r) for r in rows]
def query_top_inflow(conn: sqlite3.Connection, limit: int = 5) -> list[dict]:
"""最新一次资金净流入排行"""
rows = conn.execute("""
SELECT ss.name, ss.change_pct, ss.net_inflow, ss.lead_stock, s.timestamp
FROM sector_snapshots ss
JOIN market_snapshots s ON ss.snapshot_id = s.id
WHERE s.id = (SELECT MAX(id) FROM market_snapshots)
AND ss.net_inflow IS NOT NULL
ORDER BY ss.net_inflow DESC LIMIT ?
""", (limit,)).fetchall()
return [dict(r) for r in rows]
def query_consecutive_inflow(conn: sqlite3.Connection, days: int = 3) -> list[dict]:
"""连续N次净流入的板块"""
rows = conn.execute("""
SELECT name, COUNT(*) as times, ROUND(AVG(net_inflow), 2) as avg_inflow,
ROUND(AVG(change_pct), 2) as avg_change
FROM sector_snapshots ss
JOIN market_snapshots s ON ss.snapshot_id = s.id
WHERE s.id > (SELECT MAX(id) - ? FROM market_snapshots)
AND net_inflow > 0
GROUP BY name HAVING COUNT(*) >= ?
ORDER BY avg_inflow DESC
""", (days, days)).fetchall()
return [dict(r) for r in rows]
def query_market_mood(conn: sqlite3.Connection, limit: int = 10) -> list[dict]:
"""市场情绪趋势"""
rows = conn.execute("""
SELECT timestamp, source, up_ratio, mood
FROM market_snapshots ORDER BY timestamp DESC LIMIT ?
""", (limit,)).fetchall()
return [dict(r) for r in rows]
def query_db_stats(conn: sqlite3.Connection) -> dict:
"""数据库概览"""
snap_count = conn.execute("SELECT COUNT(*) FROM market_snapshots").fetchone()[0]
sector_count = conn.execute("SELECT COUNT(*) FROM sector_snapshots").fetchone()[0]
stock_count = conn.execute("SELECT COUNT(*) FROM stocks").fetchone()[0]
kline_count = conn.execute(
"SELECT COUNT(*) FROM stock_daily").fetchone()[0]
event_count = conn.execute("SELECT COUNT(*) FROM price_events").fetchone()[0]
latest = conn.execute(
"SELECT timestamp, source FROM market_snapshots ORDER BY id DESC LIMIT 1").fetchone()
return {
"snapshots": snap_count, "sector_rows": sector_count,
"stocks": stock_count, "daily_klines": kline_count,
"price_events": event_count,
"latest_snapshot": dict(latest) if latest else None,
}