Files
MoFin/strategy_tree.py
T

435 lines
16 KiB
Python

#!/usr/bin/env python3
"""
strategy_tree.py — 情景化多分支策略决策引擎
核心理念:
每只股票不再只有一个买入区+止损,而是有一棵决策树。
每个分支 = {条件, 动作, 优先级, 触发统计}
当前宏观情景决定走哪个分支。
自成长:
→ 每次分支被触发,记录 trigger_count + 后续5日盈亏
→ success_rate < 30% 且触发≥5次 → 自动标记 pruning_candidate
→ 每周 pruning 时剪掉低效分支
数据存在 decisions.json 的 strategy_tree 字段。
"""
import json, os, sys, re
from datetime import datetime, date, timedelta
from mo_data import read_portfolio, read_decisions, read_watchlist
# Absolute paths of the dashboard's JSON data files (hard-coded to one host).
# DECISIONS_PATH is rewritten by record_branch_trigger() and
# prune_low_performance_branches(); MACRO_PATH / MARKET_PATH are the fallback
# inputs of detect_scenario() and TREND_PATH supplies its rotation flag.
# NOTE(review): PORTFOLIO_PATH is not referenced by the code in this module
# (portfolio reads go through mo_data) — confirm it is still needed.
DECISIONS_PATH = "/home/hmo/web-dashboard/data/decisions.json"
PORTFOLIO_PATH = "/home/hmo/web-dashboard/data/portfolio.json"
MACRO_PATH = "/home/hmo/web-dashboard/data/macro_context.json"
MARKET_PATH = "/home/hmo/web-dashboard/data/market.json"
TREND_PATH = "/home/hmo/web-dashboard/data/trend_signals.json"
# ── Scenario definitions ──────────────────────────────────────────────────
# Each entry:
#   id               – key returned by detect_scenario() and matched against a
#                      branch condition's "scenario" field;
#   label / desc     – display text;
#   rules            – descriptive only: detect_scenario() does not read this
#                      field, its matching logic is hard-coded;
#   portfolio_action – suggested portfolio-level action (display text).
SCENARIOS = [
    # Sharp decline: index falling on heavy volume, multi-sector selloff.
    {
        "id": "sharp_decline",
        "label": "急跌防御",
        "desc": "大盘放量下跌,多板块共振杀跌",
        "rules": {"mood": "bearish", "sector_crash": True},
        "portfolio_action": "减仓至80%以下,优先出弱势深套",
    },
    # Weak consolidation: low-volume drift lower, diverging structure.
    # Also the default scenario of detect_scenario().
    {
        "id": "weak_consolidation",
        "label": "弱势震荡",
        "desc": "大盘缩量阴跌,结构分化",
        "rules": {"mood": "neutral", "breadth": "weak"},
        "portfolio_action": "保持仓位90%以内,调结构",
    },
    # Sector rotation: narrow index range while leading sectors switch.
    {
        "id": "sector_rotation",
        "label": "板块轮动",
        "desc": "大盘窄幅,强势板块切换",
        "rules": {"mood": "neutral", "rotation": True},
        "portfolio_action": "跟随板块切换,减旧加新",
    },
    # Bullish recovery: index rising on volume, sentiment improving.
    {
        "id": "bullish_recovery",
        "label": "反弹上行",
        "desc": "大盘放量上涨,情绪回暖",
        "rules": {"mood": "bullish"},
        "portfolio_action": "加仓至95%,追随趋势",
    },
]
# ── Scenario detection ────────────────────────────────────────────────────
def _load_json_file(path):
    """Parse the JSON file at *path*, closing the handle even on error."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _read_signals_from_db():
    """Return (mood, structure) from the newest valid macro_context_log row.

    Raises on any problem (missing DB file or table, no qualifying row, a
    structure column that is not a JSON object) so the caller can fall back
    to the JSON snapshot files.
    """
    import sqlite3
    from pathlib import Path

    db_path = Path(__file__).parent / "data" / "mofin.db"
    # sqlite3.connect() silently creates an empty database for a missing path;
    # fail fast instead so no stray empty file is left behind.
    if not db_path.is_file():
        raise FileNotFoundError(str(db_path))
    db = sqlite3.connect(str(db_path))
    try:
        row = db.execute(
            "SELECT indices, structure, sector_mood FROM macro_context_log "
            "WHERE has_valid_data=1 ORDER BY created_at DESC LIMIT 1"
        ).fetchone()
    finally:
        # Close even when the query itself raises (e.g. table missing).
        db.close()
    if not row:
        raise ValueError("no db data")
    structure = json.loads(row[1]) if row[1] else {}
    if not isinstance(structure, dict):
        raise ValueError("structure column is not a JSON object")
    return row[2] or "", structure


def _read_market_signals():
    """Return raw (mood, structure), preferring the DB over the JSON snapshots.

    Raises if neither source is usable.
    """
    try:
        return _read_signals_from_db()
    except Exception:
        # Any DB problem -> use the JSON snapshot files instead.
        macro = _load_json_file(MACRO_PATH)
        market = _load_json_file(MARKET_PATH)
        structure = macro.get("structure") or {}
        if not isinstance(structure, dict):
            structure = {}
        return market.get("mood") or "", structure


def _rotation_detected():
    """Return True if the trend snapshot flags a sector rotation.

    Best-effort: an unreadable or malformed trend file counts as "no rotation".
    """
    try:
        return bool(_load_json_file(TREND_PATH).get("rotation_detected"))
    except Exception:
        return False


def _classify_scenario(mood, overall, trend_desc, rotation_detected=False):
    """Map lower-cased market signals to a (scenario_id, confidence) pair.

    Pure function (no I/O) so the decision rules can be tested in isolation.

    Args:
        mood: market mood text; "bearish" / "bullish" / "neutral" are checked
            as substrings.
        overall: structure["overall"] text, checked for "bearish"/"bullish".
        trend_desc: structure["description"] text, searched for panic /
            weakness keywords when the market is bearish.
        rotation_detected: sector-rotation flag; only consulted for a neutral
            mood.
    """
    if "bearish" in mood or "bearish" in overall:
        if any(k in trend_desc for k in ("crash", "跌幅", "恐慌")):
            return "sharp_decline", 0.7
        if any(k in trend_desc for k in ("弱势", "疲弱")):
            return "weak_consolidation", 0.6
        return "weak_consolidation", 0.5
    if "bullish" in mood or "bullish" in overall:
        return "bullish_recovery", 0.6
    if "neutral" in mood:
        # A detected rotation must win here; the neutral fallback must not
        # overwrite it (otherwise sector_rotation can never be returned).
        if rotation_detected:
            return "sector_rotation", 0.5
        return "weak_consolidation", 0.4
    # No recognisable signal: default scenario.
    return "weak_consolidation", 0.5


def detect_scenario():
    """Classify the current market scenario from macro + market data.

    Signals come from the newest valid row of macro_context_log in
    data/mofin.db next to this module, falling back to the MACRO_PATH /
    MARKET_PATH JSON snapshots. A neutral mood additionally consults the
    TREND_PATH rotation flag.

    Returns:
        {"id", "label", "desc", "confidence", "portfolio_action"}. When no
        data source is readable, a weak_consolidation default with
        confidence 0.3 and the "观望" (wait-and-see) action is returned.
    """
    try:
        raw_mood, structure = _read_market_signals()
    except Exception:
        return {
            "id": "weak_consolidation",
            "label": "默认-弱势震荡",
            "desc": "",  # present so the result always has the same keys
            "confidence": 0.3,
            "portfolio_action": "观望",
        }
    mood = str(raw_mood).lower()
    # `or ""` also covers keys that exist with a null value.
    overall = str(structure.get("overall") or "").lower()
    trend_desc = str(structure.get("description") or "").lower()
    # The trend file is only read when the mood is neutral.
    rotation = "neutral" in mood and _rotation_detected()
    scenario_id, confidence = _classify_scenario(mood, overall, trend_desc, rotation)
    sc = next((s for s in SCENARIOS if s["id"] == scenario_id), SCENARIOS[0])
    return {
        "id": scenario_id,
        "label": sc["label"],
        "desc": sc["desc"],
        "confidence": round(confidence, 2),
        "portfolio_action": sc["portfolio_action"],
    }
# ── Branch evaluation ─────────────────────────────────────────────────────
def evaluate_branches(code, scenario_id, price, shares, cost):
    """Evaluate every strategy-tree branch of one stock under a scenario.

    Loads decisions via read_decisions(), takes the first entry whose "code"
    equals *code*, and checks each branch in its strategy_tree["branches"].

    Args:
        code: stock code to look up.
        scenario_id: current scenario id (see SCENARIOS).
        price, shares, cost: current price, position size and cost basis,
            passed through to _check_branch_condition().

    Returns:
        One dict per branch, sorted by ascending priority (missing -> 999):
        {branch_id, action_type, action_detail, priority, rationale,
        applicable}. Empty when decisions cannot be read or the stock has no
        branches.
    """
    try:
        # Call the name imported from mo_data at the top of the file; the
        # module object `mo_data` itself is not imported, so the former
        # `mo_data.read_decisions()` raised NameError and always yielded [].
        dec = read_decisions()
    except Exception:
        return []

    entry = next((e for e in dec.get("decisions", []) if e.get("code") == code), None)
    if not entry:
        return []
    branches = (entry.get("strategy_tree") or {}).get("branches") or []

    results = []
    for br in sorted(branches, key=lambda b: b.get("priority", 999)):
        action = br.get("action") or {}
        results.append({
            "branch_id": br.get("id"),
            "action_type": action.get("type", "hold"),
            "action_detail": action,
            "priority": br.get("priority", 999),
            "rationale": br.get("rationale", ""),
            "applicable": _check_branch_condition(br, scenario_id, price, shares, cost),
        })
    return results
# Comparison token inside a condition string: operator, optional spaces and an
# optionally signed number, e.g. "<=12.5", ">= 8", "<-15" (from "<-15%").
_COND_TOKEN_RE = re.compile(r"([<>=!]+)\s*(-?[\d.]+)")


def _cond_holds(value, spec):
    """Return False if *value* violates any comparison token found in *spec*.

    Recognised operators: <, >, <=, >=, == (within 0.01) and != (differs by at
    least 0.01). Unknown operators and unparsable numbers are skipped, i.e.
    treated as satisfied, so one malformed condition cannot crash evaluation.
    """
    for op, num in _COND_TOKEN_RE.findall(str(spec)):
        try:
            limit = float(num)
        except ValueError:
            continue
        if op == "<":
            ok = value < limit
        elif op == ">":
            ok = value > limit
        elif op == "<=":
            ok = value <= limit
        elif op == ">=":
            ok = value >= limit
        elif op == "==":
            ok = abs(value - limit) < 0.01
        elif op == "!=":
            ok = abs(value - limit) >= 0.01
        else:
            continue
        if not ok:
            return False
    return True


def _check_branch_condition(branch, scenario_id, price, shares, cost):
    """Return True if *branch*'s condition holds for the given market state.

    Every present sub-condition must hold:
      - "scenario": must equal *scenario_id*;
      - "price" / "price_lower": comparison strings on *price*
        (e.g. "<=12" and ">=10" for a buy zone);
      - "loss_pct": comparison on the P&L percentage versus *cost*
        (e.g. "<-15%"); skipped when *cost* is missing or not positive;
      - "trend": not evaluated yet.
    A missing or empty condition always holds. *shares* is currently unused.
    """
    cond = branch.get("condition") or {}

    required_scenario = cond.get("scenario", "")
    if required_scenario and required_scenario != scenario_id:
        return False

    # "price" and "price_lower" share one grammar; both must hold.
    for key in ("price", "price_lower"):
        spec = cond.get(key, "")
        if spec and not _cond_holds(price, spec):
            return False

    # TODO: "trend" (e.g. "uptrend") needs a multi-timeframe check; ignored for now.

    loss_spec = cond.get("loss_pct", "")
    if loss_spec and cost and cost > 0:
        # Percentage P&L against the cost basis; negative means a loss.
        pnl_pct = (price - cost) / cost * 100
        if not _cond_holds(pnl_pct, loss_spec):
            return False

    return True
# ── Branch trigger recording ──────────────────────────────────────────────
def record_branch_trigger(code, branch_id):
    """Record one trigger of a branch for the self-improvement statistics.

    On the first decision entry whose "code" equals *code*, finds the branch
    with id *branch_id*, increments its trigger_count, stamps last_triggered
    with the local ISO timestamp and rewrites DECISIONS_PATH.

    Best-effort: errors are reported on stderr instead of raised.

    Returns:
        True if the branch was found and the file written, otherwise False.
    """
    try:
        # Call the name imported from mo_data at the top of the file; the
        # module object `mo_data` is not imported, so `mo_data.read_decisions()`
        # raised NameError and nothing was ever recorded.
        dec = read_decisions()
        entry = next((e for e in dec.get("decisions", []) if e.get("code") == code), None)
        if entry is None:
            return False
        branches = (entry.get("strategy_tree") or {}).get("branches") or []
        branch = next((br for br in branches if br.get("id") == branch_id), None)
        if branch is None:
            return False  # nothing changed -> don't rewrite the file

        branch["trigger_count"] = branch.get("trigger_count", 0) + 1
        branch["last_triggered"] = datetime.now().isoformat()

        # Write a sibling temp file and atomically swap it in, so a failure
        # mid-dump cannot leave a truncated decisions.json behind.
        tmp_path = DECISIONS_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(dec, fh, ensure_ascii=False, indent=2)
        os.replace(tmp_path, DECISIONS_PATH)
        return True
    except Exception as exc:
        print(f"record_branch_trigger({code}, {branch_id}) failed: {exc}", file=sys.stderr)
        return False
# ── Branch pruning (core of the self-improvement loop) ────────────────────
def prune_low_performance_branches(min_triggers=5, min_success_rate=0.3):
    """Move under-performing strategy branches into a history list.

    A branch is pruned when it has a success_rate, has been triggered at
    least *min_triggers* times, and success_rate < *min_success_rate*.
    Pruned branches are stamped with pruned_at / prune_reason and appended
    to strategy_tree["pruned_branches"] instead of being deleted, so they
    remain traceable.

    Returns:
        "code:branch_id (rate < threshold)" strings, one per pruned branch;
        empty when nothing was pruned or decisions could not be read.
        DECISIONS_PATH is rewritten only when something was pruned.
    """
    try:
        # Call the name imported from mo_data at the top of the file; the
        # module object `mo_data` is not imported, so `mo_data.read_decisions()`
        # raised NameError and pruning silently never ran.
        dec = read_decisions()
    except Exception:
        return []

    pruned = []
    for e in dec.get("decisions", []):
        st = e.get("strategy_tree")
        if not isinstance(st, dict):
            continue  # no tree -> nothing to prune (and don't create an empty one)
        kept, moved = [], []
        for br in st.get("branches") or []:
            tc = br.get("trigger_count") or 0
            sr = br.get("success_rate")
            if sr is not None and tc >= min_triggers and sr < min_success_rate:
                br["pruned_at"] = datetime.now().isoformat()
                br["prune_reason"] = f"低成功率: {sr:.0%} (触发{tc}次)"
                moved.append(br)
                pruned.append(f'{e.get("code")}:{br.get("id")} ({sr:.0%} < {min_success_rate:.0%})')
            else:
                kept.append(br)
        if moved:
            # Archive rather than delete so pruning decisions stay auditable.
            st.setdefault("pruned_branches", []).extend(moved)
            st["branches"] = kept

    if pruned:
        # Write a sibling temp file and atomically swap it in, so a failure
        # mid-dump cannot leave a truncated decisions.json behind.
        tmp_path = DECISIONS_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(dec, fh, ensure_ascii=False, indent=2)
        os.replace(tmp_path, DECISIONS_PATH)
    return pruned
# ── Strategy-tree initialisation (default branches for one stock) ─────────
def _make_branch(branch_id, condition, action, priority, rationale):
    """Build one branch dict with zeroed self-improvement statistics."""
    return {
        "id": branch_id,
        "condition": condition,
        "action": action,
        "priority": priority,
        "rationale": rationale,
        "trigger_count": 0,
        "success_rate": None,
        "last_triggered": None,
    }


def init_default_branches(code, name, entry_low, entry_high, stop_loss, take_profit):
    """Create the default multi-branch strategy for one stock.

    Intended to be called by per_stock_reassess.

    Args:
        code: stock code; prefix of every branch id.
        name: stock name (currently unused; kept for the caller's signature).
        entry_low, entry_high: buy-zone bounds; the buy-dip branch needs
            entry_low, and entry_high adds its upper price bound.
        stop_loss: stop price; the stop-loss branch is omitted when falsy.
        take_profit: target price; enables the breakout-chase branch and,
            together with entry_low, the take-profit branch.

    Returns:
        Branch dicts in ascending priority order. The trim (3) and hold (99)
        branches are always present.
    """
    branches = []

    # Branch 0: stop loss — no scenario restriction, always active.
    if stop_loss:
        branches.append(_make_branch(
            f"{code}_stop_loss",
            {"price": f"<{stop_loss}"},
            {"type": "sell", "amount": "all", "reason": "止损"},
            0,
            "止损保护本金",
        ))

    # Branch 1: buy the dip inside the entry zone in a weak market.
    if entry_low:
        dip_cond = {"scenario": "weak_consolidation"}
        # Only add the upper bound when it exists; otherwise the condition
        # would store a meaningless "<=None".
        if entry_high:
            dip_cond["price"] = f"<={entry_high}"
        dip_cond["price_lower"] = f">={entry_low}"
        branches.append(_make_branch(
            f"{code}_buy_dip",
            dip_cond,
            {"type": "buy", "amount": "normal", "limit": entry_low, "reason": "回调支撑买入"},
            1,
            "价格回调到支撑区,弱势市场低吸",
        ))

    # Branch 2: chase a breakout above the target in a bullish market.
    # NOTE(review): at price >= take_profit in bullish_recovery this buy
    # branch and the take-profit sell branch (4) both apply; priority makes
    # the buy win — confirm that is intended.
    if take_profit:
        branches.append(_make_branch(
            f"{code}_breakout_chase",
            {"scenario": "bullish_recovery", "price": f">={take_profit}"},
            {"type": "buy", "amount": "normal", "limit": "market", "reason": "突破确认追涨"},
            2,
            "价格突破阻力,确认上升趋势后买入",
        ))

    # Branch 3: halve deeply underwater positions during a sharp decline.
    branches.append(_make_branch(
        f"{code}_trim",
        {"scenario": "sharp_decline", "loss_pct": "<-15%"},
        {"type": "sell", "amount": "half", "reason": "急跌降风险"},
        3,
        "急跌市场,深套股减半仓减少敞口",
    ))

    # Branch 4: take profit at the target price.
    if take_profit and entry_low:
        branches.append(_make_branch(
            f"{code}_take_profit",
            {"price": f">={take_profit}"},
            {"type": "sell", "amount": "half", "reason": "止盈锁利"},
            4,
            "达到目标价,减半仓锁定利润",
        ))

    # Branch 99: default hold when nothing else matches.
    branches.append(_make_branch(
        f"{code}_hold",
        {},
        {"type": "hold", "reason": "无明确信号,继续持有"},
        99,
        "没有分支匹配时的默认动作",
    ))
    return branches
# ── Portfolio constraint check ────────────────────────────────────────────
def check_portfolio_constraint(action_type, amount, cash_remain=None, portfolio=None):
    """Check whether an action fits the portfolio's constraints.

    Only the cash constraint for buys is implemented; every other action
    type passes. No position-size cap is checked yet.

    Args:
        action_type: "buy" is checked; anything else returns (True, "OK").
        amount: estimated buy amount in cash. Non-numeric or falsy values
            (such as the sizing labels "normal"/"half"/"all" used by branch
            actions) fall back to a default estimate of 100,000.
        cash_remain: when given, used as the cost estimate instead of *amount*.
        portfolio: an already-loaded portfolio dict (its "cash" is read);
            when None it is loaded via read_portfolio().

    Returns:
        (ok, message). Fails open with (True, "无法读取组合") when the
        portfolio cannot be read.
    """
    if portfolio is None:
        try:
            # Call the name imported from mo_data at the top of the file; the
            # module object `mo_data` is not imported, so the former
            # `mo_data.read_portfolio()` always hit this fail-open path.
            portfolio = read_portfolio()
        except Exception:
            return True, "无法读取组合"

    if action_type == "buy":
        if cash_remain is not None:
            cost_est = cash_remain
        elif isinstance(amount, (int, float)) and amount:
            cost_est = amount
        else:
            cost_est = 100000  # default estimate: 100k
        cash = portfolio.get("cash") or 0
        if cost_est > cash:
            return False, f"现金不足: 需要~{cost_est:.0f},可用{cash:.0f}"
    return True, "OK"
# ── CLI entry point ───────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="多分支策略决策引擎")
    parser.add_argument("--detect", action="store_true", help="检测当前情景")
    parser.add_argument("--evaluate", type=str, help="评估指定股票的分支")
    parser.add_argument("--prune", action="store_true", help="剪枝低效分支")
    # Position context for --evaluate. A hard-coded price/cost of 0 made every
    # "<X" price condition (e.g. the stop loss) look applicable and skipped
    # all loss_pct checks.
    parser.add_argument("--price", type=float, default=0.0, help="当前价格(用于 --evaluate)")
    parser.add_argument("--shares", type=float, default=0.0, help="持仓数量(用于 --evaluate)")
    parser.add_argument("--cost", type=float, default=0.0, help="成本价(用于 --evaluate)")
    args = parser.parse_args()

    if not (args.detect or args.evaluate or args.prune):
        # No action requested: show usage instead of exiting silently.
        parser.print_help()

    if args.detect:
        sc = detect_scenario()
        print(f"情景: {sc['id']} ({sc['label']})")
        print(f"置信度: {sc['confidence']}")
        print(f"组合动作: {sc['portfolio_action']}")

    if args.evaluate:
        code = args.evaluate
        sc = detect_scenario()
        print(f"当前情景: {sc['id']} ({sc['label']})")
        print(f"评估 {code}:")
        results = evaluate_branches(code, sc["id"], args.price, args.shares, args.cost)
        if not results:
            print("  (无策略分支)")
        for r in results:
            # ✓ = condition met, ✗ = not met; the previous ""/" " markers were
            # visually indistinguishable.
            status = "✓" if r["applicable"] else "✗"
            print(f"  {status} [{r['priority']}] {r['branch_id']} → {r['action_type']}: {r['rationale']}")

    if args.prune:
        pruned = prune_low_performance_branches()
        if pruned:
            print(f"已剪枝: {len(pruned)}")
            for p in pruned:
                print(f"  - {p}")
        else:
            print("无需要剪枝的分支")