Files
MoFin/scripts/branch_evaluator.py
T

157 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
branch_evaluator.py — 分支自成长引擎
每30分钟评估所有策略树的当前适用性:
1. 读取 decisions.json 中所有 strategy_tree.branches
2. 获取当前宏观情景(detect_scenario
3. 对每只股票获取实时价,评估哪些分支条件命中
4. 命中的分支 → trigger_count+1, last_triggered=now
5. 后续跟进:成功/失败取决于该分支被选中后5日盈亏(由price_monitor回填success_rate
6. 触发≥3次且成功率<30% → 标记 pruning_candidate
7. 写回 decisions.json
设计为 no_agent cron 脚本:非空输出→推送到XMPP,空输出→静默
"""
import json, sys, os, re
from datetime import datetime, date
from mo_data import read_portfolio, read_decisions
from mofin_db import get_conn, write_holding_strategy
# Data file paths.
# NOTE(review): neither constant is used by live code in this script as seen:
# DECISIONS_PATH only appears in the commented-out cold-backup write inside
# evaluate_all(), and PORTFOLIO_PATH is not referenced at all (prices come from
# mo_data.read_portfolio()). Confirm nothing else imports them before removing.
DECISIONS_PATH = "/home/hmo/web-dashboard/data/decisions.json"
PORTFOLIO_PATH = "/home/hmo/web-dashboard/data/portfolio.json"
# Load the strategy_tree module (bound as `st`) by putting the MoFin directory
# at the front of the module search path.
# NOTE: this runs after the `mo_data` / `mofin_db` imports at the top of the
# file, so those modules must be importable without this path entry.
sys.path.insert(0, "/home/hmo/MoFin")
try:
    import strategy_tree as st
except ImportError:
    # If the module can't be imported from the MoFin path, load the source
    # file directly by its path and execute it as a module named "st".
    # NOTE(review): this points at the same file the import above would have
    # found. If the ImportError came from a missing file, or from a failing
    # import inside strategy_tree itself, this fallback will most likely fail
    # too (with a different, less obvious exception).
    import importlib.util
    spec = importlib.util.spec_from_file_location("st", "/home/hmo/MoFin/strategy_tree.py")
    st = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(st)
def get_live_prices():
    """Return a ``{code: price}`` map built from the portfolio holdings.

    Reads the portfolio via ``read_portfolio()`` and maps each holding's
    ``code`` (converted to ``str``) to its ``price`` field (0 when the field
    is absent). Holdings without a code are skipped; otherwise they would be
    stored under a bogus key ("" or "None") that a code-less decision entry
    could accidentally match.

    This is best-effort. If the portfolio cannot be read or has an unexpected
    shape, the error is reported on stderr and whatever was collected so far
    is returned, which may be an empty dict. stderr is used because stdout is
    the cron summary channel.
    """
    prices = {}
    try:
        # Tolerate a None portfolio / None holdings list instead of crashing
        # on .get() or iteration.
        pf = read_portfolio() or {}
        for h in pf.get("holdings") or []:
            code = h.get("code")
            if code is None or code == "":
                continue
            prices[str(code)] = h.get("price", 0)
    except Exception as e:
        # Deliberately best-effort, but no longer silent.
        print(f"[警告] 读 portfolio 失败: {e}", file=sys.stderr)
    return prices
# A branch becomes a pruning candidate once it has triggered at least this
# many times ...
_PRUNE_MIN_TRIGGERS = 3
# ... while its success_rate (a percentage, 0-100) is below this value.
_PRUNE_MAX_SUCCESS_RATE = 30


def _ensure_tree(entry):
    """Return ``(tree, created)`` for a decision entry, auto-creating default branches if needed.

    If the entry has no (or an empty) ``strategy_tree``, default branches are
    generated with ``st.init_default_branches`` from the entry's price levels,
    stored on the entry, and ``created`` is True. If generation raises, the
    error goes to stderr and ``(None, False)`` is returned so the caller skips
    this entry.
    """
    tree = entry.get("strategy_tree")
    if tree:
        return tree, False
    code = entry.get("code", "")
    try:
        branches = st.init_default_branches(
            code=code,
            name=entry.get("name", ""),
            entry_low=entry.get("entry_low", 0),
            entry_high=entry.get("entry_high", 0),
            stop_loss=entry.get("stop_loss", 0),
            take_profit=entry.get("take_profit", 0),
        )
    except Exception as e:
        print(f"[警告] {code} 自动初始化分支失败: {e}", file=sys.stderr)
        return None, False
    tree = {"branches": branches, "initialized_at": datetime.now().isoformat()}
    entry["strategy_tree"] = tree
    return tree, True


def _apply_results(code, branches, results, now_ts):
    """Record applicable ``evaluate_branches`` results onto the matching branches.

    For every applicable result, the first branch whose ``id`` equals the
    result's ``branch_id`` gets ``trigger_count`` + 1 and ``last_triggered``
    set to ``now_ts``. If it has then triggered often enough and its
    ``success_rate`` is known and low, it is marked ``pruning_candidate``.

    Returns ``(hits, pruning_flags)``: the number of branches hit, and a
    human-readable description of each branch flagged for pruning.
    """
    # First branch per id wins, matching a linear scan that stops at the
    # first match.
    by_id = {}
    for br in branches:
        by_id.setdefault(br.get("id"), br)

    hits = 0
    flags = []
    for result in results or []:
        if not result.get("applicable"):
            continue
        br_id = result.get("branch_id", "")
        br = by_id.get(br_id)
        if br is None:
            continue
        br["trigger_count"] = br.get("trigger_count", 0) + 1
        br["last_triggered"] = now_ts
        hits += 1
        tc = br["trigger_count"]
        sr = br.get("success_rate")
        # success_rate is backfilled elsewhere; None means "not measured yet".
        if tc >= _PRUNE_MIN_TRIGGERS and sr is not None and sr < _PRUNE_MAX_SUCCESS_RATE:
            br["pruning_candidate"] = True
            flags.append(f"{code}/{br_id}(触发{tc}次/成功率{sr}%)")
    return hits, flags


def _persist(decisions):
    """Write every decision back via ``write_holding_strategy``; return True on success.

    On failure the error is printed to stderr and False is returned. The
    connection is closed in every case.
    """
    conn = None
    try:
        conn = get_conn()
        for d in decisions:
            write_holding_strategy(conn, d.get("code", ""), d.get("name", ""), d)
        return True
    except Exception as e:
        print(f"[错误] 写回 DB 失败: {e}", file=sys.stderr)
        return False
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                # Closing is best-effort; the write outcome is already decided.
                pass


def _format_summary(scenario_label, scenario_id, total_triggered, auto_init_count, pruning_flags):
    """Build the multi-line summary: a header line, then the pruning list or a 'no pruning needed' line."""
    init_note = f" | 自动初始化{auto_init_count}" if auto_init_count else ""
    lines = [f"【分支评估】情景{scenario_label}({scenario_id}) | 命中{total_triggered}{init_note}"]
    if pruning_flags:
        lines.append(f"需剪枝{len(pruning_flags)}个分支:")
        lines.extend(pruning_flags)
    else:
        lines.append("无需剪枝的分支")
    return "\n".join(lines)


def evaluate_all():
    """Evaluate every decision's strategy-tree branches against the current scenario.

    One pass of the cron job:
      1. Load decisions with ``read_decisions()``. On failure, print the error
         to stderr and return None.
      2. Detect the current macro scenario with ``st.detect_scenario()``.
      3. For each decision entry:
         - auto-create default branches if it has no strategy_tree;
         - ask ``st.evaluate_branches`` which branches apply at the current
           price (live portfolio price, falling back to the entry's own
           ``price``);
         - bump the trigger stats of applicable branches and flag pruning
           candidates;
         - stamp the tree's ``last_evaluated``.
         A failure for one stock is reported on stderr and does not abort the
         others.
      4. Write all decisions back to the DB.
      5. Print and return a summary.

    Output contract (see module docstring): non-empty stdout is pushed to
    XMPP and empty stdout stays silent. So when no branch was hit, no tree was
    auto-initialized and the write-back succeeded, nothing is printed and ""
    is returned.

    Returns:
        The summary string ("" when there is nothing to report), or None if
        the decisions could not be read.
    """
    try:
        data = read_decisions()
    except Exception as e:
        print(f"[错误] 读 decisions.json 失败: {e}", file=sys.stderr)
        return

    # Current macro scenario.
    scenario = st.detect_scenario() or {}
    scenario_id = scenario.get("id", "")
    scenario_label = scenario.get("label", "未知")
    prices = get_live_prices()
    decisions = data.get("decisions", [])

    total_triggered = 0
    auto_init_count = 0
    pruning_flags = []
    for entry in decisions:
        code = entry.get("code", "")
        tree, created = _ensure_tree(entry)
        if created:
            auto_init_count += 1
        if not tree:
            continue
        branches = tree.get("branches", [])
        if not branches:
            continue

        price = prices.get(code, 0) or entry.get("price", 0)
        try:
            results = st.evaluate_branches(
                code, scenario_id, price, entry.get("shares", 0), entry.get("cost", 0)
            )
        except Exception as e:
            # Isolate per-stock failures so the rest still get evaluated and saved.
            print(f"[警告] {code} 分支评估失败: {e}", file=sys.stderr)
            continue

        now_ts = datetime.now().isoformat()
        hits, flags = _apply_results(code, branches, results, now_ts)
        total_triggered += hits
        pruning_flags.extend(flags)
        # The tree was evaluated, whether or not any branch hit.
        tree["last_evaluated"] = now_ts

    # Write back to the DB (the decisions.json cold backup was removed in the
    # DB migration).
    saved = _persist(decisions)

    # Empty output means silent, so stay quiet when there is nothing to report.
    if not total_triggered and not auto_init_count and saved:
        return ""

    out = _format_summary(scenario_label, scenario_id, total_triggered, auto_init_count, pruning_flags)
    if not saved:
        out += "\n[错误] 写回 DB 失败,本次评估结果未保存"
    print(out)
    return out


if __name__ == "__main__":
    evaluate_all()