Files
MoFin/mofin_db.py
T
hmo 0924cf3124 refactor: 数据层重构 — 统一 SQLite 访问层 + 多脚本双写
新建 mofin_db.py 共享数据库模块:
- get_conn() 统一连接管理 (WAL + Row factory + 外键)
- init_all_tables() 幂等建表 (15张表: market/sector/stock/kline(daily/weekly/monthly)/fundamentals/sectors/holdings/strategies/watchlist/candidates/score_history/events/evaluations)
- write_market_snapshot() 市场快照双写
- write_klines() K线数据双写 (stocks + daily/weekly/monthly + fundamentals)
- write_price_event() 价格事件双写
- migrate_stock_sectors() 一次性迁移 stock_sector_map.json
- query_*() 通用查询函数 (sector_trend/top_inflow/consecutive_inflow/market_mood/db_stats)

重构现有脚本:
- market_watch.py: 删除内联 DB 代码,改用 mofin_db
- multi_timeframe.py: _save_local_history() 加 SQLite 双写
- price_monitor.py: record_event() 加 SQLite 双写
- mofin_query.py: 改用 mofin_db 查询函数

新增:
- migrate_sectors.py: 一次性迁移脚本

清理:
- get_realtime_prices.py: 死代码 (只读 portfolio.json,不调API)
2026-06-20 16:26:17 +08:00

463 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""mofin_db.py — MoFin 统一数据库访问层
所有脚本通过此模块访问 mofin.db,避免重复建表/连接逻辑。
用法:
from mofin_db import get_conn, write_market_snapshot, write_klines, ...
设计原则:
- 幂等建表(CREATE TABLE IF NOT EXISTS)
- WAL 模式 + 外键约束
- 写入函数捕获异常并回滚,以返回值报告成败:write_market_snapshot 返回 (ok, detail, snapshot_id),write_klines / write_price_event 返回 bool
- JSON 写入由调用方负责,本模块只写 SQLite
"""
import sqlite3
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
# Directory holding mofin.db and auxiliary JSON files, located next to this module.
DATA_DIR = Path(__file__).parent.joinpath("data")
# The single SQLite file that every MoFin script reads and writes.
DB_PATH = DATA_DIR.joinpath("mofin.db")
# ═══════════════════════════════════════════════════════════
# 连接管理
# ═══════════════════════════════════════════════════════════
def get_conn() -> sqlite3.Connection:
    """Open a connection to ``DB_PATH``, creating ``DATA_DIR`` first if needed.

    The connection uses WAL journaling, has foreign-key enforcement switched
    on, and returns ``sqlite3.Row`` objects (indexable by column name).
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        connection.execute(pragma)
    return connection
# ═══════════════════════════════════════════════════════════
# 建表(幂等)
# ═══════════════════════════════════════════════════════════
def init_all_tables(conn: sqlite3.Connection):
    """Create every MoFin table and index; safe to call any number of times.

    Every statement is ``CREATE ... IF NOT EXISTS``, so existing tables,
    indexes and their data are left untouched. Commits when done.

    Tables: market_snapshots, sector_snapshots, stocks, stock_daily,
    stock_weekly, stock_monthly, stock_fundamentals, stock_sectors, holdings,
    holding_strategies, watchlist_stocks, candidates, candidate_score_history,
    price_events, strategy_evaluations.
    """
    ddl = """
-- 市场快照
CREATE TABLE IF NOT EXISTS market_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
source TEXT NOT NULL DEFAULT 'ths',
up_ratio REAL,
mood TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE INDEX IF NOT EXISTS idx_snapshots_time ON market_snapshots(timestamp);
-- 板块快照
CREATE TABLE IF NOT EXISTS sector_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
snapshot_id INTEGER NOT NULL REFERENCES market_snapshots(id),
name TEXT NOT NULL,
change_pct REAL,
up_count INTEGER,
down_count INTEGER,
net_inflow REAL,
lead_stock TEXT,
lead_stock_change REAL,
volume REAL,
turnover REAL
);
CREATE INDEX IF NOT EXISTS idx_sector_name ON sector_snapshots(name);
CREATE INDEX IF NOT EXISTS idx_sector_snapshot ON sector_snapshots(snapshot_id);
CREATE INDEX IF NOT EXISTS idx_sector_name_time ON sector_snapshots(name, snapshot_id);
-- 个股
CREATE TABLE IF NOT EXISTS stocks (
code TEXT PRIMARY KEY,
name TEXT NOT NULL,
exchange TEXT DEFAULT 'SH',
type TEXT DEFAULT 'A',
updated_at TEXT
);
-- K线(日/周/月)
CREATE TABLE IF NOT EXISTS stock_daily (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL, amount REAL,
PRIMARY KEY (code, date)
);
CREATE TABLE IF NOT EXISTS stock_weekly (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL,
PRIMARY KEY (code, date)
);
CREATE TABLE IF NOT EXISTS stock_monthly (
code TEXT NOT NULL REFERENCES stocks(code),
date TEXT NOT NULL,
open REAL, close REAL, high REAL, low REAL,
volume REAL,
PRIMARY KEY (code, date)
);
-- 基本面
CREATE TABLE IF NOT EXISTS stock_fundamentals (
code TEXT PRIMARY KEY REFERENCES stocks(code),
pe REAL, pb REAL, eps REAL,
mcap_total REAL, mcap_flow REAL,
updated_at TEXT
);
-- 板块成分映射
CREATE TABLE IF NOT EXISTS stock_sectors (
code TEXT NOT NULL REFERENCES stocks(code),
sector_name TEXT NOT NULL,
source TEXT DEFAULT 'ths',
updated_at TEXT DEFAULT (datetime('now','localtime')),
PRIMARY KEY (code, sector_name)
);
CREATE INDEX IF NOT EXISTS idx_stock_sector ON stock_sectors(sector_name);
-- 持仓
CREATE TABLE IF NOT EXISTS holdings (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
shares INTEGER NOT NULL,
cost REAL,
position_pct REAL,
added_at TEXT,
is_active INTEGER DEFAULT 1,
closed_at TEXT,
close_pnl REAL
);
-- 持仓策略
CREATE TABLE IF NOT EXISTS holding_strategies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES holdings(code),
version INTEGER DEFAULT 1,
stop_loss REAL,
take_profit REAL,
entry_low REAL,
entry_high REAL,
strategy_type TEXT DEFAULT 'holding',
source TEXT,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime')),
superseded_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_strategy_code ON holding_strategies(code);
-- 自选股
CREATE TABLE IF NOT EXISTS watchlist_stocks (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
added_at TEXT DEFAULT (datetime('now','localtime')),
is_active INTEGER DEFAULT 1
);
-- 候选池
CREATE TABLE IF NOT EXISTS candidates (
code TEXT PRIMARY KEY REFERENCES stocks(code),
name TEXT NOT NULL,
sector TEXT,
reason TEXT,
entry_range TEXT,
stop_loss REAL,
target REAL,
zhiwei_star REAL,
zhiwei_reviewed INTEGER DEFAULT 0,
zhiwei_reviewed_at TEXT,
promoted INTEGER DEFAULT 0,
promoted_at TEXT,
dropped INTEGER DEFAULT 0,
drop_reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
-- 候选评分历史
CREATE TABLE IF NOT EXISTS candidate_score_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES candidates(code),
score REAL NOT NULL,
source TEXT NOT NULL,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
CREATE INDEX IF NOT EXISTS idx_candidate_history ON candidate_score_history(code, created_at);
-- 价格事件
CREATE TABLE IF NOT EXISTS price_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES stocks(code),
name TEXT,
event_type TEXT NOT NULL,
price REAL,
trigger_value TEXT,
event_label TEXT,
created_at TEXT DEFAULT (datetime('now','localtime')),
date TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_code ON price_events(code);
CREATE INDEX IF NOT EXISTS idx_events_date ON price_events(date);
-- 策略评估记录
CREATE TABLE IF NOT EXISTS strategy_evaluations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL REFERENCES stocks(code),
eval_type TEXT NOT NULL,
status TEXT DEFAULT 'pending',
old_stop_loss REAL,
new_stop_loss REAL,
old_tp REAL,
new_tp REAL,
reason TEXT,
created_at TEXT DEFAULT (datetime('now','localtime'))
);
"""
    conn.executescript(ddl)
    conn.commit()
# ═══════════════════════════════════════════════════════════
# 市场快照写入
# ═══════════════════════════════════════════════════════════
def write_market_snapshot(conn: sqlite3.Connection, market_data: dict) -> tuple[bool, str, Optional[int]]:
    """Persist one market collection run and its per-sector rows.

    Inserts one ``market_snapshots`` row, then one ``sector_snapshots`` row per
    entry of ``market_data["sectors"]`` linked by ``snapshot_id``, and commits.

    Args:
        conn: open database connection.
        market_data: must contain ``"timestamp"``. Optional keys and their
            defaults: ``source`` ("unknown"), ``up_ratio`` (0), ``mood``
            ("unknown"), ``sectors`` ([]). Each sector dict may provide
            ``name`` (""), ``change`` (0, stored as change_pct), up_count,
            down_count, net_inflow, lead_stock, lead_stock_change, volume,
            turnover (NULL when absent).

    Returns:
        ``(True, "snapshot_id=<id>, sectors=<n>", <id>)`` on success, or
        ``(False, <error text>, None)`` after a best-effort rollback when any
        exception occurs (including a missing ``"timestamp"``).
    """
    snapshot_sql = "INSERT INTO market_snapshots (timestamp, source, up_ratio, mood) VALUES (?, ?, ?, ?)"
    sector_sql = (
        "INSERT INTO sector_snapshots (snapshot_id, name, change_pct, up_count, down_count, "
        "net_inflow, lead_stock, lead_stock_change, volume, turnover) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    try:
        snapshot_params = (
            market_data["timestamp"],
            market_data.get("source", "unknown"),
            market_data.get("up_ratio", 0),
            market_data.get("mood", "unknown"),
        )
        snapshot_id = conn.execute(snapshot_sql, snapshot_params).lastrowid
        sector_rows = []
        for sector in market_data.get("sectors", []):
            sector_rows.append((
                snapshot_id,
                sector.get("name", ""),
                sector.get("change", 0),
                sector.get("up_count"),
                sector.get("down_count"),
                sector.get("net_inflow"),
                sector.get("lead_stock"),
                sector.get("lead_stock_change"),
                sector.get("volume"),
                sector.get("turnover"),
            ))
        if sector_rows:
            conn.executemany(sector_sql, sector_rows)
        conn.commit()
    except Exception as exc:
        try:
            conn.rollback()
        except Exception:
            pass
        return False, str(exc), None
    return True, f"snapshot_id={snapshot_id}, sectors={len(sector_rows)}", snapshot_id
# ═══════════════════════════════════════════════════════════
# K线写入
# ═══════════════════════════════════════════════════════════
def write_klines(conn: sqlite3.Connection, code: str, name: str,
                 daily: list = None, weekly: list = None, monthly: list = None,
                 fundamentals: dict = None) -> bool:
    """Upsert one stock's metadata, K-line bars and fundamentals into SQLite.

    Args:
        conn: open database connection.
        code: stock code. A 5-digit numeric code is stored as exchange "HK" /
            type "H"; codes starting with 6, 5 or 9 as "SH" / "A"; anything
            else as "SZ" / "A".
        name: stock name written to ``stocks``.
        daily/weekly/monthly: lists of bar dicts
            ``{date, open, close, high, low, volume}``; daily bars may also
            carry ``amount``. Bars without a truthy ``date`` are skipped.
        fundamentals: optional ``{pe, pb, eps, mcap_total, mcap_flow}``.

    Returns:
        True once all writes are committed; False if any statement failed,
        in which case the transaction is rolled back (best effort).
    """
    try:
        # Classify the listing venue from the shape of the code.
        raw = str(code)
        if len(raw) == 5 and raw.isdigit():
            exchange, stype = "HK", "H"
        elif raw.startswith(("6", "5", "9")):
            exchange, stype = "SH", "A"
        else:
            exchange, stype = "SZ", "A"
        now = datetime.now().isoformat()

        # stocks row (INSERT OR REPLACE refreshes name/exchange/updated_at).
        conn.execute(
            "INSERT OR REPLACE INTO stocks (code, name, exchange, type, updated_at) VALUES (?, ?, ?, ?, ?)",
            (code, name, exchange, stype, now))

        # (code, date) is the primary key: a bar without a date would be stored
        # under an empty/NULL date and clobber other dateless bars, so skip it.
        if daily:
            conn.executemany(
                "INSERT OR REPLACE INTO stock_daily (code, date, open, close, high, low, volume, amount) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                [(code, bar["date"], bar.get("open"), bar.get("close"), bar.get("high"),
                  bar.get("low"), bar.get("volume"), bar.get("amount"))
                 for bar in daily if bar.get("date")])
        for table, bars in (("stock_weekly", weekly), ("stock_monthly", monthly)):
            if not bars:
                continue
            # `table` comes from the fixed tuple above, never from caller input.
            conn.executemany(
                f"INSERT OR REPLACE INTO {table} (code, date, open, close, high, low, volume) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                [(code, bar["date"], bar.get("open"), bar.get("close"), bar.get("high"),
                  bar.get("low"), bar.get("volume"))
                 for bar in bars if bar.get("date")])

        if fundamentals:
            conn.execute(
                "INSERT OR REPLACE INTO stock_fundamentals (code, pe, pb, eps, mcap_total, mcap_flow, updated_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (code, fundamentals.get("pe"), fundamentals.get("pb"),
                 fundamentals.get("eps"), fundamentals.get("mcap_total"),
                 fundamentals.get("mcap_flow"), now))
        conn.commit()
        return True
    except Exception:
        # Failures are reported through the return value rather than raised.
        try:
            conn.rollback()
        except Exception:
            pass
        return False
# ═══════════════════════════════════════════════════════════
# 价格事件写入
# ═══════════════════════════════════════════════════════════
def write_price_event(conn: sqlite3.Connection, code: str, name: str,
                      event_type: str, price: float, trigger_value: str,
                      event_label: str = "") -> bool:
    """Insert one row into ``price_events`` and commit.

    ``price`` is rounded to 2 decimals; ``date`` is set to today's local date
    formatted as YYYY-MM-DD.

    Returns:
        True on success; False after a best-effort rollback on any exception
        (for example when ``price`` cannot be passed to ``round``).
    """
    sql = ("INSERT INTO price_events (code, name, event_type, price, trigger_value, event_label, date) "
           "VALUES (?, ?, ?, ?, ?, ?, ?)")
    try:
        today = datetime.now().strftime("%Y-%m-%d")
        params = (code, name, event_type, round(price, 2), trigger_value, event_label, today)
        conn.execute(sql, params)
        conn.commit()
    except Exception:
        try:
            conn.rollback()
        except Exception:
            pass
        return False
    return True
# ═══════════════════════════════════════════════════════════
# 板块成分迁移
# ═══════════════════════════════════════════════════════════
def migrate_stock_sectors(conn: sqlite3.Connection) -> tuple[int, int]:
    """Import ``DATA_DIR/stock_sector_map.json`` into the ``stock_sectors`` table.

    The JSON object maps a stock code to a list of sector names. Keys starting
    with ``_`` (metadata) and entries whose value is not a list are ignored.

    ``stock_sectors.code`` references ``stocks(code)``. When foreign keys are
    enforced (get_conn() enables them), a mapping for a code with no ``stocks``
    row is rejected, and ``INSERT OR IGNORE`` does not suppress foreign-key
    errors. A placeholder ``stocks`` row (name = code) is therefore created
    first with ``INSERT OR IGNORE``. Existing rows are never modified, and a
    later write_klines() call replaces the placeholder with real data.

    Safe to re-run: mappings that already exist are ignored.

    Returns:
        ``(migrated_stocks, total_mappings)``: the number of codes taken from
        the file, and the number of (code, sector) pairs written or already
        present. ``(0, 0)`` when the file is missing, unreadable, or not a
        JSON object.
    """
    sector_map_path = DATA_DIR / "stock_sector_map.json"
    if not sector_map_path.exists():
        return 0, 0
    try:
        with open(sector_map_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return 0, 0
    if not isinstance(data, dict):
        return 0, 0

    # Drop metadata keys and malformed values.
    mappings = [(code, sectors) for code, sectors in data.items()
                if not code.startswith("_") and isinstance(sectors, list)]
    total = 0
    for code, sectors in mappings:
        # Same exchange rule as write_klines(), so the placeholder row agrees
        # with what a later K-line write stores.
        if len(code) == 5 and code.isdigit():
            exchange, stype = "HK", "H"
        elif code.startswith(("6", "5", "9")):
            exchange, stype = "SH", "A"
        else:
            exchange, stype = "SZ", "A"
        try:
            conn.execute(
                "INSERT OR IGNORE INTO stocks (code, name, exchange, type) VALUES (?, ?, ?, ?)",
                (code, code, exchange, stype))
        except sqlite3.Error:
            pass  # the mapping inserts below then fail and are skipped one by one
        for sector in sectors:
            try:
                conn.execute(
                    "INSERT OR IGNORE INTO stock_sectors (code, sector_name, source) VALUES (?, ?, 'ths')",
                    (code, sector))
            except (sqlite3.Error, OverflowError):
                # Unbindable or rejected value (e.g. a nested list/object):
                # skip this pair and keep migrating the rest.
                continue
            total += 1
    conn.commit()
    return len(mappings), total
# ═══════════════════════════════════════════════════════════
# 查询辅助
# ═══════════════════════════════════════════════════════════
def query_sector_trend(conn: sqlite3.Connection, name: str, limit: int = 5) -> list[dict]:
    """Return up to ``limit`` rows for sector ``name``, newest snapshot timestamp first.

    Each dict holds the snapshot ``timestamp`` plus change_pct, net_inflow,
    up_count, down_count, lead_stock and lead_stock_change. Rows are converted
    with ``dict()``, so the connection needs a mapping row factory such as
    ``sqlite3.Row``.
    """
    sql = """
SELECT s.timestamp, ss.change_pct, ss.net_inflow,
ss.up_count, ss.down_count, ss.lead_stock, ss.lead_stock_change
FROM sector_snapshots ss
JOIN market_snapshots s ON ss.snapshot_id = s.id
WHERE ss.name = ? ORDER BY s.timestamp DESC LIMIT ?
"""
    cursor = conn.execute(sql, (name, limit))
    return list(map(dict, cursor.fetchall()))
def query_top_inflow(conn: sqlite3.Connection, limit: int = 5) -> list[dict]:
    """Rank sectors of the newest snapshot (highest id) by net inflow.

    Sectors whose ``net_inflow`` is NULL are excluded. Returns at most
    ``limit`` dicts with name, change_pct, net_inflow, lead_stock and the
    snapshot timestamp, largest inflow first. Requires a mapping row factory
    such as ``sqlite3.Row``.
    """
    sql = """
SELECT ss.name, ss.change_pct, ss.net_inflow, ss.lead_stock, s.timestamp
FROM sector_snapshots ss
JOIN market_snapshots s ON ss.snapshot_id = s.id
WHERE s.id = (SELECT MAX(id) FROM market_snapshots)
AND ss.net_inflow IS NOT NULL
ORDER BY ss.net_inflow DESC LIMIT ?
"""
    ranked = conn.execute(sql, (limit,)).fetchall()
    return [dict(record) for record in ranked]
def query_consecutive_inflow(conn: sqlite3.Connection, days: int = 3) -> list[dict]:
    """Find sectors with positive net inflow in every one of the latest ``days`` snapshots.

    The window is the ``days`` market_snapshots rows with the highest ids,
    i.e. the most recently inserted ones. A sector qualifies when it has at
    least ``days`` rows with ``net_inflow > 0`` inside that window.

    Returns:
        Dicts with name, times, avg_inflow and avg_change (both averages
        rounded to 2 decimals), ordered by avg_inflow descending. Empty when
        ``days`` < 1. Requires a mapping row factory such as ``sqlite3.Row``.
    """
    if days < 1:
        # SQLite treats a negative LIMIT as "no limit", so reject non-positive
        # windows explicitly instead of scanning the whole history.
        return []
    # Select the last N ids explicitly: ids may have gaps (e.g. deleted
    # snapshots), so "id > MAX(id) - N" can cover fewer than N snapshots.
    rows = conn.execute("""
SELECT ss.name, COUNT(*) as times, ROUND(AVG(ss.net_inflow), 2) as avg_inflow,
ROUND(AVG(ss.change_pct), 2) as avg_change
FROM sector_snapshots ss
WHERE ss.snapshot_id IN (
SELECT id FROM market_snapshots ORDER BY id DESC LIMIT ?)
AND ss.net_inflow > 0
GROUP BY ss.name HAVING COUNT(*) >= ?
ORDER BY avg_inflow DESC
""", (days, days)).fetchall()
    return [dict(r) for r in rows]
def query_market_mood(conn: sqlite3.Connection, limit: int = 10) -> list[dict]:
    """Return the ``limit`` latest market snapshots, newest timestamp first.

    Each dict carries timestamp, source, up_ratio and mood. Requires a mapping
    row factory such as ``sqlite3.Row``.
    """
    statement = """
SELECT timestamp, source, up_ratio, mood
FROM market_snapshots ORDER BY timestamp DESC LIMIT ?
"""
    return [dict(snapshot) for snapshot in conn.execute(statement, (limit,))]
def query_db_stats(conn: sqlite3.Connection) -> dict:
    """Summarise the database: main-table row counts plus the newest snapshot.

    Keys (in order): snapshots, sector_rows, stocks, daily_klines,
    price_events, latest_snapshot. ``latest_snapshot`` is a dict with the
    timestamp/source of the highest-id market snapshot, or None when there
    are no snapshots. Requires a mapping row factory such as ``sqlite3.Row``.
    """
    # (result key, table) pairs; table names are fixed literals, not input.
    counted_tables = (
        ("snapshots", "market_snapshots"),
        ("sector_rows", "sector_snapshots"),
        ("stocks", "stocks"),
        ("daily_klines", "stock_daily"),
        ("price_events", "price_events"),
    )
    stats = {
        key: conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
        for key, table in counted_tables
    }
    newest = conn.execute(
        "SELECT timestamp, source FROM market_snapshots ORDER BY id DESC LIMIT 1").fetchone()
    stats["latest_snapshot"] = None if newest is None else dict(newest)
    return stats