#!/usr/bin/env python3
"""mofin_db.py — unified MoFin database access layer.

Every script should reach ``mofin.db`` through this module so that table
creation and connection setup are defined in exactly one place.

Usage:
    from mofin_db import get_conn, write_market_snapshot, write_klines, ...

Design principles:
- Idempotent schema creation (``CREATE TABLE IF NOT EXISTS``).
- WAL journal mode + foreign-key enforcement on every connection.
- Write helpers catch exceptions, roll back, and report failure through their
  return value instead of raising:
    * ``write_market_snapshot`` -> ``(ok, detail, snapshot_id)``
    * ``write_klines`` / ``write_price_event`` -> ``bool``
- JSON output is the caller's responsibility; this module only writes SQLite.
"""

import json
import logging
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional

DATA_DIR = Path(__file__).parent / "data"
DB_PATH = DATA_DIR / "mofin.db"

logger = logging.getLogger(__name__)


# ═══════════════════════════════════════════════════════════
# Connection management
# ═══════════════════════════════════════════════════════════

def get_conn() -> sqlite3.Connection:
    """Open a connection to ``DB_PATH``.

    Creates ``DATA_DIR`` if needed, enables WAL journaling and foreign-key
    enforcement, and installs ``sqlite3.Row`` as the row factory so query
    results can be converted with ``dict(row)``.
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA foreign_keys=ON")
    return conn


def _rollback_quietly(conn: sqlite3.Connection) -> None:
    """Roll back the open transaction, ignoring errors (e.g. a closed connection)."""
    try:
        conn.rollback()
    except sqlite3.Error:
        pass


# ═══════════════════════════════════════════════════════════
# Schema creation (idempotent)
# ═══════════════════════════════════════════════════════════

def init_all_tables(conn: sqlite3.Connection):
    """Create every table and index; existing objects are left untouched."""
    conn.executescript("""
    -- 市场快照
    CREATE TABLE IF NOT EXISTS market_snapshots (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        timestamp TEXT NOT NULL,
        source TEXT NOT NULL DEFAULT 'ths',
        up_ratio REAL,
        mood TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime'))
    );
    CREATE INDEX IF NOT EXISTS idx_snapshots_time ON market_snapshots(timestamp);

    -- 板块快照
    CREATE TABLE IF NOT EXISTS sector_snapshots (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        snapshot_id INTEGER NOT NULL REFERENCES market_snapshots(id),
        name TEXT NOT NULL,
        change_pct REAL,
        up_count INTEGER,
        down_count INTEGER,
        net_inflow REAL,
        lead_stock TEXT,
        lead_stock_change REAL,
        volume REAL,
        turnover REAL
    );
    CREATE INDEX IF NOT EXISTS idx_sector_name ON sector_snapshots(name);
    CREATE INDEX IF NOT EXISTS idx_sector_snapshot ON sector_snapshots(snapshot_id);
    CREATE INDEX IF NOT EXISTS idx_sector_name_time ON sector_snapshots(name, snapshot_id);

    -- 个股
    CREATE TABLE IF NOT EXISTS stocks (
        code TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        exchange TEXT DEFAULT 'SH',
        type TEXT DEFAULT 'A',
        updated_at TEXT
    );

    -- K线(日/周/月)
    CREATE TABLE IF NOT EXISTS stock_daily (
        code TEXT NOT NULL REFERENCES stocks(code),
        date TEXT NOT NULL,
        open REAL,
        close REAL,
        high REAL,
        low REAL,
        volume REAL,
        amount REAL,
        PRIMARY KEY (code, date)
    );
    CREATE TABLE IF NOT EXISTS stock_weekly (
        code TEXT NOT NULL REFERENCES stocks(code),
        date TEXT NOT NULL,
        open REAL,
        close REAL,
        high REAL,
        low REAL,
        volume REAL,
        PRIMARY KEY (code, date)
    );
    CREATE TABLE IF NOT EXISTS stock_monthly (
        code TEXT NOT NULL REFERENCES stocks(code),
        date TEXT NOT NULL,
        open REAL,
        close REAL,
        high REAL,
        low REAL,
        volume REAL,
        PRIMARY KEY (code, date)
    );

    -- 基本面
    CREATE TABLE IF NOT EXISTS stock_fundamentals (
        code TEXT PRIMARY KEY REFERENCES stocks(code),
        pe REAL,
        pb REAL,
        eps REAL,
        mcap_total REAL,
        mcap_flow REAL,
        updated_at TEXT
    );

    -- 板块成分映射
    CREATE TABLE IF NOT EXISTS stock_sectors (
        code TEXT NOT NULL REFERENCES stocks(code),
        sector_name TEXT NOT NULL,
        source TEXT DEFAULT 'ths',
        updated_at TEXT DEFAULT (datetime('now','localtime')),
        PRIMARY KEY (code, sector_name)
    );
    CREATE INDEX IF NOT EXISTS idx_stock_sector ON stock_sectors(sector_name);

    -- 持仓
    CREATE TABLE IF NOT EXISTS holdings (
        code TEXT PRIMARY KEY REFERENCES stocks(code),
        name TEXT NOT NULL,
        shares INTEGER NOT NULL,
        cost REAL,
        position_pct REAL,
        added_at TEXT,
        is_active INTEGER DEFAULT 1,
        closed_at TEXT,
        close_pnl REAL
    );

    -- 持仓策略
    CREATE TABLE IF NOT EXISTS holding_strategies (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL REFERENCES holdings(code),
        version INTEGER DEFAULT 1,
        stop_loss REAL,
        take_profit REAL,
        entry_low REAL,
        entry_high REAL,
        strategy_type TEXT DEFAULT 'holding',
        source TEXT,
        reason TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime')),
        superseded_at TEXT
    );
    CREATE INDEX IF NOT EXISTS idx_strategy_code ON holding_strategies(code);

    -- 自选股
    CREATE TABLE IF NOT EXISTS watchlist_stocks (
        code TEXT PRIMARY KEY REFERENCES stocks(code),
        name TEXT NOT NULL,
        added_at TEXT DEFAULT (datetime('now','localtime')),
        is_active INTEGER DEFAULT 1
    );

    -- 候选池
    CREATE TABLE IF NOT EXISTS candidates (
        code TEXT PRIMARY KEY REFERENCES stocks(code),
        name TEXT NOT NULL,
        sector TEXT,
        reason TEXT,
        entry_range TEXT,
        stop_loss REAL,
        target REAL,
        zhiwei_star REAL,
        zhiwei_reviewed INTEGER DEFAULT 0,
        zhiwei_reviewed_at TEXT,
        promoted INTEGER DEFAULT 0,
        promoted_at TEXT,
        dropped INTEGER DEFAULT 0,
        drop_reason TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime'))
    );

    -- 候选评分历史
    CREATE TABLE IF NOT EXISTS candidate_score_history (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL REFERENCES candidates(code),
        score REAL NOT NULL,
        source TEXT NOT NULL,
        reason TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime'))
    );
    CREATE INDEX IF NOT EXISTS idx_candidate_history ON candidate_score_history(code, created_at);

    -- 价格事件
    CREATE TABLE IF NOT EXISTS price_events (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL REFERENCES stocks(code),
        name TEXT,
        event_type TEXT NOT NULL,
        price REAL,
        trigger_value TEXT,
        event_label TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime')),
        date TEXT
    );
    CREATE INDEX IF NOT EXISTS idx_events_code ON price_events(code);
    CREATE INDEX IF NOT EXISTS idx_events_date ON price_events(date);

    -- 策略评估记录
    CREATE TABLE IF NOT EXISTS strategy_evaluations (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL REFERENCES stocks(code),
        eval_type TEXT NOT NULL,
        status TEXT DEFAULT 'pending',
        old_stop_loss REAL,
        new_stop_loss REAL,
        old_tp REAL,
        new_tp REAL,
        reason TEXT,
        created_at TEXT DEFAULT (datetime('now','localtime'))
    );
    """)
    conn.commit()


# ═══════════════════════════════════════════════════════════
# Market snapshot writes
# ═══════════════════════════════════════════════════════════

def write_market_snapshot(conn: sqlite3.Connection,
                          market_data: dict) -> tuple[bool, str, Optional[int]]:
    """Store one market collection in market_snapshots + sector_snapshots.

    Args:
        conn: open connection.
        market_data: requires "timestamp"; optional "source", "up_ratio",
            "mood" and "sectors" (list of dicts read via keys "name",
            "change", "up_count", "down_count", "net_inflow", "lead_stock",
            "lead_stock_change", "volume", "turnover").

    Returns:
        (ok, message, snapshot_id). On failure the transaction is rolled
        back, message is the exception text and snapshot_id is None.
    """
    try:
        cur = conn.execute(
            "INSERT INTO market_snapshots (timestamp, source, up_ratio, mood) VALUES (?, ?, ?, ?)",
            (market_data["timestamp"],
             market_data.get("source", "unknown"),
             market_data.get("up_ratio", 0),
             market_data.get("mood", "unknown")),
        )
        sid = cur.lastrowid

        rows = [
            (sid, s.get("name", ""), s.get("change", 0),
             s.get("up_count"), s.get("down_count"),
             s.get("net_inflow"), s.get("lead_stock"),
             s.get("lead_stock_change"), s.get("volume"),
             s.get("turnover"))
            for s in market_data.get("sectors", [])
        ]
        if rows:
            conn.executemany(
                "INSERT INTO sector_snapshots (snapshot_id, name, change_pct, up_count, down_count, "
                "net_inflow, lead_stock, lead_stock_change, volume, turnover) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                rows)

        conn.commit()
        return True, f"snapshot_id={sid}, sectors={len(rows)}", sid
    except Exception as e:  # best-effort contract: report, never raise
        _rollback_quietly(conn)
        return False, str(e), None


# ═══════════════════════════════════════════════════════════
# K-line writes
# ═══════════════════════════════════════════════════════════

def _detect_exchange(code: str) -> tuple[str, str]:
    """Guess (exchange, type) from the shape of a stock code."""
    raw = str(code)
    if len(raw) == 5 and raw.isdigit():
        return "HK", "H"
    if raw.startswith(("6", "5", "9")):
        return "SH", "A"
    return "SZ", "A"


def _kline_row(code: str, bar: dict) -> tuple:
    """Build the (code, date, open, close, high, low, volume) prefix for one bar."""
    return (code, bar.get("date", ""), bar.get("open"), bar.get("close"),
            bar.get("high"), bar.get("low"), bar.get("volume"))


def write_klines(conn: sqlite3.Connection, code: str, name: str,
                 daily: list = None, weekly: list = None, monthly: list = None,
                 fundamentals: dict = None) -> bool:
    """Write one stock's K-line bars (and optional fundamentals) to SQLite.

    The stocks row is (re)written first so the K-line tables' foreign keys
    are satisfied; all writes share one transaction.

    Args:
        code: stock code.
        name: stock name (stocks.name is NOT NULL).
        daily/weekly/monthly: [{date, open, close, high, low, volume}, ...];
            daily bars may also carry "amount".
        fundamentals: {pe, pb, eps, mcap_total, mcap_flow}.

    Returns:
        True on success; False after rolling back (the error is logged).
    """
    try:
        exchange, stype = _detect_exchange(code)
        now = datetime.now().isoformat()

        # INSERT OR REPLACE deletes + reinserts the stocks row; child rows
        # survive because the FKs have no ON DELETE action.
        conn.execute(
            "INSERT OR REPLACE INTO stocks (code, name, exchange, type, updated_at) VALUES (?, ?, ?, ?, ?)",
            (code, name, exchange, stype, now))

        # Only stock_daily has an `amount` column.
        if daily:
            conn.executemany(
                "INSERT OR REPLACE INTO stock_daily (code, date, open, close, high, low, volume, amount) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                [_kline_row(code, d) + (d.get("amount"),) for d in daily])
        for table, bars in (("stock_weekly", weekly), ("stock_monthly", monthly)):
            if bars:
                conn.executemany(
                    f"INSERT OR REPLACE INTO {table} (code, date, open, close, high, low, volume) "
                    "VALUES (?, ?, ?, ?, ?, ?, ?)",
                    [_kline_row(code, d) for d in bars])

        if fundamentals:
            conn.execute(
                "INSERT OR REPLACE INTO stock_fundamentals (code, pe, pb, eps, mcap_total, mcap_flow, updated_at) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (code, fundamentals.get("pe"), fundamentals.get("pb"),
                 fundamentals.get("eps"), fundamentals.get("mcap_total"),
                 fundamentals.get("mcap_flow"), datetime.now().isoformat()))

        conn.commit()
        return True
    except Exception as e:  # best-effort contract: report, never raise
        logger.warning("write_klines failed for %s: %s", code, e)
        _rollback_quietly(conn)
        return False


# ═══════════════════════════════════════════════════════════
# Price event writes
# ═══════════════════════════════════════════════════════════

def write_price_event(conn: sqlite3.Connection, code: str, name: str,
                      event_type: str, price: float, trigger_value: str,
                      event_label: str = "") -> bool:
    """Insert one price_events row dated today (local time).

    ``price`` is rounded to 2 decimals. price_events.code references
    stocks(code), so the stock must already exist when foreign keys are on.

    Returns:
        True on success; False after rolling back (the error is logged).
    """
    try:
        today = datetime.now().strftime("%Y-%m-%d")
        conn.execute(
            "INSERT INTO price_events (code, name, event_type, price, trigger_value, event_label, date) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (code, name, event_type, round(price, 2), trigger_value,
             event_label, today))
        conn.commit()
        return True
    except Exception as e:  # best-effort contract: report, never raise
        logger.warning("write_price_event failed for %s: %s", code, e)
        _rollback_quietly(conn)
        return False


# ═══════════════════════════════════════════════════════════
# Sector-membership migration
# ═══════════════════════════════════════════════════════════

def migrate_stock_sectors(conn: sqlite3.Connection,
                          sector_map_path: Optional[Path] = None) -> tuple[int, int]:
    """Copy a ``{code: [sector, ...]}`` JSON map into the stock_sectors table.

    Keys starting with "_" (metadata) and non-list values are skipped.
    Existing (code, sector) pairs are left untouched (INSERT OR IGNORE).
    Pairs rejected by the database are skipped, e.g. when the code is absent
    from stocks and foreign keys are enforced (OR IGNORE does not cover FKs).

    Args:
        conn: open connection.
        sector_map_path: JSON file to read; defaults to
            ``DATA_DIR / "stock_sector_map.json"``.

    Returns:
        (migrated_stocks, total_mappings): stocks with at least one accepted
        mapping, and the number of accepted (inserted or already present)
        mappings. (0, 0) when the file is missing, unreadable, or not a JSON
        object.
    """
    path = (DATA_DIR / "stock_sector_map.json" if sector_map_path is None
            else Path(sector_map_path))
    if not path.exists():
        return 0, 0
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
        return 0, 0
    if not isinstance(data, dict):
        return 0, 0

    migrated_stocks = 0
    total = 0
    for code, sectors in data.items():
        if code.startswith("_") or not isinstance(sectors, list):
            continue
        accepted = 0
        for sector in sectors:
            try:
                conn.execute(
                    "INSERT OR IGNORE INTO stock_sectors (code, sector_name, source) VALUES (?, ?, 'ths')",
                    (code, sector))
            except (sqlite3.IntegrityError, sqlite3.InterfaceError,
                    sqlite3.ProgrammingError):
                # FK violation (unknown stock) or an unbindable sector value.
                continue
            accepted += 1
        if accepted:
            migrated_stocks += 1
        total += accepted
    conn.commit()
    return migrated_stocks, total


# ═══════════════════════════════════════════════════════════
# Query helpers (rows are returned as dicts; these rely on the
# connection's row factory being sqlite3.Row, as set by get_conn)
# ═══════════════════════════════════════════════════════════

def query_sector_trend(conn: sqlite3.Connection, name: str, limit: int = 5) -> list[dict]:
    """Return the sector's last ``limit`` snapshots, newest timestamp first."""
    rows = conn.execute("""
        SELECT s.timestamp, ss.change_pct, ss.net_inflow, ss.up_count, ss.down_count,
               ss.lead_stock, ss.lead_stock_change
        FROM sector_snapshots ss
        JOIN market_snapshots s ON ss.snapshot_id = s.id
        WHERE ss.name = ?
        ORDER BY s.timestamp DESC
        LIMIT ?
    """, (name, limit)).fetchall()
    return [dict(r) for r in rows]


def query_top_inflow(conn: sqlite3.Connection, limit: int = 5) -> list[dict]:
    """Return the top ``limit`` sectors by net inflow in the newest snapshot."""
    rows = conn.execute("""
        SELECT ss.name, ss.change_pct, ss.net_inflow, ss.lead_stock, s.timestamp
        FROM sector_snapshots ss
        JOIN market_snapshots s ON ss.snapshot_id = s.id
        WHERE s.id = (SELECT MAX(id) FROM market_snapshots)
          AND ss.net_inflow IS NOT NULL
        ORDER BY ss.net_inflow DESC
        LIMIT ?
    """, (limit,)).fetchall()
    return [dict(r) for r in rows]


def query_consecutive_inflow(conn: sqlite3.Connection, days: int = 3) -> list[dict]:
    """Return sectors with positive net inflow in each of the latest ``days`` snapshots.

    The window is the ``days`` newest snapshot ids, selected explicitly so
    that gaps in the id sequence (deleted snapshots) do not shrink it.

    Returns:
        Dicts with keys name, times, avg_inflow, avg_change, ordered by
        avg_inflow descending. Empty when ``days`` <= 0.
    """
    if days <= 0:
        # Also avoids SQLite treating a negative LIMIT as "no limit".
        return []
    rows = conn.execute("""
        SELECT ss.name AS name,
               COUNT(DISTINCT ss.snapshot_id) AS times,
               ROUND(AVG(ss.net_inflow), 2) AS avg_inflow,
               ROUND(AVG(ss.change_pct), 2) AS avg_change
        FROM sector_snapshots ss
        WHERE ss.snapshot_id IN (
                SELECT id FROM market_snapshots ORDER BY id DESC LIMIT ?
              )
          AND ss.net_inflow > 0
        GROUP BY ss.name
        HAVING COUNT(DISTINCT ss.snapshot_id) >= ?
        ORDER BY avg_inflow DESC
    """, (days, days)).fetchall()
    return [dict(r) for r in rows]


def query_market_mood(conn: sqlite3.Connection, limit: int = 10) -> list[dict]:
    """Return the last ``limit`` market snapshots (timestamp, source, up_ratio, mood)."""
    rows = conn.execute("""
        SELECT timestamp, source, up_ratio, mood
        FROM market_snapshots
        ORDER BY timestamp DESC
        LIMIT ?
    """, (limit,)).fetchall()
    return [dict(r) for r in rows]


def query_db_stats(conn: sqlite3.Connection) -> dict:
    """Return row counts for the main tables plus the newest snapshot's timestamp/source."""
    snap_count = conn.execute("SELECT COUNT(*) FROM market_snapshots").fetchone()[0]
    sector_count = conn.execute("SELECT COUNT(*) FROM sector_snapshots").fetchone()[0]
    stock_count = conn.execute("SELECT COUNT(*) FROM stocks").fetchone()[0]
    kline_count = conn.execute("SELECT COUNT(*) FROM stock_daily").fetchone()[0]
    event_count = conn.execute("SELECT COUNT(*) FROM price_events").fetchone()[0]
    latest = conn.execute(
        "SELECT timestamp, source FROM market_snapshots ORDER BY id DESC LIMIT 1").fetchone()
    return {
        "snapshots": snap_count,
        "sector_rows": sector_count,
        "stocks": stock_count,
        "daily_klines": kline_count,
        "price_events": event_count,
        "latest_snapshot": dict(latest) if latest else None,
    }