diff --git a/README.md b/README.md index f043915..66d6017 100644 --- a/README.md +++ b/README.md @@ -55,17 +55,27 @@ systemctl start xmpp-{bot} ``` AgentsMeeting/ -├── README.md # 本文件 -├── gateway/ # 实际运行的代码(脚本+工具+日志) +├── xmpp_agent_core.py # 统一 Bot 核心(所有 Agent 共享) +│ 用法: python xmpp_agent_core.py --agent xxm|mohe|zhiwei|xiaoguo +│ 功能: PID锁/重连/MUC/dedup/batching/coordinator协议/HTTP桥 +├── xxm_bot.py # 兼容入口 → xmpp_agent_core.py --agent xxm +├── mohe_bot.py # 兼容入口 → xmpp_agent_core.py --agent mohe +├── zhiwei_bot.py # 兼容入口 → xmpp_agent_core.py --agent zhiwei +├── xiaoguo_bot.py # 兼容入口 → xmpp_agent_core.py --agent xiaoguo +├── gateway/ # 运行时组件(脚本+工具+日志) │ ├── README.md # gateway 自身说明 │ ├── scripts/ # 运行时脚本 -│ │ ├── xmpp_bot.py # XMPP Bot (HTTP桥 :5802) +│ │ ├── chat_bridge.py # xxm LLM 桥接(SessionBridge) +│ │ ├── session_router.py # 消息路由 │ │ ├── wechat_agent.py # 微信桥接代理 │ │ ├── api_proxy.py # API 反向代理 (:8787) │ │ ├── xmpp_watchdog.py # 进程看门狗 │ │ ├── health_check_xxm.py# 消息流健康检查 │ │ ├── mohe_watcher.py # 莫荷消息监控 │ │ ├── dashboard.py # 管理门户后端 +│ │ ├── proc_guard.py # PID 锁(防重复启动) +│ │ ├── qq_bot.py # QQ bot 桥接 +│ │ ├── vc_webhook.py # VoceChat webhook │ │ └── templates/ │ │ └── dashboard.html # 管理门户前端 │ ├── logs/ # 运行时日志 @@ -73,37 +83,37 @@ AgentsMeeting/ ├── docs/ │ ├── ARCHITECTURE.md # 架构设计文档 │ ├── AUDIT.md # 稳定性审计报告 +│ ├── CLEANUP_PLAN.md # 代码清理方案 │ ├── OBSERVER-PROTOCOL.md # 群聊观察者模式协议 │ ├── SESSION-BRIDGE-PROTOCOL.md # 跨Session消息桥接协议 │ ├── PRD_v0.2.md # 产品需求文档 │ ├── DEPLOY.md # 部署指南 │ └── OPS.md # 运维手册 ├── config/ -│ ├── .env.example # 环境变量模板 -│ └── profiles/ # 各 Agent 配置文件 -│ ├── xxm/ -│ ├── mohe/ -│ ├── xiaoguo/ -│ └── zhiwei/ -├── src/ # 重构版本(逐步迁移中) -│ ├── shared/ # 共享库 -│ │ ├── config.py # 集中配置管理 -│ │ └── bot_base.py # Bot 基类 -│ ├── bots/ # Bot 实现 -│ ├── channels/ # 通道桥接 -│ │ ├── wechat/ # 微信桥接 -│ │ └── qq/ # QQ 桥接(规划中) -│ └── ops/ # 运维工具 -│ ├── watchdog.py -│ └── health_check.py +│ ├── agents.yaml # Agent 实例注册 +│ └── .env.example # 环境变量模板 +├── configs/ # 各 Agent 配置/SOUL +│ ├── main/ +│ ├── mohe/ +│ ├── xiaoguo/ +│ └── position-analyst/ +├── src/ +│ ├── shared/ +│ │ ├── config.py # 集中配置管理(env-var 方式) +│ │ └── bot_base.py # Bot 基类(功能已合并进 xmpp_agent_core.py) +│ ├── channels/ +│ │ └── qq/ +│ │ └── bridge.py # QQ 通道骨架 +│ └── ops/ +│ └── watch_group.py # 群聊监控 ├── deploy/ │ ├── windows/ │ │ ├── start.ps1 # 一键启动 │ │ ├── stop.ps1 # 一键停止 │ │ └── check.ps1 # 状态检查 │ └── linux/ -│ ├── install.sh # 安装脚本(待 mohe) -│ └── hermes-*.service # systemd 模板(待 mohe) +│ ├── install.sh # 安装脚本 +│ └── hermes-*.service # systemd 模板 └── tests/ # 测试套件 ``` @@ -114,8 +124,8 @@ AgentsMeeting/ | 组件 | 方式 | 频率 | |------|------|------| | 管理门户 | dashboard.py + Web UI (:5803) | 实时轮询 5s | -| xmpp_bot 进程 | watchdog (xmpp_watchdog.py) | 30s | -| xmpp_bot 消息流 | health_check_xxm.py | 5min (scheduled task) | +| xmpp Agent Bot | watchdog (xmpp_watchdog.py) | 30s | +| xmpp Agent 消息流 | health_check_xxm.py | 5min (scheduled task) | | wechat_agent | 内置看门狗 | 120s | | 日志 | `logs/health_check.log` | 人工查看 | @@ -138,7 +148,7 @@ AgentsMeeting/ | ID | 问题 | 平台 | 状态 | |----|------|------|------| -| R01 | MUC join 超时 (conference.yoin.fun DNS) | Linux (ejabberd) | 🔴 | +| R01 | MUC join 超时 (conference.yoin.fun DNS) | Linux (ejabberd) | 🟢 已修复(join_muc + raw presence 双保险) | | R02 | wechat_agent 无系统级自动恢复 | Windows | 🟡 | | R03 | Gateway 进程无 systemd auto-restart | Linux | 🔴 | | R04 | 日志无系统级轮转 | Windows + Linux | 🟡 | @@ -149,7 +159,7 @@ AgentsMeeting/ ## 开发流程 1. **架构设计** → `docs/ARCHITECTURE.md` -2. **代码工程化** → `src/` 按特征优先组织 +2. **核心修改** → `xmpp_agent_core.py`(统一 Bot 核心,所有 Agent 共享) 3. **部署脚本** → `deploy/` 一键启停 4. **测试** → `tests/` 自动化测试 5. **文档** → `docs/` 持续更新 diff --git a/bots/xmpp_bot.py b/bots/xmpp_bot.py deleted file mode 100644 index eb76f7c..0000000 --- a/bots/xmpp_bot.py +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env python3 -"""XMPP Bot mohe@yoin.fun - 稳定重连版""" -import asyncio, logging, ssl, json, urllib.request, os, time, re -from slixmpp import ClientXMPP - -logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = "http://localhost:8642/v1/chat/completions" -API_KEY = "hermes123" -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) - -class MoheBot(ClientXMPP): - def __init__(self): - super().__init__('mohe@yoin.fun', 'hermes123') - self.add_event_handler('session_bind', self.on_bind) - self.add_event_handler('message', self.on_msg) - self.add_event_handler('disconnected', self.on_disconnect) - self.add_event_handler('connected', self.on_connected) - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - self.ssl_context = ctx - self.ready = asyncio.Event() - self._call_seq = 0 - self._muc_joined = False - - async def on_connected(self, event): - logging.info("🔗 TCP连接已建立") - - async def on_bind(self, event): - self.send_presence() - self.get_roster() - # 加入内核组(每次重连后重新加入) - self.plugin['xep_0045'].join_muc('coregroup@conference.yoin.fun', 'mohe') - self._muc_joined = True - self.ready.set() - logging.info("✅ 莫荷 XMPP 上线") - - async def on_disconnect(self, event): - self.ready.clear() - self._muc_joined = False - logging.warning("⚠️ XMPP 断线") - - async def on_msg(self, msg): - body = msg['body'] - sender = str(msg['from']) - msg_type = msg['type'] - if not body: - return - if msg_type == 'groupchat': - if 'mohe@yoin.fun' in sender: - return - nickname = sender.split('/')[-1] if '/' in sender else '' - if nickname in ('hmo', 'xxm'): - logging.info(f"📩 群消息 [{sender}]: {body[:100]}") - room = sender.split('/')[0] - ctx_body = f"[核心群 {room}] {nickname} 说: {body}" - await self.call_hermes(ctx_body, room, is_group=True) - return - if msg_type == 'chat' and 'hmo@yoin.fun' in sender: - self._call_seq += 1 - logging.info(f"📩 老爸(#{self._call_seq}): {body}") - await self.call_hermes(body, sender, seq=self._call_seq) - - async def call_hermes(self, content, sender, is_group=False, seq=None): - msg_type = 'groupchat' if is_group else 'chat' - try: - payload = json.dumps({ - "model": "hermes-agent", - "messages": [{"role": "user", "content": content}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type", "application/json") - req.add_header("Authorization", f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id", "xmpp-mohe-v2") - - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: _opener.open(req, timeout=600)) - - if seq is not None and seq < self._call_seq: - return - - data = json.loads(result.read()) - reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") - # 处理 __SILENT__ 和 __REPLY__ 标记 - if reply.strip().startswith('__SILENT__'): - logging.info("⏭️ 决定沉默,不发送") - return - reply = re.sub(r'^__REPLY__\s*', '', reply) - finish = data.get("choices", [{}])[0].get("finish_reason", "") - - if reply.strip() and finish != "silent": - if msg_type == 'groupchat': - self.send_message(mto=sender, mbody=reply, mtype='groupchat') - else: - import subprocess as sp - from xml.sax.saxutils import escape - safe = escape(reply) - sp.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "send_stanza", - "mohe@yoin.fun", str(sender), - f"{safe}" - ], capture_output=True, timeout=10) - logging.info(f"✅ 回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ 错误: {e}") - -async def main(): - retry_delay = 1 # 初始重试间隔(秒) - max_delay = 60 # 最大重试间隔 - while True: - try: - bot = MoheBot() - bot.register_plugin('xep_0030') # Service Discovery - bot.register_plugin('xep_0045') # MUC - bot.register_plugin('xep_0199') # XMPP Ping(保活) - - bot.connect(host='127.0.0.1', port=5222) - await asyncio.wait_for(bot.ready.wait(), timeout=30) - logging.info("莫荷 XMPP 就绪") - retry_delay = 1 # 连接成功后重置重试间隔 - - # 保持运行,断线时自动重连 - while True: - await asyncio.sleep(15) - if not bot.is_connected(): - logging.warning("检测到断线,准备重连...") - break - - except asyncio.TimeoutError: - logging.warning("连接超时,准备重连...") - except Exception as e: - logging.error(f"❌ 主循环错误: {e}") - - # 指数退避重连:1s → 2s → 4s → 8s → ... → 60s max - logging.info(f"⏳ 等待 {retry_delay} 秒后重连...") - await asyncio.sleep(retry_delay) - retry_delay = min(retry_delay * 2, max_delay) - -if __name__ == '__main__': - try: - asyncio.run(main()) - except KeyboardInterrupt: - pass diff --git a/bots/xmpp_xiaoguo_bot.py b/bots/xmpp_xiaoguo_bot.py deleted file mode 100644 index 8ff06de..0000000 --- a/bots/xmpp_xiaoguo_bot.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -"""XMPP Bot xiaoguo@yoin.fun - 跑在 Linux 上""" -import asyncio, logging, ssl, json, urllib.request, subprocess, re -from xml.sax.saxutils import escape - -logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = "http://localhost:8645/v1/chat/completions" -API_KEY = "hermes123" -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) - -def send(from_jid, to_jid, body): - safe = escape(body) - subprocess.run(["docker","exec","ejabberd","ejabberdctl","send_stanza", - from_jid, to_jid, - f"{safe}" - ], capture_output=True, timeout=10) - -class XiaoGuoBot: - def __init__(self): - import slixmpp - self.xmpp = slixmpp.ClientXMPP('xiaoguo@yoin.fun', 'hermes123') - self.xmpp.add_event_handler('session_bind', self.on_bind) - self.xmpp.add_event_handler('message', self.on_msg) - self.xmpp.add_event_handler('disconnected', self.on_disconnect) - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - self.xmpp.ssl_context = ctx - self.ready = asyncio.Event() - self._call_seq = 0 - - async def on_bind(self, event): - self.xmpp.send_presence() - self.xmpp.get_roster() - # 加入内核组 - self.xmpp.plugin['xep_0045'].join_muc('coregroup@conference.yoin.fun', 'xiaoguo') - self.ready.set() - logging.info("✅ 小果上线") - - async def on_disconnect(self, event): - self.ready.clear() - logging.warning("⚠️ 小果断线") - - async def on_msg(self, msg): - body = msg['body'] - sender = str(msg['from']) - msg_type = msg['type'] - if not body: - return - # 群聊 - if msg_type == 'groupchat': - if 'xiaoguo@yoin.fun' in sender: - return - nickname = sender.split('/')[-1] if '/' in sender else '' - if nickname in ('hmo', 'xxm'): - logging.info(f"📩 群消息 [{sender}]: {body[:80]}") - room = sender.split('/')[0] - ctx_body = f"[核心群 {room}] {nickname} 说: {body}" - await self.call_hermes(ctx_body, room, is_group=True) - return - # 私聊 - if msg_type == 'chat' and 'hmo@yoin.fun' in sender: - self._call_seq += 1 - logging.info(f"📩 老爸(#{self._call_seq}): {body}") - await self.call_hermes(body, sender) - - async def call_hermes(self, content, sender, is_group=False): - msg_type = 'groupchat' if is_group else 'chat' - try: - payload = json.dumps({ - "model": "hermes-agent", - "messages": [{"role": "user", "content": f"[xiaoguo] {content}"}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type", "application/json") - req.add_header("Authorization", f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id", "xmpp-xiaoguo") - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: _opener.open(req, timeout=600)) - data = json.loads(result.read()) - reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") - finish = data.get("choices", [{}])[0].get("finish_reason", "") - # 处理 __SILENT__ 和 __REPLY__ 标记(和莫荷保持一致) - stripped = reply.strip() - if stripped.startswith('__SILENT__'): - logging.info("⏭️ 小果决定沉默,不发送") - return - # 安全网:过滤沉默宣告类文本(防止 LLM 不按规则走) - if re.match(r'^(保持安静|不插嘴|我沉默了|收到|明白|好的|在的?|在呢|来了|沉默|安静)([,,!!。.?\s]|$)', stripped, re.IGNORECASE): - logging.info(f"⏭️ 小果沉默宣告被拦截: {stripped[:60]}") - return - reply = re.sub(r'^__REPLY__\s*', '', reply) - if reply.strip() and finish != "silent": - if is_group: - self.xmpp.send_message(mto=sender, mbody=reply, mtype='groupchat') - else: - send("xiaoguo@yoin.fun", sender, reply) - logging.info(f"✅ 小果回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ 小果错误: {e}") - -async def main(): - while True: - try: - z = XiaoGuoBot() - z.xmpp.register_plugin('xep_0030') - z.xmpp.register_plugin('xep_0045') - z.xmpp.register_plugin('xep_0199') - z.xmpp.connect(host='127.0.0.1', port=5222) - await asyncio.wait_for(z.ready.wait(), timeout=30) - logging.info("小果就绪") - await asyncio.Event().wait() - except Exception as e: - logging.error(f"小果main错误: {e}") - await asyncio.sleep(3) - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/bots/xmpp_zhiwei_bot.py b/bots/xmpp_zhiwei_bot.py deleted file mode 100644 index 219a7dc..0000000 --- a/bots/xmpp_zhiwei_bot.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -"""XMPP Bot zhiwei@yoin.fun - Hermes API 版(稳定重连版)""" -import asyncio, logging, ssl, json, urllib.request, os, subprocess, time -from xml.sax.saxutils import escape - -logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = "http://localhost:8643/v1/chat/completions" -API_KEY = "hermes123" -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) - -def send(from_jid, to_jid, body): - safe = escape(body) - subprocess.run(["docker","exec","ejabberd","ejabberdctl","send_stanza", - from_jid, to_jid, - f"{safe}" - ], capture_output=True, timeout=10) - -class ZhiweiBot: - def __init__(self): - import slixmpp - self.xmpp = slixmpp.ClientXMPP('zhiwei@yoin.fun', 'hermes123') - self.xmpp.add_event_handler('session_bind', self.on_bind) - self.xmpp.add_event_handler('message', self.on_msg) - self.xmpp.add_event_handler('disconnected', self.on_disconnect) - self.xmpp.add_event_handler('connected', self.on_connected) - # 启用slixmpp内置自动重连(已禁用—与手动重连冲突) - # self.xmpp.auto_reconnect = True - ctx = ssl.create_default_context(); ctx.check_hostname = False; ctx.verify_mode = ssl.CERT_NONE - self.xmpp.ssl_context = ctx - self.ready = asyncio.Event() - self._call_seq = 0 - - async def on_connected(self, event): - logging.info("🔗 知微TCP连接已建立") - - async def on_bind(self, event): - self.xmpp.send_presence(); self.xmpp.get_roster(); self.ready.set() - logging.info("✅ 知微上线") - - async def on_disconnect(self, event): - self.ready.clear() - logging.warning("⚠️ 知微断线") - # 不要在这里调用 self.xmpp.disconnect(),让 auto_reconnect 处理 - - async def on_msg(self, msg): - body = msg['body']; sender = str(msg['from']) - if not body or msg['type'] != 'chat': return - if 'hmo@yoin.fun' in sender: - self._call_seq += 1 - logging.info(f"📩 老爸(#{self._call_seq}): {body}") - try: - payload = json.dumps({ - "model":"hermes-agent", - "messages":[{"role":"user","content":f"[zhiwei] {body}"}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type","application/json") - req.add_header("Authorization",f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id","xmpp-zhiwei") - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: _opener.open(req, timeout=600)) - data = json.loads(result.read()) - reply = data.get("choices",[{}])[0].get("message",{}).get("content","") - finish = data.get("choices",[{}])[0].get("finish_reason","") - if reply.strip() and finish != "silent": - send("zhiwei@yoin.fun", sender, reply) - logging.info(f"✅ 知微回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ 知微错误: {e}") - -async def main(): - retry_delay = 1 - max_delay = 60 - while True: - try: - z = ZhiweiBot() - z.xmpp.register_plugin('xep_0030'); z.xmpp.register_plugin('xep_0199') - z.xmpp.connect(host='127.0.0.1', port=5222) - await asyncio.wait_for(z.ready.wait(), timeout=30); logging.info("知微就绪") - retry_delay = 1 - - # 保持运行,断线时自动重连 - while True: - await asyncio.sleep(15) - if not z.xmpp.is_connected(): - logging.warning("知微连接丢失,准备重连...") - break - - except Exception as e: - logging.error(f"知微main错误: {e}") - - # 指数退避重连 - logging.info(f"⏳ 知微等待 {retry_delay} 秒后重连...") - await asyncio.sleep(retry_delay) - retry_delay = min(retry_delay * 2, max_delay) - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/docs/CLEANUP_PLAN.md b/docs/CLEANUP_PLAN.md new file mode 100644 index 0000000..c9cbfc6 --- /dev/null +++ b/docs/CLEANUP_PLAN.md @@ -0,0 +1,268 @@ +# AgentsMeeting 代码清理方案 + +> 目标:消除 XMPP Bot 代码碎片化,统一架构。所有 Agent 共享同一份核心代码。 + +--- + +## 1. 现状问题 + +### 1.1 四套独立 Bot 实现 + +| # | 位置 | 行数 | 状态 | 用途 | LLM 调用 | +|---|------|-------|------|------|----------| +| 1 | `gateway/scripts/xmpp_bot.py` | 943 | 活跃 | xxm bot | chat_bridge | +| 2 | `xmpp_agent_core.py `(root) | 348 | 活跃 | mohe/zhiwei/xiaoguo | Hermes API | +| 3 | `bots/xmpp_bot.py` + 姐妹文件 | 359 | 孤儿 | mohe/zhiwei/xiaoguo(旧版) | Hermes API | +| 4 | `xmpp_bot_rest.py` | 72 | 废弃 | ejabberd REST 实验 | Hermes API | + +**问题**:同样的 XMPP 连接、MUC join、消息处理逻辑,在 4 套代码中重复实现。修复了其中一个,其他三个还是坏的。 + +### 1.2 基础设施碎片化 + +bot_base.py(263 行)已经提炼出完善的 XMPP 基类,但没有任何 bot 实际使用它: + +``` +bot_base.py 有: + ✓ PID 锁(proc_guard) + ✓ 自动重连 + MUC join 重试 + ✓ 消息去重(dedup) + ✓ 消息合并(batching) + ✓ silence/shutup 协议 + ✓ send_group / send_private + +没有被任何 bot 使用 ✗ +``` + +### 1.3 重复/孤儿文件 + +| 文件 | 问题 | +|------|------| +| `xmpp_bot.py`(root) | `xmpp_agent_core.py` 的较旧副本 | +| `xmpp_bot_rest.py` | 废弃方案 | +| `bots/` 目录(3 文件) | 被 xmpp_agent_core.py 取代 | +| `hermes_state.py`(root) | 孤儿,import 已断 | +| `scripts/`(6 文件) | 一次性脚本 | +| `src/bots/{mohe,xiaoguo,xxm,zhiwei}/` | 空目录,重构未完成 | +| `src/channels/wechat/` | 空目录 | + +--- + +## 2. 目标架构 + +``` +AgentsMeeting/ +├── xmpp_agent_core.py ← 唯一 Bot 核心(统一架构) +│ 支持 --agent xxm|mohe|zhiwei|xiaoguo +│ 包含:PID锁/重连/MUC/dedup/batching/ +│ coordinator协议/HTTP桥 +│ +├── xxm_bot.py ← 兼容入口(python xxm_bot.py = python xmpp_agent_core.py --agent xxm) +├── mohe_bot.py ← 兼容入口 +├── xiaoguo_bot.py ← 兼容入口 +├── zhiwei_bot.py ← 兼容入口 +│ +└── gateway/ + └── scripts/ + ├── chat_bridge.py ← xxm LLM 桥(不变) + ├── session_router.py ← 消息路由(不变) + └── ...(其他脚本不变) +``` + +**核心原则**: +- `xmpp_agent_core.py` = 唯一核心,所有 Agent 共享 +- 每个 Agent 只有 LLM 调用方式不同(在配置中定义) +- coordinator/GRANT/REVOKE 协议只有一个实现,所有 bot 统一遵守 +- `bot_base.py` 的基础设施合并进 `xmpp_agent_core.py` + +--- + +## 3. 迁移步骤 + +### Phase 0:确认基线(先保证现有 bot 正常工作) + +- [ ] 确认 `gateway/scripts/xmpp_bot.py` 当前运行正常,能收发群消息 +- [ ] 确认 `xmpp_agent_core.py`(mohe)当前运行正常 + +### Phase 1:合并基础设施到 xmpp_agent_core.py + +将 `bot_base.py` 和 `gateway/scripts/xmpp_bot.py` 中的公共功能合并进 `xmpp_agent_core.py`: + +| 功能 | 来源 | 说明 | +|------|------|------| +| PID 锁(proc_guard) | bot_base.py / gateway | 防重复启动 | +| 消息去重(dedup) | bot_base.py / gateway | 防重复处理同一条消息 | +| 消息合并(batching) | bot_base.py / gateway | 3s debounce,附近消息合并一次 LLM 调用 | +| 自动重连 | bot_base.py / gateway | slixmpp auto_reconnect + reconnect_max_delay | +| MUC join 双重保证 | gateway | `join_muc()` + `send_raw(presence)` | +| MAM 启动恢复 | gateway | 启动时拉取最近 50 条历史消息补上下文 | +| HTTP 桥丰富 API | gateway | /health, /presence, /messages, POST /send | +| sub-agent exec | gateway | `##exec:command##` 用于工具调用 | +| delayed reply | gateway | `##delay:N##` 延迟回复 | + +### Phase 2:统一 coordinator 协议 + +gateway/scripts/xmpp_bot.py 已经实现了 coordinator/GRANT/REVOKE 协议。将这套实现整合进 `xmpp_agent_core.py`,确保所有 Agent 使用同一份协议代码。 + +### Phase 3:添加 xxm 作为支持 Agent + +在 `xmpp_agent_core.py` 的 AGENTS 配置中添加 xxm: + +```python +AGENTS = { + "mohe": {..., "gateway": "http://localhost:8642/v1/chat/completions"}, + "zhiwei": {..., "gateway": "http://localhost:8643/v1/chat/completions"}, + "xiaoguo":{..., "gateway": "http://localhost:8645/v1/chat/completions"}, + "xxm": {..., "bridge": "chat_bridge"}, # xxm 走本地 chat_bridge +} +``` + +xxm 的 LLM 调用方式不同(不走 HTTP Hermes API,走本地 `chat_bridge.py`),所以需要在核心中抽象 LLM 调用层: + +```python +def _call_llm(self, content: str) -> str: + if self.cfg.get("bridge") == "chat_bridge": + return _call_chat_bridge(content) # 本地调用 + else: + return _call_hermes_api(content, self.cfg["gateway"]) # HTTP 调用 +``` + +### Phase 4:删除孤儿文件 + +| 文件 | 操作 | +|------|------| +| `xmpp_bot.py`(root) | 移到 trashbox | +| `xmpp_bot_rest.py` | 移到 trashbox | +| `bots/` 整个目录 | 移到 trashbox | +| `hermes_state.py`(root) | 移到 trashbox | +| `scripts/gen_prd.py` | 移到 trashbox | +| `scripts/write_prd.py` | 移到 trashbox | +| `scripts/gen_prd_v02.py` | 移到 trashbox | +| `scripts/write_prd_v02.py` | 移到 trashbox | +| `scripts/build_prd.py` | 移到 trashbox | +| `scripts/test_echo.py` | 移到 trashbox | +| `scripts/gen_b64.py` | 保留(可能是通用工具) | + +### Phase 5:清理空目录 + +| 目录 | 操作 | +|------|------| +| `src/bots/mohe/` | 删除 | +| `src/bots/xiaoguo/` | 删除 | +| `src/bots/xxm/` | 删除 | +| `src/bots/zhiwei/` | 删除 | +| `src/channels/wechat/` | 删除 | +| `gateway/assets/` | 删除 | + +### Phase 6:整理 gateway/temp/ + +保留:PID 文件、活跃缓存(.bridge_context.jsonl、.model_cache.json) +其余 >200 个临时文件:移到 temp/archive/ 或按需清理。 + +### Phase 7:文档更新 + +- [ ] 更新 `docs/ARCHITECTURE.md` —— 反映统一架构 +- [ ] 更新 `README.md` —— 更新项目结构描述 +- [ ] 更新 `config/agents.yaml` —— 保持准确 + +--- + +## 4. 核心文件变更清单 + +### 修改:xmpp_agent_core.py + +从 348 行扩展为 ~600 行,新增: + +``` +新增功能模块: +├── proc_guard PID 锁 +├── 消息去重(_dedup_cache) +├── 消息合并(_batch_* 系统,3s debounce) +├── MAM 启动恢复(_fetch_mam_history) +├── HTTP 桥丰富版(/health, /presence, /messages, POST /send) +├── sub-agent exec(##exec:command##) +├── delayed reply(##delay:N##) +├── 抽象 LLM 调用层(支持 chat_bridge + Hermes API) +└── coordinator/GRANT/REVOKE 协议(从 gateway 版提升) +``` + +新增 AGENTS 配置条目: + +```python +"xxm": { + "jid": "xxm@yoin.fun", + "password": "hermes123", + "nick": "xxm", + "name_cn": "笑笑", + "bridge": "chat_bridge", # 使用本地 chat_bridge + "http_port": 5802, # HTTP 桥端口 + "muc_rooms": [ + "coregroup@conference.yoin.fun", + "jujidina@conference.yoin.fun", + ], + "server": "192.168.1.246", # LAN 直连 + "port": 5222, + "session_id": "ses_xxm_xmpp", + "mention": "@xxm/@笑笑", +} +``` + +### 创建:xxm_bot.py + +``` +#!/usr/bin/env python3 +"""Wrapper for xmpp_agent_core.py --agent xxm""" +import sys, os +sys.argv = [sys.argv[0], '--agent', 'xxm'] +exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py')).read()) +``` + +(与现有的 mohe_bot.py / zhiwei_bot.py / xiaoguo_bot.py 模式一致) + +### 保留不变 + +| 文件 | 说明 | +|------|------| +| `gateway/scripts/chat_bridge.py` | xxm LLM 桥,不迁移 | +| `gateway/scripts/session_router.py` | 消息路由,不迁移 | +| `gateway/scripts/wechat_agent.py` | 微信桥接,独立组件 | +| `gateway/scripts/qq_bot.py` | QQ bot,独立组件 | +| `gateway/scripts/vc_webhook.py` | VoceChat webhook,独立组件 | +| `gateway/scripts/dashboard.py` | 管理门户,独立组件 | +| `gateway/scripts/health_check_xxm.py` | 健康检查,独立组件 | +| `gateway/scripts/xmpp_watchdog.py` | 看门狗,独立组件 | +| `gateway/scripts/mohe_watcher.py` | 莫荷消息监控,独立组件 | +| `gateway/scripts/api_proxy.py` | API 代理,独立组件 | + +--- + +## 5. 测试计划 + +### 5.1 单元测试 + +- [ ] `python tests/test_core.py` —— bot_base 测试保持通过 +- [ ] 新增测试:合并后的 xmpp_agent_core.py 的 LLM 抽象层 + +### 5.2 集成测试 + +- [ ] `python xmpp_agent_core.py --agent xxm` 能启动、连接、加入 MUC +- [ ] `python xmpp_agent_core.py --agent mohe` 能启动、连接、加入 MUC +- [ ] 各 Agent 在 coregroup 中能正确响应 @mention +- [ ] coordinator/GRANT/REVOKE 协议各 Agent 一致 + +### 5.3 部署验证 + +- [ ] `python tests/verify_deploy.py` pass +- [ ] gateway/scripts/xmpp_bot.py → 改为调用 xmpp_agent_core.py --agent xxm +- [ ] Linux 端 update systemd service 指向新路径 + +--- + +## 6. 风险与回滚 + +| 风险 | 缓解 | +|------|------| +| 合并后 xxm bot 不工作 | Phase 0 先备份当前 `gateway/scripts/xmpp_bot.py` | +| 协议行为不一致 | coordinator 协议从 gateway 版提取,与 xmpp_agent_core 现有逻辑逐行对比 | +| 启动命令需要改 | 兼容入口(xxm_bot.py 等)保持 CLI 不变 | + +**回滚方案**:Phase 1-4 每步完成后验证。出问题从 trashbox 恢复删除的文件。 diff --git a/gateway/api/history_api.py b/gateway/api/history_api.py index 55ad551..103d3e6 100644 --- a/gateway/api/history_api.py +++ b/gateway/api/history_api.py @@ -66,9 +66,9 @@ def get_db_handle(): dbs = r.get("data") or [] for db in dbs: dbname = db.get("databaseName", "") - if "MSG" in dbname or "Msg" in dbname: - db_handle_cache = db.get("handle") - return db_handle_cache + if dbname.startswith("MSG") and "Media" not in dbname: + db_handle_cache = db.get("handle") + return db_handle_cache return None @@ -94,24 +94,42 @@ def query_history(wxid, limit=10): return None limit_val = min(int(limit), 200) sql = ( - f"SELECT CreateTime, IsSender, Type, SubType, StrContent, DisplayContent " - f"FROM MSG WHERE StrTalker='{wxid}' AND Type IN (1,49) " + f"SELECT CreateTime, IsSender, Type, SubType, StrContent, DisplayContent, CompressContent, BytesExtra " + f"FROM MSG WHERE StrTalker='{wxid}' AND Type IN ('1','49') " f"ORDER BY CreateTime DESC LIMIT {limit_val}" ) r = wxpost("/api/execSql", {"dbHandle": h, "sql": sql}, timeout=15) data = r.get("data") or [] if not data or len(data) < 2: return None - # Skip header row, reverse to chronological order - rows = data[1:] + # wxhelper returns [{value: [cols]}, {value: [row1]}, ...] + rows = [item.get("value", item) if isinstance(item, dict) else item for item in data] + rows = rows[1:] # skip header rows.reverse() results = [] for row in rows: content = (row[4] or "").strip() if len(row) > 4 else "" if not content and len(row) > 5: content = (row[5] or "").strip() + # Type 49 (article link): extract URL from CompressContent or BytesExtra + if not content and str(row[2]) == "49": + try: + import re + # Try BytesExtra first (row[7]) + for idx in [7, 6]: + if idx < len(row) and row[idx]: + text = str(row[idx]) + urls = re.findall(r'https?://[^\s\x00-\x1f<>\"\']{10,}', text) + if urls: + content = urls[0] + break + except: + pass if not content: - continue + if str(row[2]) == "49": + content = "[文章链接]" + else: + continue results.append({ "CreateTime": row[0], "IsSender": row[1], @@ -181,29 +199,25 @@ def get_recent_chats(limit=20): if not h: return [] sql = ( - f"SELECT StrTalker, MAX(CreateTime) as last_time, COUNT(*) as msg_count " - f"FROM MSG WHERE Type IN (1,49) " - f"GROUP BY StrTalker ORDER BY last_time DESC LIMIT {min(limit, 50)}" + f"SELECT DISTINCT StrTalker FROM MSG WHERE Type IN ('1','49') " + f"LIMIT {min(limit, 50)}" ) r = wxpost("/api/execSql", {"dbHandle": h, "sql": sql}, timeout=15) data = r.get("data") or [] if not data or len(data) < 2: return [] + rows = [item.get("value", item) if isinstance(item, dict) else item for item in data] results = [] - for row in data[1:]: + for row in rows[1:]: wxid = (row[0] or "").strip() - if not wxid or wxid in ("fmessage", "weixin", "wechat", "filehelper"): + if not wxid or wxid in ("fmessage", "weixin", "wechat", "filehelper", "medianote", "floatbottle", "qmessage"): continue - if wxid.startswith("gh_"): - continue - ts = int(row[1]) if row[1] else 0 - count = int(row[2]) if len(row) > 2 and row[2] else 0 results.append({ "wxid": wxid, "nickname": get_nickname(wxid), - "last_message_time": datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S") if ts else None, - "last_message_ts": ts, - "message_count": count, + "last_message_time": None, + "last_message_ts": 0, + "message_count": 0, }) return results diff --git a/gateway/scripts/wechat_agent.py b/gateway/scripts/wechat_agent.py index 6b72438..5935a36 100644 --- a/gateway/scripts/wechat_agent.py +++ b/gateway/scripts/wechat_agent.py @@ -21,7 +21,8 @@ if not _lock.ok: BOT_WXID = "wxid_5bhmquvkbude22" BLOCK_WXIDS = {"fmessage", "weixin", "wechat"} # ϵͳ?˺?/΢???Ŷӣ----ظ? WX_API = "http://127.0.0.1:19088" -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.dirname(SCRIPT_DIR) LOG_DIR = os.path.join(PROJECT_ROOT, "logs") TEMP_DIR = os.path.join(PROJECT_ROOT, "temp") LOG_FILE = os.path.join(LOG_DIR, "wechat_agent.log") @@ -155,8 +156,8 @@ HERMES_KEY = "hermes123" SENSENOVA_KEY = "sk-aRNj3UwKSLPsDfh15QNTPwbHxahblfaO" SENSENOVA_URL = "https://token.sensenova.cn/v1" -INJECTOR = r"D:\F\NewI\opencode\daily-workspace\projects\wechat-hermes-gateway\tools\Injector_x64.exe" -WXHELPER_DLL = r"D:\F\NewI\opencode\daily-workspace\projects\wechat-hermes-gateway\tools\wxhelper_official_39581.dll" +INJECTOR = os.path.join(SCRIPT_DIR, "..", "tools", "Injector_x64.exe") +WXHELPER_DLL = os.path.join(SCRIPT_DIR, "..", "tools", "wxhelper_official_39581.dll") def log(m): with open(LOG_FILE, "a", encoding="utf-8") as f: @@ -569,8 +570,24 @@ def process_msg(raw_data): ct = d.get("content", "") or d.get("msg", "") or d.get("text", "") msg_type = d.get("type", 1) is_self = d.get("isSelf", 0) or d.get("self", 0) + # DEBUG: capture Type 49 full XML for URL analysis + if msg_type == 49: + try: + with open(LOG_DIR + "/t49_xml.txt", "a", encoding="utf-8") as _f: + _f.write(f"\n=== {time.time()} type=49 from={fu} ===\n{ct[:10000]}\n") + except: pass if "@chatroom" in fu: log(f"GROUP RAW DUMP: keys={list(d.keys())} ct_len={len(ct)} ct[:100]={ct[:100]}") + # DEBUG: capture full raw data for quote analysis + try: + with open(LOG_DIR + "/group_raw.jsonl", "a", encoding="utf-8") as _f: + _f.write(json.dumps({k: str(v)[:2000] for k, v in d.items()}, ensure_ascii=False) + "\n") + except: pass + # DEBUG: capture all raw msgs for field analysis + try: + with open(LOG_DIR + "/all_raw.jsonl", "a", encoding="utf-8") as _f: + _f.write(json.dumps({k: str(v)[:500] for k, v in d.items()}, ensure_ascii=False) + "\n") + except: pass if not fu or not ct or fu == BOT_WXID or fu in BLOCK_WXIDS or fu.startswith("gh_") or is_self: log(f"SKIP: fu={fu} self={is_self}") return @@ -608,6 +625,64 @@ def process_msg(raw_data): else: log(f"-> {fu}: skip (blank image response)") return + # Type 49 (forwarded article) - extract URL and process via article_processor + if msg_type == 49 and ct.strip().startswith(" first, then , then + urls = re.findall(r'(https?://mp\.weixin\.qq\.com[^<]+)', ct) + if not urls: + urls = re.findall(r'(https?://mp\.weixin\.qq\.com[^<]+)', ct) + if not urls: + urls = re.findall(r'(https?://mp\.weixin\.qq\.com[^<]+)', ct) + url = urls[0] if urls else None + # Extract title from XML + titles = re.findall(r'(.*?)', ct) + title = titles[0] if titles else "" + # Extract description + descs = re.findall(r'(.*?)', ct) + desc = descs[0] if descs else "" + + if url: + log(f"ARTICLE URL: {url}") + # Call article_processor on localhost + import urllib.request as ur + req = ur.Request("http://127.0.0.1:5810/process", + data=json.dumps({"url": url}).encode("utf-8"), + headers={"Content-Type": "application/json"}) + with ur.urlopen(req, timeout=180) as resp: + result = json.loads(resp.read().decode("utf-8")) + if result.get("status") == "ok": + content = result.get("content", "")[:3000] + title = result.get("title", "") + images = result.get("images_ocr", 0) + enriched = f"[老莫转发了一篇文章{(chr(10)+'标题: '+title) if title else ''},{images}张图片已OCR]\n\n{content}" + log(f"ARTICLE processed: {len(content)} chars") + reply = call_hermes(fu, enriched) + if reply and reply.strip(): + log(f"-> {fu}: {reply[:50]}") + send_wx(fu, reply.strip()) + return + else: + log(f"ARTICLE process failed: {result.get('error','')[:100]}") + # Fallback: send title + description + fallback = f"[老莫转发了一篇文章]{(chr(10)+'标题: '+title) if title else ''}{(chr(10)+'摘要: '+desc[:200]) if desc else ''}\n(全文抓取失败: {result.get('error','')[:60]})" + reply = call_hermes(fu, fallback) + if reply and reply.strip(): + send_wx(fu, reply.strip()) + return + else: + # No URL found, send title + description + if title: + log(f"ARTICLE: no URL, sending title+desc") + fallback = f"[老莫转发了一篇文章]{(chr(10)+'标题: '+title) if title else ''}{(chr(10)+'摘要: '+desc[:200]) if desc else ''}" + reply = call_hermes(fu, fallback) + if reply and reply.strip(): + send_wx(fu, reply.strip()) + return + except Exception as e: + log(f"ARTICLE handler error: {e}") + # Fall through to text handler # Text - prepend sender wxid+name so Hermes knows who's talking sender_name = get_nickname(fu) chat_type = "Group" if "@chatroom" in fu else "Private" diff --git a/gateway/scripts/xmpp_bot.py b/gateway/scripts/xmpp_bot.py index d23e5b2..203783d 100644 --- a/gateway/scripts/xmpp_bot.py +++ b/gateway/scripts/xmpp_bot.py @@ -26,8 +26,8 @@ if not _lock.ok: # ── Config ── JID = "xxm@yoin.fun" PASSWORD = "hermes123" -SERVER = "xmpp.yoin.fun" -PORT = 3021 +SERVER = "192.168.1.246" +PORT = 5222 ATTACH_SESSION = "ses_xxm_xmpp" MUC_ROOMS = [ "coregroup@conference.yoin.fun", # core group chat @@ -696,23 +696,22 @@ if __name__ == "__main__": bot_nick = JID.split("@")[0] async def _join_silent(): for room_jid in MUC_ROOMS: - for attempt in range(3): - try: - # Use join_muc_wait to ensure room join completes - await self.plugin['xep_0045'].join_muc_wait(room_jid, bot_nick, timeout=60) - log(f"Joined {room_jid} (silent)") - break - except asyncio.TimeoutError: - log(f"MUC join timeout ({attempt+1}/3) for {room_jid}") - if attempt == 2: - log(f"MUC setup failed for {room_jid} after 3 attempts") - await asyncio.sleep(5) - else: - await asyncio.sleep(3) - except Exception as e: - log(f"MUC setup failed for {room_jid}: {e} (type={type(e).__name__})") - await asyncio.sleep(5) - break + nick = bot_nick + try: + # Use join_muc (non-waiting) to register plugin state + self.plugin['xep_0045'].join_muc(room_jid, nick) + # Also send raw presence as backup + presence = ( + f"" + f"" + f"" + f"" + ) + self.send_raw(presence) + log(f"Joined {room_jid} (async)") + except Exception as e: + log(f"MUC join failed for {room_jid}: {type(e).__name__}: {e}") + await asyncio.sleep(2) # After joining, query MAM for recent history await asyncio.sleep(3) # wait for MUC join to propagate await _fetch_mam_history() diff --git a/hermes_state.py b/hermes_state.py deleted file mode 100644 index de8526b..0000000 --- a/hermes_state.py +++ /dev/null @@ -1,4372 +0,0 @@ -#!/usr/bin/env python3 -""" -SQLite State Store for Hermes Agent. - -Provides persistent session storage with FTS5 full-text search, replacing -the per-session JSONL file approach. Stores session metadata, full message -history, and model configuration for CLI and gateway sessions. - -Key design decisions: -- WAL mode for concurrent readers + one writer (gateway multi-platform) -- FTS5 virtual table for fast text search across all session messages -- Compression-triggered session splitting via parent_session_id chains -- Batch runner and RL trajectories are NOT stored here (separate systems) -- Session source tagging ('cli', 'telegram', 'discord', etc.) for filtering -""" - -import json -import logging -import random -import re -import sqlite3 -import threading -import time -from pathlib import Path - -from agent.memory_manager import sanitize_context -from hermes_constants import get_hermes_home -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - -DEFAULT_DB_PATH = get_hermes_home() / "state.db" - -SCHEMA_VERSION = 15 - -# --------------------------------------------------------------------------- -# WAL-compatibility fallback -# --------------------------------------------------------------------------- -# SQLite's WAL mode requires shared-memory (mmap) coordination and fcntl -# byte-range locks that don't reliably work on network filesystems (NFS, -# SMB/CIFS, some FUSE mounts, WSL1). Upstream documents this explicitly: -# https://www.sqlite.org/wal.html#sometimes_queries_return_sqlite_busy_in_wal_mode -# -# On those filesystems ``PRAGMA journal_mode=WAL`` raises -# ``sqlite3.OperationalError: locking protocol`` (SQLITE_PROTOCOL). If we -# propagate that, every feature backed by state.db / kanban.db breaks -# silently — /resume, /title, /history, /branch, kanban dispatcher, etc. -# -# Instead, fall back to ``journal_mode=DELETE`` (the pre-WAL default) which -# works on NFS. Concurrency drops — concurrent readers are blocked during -# a write — but the feature works. -_WAL_INCOMPAT_MARKERS = ( - "locking protocol", # SQLITE_PROTOCOL on NFS/SMB - "not authorized", # Some FUSE mounts block WAL pragma outright -) - -# Last SessionDB() init error, per-process. Surfaced in /resume and -# related slash-command error strings so users know WHY the DB is -# unavailable instead of getting a bare "Session database not available." -# Only SessionDB.__init__ writes to this; kanban_db.connect() failures -# do not update it (by design — kanban failures are reported via their -# own caller's error handling, not via /resume-style slash commands). -_last_init_error: Optional[str] = None -_last_init_error_lock = threading.Lock() - -# Paths for which we've already logged a WAL-fallback WARNING. Without -# this, kanban_db.connect() (called on every kanban operation — see -# hermes_cli/kanban_db.py for ~30 call sites) would re-log the same -# filesystem-incompat warning on every connection, filling errors.log. -_wal_fallback_warned_paths: set[str] = set() -_wal_fallback_warned_lock = threading.Lock() - -_FTS_TRIGGERS = ( - "messages_fts_insert", - "messages_fts_delete", - "messages_fts_update", - "messages_fts_trigram_insert", - "messages_fts_trigram_delete", - "messages_fts_trigram_update", -) - - -def _set_last_init_error(msg: Optional[str]) -> None: - """Record (or clear) the most recent state.db init failure. - - Thread-safe via _last_init_error_lock. Callers pass a message to - record a failure or None to clear. SessionDB.__init__ only calls - this to SET on failure — it deliberately does NOT clear on success, - because in a multi-threaded caller (e.g. gateway / web_server per- - request SessionDB() instantiation), a concurrent successful open - racing past a different thread's failure would erase the cause - string that thread's /resume handler is about to format. Explicit - clears (e.g. test fixtures) are still supported by passing None. - """ - global _last_init_error - with _last_init_error_lock: - _last_init_error = msg - - -def get_last_init_error() -> Optional[str]: - """Return the most recent state.db init failure, if any. - - Slash-command handlers (``/resume``, ``/title``, ``/history``, ``/branch``) - call this to surface the underlying cause in their error messages when - ``_session_db is None``. Returns ``None`` if SessionDB initialized - successfully (or hasn't been attempted). - """ - return _last_init_error - - -def format_session_db_unavailable(prefix: str = "Session database not available") -> str: - """Format a user-facing 'session DB unavailable' message with cause. - - When ``SessionDB()`` init fails, callers set ``_session_db = None`` and - several slash commands (/resume, /title, /history, /branch) previously - responded with a bare ``"Session database not available."`` — no - indication of WHY. This helper includes the captured cause (typically - ``"locking protocol"`` from NFS/SMB) and points users at the known - culprit so they can fix it themselves. - - Example output: - Session database not available: locking protocol (state.db may be - on NFS/SMB — see https://www.sqlite.org/wal.html). - """ - cause = get_last_init_error() - if not cause: - return f"{prefix}." - hint = "" - if any(marker in cause.lower() for marker in _WAL_INCOMPAT_MARKERS): - hint = " (state.db may be on NFS/SMB/FUSE — see https://www.sqlite.org/wal.html)" - return f"{prefix}: {cause}{hint}." - - -def _on_disk_journal_mode(conn: sqlite3.Connection) -> Optional[str]: - """Read the journal mode from the SQLite DB header on disk. - - Returns the mode string (e.g. ``"wal"``, ``"delete"``), or ``None`` - if the value cannot be determined (new DB, or PRAGMA read failed). - """ - try: - row = conn.execute("PRAGMA journal_mode").fetchone() - except sqlite3.OperationalError: - return None - if row is None: - return None - mode = row[0] - if isinstance(mode, bytes): # defensive: sqlite3 occasionally returns bytes - try: - mode = mode.decode("ascii") - except UnicodeDecodeError: - return None - return str(mode).strip().lower() if mode is not None else None - - -def apply_wal_with_fallback( - conn: sqlite3.Connection, - *, - db_label: str = "state.db", -) -> str: - """Set ``journal_mode=WAL`` on ``conn``, falling back to DELETE on failure. - - Returns the journal mode actually set (``"wal"`` or ``"delete"``). - - On WAL-incompatible filesystems (NFS, SMB, some FUSE), SQLite raises - ``OperationalError("locking protocol")`` when setting WAL. We fall - back to DELETE mode — the pre-WAL default, which works on NFS — and - log one WARNING explaining why. - - The WARNING is deduplicated per ``db_label``: repeated connections - to the same underlying DB (e.g. kanban_db.connect() which is called - on every kanban operation) log once per process, not once per call. - Different db_labels log independently, so state.db and kanban.db - each get one warning on the same NFS mount. - - Shared by :class:`SessionDB` and ``hermes_cli.kanban_db.connect`` so - both databases get identical fallback behavior. - - Never downgrades to DELETE if the on-disk DB header reports WAL — see _on_disk_journal_mode. - """ - # Read-only probe — no flock, no checkpoint, no WAL/SHM unlink. - # Skipping the set-pragma prevents WAL-init from unlinking files other connections hold open. - try: - current_mode = conn.execute("PRAGMA journal_mode").fetchone() - if current_mode and current_mode[0] == "wal": - return "wal" - except sqlite3.OperationalError: - pass - - try: - conn.execute("PRAGMA journal_mode=WAL") - return "wal" - except sqlite3.OperationalError as exc: - msg = str(exc).lower() - if not any(marker in msg for marker in _WAL_INCOMPAT_MARKERS): - # Unrelated OperationalError — don't silently swallow. - raise - # Don't downgrade if another process already set WAL on disk. - existing = _on_disk_journal_mode(conn) - if existing == "wal": - raise - _log_wal_fallback_once(db_label, exc) - conn.execute("PRAGMA journal_mode=DELETE") - return "delete" - - -def _log_wal_fallback_once(db_label: str, exc: Exception) -> None: - """Log a single WARNING per (process, db_label) about WAL fallback. - - Without this dedup, NFS users running kanban (which opens a fresh - connection on every operation — see hermes_cli/kanban_db.py) would - fill errors.log with hundreds of identical warnings per hour. - """ - with _wal_fallback_warned_lock: - if db_label in _wal_fallback_warned_paths: - return - _wal_fallback_warned_paths.add(db_label) - logger.warning( - "%s: WAL journal_mode unsupported on this filesystem (%s) — " - "falling back to journal_mode=DELETE (slower rollback-journal " - "mode; reduces concurrency but works on NFS/SMB/FUSE). See " - "https://www.sqlite.org/wal.html for details. This warning " - "fires once per process per database.", - db_label, - exc, - ) - -SCHEMA_SQL = """ -CREATE TABLE IF NOT EXISTS schema_version ( - version INTEGER NOT NULL -); - -CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - source TEXT NOT NULL, - user_id TEXT, - model TEXT, - model_config TEXT, - system_prompt TEXT, - parent_session_id TEXT, - started_at REAL NOT NULL, - ended_at REAL, - end_reason TEXT, - message_count INTEGER DEFAULT 0, - tool_call_count INTEGER DEFAULT 0, - input_tokens INTEGER DEFAULT 0, - output_tokens INTEGER DEFAULT 0, - cache_read_tokens INTEGER DEFAULT 0, - cache_write_tokens INTEGER DEFAULT 0, - reasoning_tokens INTEGER DEFAULT 0, - cwd TEXT, - billing_provider TEXT, - billing_base_url TEXT, - billing_mode TEXT, - estimated_cost_usd REAL, - actual_cost_usd REAL, - cost_status TEXT, - cost_source TEXT, - pricing_version TEXT, - title TEXT, - api_call_count INTEGER DEFAULT 0, - handoff_state TEXT, - handoff_platform TEXT, - handoff_error TEXT, - rewind_count INTEGER NOT NULL DEFAULT 0, - archived INTEGER NOT NULL DEFAULT 0, - FOREIGN KEY (parent_session_id) REFERENCES sessions(id) -); - -CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT NOT NULL REFERENCES sessions(id), - role TEXT NOT NULL, - content TEXT, - tool_call_id TEXT, - tool_calls TEXT, - tool_name TEXT, - timestamp REAL NOT NULL, - token_count INTEGER, - finish_reason TEXT, - reasoning TEXT, - reasoning_content TEXT, - reasoning_details TEXT, - codex_reasoning_items TEXT, - codex_message_items TEXT, - platform_message_id TEXT, - observed INTEGER DEFAULT 0, - active INTEGER NOT NULL DEFAULT 1 -); - -CREATE TABLE IF NOT EXISTS state_meta ( - key TEXT PRIMARY KEY, - value TEXT -); - -CREATE TABLE IF NOT EXISTS compression_locks ( - session_id TEXT PRIMARY KEY, - holder TEXT NOT NULL, - acquired_at REAL NOT NULL, - expires_at REAL NOT NULL -); - -CREATE INDEX IF NOT EXISTS idx_sessions_source ON sessions(source); -CREATE INDEX IF NOT EXISTS idx_sessions_source_id ON sessions(source, id); -CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id); -CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC); -CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp); -CREATE INDEX IF NOT EXISTS idx_compression_locks_expires ON compression_locks(expires_at); -""" - -# Indexes that reference columns added in later schema versions must be -# created AFTER _reconcile_columns() has had a chance to ADD them on -# existing databases. SCHEMA_SQL above is run by sqlite executescript -# which would otherwise fail on legacy DBs ("no such column: active"). -DEFERRED_INDEX_SQL = """ -CREATE INDEX IF NOT EXISTS idx_messages_session_active - ON messages(session_id, active, timestamp); -""" - -FTS_SQL = """ -CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5( - content -); - -CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN - INSERT INTO messages_fts(rowid, content) VALUES ( - new.id, - COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '') - ); -END; - -CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN - DELETE FROM messages_fts WHERE rowid = old.id; -END; - -CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN - DELETE FROM messages_fts WHERE rowid = old.id; - INSERT INTO messages_fts(rowid, content) VALUES ( - new.id, - COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '') - ); -END; -""" - -# Trigram FTS5 table for CJK substring search. The default unicode61 -# tokenizer splits CJK characters into individual tokens, breaking phrase -# matching. The trigram tokenizer creates overlapping 3-byte sequences so -# substring queries work natively for any script (CJK, Thai, etc.). -FTS_TRIGRAM_SQL = """ -CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts_trigram USING fts5( - content, - tokenize='trigram' -); - -CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_insert AFTER INSERT ON messages BEGIN - INSERT INTO messages_fts_trigram(rowid, content) VALUES ( - new.id, - COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '') - ); -END; - -CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_delete AFTER DELETE ON messages BEGIN - DELETE FROM messages_fts_trigram WHERE rowid = old.id; -END; - -CREATE TRIGGER IF NOT EXISTS messages_fts_trigram_update AFTER UPDATE ON messages BEGIN - DELETE FROM messages_fts_trigram WHERE rowid = old.id; - INSERT INTO messages_fts_trigram(rowid, content) VALUES ( - new.id, - COALESCE(new.content, '') || ' ' || COALESCE(new.tool_name, '') || ' ' || COALESCE(new.tool_calls, '') - ); -END; -""" - - -class SessionDB: - """ - SQLite-backed session storage with FTS5 search. - - Thread-safe for the common gateway pattern (multiple reader threads, - single writer via WAL mode). Each method opens its own cursor. - """ - - # ── Write-contention tuning ── - # With multiple hermes processes (gateway + CLI sessions + worktree agents) - # all sharing one state.db, WAL write-lock contention causes visible TUI - # freezes. SQLite's built-in busy handler uses a deterministic sleep - # schedule that causes convoy effects under high concurrency. - # - # Instead, we keep the SQLite timeout short (1s) and handle retries at the - # application level with random jitter, which naturally staggers competing - # writers and avoids the convoy. - _WRITE_MAX_RETRIES = 15 - _WRITE_RETRY_MIN_S = 0.020 # 20ms - _WRITE_RETRY_MAX_S = 0.150 # 150ms - # Attempt a PASSIVE WAL checkpoint every N successful writes. - _CHECKPOINT_EVERY_N_WRITES = 50 - - def __init__(self, db_path: Path = None, read_only: bool = False): - self.db_path = db_path or DEFAULT_DB_PATH - self.read_only = read_only - - self._lock = threading.Lock() - self._write_count = 0 - self._fts_enabled = False - self._fts_unavailable_warned = False - try: - if read_only: - # Read-only attach for cross-profile aggregation: SELECT-only, - # so we skip schema init entirely (no DDL, no FTS probe, no - # column reconcile). Crucially this takes NO write lock, so - # polling another profile's live DB on every sidebar refresh - # never contends with that profile's running backend. The DB - # must already exist + be initialised (callers guard on - # db_path.exists()); a SELECT against an empty file raises and - # the caller degrades per-profile. - self._conn = sqlite3.connect( - f"file:{self.db_path}?mode=ro", - uri=True, - check_same_thread=False, - timeout=1.0, - isolation_level=None, - ) - self._conn.row_factory = sqlite3.Row - return - - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self._conn = sqlite3.connect( - str(self.db_path), - check_same_thread=False, - # Short timeout — application-level retry with random jitter - # handles contention instead of sitting in SQLite's internal - # busy handler for up to 30s. - timeout=1.0, - # auto-starts transactions on DML, which conflicts with our - # explicit BEGIN IMMEDIATE. None = we manage transactions - # ourselves. - isolation_level=None, - ) - self._conn.row_factory = sqlite3.Row - apply_wal_with_fallback(self._conn, db_label="state.db") - self._conn.execute("PRAGMA foreign_keys=ON") - - self._init_schema() - except Exception as exc: - # Capture the cause so /resume and friends can surface WHY the - # session DB is unavailable instead of a bare "Session database - # not available." Callers that catch this exception keep their - # existing ``self._session_db = None`` degradation path. - # - # Note: we deliberately do NOT clear _last_init_error on the - # success path (no else branch). In multi-threaded callers - # (gateway, web_server per-request SessionDB()), a concurrent - # successful open racing past this failure would erase the - # cause that another thread's /resume is about to format. - # Tests that need to reset the state can call - # ``hermes_state._set_last_init_error(None)`` explicitly. - _set_last_init_error(f"{type(exc).__name__}: {exc}") - raise - - # ── Core write helper ── - - @staticmethod - def _is_fts5_unavailable_error(exc: sqlite3.OperationalError) -> bool: - err = str(exc).lower() - return "no such module" in err and "fts5" in err - - def _warn_fts5_unavailable(self, exc: sqlite3.OperationalError) -> None: - self._fts_enabled = False - if self._fts_unavailable_warned: - return - self._fts_unavailable_warned = True - logger.warning( - "SQLite FTS5 unavailable for %s; full-text session search " - "disabled. Run `hermes update` to rebuild the venv with a " - "current Python (managed uv guarantees FTS5). " - "(underlying error: %s)", - self.db_path, - exc, - ) - - def _sqlite_supports_fts5(self, cursor: sqlite3.Cursor) -> bool: - try: - cursor.execute("CREATE VIRTUAL TABLE temp._hermes_fts5_probe USING fts5(x)") - cursor.execute("DROP TABLE temp._hermes_fts5_probe") - return True - except sqlite3.OperationalError as exc: - if not self._is_fts5_unavailable_error(exc): - raise - self._warn_fts5_unavailable(exc) - return False - - @staticmethod - def _drop_fts_triggers(cursor: sqlite3.Cursor) -> None: - for trigger in _FTS_TRIGGERS: - try: - cursor.execute(f"DROP TRIGGER IF EXISTS {trigger}") - except sqlite3.OperationalError: - pass - - @staticmethod - def _fts_trigger_count(cursor: sqlite3.Cursor) -> int: - placeholders = ",".join("?" for _ in _FTS_TRIGGERS) - row = cursor.execute( - f"SELECT COUNT(*) FROM sqlite_master " - f"WHERE type = 'trigger' AND name IN ({placeholders})", - _FTS_TRIGGERS, - ).fetchone() - return int(row[0] if not isinstance(row, sqlite3.Row) else row[0]) - - @staticmethod - def _rebuild_fts_indexes(cursor: sqlite3.Cursor) -> None: - for table_name in ("messages_fts", "messages_fts_trigram"): - cursor.execute(f"DELETE FROM {table_name}") - cursor.execute( - "INSERT INTO messages_fts(rowid, content) " - "SELECT id, " - "COALESCE(content, '') || ' ' || " - "COALESCE(tool_name, '') || ' ' || " - "COALESCE(tool_calls, '') " - "FROM messages" - ) - cursor.execute( - "INSERT INTO messages_fts_trigram(rowid, content) " - "SELECT id, " - "COALESCE(content, '') || ' ' || " - "COALESCE(tool_name, '') || ' ' || " - "COALESCE(tool_calls, '') " - "FROM messages" - ) - - def _fts_table_probe(self, cursor: sqlite3.Cursor, table_name: str) -> Optional[bool]: - try: - cursor.execute(f"SELECT * FROM {table_name} LIMIT 0") - return True - except sqlite3.OperationalError as exc: - if self._is_fts5_unavailable_error(exc): - self._warn_fts5_unavailable(exc) - return None - if "no such table" in str(exc).lower(): - return False - raise - - def _ensure_fts_schema( - self, - cursor: sqlite3.Cursor, - table_name: str, - ddl: str, - ) -> bool: - status = self._fts_table_probe(cursor, table_name) - if status is None: - return False - try: - # Run even when the virtual table exists so any dropped or missing - # triggers are recreated after a previous no-FTS5 runtime disabled - # them to keep message writes working. - cursor.executescript(ddl) - return True - except sqlite3.OperationalError as exc: - if not self._is_fts5_unavailable_error(exc): - raise - self._warn_fts5_unavailable(exc) - return False - - def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T: - """Execute a write transaction with BEGIN IMMEDIATE and jitter retry. - - *fn* receives the connection and should perform INSERT/UPDATE/DELETE - statements. The caller must NOT call ``commit()`` — that's handled - here after *fn* returns. - - BEGIN IMMEDIATE acquires the WAL write lock at transaction start - (not at commit time), so lock contention surfaces immediately. - On ``database is locked``, we release the Python lock, sleep a - random 20-150ms, and retry — breaking the convoy pattern that - SQLite's built-in deterministic backoff creates. - - Returns whatever *fn* returns. - """ - last_err: Optional[Exception] = None - for attempt in range(self._WRITE_MAX_RETRIES): - try: - with self._lock: - self._conn.execute("BEGIN IMMEDIATE") - try: - result = fn(self._conn) - self._conn.commit() - except BaseException: - try: - self._conn.rollback() - except Exception: - pass - raise - # Success — periodic best-effort checkpoint. - self._write_count += 1 - if self._write_count % self._CHECKPOINT_EVERY_N_WRITES == 0: - self._try_wal_checkpoint() - return result - except sqlite3.OperationalError as exc: - err_msg = str(exc).lower() - if "locked" in err_msg or "busy" in err_msg: - last_err = exc - if attempt < self._WRITE_MAX_RETRIES - 1: - jitter = random.uniform( - self._WRITE_RETRY_MIN_S, - self._WRITE_RETRY_MAX_S, - ) - time.sleep(jitter) - continue - # Non-lock error or retries exhausted — propagate. - raise - # Retries exhausted (shouldn't normally reach here). - raise last_err or sqlite3.OperationalError( - "database is locked after max retries" - ) - - def _try_wal_checkpoint(self) -> None: - """Best-effort TRUNCATE WAL checkpoint. Never raises. - - Flushes committed WAL frames back into the main DB file and - truncates the WAL file to zero bytes. Keeps the WAL from - growing unbounded when many processes hold persistent - connections. - - PASSIVE checkpoint was previously used here, but it never - truncates the WAL file — the file stays at its high-water - mark until an explicit TRUNCATE is called (which only - happened inside the infrequent vacuum()). - - TRUNCATE may block writers briefly while checkpointing, but - _try_wal_checkpoint is called off the hot path (every 50 - writes) and already runs under ``self._lock``, so the - additional hold time is negligible. - """ - try: - with self._lock: - result = self._conn.execute( - "PRAGMA wal_checkpoint(TRUNCATE)" - ).fetchone() - if result and result[1] > 0: - logger.debug( - "WAL checkpoint: %d/%d pages checkpointed", - result[2], result[1], - ) - except Exception: - pass # Best effort — never fatal. - - def close(self): - """Close the database connection. - - Attempts a TRUNCATE WAL checkpoint first so that exiting processes - help shrink the WAL file. - """ - with self._lock: - if self._conn: - try: - self._conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") - except Exception: - pass - self._conn.close() - self._conn = None - - @staticmethod - def _parse_schema_columns(schema_sql: str) -> Dict[str, Dict[str, str]]: - """Extract expected columns per table from SCHEMA_SQL. - - Uses an in-memory SQLite database to parse the SQL — SQLite itself - handles all syntax (DEFAULT expressions with commas, inline - REFERENCES, CHECK constraints, etc.) so there are zero regex - edge cases. The in-memory DB is opened, the schema DDL is - executed, and PRAGMA table_info extracts the column metadata. - - Adding a column to SCHEMA_SQL is all that's needed; the - reconciliation loop picks it up automatically. - """ - ref = sqlite3.connect(":memory:") - try: - ref.executescript(schema_sql) - table_columns: Dict[str, Dict[str, str]] = {} - for (tbl,) in ref.execute( - "SELECT name FROM sqlite_master " - "WHERE type='table' AND name NOT LIKE 'sqlite_%'" - ).fetchall(): - cols: Dict[str, str] = {} - for row in ref.execute( - f'PRAGMA table_info("{tbl}")' - ).fetchall(): - # row: (cid, name, type, notnull, dflt_value, pk) - col_name = row[1] - col_type = row[2] or "" - notnull = row[3] - default = row[4] - pk = row[5] - # Reconstruct the type expression for ALTER TABLE ADD COLUMN - parts = [col_type] if col_type else [] - if notnull and not pk: - parts.append("NOT NULL") - if default is not None: - parts.append(f"DEFAULT {default}") - cols[col_name] = " ".join(parts) - table_columns[tbl] = cols - return table_columns - finally: - ref.close() - - def _reconcile_columns(self, cursor: sqlite3.Cursor) -> None: - """Ensure live tables have every column declared in SCHEMA_SQL. - - Follows the Beets/sqlite-utils pattern: the CREATE TABLE definition - in SCHEMA_SQL is the single source of truth for the desired schema. - On every startup this method diffs the live columns (via PRAGMA - table_info) against the declared columns, and ADDs any that are - missing. - - This makes column additions a declarative operation — just add - the column to SCHEMA_SQL and it appears on the next startup. - Version-gated migration blocks are no longer needed for ADD COLUMN. - """ - expected = self._parse_schema_columns(SCHEMA_SQL) - for table_name, declared_cols in expected.items(): - # Get current columns from the live table - try: - rows = cursor.execute( - f'PRAGMA table_info("{table_name}")' - ).fetchall() - except sqlite3.OperationalError: - continue # Table doesn't exist yet (shouldn't happen after executescript) - live_cols = set() - for row in rows: - # PRAGMA table_info returns (cid, name, type, notnull, dflt_value, pk) - name = row[1] if isinstance(row, (tuple, list)) else row["name"] - live_cols.add(name) - - for col_name, col_type in declared_cols.items(): - if col_name not in live_cols: - safe_name = col_name.replace('"', '""') - try: - cursor.execute( - f'ALTER TABLE "{table_name}" ADD COLUMN "{safe_name}" {col_type}' - ) - except sqlite3.OperationalError as exc: - # Expected: "duplicate column name" from a race or - # re-run. Unexpected: "Cannot add a NOT NULL column - # with default value NULL" from a schema mistake. - # Log at DEBUG so it's visible in agent.log. - logger.debug( - "reconcile %s.%s: %s", table_name, col_name, exc, - ) - - def _init_schema(self): - """Create tables and FTS if they don't exist, reconcile columns. - - Schema management follows the declarative reconciliation pattern - (Beets, sqlite-utils): SCHEMA_SQL is the single source of truth. - On existing databases, _reconcile_columns() diffs live columns - against SCHEMA_SQL and ADDs any missing ones. This eliminates - the version-gated migration chain for column additions, making - it impossible for reordered or inserted migrations to skip columns. - - The schema_version table is retained for future data migrations - (transforming existing rows) which cannot be handled declaratively. - """ - cursor = self._conn.cursor() - - cursor.executescript(SCHEMA_SQL) - - # ── Declarative column reconciliation ────────────────────────── - # Diff live tables against SCHEMA_SQL and ADD any missing columns. - # This is idempotent and self-healing: even if a version-gated - # migration was skipped (e.g. due to version renumbering), the - # column gets created here. - self._reconcile_columns(cursor) - - # Indexes that reference reconciler-added columns must be created - # AFTER _reconcile_columns runs — declaring them in SCHEMA_SQL - # makes the initial executescript fail on legacy DBs (the index's - # WHERE clause references a column that doesn't exist yet). - try: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_messages_platform_msg_id " - "ON messages(session_id, platform_message_id) " - "WHERE platform_message_id IS NOT NULL" - ) - except sqlite3.OperationalError as exc: - logger.debug("idx_messages_platform_msg_id create skipped: %s", exc) - - # Deferred indexes that reference the reconciler-added ``active`` - # column (idx_messages_session_active) — same ordering constraint. - cursor.executescript(DEFERRED_INDEX_SQL) - - fts5_available = self._sqlite_supports_fts5(cursor) - fts_migrations_complete = True - if not fts5_available: - # Existing FTS triggers can still fire on messages INSERT/UPDATE - # even though the current sqlite runtime cannot read the virtual - # tables they target. Drop only the triggers so core persistence - # continues; if a future runtime has FTS5, _ensure_fts_schema() - # recreates them. - self._drop_fts_triggers(cursor) - - # ── Schema version bookkeeping ───────────────────────────────── - # Bump to current so future data migrations (if any) can gate on - # version. No version-gated column additions remain. - cursor.execute("SELECT version FROM schema_version LIMIT 1") - row = cursor.fetchone() - if row is None: - cursor.execute( - "INSERT INTO schema_version (version) VALUES (?)", - (SCHEMA_VERSION,), - ) - else: - current_version = row["version"] if isinstance(row, sqlite3.Row) else row[0] - # Data migrations that can't be expressed declaratively (row - # backfills, index changes tied to a specific version step) stay - # in a version-gated chain. Column additions are handled by - # _reconcile_columns() above and no longer need entries here. - if current_version < 10: - # v10: trigram FTS5 table for CJK/substring search. The - # virtual table + triggers are created unconditionally via - # FTS_TRIGRAM_SQL below, but existing rows need a one-time - # backfill into the FTS index. - if fts5_available: - _fts_trigram_exists = self._fts_table_probe( - cursor, "messages_fts_trigram" - ) - if _fts_trigram_exists is False: - if self._ensure_fts_schema( - cursor, "messages_fts_trigram", FTS_TRIGRAM_SQL - ): - cursor.execute( - "INSERT INTO messages_fts_trigram(rowid, content) " - "SELECT id, content FROM messages WHERE content IS NOT NULL" - ) - else: - fts_migrations_complete = False - elif _fts_trigram_exists is None: - fts_migrations_complete = False - else: - fts_migrations_complete = False - if current_version < 11: - # v11: re-index FTS5 tables to cover tool_name + tool_calls and - # switch from external-content to inline mode. Existing DBs have - # old-schema FTS tables and triggers that IF NOT EXISTS won't - # overwrite, so we drop them explicitly and let the post-migration - # existence checks (below) recreate them from FTS_SQL / - # FTS_TRIGRAM_SQL, then backfill every message row. Fixes #16751. - if fts5_available: - self._drop_fts_triggers(cursor) - for _tbl in ("messages_fts", "messages_fts_trigram"): - try: - cursor.execute(f"DROP TABLE IF EXISTS {_tbl}") - except sqlite3.OperationalError as exc: - if not self._is_fts5_unavailable_error(exc): - raise - self._warn_fts5_unavailable(exc) - fts5_available = False - fts_migrations_complete = False - break - - if fts5_available: - # Recreate virtual tables + triggers with the new inline-mode - # schema that indexes content || tool_name || tool_calls. - if ( - self._ensure_fts_schema(cursor, "messages_fts", FTS_SQL) - and self._ensure_fts_schema( - cursor, "messages_fts_trigram", FTS_TRIGRAM_SQL - ) - ): - # Backfill both indexes from every existing messages row. - cursor.execute( - "INSERT INTO messages_fts(rowid, content) " - "SELECT id, " - "COALESCE(content, '') || ' ' || " - "COALESCE(tool_name, '') || ' ' || " - "COALESCE(tool_calls, '') " - "FROM messages" - ) - cursor.execute( - "INSERT INTO messages_fts_trigram(rowid, content) " - "SELECT id, " - "COALESCE(content, '') || ' ' || " - "COALESCE(tool_name, '') || ' ' || " - "COALESCE(tool_calls, '') " - "FROM messages" - ) - else: - fts_migrations_complete = False - else: - fts_migrations_complete = False - if current_version < 12: - # v12: messages.active flag for rewind/undo soft-deletion. - # The declarative reconcile_columns() above adds the - # column itself; this UPDATE is belt-and-suspenders to - # ensure any rows that pre-existed the ADD COLUMN have - # active=1 rather than NULL. - try: - cursor.execute( - "UPDATE messages SET active = 1 WHERE active IS NULL" - ) - except sqlite3.OperationalError: - pass - if current_version < SCHEMA_VERSION and fts_migrations_complete: - cursor.execute( - "UPDATE schema_version SET version = ?", - (SCHEMA_VERSION,), - ) - - # Unique title index — always ensure it exists - try: - cursor.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique " - "ON sessions(title) WHERE title IS NOT NULL" - ) - except sqlite3.OperationalError: - pass # Index already exists - - if fts5_available: - # FTS5 setup. Run the DDL even when the virtual table exists so - # CREATE TRIGGER IF NOT EXISTS repairs trigger-only degradation from - # an earlier no-FTS5 runtime. - triggers_need_repair = self._fts_trigger_count(cursor) < len(_FTS_TRIGGERS) - self._fts_enabled = self._ensure_fts_schema(cursor, "messages_fts", FTS_SQL) - - # Trigram FTS5 for CJK/substring search. This is optional relative - # to the main FTS table; if it cannot be created, CJK search falls - # back to LIKE. - if self._fts_enabled: - trigram_enabled = self._ensure_fts_schema( - cursor, "messages_fts_trigram", FTS_TRIGRAM_SQL - ) - if trigram_enabled and triggers_need_repair: - self._rebuild_fts_indexes(cursor) - - self._conn.commit() - - # ========================================================================= - # Session lifecycle - # ========================================================================= - - def _insert_session_row( - self, - session_id: str, - source: str, - model: str = None, - model_config: Dict[str, Any] = None, - system_prompt: str = None, - user_id: str = None, - parent_session_id: str = None, - cwd: str = None, - ) -> None: - """Shared INSERT OR IGNORE for session rows.""" - def _do(conn): - conn.execute( - """INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config, - system_prompt, parent_session_id, cwd, started_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - source, - user_id, - model, - json.dumps(model_config) if model_config else None, - system_prompt, - parent_session_id, - cwd, - time.time(), - ), - ) - self._execute_write(_do) - - def create_session(self, session_id: str, source: str, **kwargs) -> str: - """Create a new session record. Returns the session_id.""" - self._insert_session_row(session_id, source, **kwargs) - return session_id - def end_session(self, session_id: str, end_reason: str) -> None: - """Mark a session as ended. - - No-ops when the session is already ended. The first end_reason wins: - compression-split sessions must keep their ``end_reason = 'compression'`` - record even if a later stale ``end_session()`` call (e.g. from a - desynced CLI session_id after ``/resume`` or ``/branch``) targets them - with a different reason. Use ``reopen_session()`` first if you - intentionally need to re-end a closed session with a new reason. - """ - def _do(conn): - conn.execute( - "UPDATE sessions SET ended_at = ?, end_reason = ? " - "WHERE id = ? AND ended_at IS NULL", - (time.time(), end_reason, session_id), - ) - self._execute_write(_do) - - def reopen_session(self, session_id: str) -> None: - """Clear ended_at/end_reason so a session can be resumed.""" - def _do(conn): - conn.execute( - "UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?", - (session_id,), - ) - self._execute_write(_do) - - def update_session_cwd(self, session_id: str, cwd: str) -> None: - """Persist the session working directory when a frontend knows it.""" - if not session_id or not cwd: - return - - def _do(conn): - conn.execute("UPDATE sessions SET cwd = ? WHERE id = ?", (cwd, session_id)) - - self._execute_write(_do) - # ────────────────────────────────────────────────────────────────────── - # Compression locks - # ────────────────────────────────────────────────────────────────────── - # Atomic per-session locks that prevent two compression paths from - # racing on the same session_id and producing orphan child sessions. - # - # The race: ``conversation_compression.py`` rotates ``agent.session_id`` - # as a side effect of a successful compression (end old session, create - # new). That mutation is local to the AIAgent instance — but ``state.db`` - # is shared across all instances. Two AIAgents that share the same - # ``session_id`` at the moment they both decide to compress (most - # commonly the parent turn's agent + a background-review fork started - # right after the turn ended) each end the parent and create their own - # NEW session, parented to the same old id. The gateway SessionEntry - # only catches one rotation; the other child silently accumulates - # writes — Damien's "parent → two orphan children" repro shape. - # - # The lock is keyed by ``session_id`` and is held for the duration of - # the compress() call plus the rotation. ``holder`` identifies the - # current owner (pid:tid:nonce) for diagnostics; the lock is recovered - # via ``expires_at`` if the holder process crashed without releasing. - def try_acquire_compression_lock( - self, - session_id: str, - holder: str, - ttl_seconds: float = 300.0, - ) -> bool: - """Try to atomically acquire the compression lock for ``session_id``. - - Returns ``True`` on success (caller now owns the lock and must - release via :meth:`release_compression_lock`). Returns ``False`` - if another holder already owns a non-expired lock — the caller - MUST NOT proceed with compression in that case (its rotation would - race against the holder's, splitting the session lineage). - - Expired locks (``expires_at < now``) are reclaimed transparently: - the stale row is deleted and the new holder acquires it. This - prevents a crashed compressor from permanently blocking the - session. - - Implementation: single-transaction DELETE-expired + INSERT-or-IGNORE, - followed by a SELECT to confirm we got the row. SQLite serialises - writes, so the whole sequence is atomic against other writers. - """ - if not session_id: - return False - now = time.time() - expires_at = now + ttl_seconds - - def _do(conn): - # First: reclaim any expired lock for this session_id. - conn.execute( - "DELETE FROM compression_locks " - "WHERE session_id = ? AND expires_at < ?", - (session_id, now), - ) - # Then: try to insert. INSERT OR IGNORE returns no rowcount - # difference — verify ownership via SELECT. - conn.execute( - "INSERT OR IGNORE INTO compression_locks " - "(session_id, holder, acquired_at, expires_at) " - "VALUES (?, ?, ?, ?)", - (session_id, holder, now, expires_at), - ) - row = conn.execute( - "SELECT holder FROM compression_locks WHERE session_id = ?", - (session_id,), - ).fetchone() - return row is not None and ( - row["holder"] if isinstance(row, sqlite3.Row) else row[0] - ) == holder - - try: - return bool(self._execute_write(_do)) - except sqlite3.Error as exc: - logger.warning( - "try_acquire_compression_lock(%s) failed: %s", - session_id, exc, - ) - # Fail open: returning False makes the caller skip compression, - # which is the safe behaviour when the lock subsystem is broken. - return False - - def release_compression_lock(self, session_id: str, holder: str) -> None: - """Release the compression lock for ``session_id`` iff we own it. - - Idempotent: no-op when the lock has already expired and been - reclaimed by a different holder, or when no lock exists. The - ``holder`` check prevents a late-returning compressor from - clobbering a fresh lock held by someone else. - """ - if not session_id: - return - - def _do(conn): - conn.execute( - "DELETE FROM compression_locks " - "WHERE session_id = ? AND holder = ?", - (session_id, holder), - ) - - try: - self._execute_write(_do) - except sqlite3.Error as exc: - logger.warning( - "release_compression_lock(%s) failed: %s", - session_id, exc, - ) - - def get_compression_lock_holder(self, session_id: str) -> Optional[str]: - """Return the current (non-expired) holder for ``session_id``, or None. - - Diagnostic helper — not used by the locking protocol itself. - """ - if not session_id: - return None - now = time.time() - row = self._conn.execute( - "SELECT holder FROM compression_locks " - "WHERE session_id = ? AND expires_at >= ?", - (session_id, now), - ).fetchone() - if row is None: - return None - return row["holder"] if isinstance(row, sqlite3.Row) else row[0] - - def update_session_meta( - self, - session_id: str, - model_config_json: str, - model: Optional[str] = None, - ) -> None: - """Update model_config and optionally model for an existing session. - - Uses COALESCE so that passing model=None leaves the stored model - column unchanged. Routes through _execute_write for the standard - BEGIN IMMEDIATE + jitter-retry + lock guarantee. - """ - def _do(conn): - conn.execute( - "UPDATE sessions SET model_config = ?, model = COALESCE(?, model) WHERE id = ?", - (model_config_json, model, session_id), - ) - self._execute_write(_do) - - def update_system_prompt(self, session_id: str, system_prompt: str) -> None: - """Store the full assembled system prompt snapshot.""" - def _do(conn): - conn.execute( - "UPDATE sessions SET system_prompt = ? WHERE id = ?", - (system_prompt, session_id), - ) - self._execute_write(_do) - - def update_session_model(self, session_id: str, model: str) -> None: - """Update the model for a session after a mid-session switch. - - Unlike ``update_token_counts`` which uses ``COALESCE(model, ?)`` - (only filling in NULL), this unconditionally sets the model column - so that the dashboard reflects the user's latest /model choice. - """ - def _do(conn): - conn.execute( - "UPDATE sessions SET model = ? WHERE id = ?", - (model, session_id), - ) - self._execute_write(_do) - - def update_token_counts( - self, - session_id: str, - input_tokens: int = 0, - output_tokens: int = 0, - model: str = None, - cache_read_tokens: int = 0, - cache_write_tokens: int = 0, - reasoning_tokens: int = 0, - estimated_cost_usd: Optional[float] = None, - actual_cost_usd: Optional[float] = None, - cost_status: Optional[str] = None, - cost_source: Optional[str] = None, - pricing_version: Optional[str] = None, - billing_provider: Optional[str] = None, - billing_base_url: Optional[str] = None, - billing_mode: Optional[str] = None, - api_call_count: int = 0, - absolute: bool = False, - ) -> None: - """Update token counters and backfill model if not already set. - - When *absolute* is False (default), values are **incremented** — use - this for per-API-call deltas (CLI path). - - When *absolute* is True, values are **set directly** — use this when - the caller already holds cumulative totals (gateway path, where the - cached agent accumulates across messages). - """ - # Ensure the session row exists so the UPDATE doesn't silently affect - # 0 rows. Under concurrent load (cron + kanban + delegate_task) the - # initial create_session() may have failed due to SQLite locking. - # INSERT OR IGNORE is cheap and idempotent. - self._insert_session_row(session_id, "unknown", model=model) - if absolute: - sql = """UPDATE sessions SET - input_tokens = ?, - output_tokens = ?, - cache_read_tokens = ?, - cache_write_tokens = ?, - reasoning_tokens = ?, - estimated_cost_usd = COALESCE(?, 0), - actual_cost_usd = CASE - WHEN ? IS NULL THEN actual_cost_usd - ELSE ? - END, - cost_status = COALESCE(?, cost_status), - cost_source = COALESCE(?, cost_source), - pricing_version = COALESCE(?, pricing_version), - billing_provider = COALESCE(billing_provider, ?), - billing_base_url = COALESCE(billing_base_url, ?), - billing_mode = COALESCE(billing_mode, ?), - model = COALESCE(model, ?), - api_call_count = ? - WHERE id = ?""" - else: - sql = """UPDATE sessions SET - input_tokens = input_tokens + ?, - output_tokens = output_tokens + ?, - cache_read_tokens = cache_read_tokens + ?, - cache_write_tokens = cache_write_tokens + ?, - reasoning_tokens = reasoning_tokens + ?, - estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0), - actual_cost_usd = CASE - WHEN ? IS NULL THEN actual_cost_usd - ELSE COALESCE(actual_cost_usd, 0) + ? - END, - cost_status = COALESCE(?, cost_status), - cost_source = COALESCE(?, cost_source), - pricing_version = COALESCE(?, pricing_version), - billing_provider = COALESCE(billing_provider, ?), - billing_base_url = COALESCE(billing_base_url, ?), - billing_mode = COALESCE(billing_mode, ?), - model = COALESCE(model, ?), - api_call_count = COALESCE(api_call_count, 0) + ? - WHERE id = ?""" - params = ( - input_tokens, - output_tokens, - cache_read_tokens, - cache_write_tokens, - reasoning_tokens, - estimated_cost_usd, - actual_cost_usd, - actual_cost_usd, - cost_status, - cost_source, - pricing_version, - billing_provider, - billing_base_url, - billing_mode, - model, - api_call_count, - session_id, - ) - def _do(conn): - conn.execute(sql, params) - self._execute_write(_do) - - def ensure_session( - self, - session_id: str, - source: str = "unknown", - model: str = None, - **kwargs, - ) -> str: - """Ensure a session row exists (INSERT OR IGNORE). Accepts optional kwargs.""" - self._insert_session_row(session_id, source, model=model, **kwargs) - return session_id - - def prune_empty_ghost_sessions(self, sessions_dir: "Optional[Path]" = None) -> int: - """Remove empty TUI ghost sessions (no messages, no title, >24hr old).""" - cutoff = time.time() - 86400 # Only sessions older than 24 hours - - def _do(conn): - rows = conn.execute(""" - SELECT id FROM sessions - WHERE source = 'tui' - AND title IS NULL - AND ended_at IS NOT NULL - AND started_at < ? - AND NOT EXISTS ( - SELECT 1 FROM messages WHERE messages.session_id = sessions.id - ) - """, (cutoff,)).fetchall() - ids = [r[0] if isinstance(r, (tuple, list)) else r["id"] for r in rows] - if ids: - placeholders = ",".join("?" * len(ids)) - conn.execute( - f"DELETE FROM sessions WHERE id IN ({placeholders})", ids - ) - return ids - - removed_ids = self._execute_write(_do) or [] - # Clean up any on-disk session files (belt-and-suspenders) - if sessions_dir and removed_ids: - for sid in removed_ids: - self._remove_session_files(sessions_dir, sid) - return len(removed_ids) - - def finalize_orphaned_compression_sessions(self) -> int: - """Mark orphaned compression continuation sessions as ended. - - Targets child sessions that were never finalized: parent is ended - with reason='compression', child has messages but no end_reason/ended_at - and api_call_count=0. Non-destructive: preserves all messages and sets - end_reason='orphaned_compression'. Fix for #20001. - """ - cutoff = time.time() - 604800 # 7 days - - def _do(conn): - now = time.time() - result = conn.execute( - """ - UPDATE sessions - SET ended_at = ?, - end_reason = 'orphaned_compression' - WHERE api_call_count = 0 - AND end_reason IS NULL - AND ended_at IS NULL - AND started_at < ? - AND parent_session_id IS NOT NULL - AND EXISTS ( - SELECT 1 FROM sessions p - WHERE p.id = sessions.parent_session_id - AND p.end_reason = 'compression' - AND p.ended_at IS NOT NULL - ) - AND EXISTS ( - SELECT 1 FROM messages m - WHERE m.session_id = sessions.id - ) - """, - (now, cutoff), - ) - return result.rowcount - - return self._execute_write(_do) or 0 - - def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: - """Get a session by ID.""" - with self._lock: - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() - return dict(row) if row else None - - def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]: - """Resolve an exact or uniquely prefixed session ID to the full ID. - - Returns the exact ID when it exists. Otherwise treats the input as a - prefix and returns the single matching session ID if the prefix is - unambiguous. Returns None for no matches or ambiguous prefixes. - """ - exact = self.get_session(session_id_or_prefix) - if exact: - return exact["id"] - - escaped = ( - session_id_or_prefix - .replace("\\", "\\\\") - .replace("%", "\\%") - .replace("_", "\\_") - ) - with self._lock: - cursor = self._conn.execute( - "SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2", - (f"{escaped}%",), - ) - matches = [row["id"] for row in cursor.fetchall()] - if len(matches) == 1: - return matches[0] - return None - - # Maximum length for session titles - MAX_TITLE_LENGTH = 100 - - @staticmethod - def sanitize_title(title: Optional[str]) -> Optional[str]: - """Validate and sanitize a session title. - - - Strips leading/trailing whitespace - - Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic - Unicode control chars (zero-width, RTL/LTR overrides, etc.) - - Collapses internal whitespace runs to single spaces - - Normalizes empty/whitespace-only strings to None - - Enforces MAX_TITLE_LENGTH - - Returns the cleaned title string or None. - Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning. - """ - if not title: - return None - - # Remove ASCII control characters (0x00-0x1F, 0x7F) but keep - # whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be - # normalized to spaces by the whitespace collapsing step below - cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title) - - # Remove problematic Unicode control characters: - # - Zero-width chars (U+200B-U+200F, U+FEFF) - # - Directional overrides (U+202A-U+202E, U+2066-U+2069) - # - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB) - cleaned = re.sub( - r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]', - '', cleaned, - ) - - # Collapse internal whitespace runs and strip - cleaned = re.sub(r'\s+', ' ', cleaned).strip() - - if not cleaned: - return None - - if len(cleaned) > SessionDB.MAX_TITLE_LENGTH: - raise ValueError( - f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})" - ) - - return cleaned - - def set_session_title(self, session_id: str, title: str) -> bool: - """Set or update a session's title. - - Returns True if session was found and title was set. - Raises ValueError if title is already in use by another session, - or if the title fails validation (too long, invalid characters). - Empty/whitespace-only strings are normalized to None (clearing the title). - """ - title = self.sanitize_title(title) - def _do(conn): - if title: - # Check uniqueness (allow the same session to keep its own title) - cursor = conn.execute( - "SELECT id FROM sessions WHERE title = ? AND id != ?", - (title, session_id), - ) - conflict = cursor.fetchone() - if conflict: - raise ValueError( - f"Title '{title}' is already in use by session {conflict['id']}" - ) - cursor = conn.execute( - "UPDATE sessions SET title = ? WHERE id = ?", - (title, session_id), - ) - return cursor.rowcount - rowcount = self._execute_write(_do) - return rowcount > 0 - - def get_session_title(self, session_id: str) -> Optional[str]: - """Get the title for a session, or None.""" - with self._lock: - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() - return row["title"] if row else None - - def set_session_archived(self, session_id: str, archived: bool) -> bool: - """Archive or unarchive a session. - - Archived sessions are hidden from the default session list but keep all - their messages — this is a soft hide, not a delete. Returns True when a - row was updated. - """ - def _do(conn): - cursor = conn.execute( - "UPDATE sessions SET archived = ? WHERE id = ?", - (1 if archived else 0, session_id), - ) - return cursor.rowcount - rowcount = self._execute_write(_do) - return rowcount > 0 - - def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: - """Look up a session by exact title. Returns session dict or None.""" - with self._lock: - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE title = ?", (title,) - ) - row = cursor.fetchone() - return dict(row) if row else None - - def resolve_session_by_title(self, title: str) -> Optional[str]: - """Resolve a title to a session ID, preferring the latest in a lineage. - - If the exact title exists, returns that session's ID. - If not, searches for "title #N" variants and returns the latest one. - If the exact title exists AND numbered variants exist, returns the - latest numbered variant (the most recent continuation). - """ - # First try exact match - exact = self.get_session_by_title(title) - - # Also search for numbered variants: "title #2", "title #3", etc. - # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches - escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - with self._lock: - cursor = self._conn.execute( - "SELECT id, title, started_at FROM sessions " - "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", - (f"{escaped} #%",), - ) - numbered = cursor.fetchall() - - if numbered: - # Return the most recent numbered variant - return numbered[0]["id"] - elif exact: - return exact["id"] - return None - - def get_next_title_in_lineage(self, base_title: str) -> str: - """Generate the next title in a lineage (e.g., "my session" → "my session #2"). - - Strips any existing " #N" suffix to find the base name, then finds - the highest existing number and increments. - """ - # Strip existing #N suffix to find the true base - match = re.match(r'^(.*?) #(\d+)$', base_title) - if match: - base = match.group(1) - else: - base = base_title - - # Find all existing numbered variants - # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches - escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - with self._lock: - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", - (base, f"{escaped} #%"), - ) - existing = [row["title"] for row in cursor.fetchall()] - - if not existing: - return base # No conflict, use the base name as-is - - # Find the highest number - max_num = 1 # The unnumbered original counts as #1 - for t in existing: - m = re.match(r'^.* #(\d+)$', t) - if m: - max_num = max(max_num, int(m.group(1))) - - return f"{base} #{max_num + 1}" - - def get_compression_tip(self, session_id: str) -> Optional[str]: - """Walk the compression-continuation chain forward and return the tip. - - A compression continuation is a child session where: - 1. The parent's ``end_reason = 'compression'`` - 2. The child was created AFTER the parent was ended (started_at >= ended_at) - - The second condition distinguishes compression continuations from - delegate subagents or branch children, which can also have a - ``parent_session_id`` but were created while the parent was still live. - - Returns the session_id of the latest continuation in the chain, or the - input ``session_id`` if it isn't part of a compression chain (or if the - input itself doesn't exist). - """ - current = session_id - # Bound the walk defensively — compression chains this deep are - # pathological and shouldn't happen in practice. 100 = plenty. - for _ in range(100): - with self._lock: - cursor = self._conn.execute( - "SELECT id FROM sessions " - "WHERE parent_session_id = ? " - " AND started_at >= (" - " SELECT ended_at FROM sessions " - " WHERE id = ? AND end_reason = 'compression'" - " ) " - "ORDER BY started_at DESC LIMIT 1", - (current, current), - ) - row = cursor.fetchone() - if row is None: - return current - current = row["id"] - return current - - def list_sessions_rich( - self, - source: str = None, - exclude_sources: List[str] = None, - limit: int = 20, - offset: int = 0, - include_children: bool = False, - min_message_count: int = 0, - project_compression_tips: bool = True, - order_by_last_active: bool = False, - include_archived: bool = False, - archived_only: bool = False, - id_query: str = None, - ) -> List[Dict[str, Any]]: - """List sessions with preview (first user message) and last active timestamp. - - Returns dicts with keys: id, source, model, title, started_at, ended_at, - message_count, preview (first 60 chars of first user message), - last_active (timestamp of last message). - - Uses a single query with correlated subqueries instead of N+2 queries. - - By default, child sessions (subagent runs, compression continuations) - are excluded. Pass ``include_children=True`` to include them. - - With ``project_compression_tips=True`` (default), sessions that are - roots of compression chains are projected forward to their latest - continuation — one logical conversation = one list entry, showing the - live continuation's id/message_count/title/last_active. This prevents - compressed continuations from being invisible to users while keeping - delegate subagents and branches hidden. Pass ``False`` to return the - raw root rows (useful for admin/debug UIs). - - Pass ``order_by_last_active=True`` to sort by most-recent activity - instead of original conversation start time. For compression chains, - the "most-recent activity" is taken from the live tip (not the root), - so an old conversation that was compressed and continued recently - surfaces in the correct slot. Ordering is computed at SQL level via - a recursive CTE that walks compression-continuation edges, so LIMIT - and OFFSET still apply efficiently. - """ - where_clauses = [] - params = [] - - if not include_children: - # Show root sessions and branch sessions, while still hiding - # sub-agent runs and compression continuations (which also carry a - # parent_session_id but were spawned while the parent was still - # live — i.e., started_at < parent.ended_at). - # - # Branch sessions are identified two ways, OR'd for robustness: - # 1. A stable ``_branched_from`` marker in model_config, written - # by /branch at creation time. This survives the parent being - # reopened and re-ended with a different end_reason (e.g. - # tui_shutdown overwriting 'branched'), which otherwise hides - # the branch — see issue #20856. - # 2. The legacy heuristic (parent ended with 'branched' before the - # child started), covering branch sessions created before the - # marker existed. - where_clauses.append( - "(s.parent_session_id IS NULL" - " OR json_extract(s.model_config, '$._branched_from') IS NOT NULL" - " OR EXISTS (SELECT 1 FROM sessions p" - " WHERE p.id = s.parent_session_id" - " AND p.end_reason = 'branched'" - " AND s.started_at >= p.ended_at))" - ) - - if source: - where_clauses.append("s.source = ?") - params.append(source) - if exclude_sources: - placeholders = ",".join("?" for _ in exclude_sources) - where_clauses.append(f"s.source NOT IN ({placeholders})") - params.extend(exclude_sources) - if min_message_count > 0: - where_clauses.append("s.message_count >= ?") - params.append(min_message_count) - if archived_only: - where_clauses.append("s.archived = 1") - elif not include_archived: - where_clauses.append("s.archived = 0") - - where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - # Optional session-id filter, pushed into SQL so callers (Desktop - # session-id search) don't have to fetch every row and filter in - # Python. ``id_query`` is matched as a case-insensitive substring - # against each surfaced row's id AND every id in its forward - # compression chain — so searching a compression *root* id or a *tip* - # id both resolve to the same projected conversation. Only used in the - # order_by_last_active path (which builds the chain CTE); other callers - # pass id_query=None. - id_needle = (id_query or "").strip().lower() - if order_by_last_active: - # Compute effective_last_active by walking each surfaced session's - # compression-continuation chain forward in SQL and taking the MAX - # timestamp across the chain. This lets us ORDER BY + LIMIT at SQL - # level instead of fetching every row and sorting in Python, while - # still surfacing old compression roots whose live tip is fresh. - # - # The CTE seeds from rows the outer WHERE admits (roots + branch - # children), then recursively joins forward through - # compression-continuation edges using the same criteria as - # get_compression_tip (parent.end_reason='compression' AND - # child.started_at >= parent.ended_at). - outer_where = where_sql - id_params: List[Any] = [] - if id_needle: - # Admit a surfaced row if its own id or any id in its forward - # compression chain matches the needle. LIKE with a leading - # wildcard can't use an index, but the chain membership and - # the small result set keep this bounded — far cheaper than - # fetching every session and scanning in Python. - id_clause = ( - "EXISTS (SELECT 1 FROM chain cq" - " WHERE cq.root_id = s.id" - " AND LOWER(cq.cur_id) LIKE ? ESCAPE '\\')" - ) - like_pattern = ( - "%" - + id_needle.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - + "%" - ) - id_params = [like_pattern] - outer_where = ( - f"{where_sql} AND {id_clause}" if where_sql else f"WHERE {id_clause}" - ) - query = f""" - WITH RECURSIVE chain(root_id, cur_id) AS ( - SELECT s.id, s.id FROM sessions s {where_sql} - UNION ALL - SELECT c.root_id, child.id - FROM chain c - JOIN sessions parent ON parent.id = c.cur_id - JOIN sessions child ON child.parent_session_id = c.cur_id - WHERE parent.end_reason = 'compression' - AND child.started_at >= parent.ended_at - ), - chain_max AS ( - SELECT - root_id, - MAX(COALESCE( - (SELECT MAX(m.timestamp) FROM messages m WHERE m.session_id = cur_id), - (SELECT started_at FROM sessions ss WHERE ss.id = cur_id) - )) AS effective_last_active - FROM chain - GROUP BY root_id - ) - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active, - COALESCE(cm.effective_last_active, s.started_at) AS _effective_last_active - FROM sessions s - LEFT JOIN chain_max cm ON cm.root_id = s.id - {outer_where} - ORDER BY _effective_last_active DESC, s.started_at DESC, s.id DESC - LIMIT ? OFFSET ? - """ - # WHERE params apply twice (CTE seed + outer select); the id filter - # only applies to the outer select. - params = params + params + id_params + [limit, offset] - else: - query = f""" - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active - FROM sessions s - {where_sql} - ORDER BY s.started_at DESC - LIMIT ? OFFSET ? - """ - params.extend([limit, offset]) - with self._lock: - cursor = self._conn.execute(query, params) - rows = cursor.fetchall() - sessions = [] - for row in rows: - s = dict(row) - # Build the preview from the raw substring - raw = s.pop("_preview_raw", "").strip() - if raw: - text = raw[:60] - s["preview"] = text + ("..." if len(raw) > 60 else "") - else: - s["preview"] = "" - # Drop the internal ordering column so callers see a clean dict. - s.pop("_effective_last_active", None) - sessions.append(s) - - # Project compression roots forward to their tips. Each row whose - # end_reason is 'compression' has a continuation child; replace the - # surfaced fields (id, message_count, title, last_active, ended_at, - # end_reason, preview) with the tip's values so the list entry acts - # as the live conversation. Keep the root's started_at to preserve - # chronological ordering by original conversation start. - if project_compression_tips and not include_children: - projected = [] - for s in sessions: - if s.get("end_reason") != "compression": - projected.append(s) - continue - tip_id = self.get_compression_tip(s["id"]) - if tip_id == s["id"]: - projected.append(s) - continue - tip_row = self._get_session_rich_row(tip_id) - if not tip_row: - projected.append(s) - continue - # Preserve the root's started_at for stable sort order, but - # surface the tip's identity and activity data. - merged = dict(s) - for key in ( - "id", "ended_at", "end_reason", "message_count", - "tool_call_count", "title", "last_active", "preview", - "model", "system_prompt", "cwd", - ): - if key in tip_row: - merged[key] = tip_row[key] - merged["_lineage_root_id"] = s["id"] - projected.append(merged) - sessions = projected - - return sessions - - def list_cron_job_runs( - self, - job_id: str, - limit: int = 20, - offset: int = 0, - ) -> List[Dict[str, Any]]: - """List the run sessions produced by a single cron job, newest first. - - Cron runs are flat, independent sessions whose id is - ``cron_{job_id}_{timestamp}`` (see ``cron/scheduler.run_job``). They are - never compression roots and never branch, so this deliberately skips the - ``list_sessions_rich`` recursive compression-chain CTE / leading-wildcard - ``id_query`` path — that path seeds from *every* ``source='cron'`` row in - the DB and only filters to one job's runs after the scan, so it scales - with the whole cron pile (a heavy history makes the desktop run-history - endpoint time out before it eventually populates). - - Instead this binds to one job with a ``[prefix, prefix_hi)`` range over - the id (an index range scan, not a ``%...%`` substring), filters - ``source='cron'``, and orders by ``started_at DESC``. Work scales with - the requested window, not the total cron history. - - Returns the same enriched row shape as ``list_sessions_rich`` (adds - ``preview`` + ``last_active``) so callers can reuse it. - """ - prefix = f"cron_{job_id}_" - # Half-open upper bound for an index range scan: increment the final - # byte of the prefix so the range covers exactly the ids that start - # with ``prefix`` and nothing else. ``prefix`` always ends in '_', but - # compute it generically rather than hardcoding the successor char. - prefix_hi = prefix[:-1] + chr(ord(prefix[-1]) + 1) - - query = """ - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active - FROM sessions s - WHERE s.source = 'cron' AND s.id >= ? AND s.id < ? - ORDER BY s.started_at DESC, s.id DESC - LIMIT ? OFFSET ? - """ - with self._lock: - cursor = self._conn.execute(query, (prefix, prefix_hi, limit, offset)) - rows = cursor.fetchall() - - runs: List[Dict[str, Any]] = [] - for row in rows: - s = dict(row) - raw = s.pop("_preview_raw", "").strip() - if raw: - text = raw[:60] - s["preview"] = text + ("..." if len(raw) > 60 else "") - else: - s["preview"] = "" - runs.append(s) - return runs - - def _get_session_rich_row(self, session_id: str) -> Optional[Dict[str, Any]]: - """Fetch a single session with the same enriched columns as - ``list_sessions_rich`` (preview + last_active). Returns None if the - session doesn't exist. - """ - query = """ - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active - FROM sessions s - WHERE s.id = ? - """ - with self._lock: - cursor = self._conn.execute(query, (session_id,)) - row = cursor.fetchone() - if not row: - return None - s = dict(row) - raw = s.pop("_preview_raw", "").strip() - if raw: - text = raw[:60] - s["preview"] = text + ("..." if len(raw) > 60 else "") - else: - s["preview"] = "" - return s - - # ========================================================================= - # Message storage - # ========================================================================= - - # Sentinel prefix used to distinguish JSON-encoded structured content - # (multimodal messages: lists of parts like text + image_url) from plain - # string content. The NUL byte is not legal in normal text, so this - # cannot collide with real user content. - _CONTENT_JSON_PREFIX = "\x00json:" - - @classmethod - def _encode_content(cls, content: Any) -> Any: - """Serialize structured (list/dict) message content for sqlite. - - sqlite3 can only bind ``str``, ``bytes``, ``int``, ``float``, and ``None`` - to query parameters. Multimodal messages have ``content`` as a list of - parts (``[{"type": "text", ...}, {"type": "image_url", ...}]``), which - raises ``ProgrammingError: Error binding parameter N: type 'list' is - not supported`` when bound directly. - - Returns the value unchanged when it's already a safe scalar, or a - sentinel-prefixed JSON string for lists/dicts. Paired with - :meth:`_decode_content` on read. - """ - if content is None or isinstance(content, (str, bytes, int, float)): - return content - try: - return cls._CONTENT_JSON_PREFIX + json.dumps(content) - except (TypeError, ValueError): - # Last-resort fallback: stringify so persistence never fails. - return str(content) - - @classmethod - def _decode_content(cls, content: Any) -> Any: - """Reverse :meth:`_encode_content`; returns scalars unchanged.""" - if isinstance(content, str) and content.startswith(cls._CONTENT_JSON_PREFIX): - try: - return json.loads(content[len(cls._CONTENT_JSON_PREFIX):]) - except (json.JSONDecodeError, TypeError): - logger.warning( - "Failed to decode JSON-encoded message content; " - "returning raw string" - ) - return content - return content - - def append_message( - self, - session_id: str, - role: str, - content: str = None, - tool_name: str = None, - tool_calls: Any = None, - tool_call_id: str = None, - token_count: int = None, - finish_reason: str = None, - reasoning: str = None, - reasoning_content: str = None, - reasoning_details: Any = None, - codex_reasoning_items: Any = None, - codex_message_items: Any = None, - platform_message_id: str = None, - observed: bool = False, - ) -> int: - """ - Append a message to a session. Returns the message row ID. - - Also increments the session's message_count (and tool_call_count - if role is 'tool' or tool_calls is present). - - ``platform_message_id`` is the external messaging platform's own - message ID (e.g. Telegram update_id, Yuanbao msg_id). It is - independent of the SQLite autoincrement primary key and is used by - platform-specific flows like yuanbao's recall guard to redact a - message by its platform-side identifier. - """ - # Serialize structured fields to JSON before entering the write txn - reasoning_details_json = ( - json.dumps(reasoning_details) - if reasoning_details else None - ) - codex_items_json = ( - json.dumps(codex_reasoning_items) - if codex_reasoning_items else None - ) - codex_message_items_json = ( - json.dumps(codex_message_items) - if codex_message_items else None - ) - tool_calls_json = json.dumps(tool_calls) if tool_calls else None - # Multimodal content (list of parts) must be JSON-encoded: sqlite3 - # cannot bind list/dict parameters directly. - stored_content = self._encode_content(content) - - # Pre-compute tool call count - num_tool_calls = 0 - if tool_calls is not None: - num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 - - def _do(conn): - cursor = conn.execute( - """INSERT INTO messages (session_id, role, content, tool_call_id, - tool_calls, tool_name, timestamp, token_count, finish_reason, - reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items, platform_message_id, observed) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - role, - stored_content, - tool_call_id, - tool_calls_json, - tool_name, - time.time(), - token_count, - finish_reason, - reasoning, - reasoning_content, - reasoning_details_json, - codex_items_json, - codex_message_items_json, - platform_message_id, - 1 if observed else 0, - ), - ) - msg_id = cursor.lastrowid - - # Update counters - if num_tool_calls > 0: - conn.execute( - """UPDATE sessions SET message_count = message_count + 1, - tool_call_count = tool_call_count + ? WHERE id = ?""", - (num_tool_calls, session_id), - ) - else: - conn.execute( - "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", - (session_id,), - ) - return msg_id - - return self._execute_write(_do) - - def replace_messages(self, session_id: str, messages: List[Dict[str, Any]]) -> None: - """Atomically replace every message for a session. - - Used by transcript-rewrite flows such as /retry, /undo, and /compress. - The delete + reinsert sequence must commit as one transaction so a - mid-rewrite failure does not leave SQLite with a partial transcript. - """ - - def _do(conn): - conn.execute( - "DELETE FROM messages WHERE session_id = ?", (session_id,) - ) - conn.execute( - "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", - (session_id,), - ) - - now_ts = time.time() - total_messages = 0 - total_tool_calls = 0 - for msg in messages: - role = msg.get("role", "unknown") - tool_calls = msg.get("tool_calls") - reasoning_details = msg.get("reasoning_details") if role == "assistant" else None - codex_reasoning_items = ( - msg.get("codex_reasoning_items") if role == "assistant" else None - ) - codex_message_items = ( - msg.get("codex_message_items") if role == "assistant" else None - ) - - reasoning_details_json = ( - json.dumps(reasoning_details) if reasoning_details else None - ) - codex_items_json = ( - json.dumps(codex_reasoning_items) if codex_reasoning_items else None - ) - codex_message_items_json = ( - json.dumps(codex_message_items) if codex_message_items else None - ) - tool_calls_json = json.dumps(tool_calls) if tool_calls else None - # Accept either `platform_message_id` (new explicit name) or - # `message_id` (yuanbao's existing convention on message dicts). - platform_msg_id = ( - msg.get("platform_message_id") or msg.get("message_id") - ) - - conn.execute( - """INSERT INTO messages (session_id, role, content, tool_call_id, - tool_calls, tool_name, timestamp, token_count, finish_reason, - reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items, platform_message_id, observed) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - role, - self._encode_content(msg.get("content")), - msg.get("tool_call_id"), - tool_calls_json, - msg.get("tool_name"), - now_ts, - msg.get("token_count"), - msg.get("finish_reason"), - msg.get("reasoning") if role == "assistant" else None, - msg.get("reasoning_content") if role == "assistant" else None, - reasoning_details_json, - codex_items_json, - codex_message_items_json, - platform_msg_id, - 1 if msg.get("observed") else 0, - ), - ) - total_messages += 1 - if tool_calls is not None: - total_tool_calls += ( - len(tool_calls) if isinstance(tool_calls, list) else 1 - ) - now_ts += 1e-6 - - conn.execute( - "UPDATE sessions SET message_count = ?, tool_call_count = ? WHERE id = ?", - (total_messages, total_tool_calls, session_id), - ) - - self._execute_write(_do) - - def get_messages( - self, session_id: str, include_inactive: bool = False - ) -> List[Dict[str, Any]]: - """Load messages for a session in insertion order. - - By default only active messages are returned. Pass - ``include_inactive=True`` to load soft-deleted rows (e.g. for - audit / debug views of rewound history). See - :meth:`rewind_to_message` for the soft-delete mechanic. - - Ordered by AUTOINCREMENT id (true insertion order) rather than - timestamp — see c03acca50 for the WSL2 clock-regression rationale. - """ - active_clause = "" if include_inactive else " AND active = 1" - with self._lock: - cursor = self._conn.execute( - "SELECT * FROM messages WHERE session_id = ?" - f"{active_clause} ORDER BY id", - (session_id,), - ) - rows = cursor.fetchall() - result = [] - for row in rows: - msg = dict(row) - if "content" in msg: - msg["content"] = self._decode_content(msg["content"]) - if msg.get("tool_calls"): - try: - msg["tool_calls"] = json.loads(msg["tool_calls"]) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to deserialize tool_calls in get_messages, falling back to []") - msg["tool_calls"] = [] - result.append(msg) - return result - - def get_messages_around( - self, - session_id: str, - around_message_id: int, - window: int = 5, - ) -> Dict[str, Any]: - """Load a window of messages anchored on a specific message id. - - Returns a dict with: - - ``window``: up to ``window`` messages before the anchor, the anchor - itself, and up to ``window`` messages after, ordered by id ascending. - - ``messages_before``: count of messages strictly before the anchor - still in the session (== window unless we hit the start). - - ``messages_after``: count of messages strictly after the anchor - still in the session (== window unless we hit the end). - - Used by ``session_search`` for both the discovery shape (anchored on the - FTS5 match) and the scroll shape (anchored on any message id). The - ``messages_before`` / ``messages_after`` counts let the caller detect - session boundaries: when either is less than ``window``, the agent has - reached one end of the session. - - Returns an empty window when ``around_message_id`` is not a real id in - ``session_id`` — callers decide how to surface that. - """ - if window < 0: - window = 0 - with self._lock: - # Confirm the anchor exists in this session. - anchor_exists = self._conn.execute( - "SELECT 1 FROM messages WHERE id = ? AND session_id = ? LIMIT 1", - (around_message_id, session_id), - ).fetchone() - if not anchor_exists: - return {"window": [], "messages_before": 0, "messages_after": 0} - - # Two queries: anchor + before (DESC, take window+1), and after - # (ASC, take window). Final order is id ASC. - before_rows = self._conn.execute( - "SELECT * FROM messages " - "WHERE session_id = ? AND id <= ? " - "ORDER BY id DESC LIMIT ?", - (session_id, around_message_id, window + 1), - ).fetchall() - after_rows = self._conn.execute( - "SELECT * FROM messages " - "WHERE session_id = ? AND id > ? " - "ORDER BY id ASC LIMIT ?", - (session_id, around_message_id, window), - ).fetchall() - - # before_rows is DESC; reverse so it's ASC, then concatenate after_rows. - rows = list(reversed(before_rows)) + list(after_rows) - result = [] - for row in rows: - msg = dict(row) - if "content" in msg: - msg["content"] = self._decode_content(msg["content"]) - if msg.get("tool_calls"): - try: - msg["tool_calls"] = json.loads(msg["tool_calls"]) - except (json.JSONDecodeError, TypeError): - logger.warning( - "Failed to deserialize tool_calls in get_messages_around, falling back to []" - ) - msg["tool_calls"] = [] - result.append(msg) - - # before_rows includes the anchor itself; subtract 1 for the count of - # messages strictly before the anchor in the returned slice. - messages_before = max(0, len(before_rows) - 1) - messages_after = len(after_rows) - return { - "window": result, - "messages_before": messages_before, - "messages_after": messages_after, - } - - def get_anchored_view( - self, - session_id: str, - around_message_id: int, - window: int = 5, - bookend: int = 3, - keep_roles: Optional[Tuple[str, ...]] = ("user", "assistant"), - ) -> Dict[str, Any]: - """Return an anchored window plus session bookends. - - Built on top of ``get_messages_around``. Three slices: - - - ``window``: messages immediately surrounding the anchor. Filtered - to ``keep_roles`` (tool-response noise dropped by default), EXCEPT - the anchor itself is always preserved regardless of role. - - ``bookend_start``: first ``bookend`` user/assistant messages of the - session — but only those whose id is strictly before the window's - first message id. Empty when the window already overlaps the - session head. Empty-content messages (tool-call-only assistant - turns) are skipped so they don't crowd out actual prose openings. - - ``bookend_end``: last ``bookend`` user/assistant messages of the - session, same non-overlap rule at the tail. - - Bookends let an FTS5 hit anywhere in a long session yield the goal - (opening) and the resolution (closing) on a single call — without - loading the whole transcript. - - Returns ``{"window": [], "messages_before": 0, "messages_after": 0, - "bookend_start": [], "bookend_end": []}`` when the anchor isn't in - the session. - - ``keep_roles=None`` disables role filtering (raw window + raw - bookends). - """ - if bookend < 0: - bookend = 0 - - # Reuse the primitive — handles anchor-existence, content decoding, - # tool_calls deserialisation, and boundary counts. - primitive = self.get_messages_around( - session_id, around_message_id, window=window - ) - window_rows = primitive["window"] - if not window_rows: - return { - "window": [], - "messages_before": 0, - "messages_after": 0, - "bookend_start": [], - "bookend_end": [], - } - - # Apply role filter to the window, but never drop the anchor itself. - if keep_roles is not None: - keep_set = set(keep_roles) - filtered_window = [ - m for m in window_rows - if m.get("id") == around_message_id or m.get("role") in keep_set - ] - else: - filtered_window = window_rows - - window_min_id = window_rows[0]["id"] - window_max_id = window_rows[-1]["id"] - - # Fetch bookends only when there's room outside the window. SQL filters - # by id range, role, and non-empty content — tool-call-only assistant - # turns (content='' with tool_calls populated) are excluded so they - # don't crowd out actual prose openings/closings. - bookend_start_rows: List[Any] = [] - bookend_end_rows: List[Any] = [] - if bookend > 0: - with self._lock: - role_clause = "" - role_params: list = [] - if keep_roles is not None: - role_placeholders = ",".join("?" for _ in keep_roles) - role_clause = f" AND role IN ({role_placeholders})" - role_params = list(keep_roles) - - bookend_start_rows = self._conn.execute( - f"SELECT * FROM messages " - f"WHERE session_id = ? AND id < ?{role_clause} " - f"AND length(content) > 0 " - f"ORDER BY id ASC LIMIT ?", - (session_id, window_min_id, *role_params, bookend), - ).fetchall() - - bookend_end_rows = self._conn.execute( - f"SELECT * FROM messages " - f"WHERE session_id = ? AND id > ?{role_clause} " - f"AND length(content) > 0 " - f"ORDER BY id DESC LIMIT ?", - (session_id, window_max_id, *role_params, bookend), - ).fetchall() - # End rows came back DESC for the LIMIT cap; flip to ASC. - bookend_end_rows = list(reversed(bookend_end_rows)) - - def _hydrate(row) -> Dict[str, Any]: - msg = dict(row) - if "content" in msg: - msg["content"] = self._decode_content(msg["content"]) - if msg.get("tool_calls"): - try: - msg["tool_calls"] = json.loads(msg["tool_calls"]) - except (json.JSONDecodeError, TypeError): - logger.warning( - "Failed to deserialize tool_calls in get_anchored_view, falling back to []" - ) - msg["tool_calls"] = [] - return msg - - return { - "window": filtered_window, - "messages_before": primitive["messages_before"], - "messages_after": primitive["messages_after"], - "bookend_start": [_hydrate(r) for r in bookend_start_rows], - "bookend_end": [_hydrate(r) for r in bookend_end_rows], - } - - def resolve_resume_session_id(self, session_id: str) -> str: - """Redirect a resume target to the descendant session that holds the messages. - - Context compression ends the current session and forks a new child session - (linked via ``parent_session_id``). The flush cursor is reset, so the - child is where new messages actually land — the parent ends up with - ``message_count = 0`` rows unless messages had already been flushed to - it before compression. See #15000. - - This helper walks ``parent_session_id`` forward from ``session_id`` and - returns the first descendant in the chain that has at least one message - row. If the original session already has messages, or no descendant - has any, the original ``session_id`` is returned unchanged. - - The chain is always walked via the child whose ``started_at`` is - latest; that matches the single-chain shape that compression creates. - A depth cap (32) guards against accidental loops in malformed data. - """ - if not session_id: - return session_id - - with self._lock: - # If this session already has messages, nothing to redirect. - try: - row = self._conn.execute( - "SELECT 1 FROM messages WHERE session_id = ? LIMIT 1", - (session_id,), - ).fetchone() - except Exception: - return session_id - if row is not None: - return session_id - - # Walk descendants: at each step, pick the most-recently-started - # child session; stop once we find one with messages. - current = session_id - seen = {current} - for _ in range(32): - try: - child_row = self._conn.execute( - "SELECT id FROM sessions " - "WHERE parent_session_id = ? " - "ORDER BY started_at DESC, id DESC LIMIT 1", - (current,), - ).fetchone() - except Exception: - return session_id - if child_row is None: - return session_id - child_id = child_row["id"] if hasattr(child_row, "keys") else child_row[0] - if not child_id or child_id in seen: - return session_id - seen.add(child_id) - try: - msg_row = self._conn.execute( - "SELECT 1 FROM messages WHERE session_id = ? LIMIT 1", - (child_id,), - ).fetchone() - except Exception: - return session_id - if msg_row is not None: - return child_id - current = child_id - return session_id - - def get_messages_as_conversation( - self, - session_id: str, - include_ancestors: bool = False, - include_inactive: bool = False, - ) -> List[Dict[str, Any]]: - """ - Load messages in the OpenAI conversation format (role + content dicts). - Used by the gateway to restore conversation history. - - By default only active messages are returned. Pass - ``include_inactive=True`` to load soft-deleted (rewound) rows - as well. See :meth:`rewind_to_message`. - """ - session_ids = [session_id] - if include_ancestors: - session_ids = self._session_lineage_root_to_tip(session_id) - - active_clause = "" if include_inactive else " AND active = 1" - with self._lock: - placeholders = ",".join("?" for _ in session_ids) - # 只取最近200条,不压缩不丢内容 - rows = self._conn.execute( - "SELECT role, content, tool_call_id, tool_calls, tool_name, " - "finish_reason, reasoning, reasoning_content, reasoning_details, " - "codex_reasoning_items, codex_message_items, platform_message_id, observed " - f"FROM (" - f"SELECT id, role, content, tool_call_id, tool_calls, tool_name, " - f"finish_reason, reasoning, reasoning_content, reasoning_details, " - f"codex_reasoning_items, codex_message_items, platform_message_id, observed " - f"FROM messages WHERE session_id IN ({placeholders})" - f"{active_clause} ORDER BY id DESC LIMIT 200" - f") ORDER BY id ASC", - tuple(session_ids), - ).fetchall() - - messages = [] - for row in rows: - content = self._decode_content(row["content"]) - if row["role"] in {"user", "assistant"} and isinstance(content, str): - content = sanitize_context(content).strip() - msg = {"role": row["role"], "content": content} - if row["tool_call_id"]: - msg["tool_call_id"] = row["tool_call_id"] - if row["tool_name"]: - msg["tool_name"] = row["tool_name"] - if row["tool_calls"]: - try: - msg["tool_calls"] = json.loads(row["tool_calls"]) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []") - msg["tool_calls"] = [] - # Surface the platform-side message id (e.g. yuanbao msg_id, - # telegram update_id) so platform-specific flows like recall - # can match by external identifier instead of having to fall - # back to content-match heuristics. Exposed as ``message_id`` - # for backward compatibility with the JSONL transcript shape. - if row["platform_message_id"]: - msg["message_id"] = row["platform_message_id"] - if row["observed"]: - msg["observed"] = True - # Restore reasoning fields on assistant messages so providers - # that replay reasoning (OpenRouter, OpenAI, Nous) receive - # coherent multi-turn reasoning context. - if row["role"] == "assistant": - if row["finish_reason"]: - msg["finish_reason"] = row["finish_reason"] - if row["reasoning"]: - msg["reasoning"] = row["reasoning"] - if row["reasoning_content"] is not None: - msg["reasoning_content"] = row["reasoning_content"] - if row["reasoning_details"]: - try: - msg["reasoning_details"] = json.loads(row["reasoning_details"]) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to deserialize reasoning_details, falling back to None") - msg["reasoning_details"] = None - if row["codex_reasoning_items"]: - try: - msg["codex_reasoning_items"] = json.loads(row["codex_reasoning_items"]) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to deserialize codex_reasoning_items, falling back to None") - msg["codex_reasoning_items"] = None - if row["codex_message_items"]: - try: - msg["codex_message_items"] = json.loads(row["codex_message_items"]) - except (json.JSONDecodeError, TypeError): - logger.warning("Failed to deserialize codex_message_items, falling back to None") - msg["codex_message_items"] = None - if include_ancestors and self._is_duplicate_replayed_user_message(messages, msg): - continue - messages.append(msg) - return messages - - def _session_lineage_root_to_tip(self, session_id: str) -> List[str]: - if not session_id: - return [session_id] - - chain = [] - current = session_id - seen = set() - with self._lock: - for _ in range(100): - if not current or current in seen: - break - seen.add(current) - chain.append(current) - row = self._conn.execute( - "SELECT parent_session_id FROM sessions WHERE id = ?", - (current,), - ).fetchone() - if row is None: - break - current = row["parent_session_id"] if hasattr(row, "keys") else row[0] - return list(reversed(chain)) or [session_id] - - @staticmethod - def _is_duplicate_replayed_user_message(messages: List[Dict[str, Any]], msg: Dict[str, Any]) -> bool: - if msg.get("role") != "user": - return False - content = msg.get("content") - if not isinstance(content, str) or not content: - return False - for prev in reversed(messages): - if prev.get("role") == "user" and prev.get("content") == content: - return True - if prev.get("role") == "assistant" and (prev.get("content") or prev.get("tool_calls")): - return False - return False - - # ========================================================================= - # Rewind (soft-delete) — see /rewind slash command + issue #21910 - # ========================================================================= - - def rewind_to_message( - self, session_id: str, target_message_id: int - ) -> Dict[str, Any]: - """Soft-delete all messages with id >= ``target_message_id`` in *session_id*. - - The target message itself becomes inactive as well so the caller - can pre-fill it as the next user prompt without it appearing - twice in the replayed transcript. Rewound rows are kept on - disk with ``active=0`` for audit / forensic inspection — use - :meth:`get_messages` with ``include_inactive=True`` to see them. - - Returns a dict:: - - { - "rewound_count": int, # number of rows newly flipped to active=0 - "target_message": dict, # full row dict of the target - "new_head_id": int|None # id of the last still-active row, or None - } - - Raises ``ValueError`` if the target message does not exist in - *session_id* or if its role is not ``"user"``. - - Always increments ``sessions.rewind_count`` — even when the - target is already inactive — so the counter accurately reflects - the number of rewind operations performed against the session. - Idempotent on the ``active`` flag: re-rewinding past the same - target is a no-op on row state but still bumps the counter. - """ - - # 1) Validate target up-front (read-only, outside the write txn). - with self._lock: - row = self._conn.execute( - "SELECT * FROM messages WHERE id = ? AND session_id = ?", - (target_message_id, session_id), - ).fetchone() - if row is None: - raise ValueError( - f"message {target_message_id} not found in session {session_id}" - ) - target_row = dict(row) - if target_row.get("role") != "user": - raise ValueError( - f"rewind target must be a 'user' message (got role=" - f"{target_row.get('role')!r}, id={target_message_id})" - ) - - # Decode content for callers (prefill the prompt buffer). - target_row["content"] = self._decode_content(target_row.get("content")) - - rewound: List[int] = [] - - def _do(conn): - cursor = conn.execute( - "SELECT id FROM messages " - "WHERE session_id = ? AND id >= ? AND active = 1", - (session_id, target_message_id), - ) - ids = [r[0] for r in cursor.fetchall()] - if ids: - placeholders = ",".join("?" for _ in ids) - conn.execute( - f"UPDATE messages SET active = 0 WHERE id IN ({placeholders})", - ids, - ) - conn.execute( - "UPDATE sessions SET rewind_count = COALESCE(rewind_count, 0) + 1 " - "WHERE id = ?", - (session_id,), - ) - return ids - - rewound = self._execute_write(_do) - - # 2) Compute new head id (largest still-active row id in session). - with self._lock: - head_row = self._conn.execute( - "SELECT MAX(id) FROM messages WHERE session_id = ? AND active = 1", - (session_id,), - ).fetchone() - new_head_id = head_row[0] if head_row and head_row[0] is not None else None - - return { - "rewound_count": len(rewound), - "target_message": target_row, - "new_head_id": new_head_id, - } - - def restore_rewound(self, session_id: str, since_message_id: int) -> int: - """Mark inactive messages with id >= *since_message_id* active again. - - Returns the number of rows flipped back to ``active=1``. - Intended for undo-of-rewind and test cleanup; not wired to a - slash command in v1. - """ - def _do(conn): - cursor = conn.execute( - "SELECT id FROM messages " - "WHERE session_id = ? AND id >= ? AND active = 0", - (session_id, since_message_id), - ) - ids = [r[0] for r in cursor.fetchall()] - if ids: - placeholders = ",".join("?" for _ in ids) - conn.execute( - f"UPDATE messages SET active = 1 WHERE id IN ({placeholders})", - ids, - ) - return len(ids) - - return self._execute_write(_do) - - def list_recent_user_messages( - self, - session_id: str, - limit: int = 20, - include_inactive: bool = False, - ) -> List[Dict[str, Any]]: - """Return the *limit* most-recent user messages, newest first. - - Each entry is a dict with keys ``id``, ``timestamp``, ``preview``. - ``preview`` is the first 80 characters of the message content - (with line breaks collapsed to spaces). Used by the /rewind - slash command picker. - - By default only active messages are returned. - """ - active_clause = "" if include_inactive else " AND active = 1" - with self._lock: - cursor = self._conn.execute( - "SELECT id, timestamp, content FROM messages " - "WHERE session_id = ? AND role = 'user'" - f"{active_clause} " - "ORDER BY id DESC LIMIT ?", - (session_id, int(limit)), - ) - rows = cursor.fetchall() - - result: List[Dict[str, Any]] = [] - for row in rows: - decoded = self._decode_content(row["content"]) - if isinstance(decoded, list): - # Multimodal — flatten text parts. - text_parts = [ - p.get("text", "") for p in decoded - if isinstance(p, dict) and p.get("type") == "text" - ] - preview = " ".join(t for t in text_parts if t).strip() - if not preview: - preview = "[multimodal content]" - elif isinstance(decoded, str): - preview = decoded - else: - preview = "" - preview = " ".join(preview.split()) # collapse whitespace - if len(preview) > 80: - preview = preview[:77] + "..." - result.append( - { - "id": row["id"], - "timestamp": row["timestamp"], - "preview": preview, - } - ) - return result - - # ========================================================================= - # Search - # ========================================================================= - - @staticmethod - def _sanitize_fts5_query(query: str) -> str: - """Sanitize user input for safe use in FTS5 MATCH queries. - - FTS5 has its own query syntax where characters like ``"``, ``(``, ``)``, - ``+``, ``*``, ``{``, ``}``, the column-filter operator ``:`` and bare - boolean operators (``AND``, ``OR``, ``NOT``) have special meaning. - Passing raw user input directly to MATCH can cause - ``sqlite3.OperationalError``. - - Strategy: - - Preserve properly paired quoted phrases (``"exact phrase"``) - - Strip unmatched FTS5-special characters that would cause errors - - Wrap unquoted hyphenated and dotted terms in quotes so FTS5 - matches them as exact phrases instead of splitting on the - hyphen/dot (e.g. ``chat-send``, ``P2.2``, ``my-app.config.ts``) - """ - # Step 1: Extract balanced double-quoted phrases and protect them - # from further processing via numbered placeholders. - _quoted_parts: list = [] - - def _preserve_quoted(m: re.Match) -> str: - _quoted_parts.append(m.group(0)) - return f"\x00Q{len(_quoted_parts) - 1}\x00" - - sanitized = re.sub(r'"[^"]*"', _preserve_quoted, query) - - # Step 2: Strip remaining (unmatched) FTS5-special characters. ``:`` is - # FTS5's column-filter operator (``col:term``); since the FTS table has a - # single ``content`` column, an unquoted colon query like ``TODO: fix`` - # parses as ``column:term`` and raises "no such column" — swallowed at - # the execute site into zero results. Strip it like the others. - sanitized = re.sub(r'[+{}():\"^]', " ", sanitized) - - # Step 3: Collapse repeated * (e.g. "***") into a single one, - # and remove leading * (prefix-only needs at least one char before *) - sanitized = re.sub(r"\*+", "*", sanitized) - sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized) - - # Step 4: Remove dangling boolean operators at start/end that would - # cause syntax errors (e.g. "hello AND" or "OR world") - sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip()) - sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip()) - - # Step 5: Wrap unquoted dotted and/or hyphenated terms in double - # quotes. FTS5's tokenizer splits on dots and hyphens, turning - # ``chat-send`` into ``chat AND send`` and ``P2.2`` into ``p2 AND 2``. - # Quoting preserves phrase semantics. A single pass avoids the - # double-quoting bug that would occur if dotted, hyphenated and underscored - # patterns were applied sequentially (e.g. ``my-app.config``). - sanitized = re.sub(r"\b(\w+(?:[._-]\w+)+)\b", r'"\1"', sanitized) - - # Step 6: Restore preserved quoted phrases - for i, quoted in enumerate(_quoted_parts): - sanitized = sanitized.replace(f"\x00Q{i}\x00", quoted) - - return sanitized.strip() - - - @staticmethod - def _is_cjk_codepoint(cp: int) -> bool: - return (0x4E00 <= cp <= 0x9FFF or # CJK Unified Ideographs - 0x3400 <= cp <= 0x4DBF or # CJK Extension A - 0x20000 <= cp <= 0x2A6DF or # CJK Extension B - 0x3000 <= cp <= 0x303F or # CJK Symbols - 0x3040 <= cp <= 0x309F or # Hiragana - 0x30A0 <= cp <= 0x30FF or # Katakana - 0xAC00 <= cp <= 0xD7AF) # Hangul Syllables - - @staticmethod - def _contains_cjk(text: str) -> bool: - """Check if text contains CJK (Chinese, Japanese, Korean) characters.""" - for ch in text: - cp = ord(ch) - if (0x4E00 <= cp <= 0x9FFF or # CJK Unified Ideographs - 0x3400 <= cp <= 0x4DBF or # CJK Extension A - 0x20000 <= cp <= 0x2A6DF or # CJK Extension B - 0x3000 <= cp <= 0x303F or # CJK Symbols - 0x3040 <= cp <= 0x309F or # Hiragana - 0x30A0 <= cp <= 0x30FF or # Katakana - 0xAC00 <= cp <= 0xD7AF): # Hangul Syllables - return True - return False - - @classmethod - def _count_cjk(cls, text: str) -> int: - """Count CJK characters in text.""" - return sum(1 for ch in text if cls._is_cjk_codepoint(ord(ch))) - - def search_messages( - self, - query: str, - source_filter: List[str] = None, - exclude_sources: List[str] = None, - role_filter: List[str] = None, - limit: int = 20, - offset: int = 0, - sort: str = None, - include_inactive: bool = False, - ) -> List[Dict[str, Any]]: - """ - Full-text search across session messages using FTS5. - - Supports FTS5 query syntax: - - Simple keywords: "docker deployment" - - Phrases: '"exact phrase"' - - Boolean: "docker OR kubernetes", "python NOT java" - - Prefix: "deploy*" - - Returns matching messages with session metadata, content snippet, - and surrounding context (1 message before and after the match). - - ``sort`` controls temporal ordering: - - ``None`` (default): FTS5 BM25 relevance only. Time-neutral. - - ``"newest"``: order by message timestamp DESC, then by rank. - - ``"oldest"``: order by message timestamp ASC, then by rank. - - The short-CJK LIKE fallback already orders by timestamp DESC and - ignores ``sort``. The trigram CJK path honours ``sort`` like the main - FTS5 path. - - Rewound (``active=0``) rows are excluded by default. Pass - ``include_inactive=True`` to search every row. - """ - if not self._fts_enabled: - return [] - - if not query or not query.strip(): - return [] - - query = self._sanitize_fts5_query(query) - if not query: - return [] - - # Normalise sort. Anything not in the allowed set falls back to None - # (FTS5 rank-only) so callers can pass through user input without - # validation. - if isinstance(sort, str): - sort_norm = sort.strip().lower() - if sort_norm not in ("newest", "oldest"): - sort_norm = None - else: - sort_norm = None - - # ORDER BY shared across the main FTS5 path and trigram CJK path. - # With sort set, timestamp is primary and rank is the tiebreaker. - if sort_norm == "newest": - order_by_sql = "ORDER BY m.timestamp DESC, rank" - elif sort_norm == "oldest": - order_by_sql = "ORDER BY m.timestamp ASC, rank" - else: - order_by_sql = "ORDER BY rank" - - # Build WHERE clauses dynamically - where_clauses = ["messages_fts MATCH ?"] - params: list = [query] - if not include_inactive: - where_clauses.append("m.active = 1") - - if source_filter is not None: - source_placeholders = ",".join("?" for _ in source_filter) - where_clauses.append(f"s.source IN ({source_placeholders})") - params.extend(source_filter) - - if exclude_sources is not None: - exclude_placeholders = ",".join("?" for _ in exclude_sources) - where_clauses.append(f"s.source NOT IN ({exclude_placeholders})") - params.extend(exclude_sources) - - if role_filter: - role_placeholders = ",".join("?" for _ in role_filter) - where_clauses.append(f"m.role IN ({role_placeholders})") - params.extend(role_filter) - - where_sql = " AND ".join(where_clauses) - params.extend([limit, offset]) - - sql = f""" - SELECT - m.id, - m.session_id, - m.role, - snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet, - m.content, - m.timestamp, - m.tool_name, - s.source, - s.model, - s.started_at AS session_started - FROM messages_fts - JOIN messages m ON m.id = messages_fts.rowid - JOIN sessions s ON s.id = m.session_id - WHERE {where_sql} - {order_by_sql} - LIMIT ? OFFSET ? - """ - - # CJK queries bypass the unicode61 FTS5 table. The default tokenizer - # splits CJK characters into individual tokens, so "大别山项目" becomes - # "大 AND 别 AND 山 AND 项 AND 目" — producing false positives and - # missing exact phrase matches. - # - # For queries with 3+ CJK characters, we use the trigram FTS5 table - # (indexed substring matching with ranking and snippets). For shorter - # CJK queries (1-2 chars), trigram can't match (it needs ≥9 UTF-8 - # bytes = 3 CJK chars), so we fall back to LIKE. - is_cjk = self._contains_cjk(query) - if is_cjk: - raw_query = query.strip('"').strip() - cjk_count = self._count_cjk(raw_query) - - # Per-token CJK length check (#20494): trigram needs >=3 CJK chars - # per token. A query like "广西 OR 桂林 OR 漓江" has cjk_count=6 - # (>=3) but each individual token is only 2 chars — trigram returns 0. - # Route to LIKE when any non-operator CJK token is <3 CJK chars. - _tokens_for_check = [ - t for t in raw_query.split() - if t.upper() not in {"AND", "OR", "NOT"} and self._contains_cjk(t) - ] - _any_short_cjk = any( - self._count_cjk(t) < 3 for t in _tokens_for_check - ) - - if cjk_count >= 3 and not _any_short_cjk: - # Trigram FTS5 path — quote each non-operator token to handle - # FTS5 special chars (%, *, etc.) while preserving boolean - # operators (AND, OR, NOT) for multi-term queries. - tokens = raw_query.split() - parts = [] - for tok in tokens: - if tok.upper() in {"AND", "OR", "NOT"}: - parts.append(tok) - else: - parts.append('"' + tok.replace('"', '""') + '"') - trigram_query = " ".join(parts) - tri_where = ["messages_fts_trigram MATCH ?"] - tri_params: list = [trigram_query] - if not include_inactive: - tri_where.append("m.active = 1") - if source_filter is not None: - tri_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})") - tri_params.extend(source_filter) - if exclude_sources is not None: - tri_where.append(f"s.source NOT IN ({','.join('?' for _ in exclude_sources)})") - tri_params.extend(exclude_sources) - if role_filter: - tri_where.append(f"m.role IN ({','.join('?' for _ in role_filter)})") - tri_params.extend(role_filter) - tri_sql = f""" - SELECT - m.id, - m.session_id, - m.role, - snippet(messages_fts_trigram, 0, '>>>', '<<<', '...', 40) AS snippet, - m.content, - m.timestamp, - m.tool_name, - s.source, - s.model, - s.started_at AS session_started - FROM messages_fts_trigram - JOIN messages m ON m.id = messages_fts_trigram.rowid - JOIN sessions s ON s.id = m.session_id - WHERE {' AND '.join(tri_where)} - {order_by_sql} - LIMIT ? OFFSET ? - """ - tri_params.extend([limit, offset]) - with self._lock: - try: - tri_cursor = self._conn.execute(tri_sql, tri_params) - except sqlite3.OperationalError: - matches = [] - else: - matches = [dict(row) for row in tri_cursor.fetchall()] - else: - # Short / mixed CJK query: trigram cannot match tokens with - # <3 CJK chars. Fall back to LIKE substring search. - # For multi-token OR queries (e.g. "广西 OR 桂林 OR 漓江"), - # build one LIKE condition per non-operator token so each term - # is matched independently (#20494). - non_op_tokens = [ - t for t in raw_query.split() - if t.upper() not in {"AND", "OR", "NOT"} - ] or [raw_query] - token_clauses = [] - like_params: list = [] - for tok in non_op_tokens: - esc = tok.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - token_clauses.append( - "(m.content LIKE ? ESCAPE '\\' OR m.tool_name LIKE ? ESCAPE '\\' OR m.tool_calls LIKE ? ESCAPE '\\')" - ) - like_params += [f"%{esc}%", f"%{esc}%", f"%{esc}%"] - like_where = [f"({' OR '.join(token_clauses)})"] - if source_filter is not None: - like_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})") - like_params.extend(source_filter) - if exclude_sources is not None: - like_where.append(f"s.source NOT IN ({','.join('?' for _ in exclude_sources)})") - like_params.extend(exclude_sources) - if role_filter: - like_where.append(f"m.role IN ({','.join('?' for _ in role_filter)})") - like_params.extend(role_filter) - like_sql = f""" - SELECT m.id, m.session_id, m.role, - substr(m.content, - max(1, instr(m.content, ?) - 40), - 120) AS snippet, - m.content, m.timestamp, m.tool_name, - s.source, s.model, s.started_at AS session_started - FROM messages m - JOIN sessions s ON s.id = m.session_id - WHERE {' AND '.join(like_where)} - ORDER BY m.timestamp DESC - LIMIT ? OFFSET ? - """ - like_params.extend([limit, offset]) - # instr() for snippet uses first search token - like_params = [non_op_tokens[0]] + like_params - with self._lock: - like_cursor = self._conn.execute(like_sql, like_params) - matches = [dict(row) for row in like_cursor.fetchall()] - else: - with self._lock: - try: - cursor = self._conn.execute(sql, params) - except sqlite3.OperationalError: - # FTS5 query syntax error despite sanitization — return empty - return [] - else: - matches = [dict(row) for row in cursor.fetchall()] - - # Add surrounding context (1 message before + after each match). - # Done outside the lock so we don't hold it across N sequential queries. - for match in matches: - try: - with self._lock: - ctx_cursor = self._conn.execute( - """WITH target AS ( - SELECT session_id, timestamp, id - FROM messages - WHERE id = ? - ) - SELECT role, content - FROM ( - SELECT m.id, m.timestamp, m.role, m.content - FROM messages m - JOIN target t ON t.session_id = m.session_id - WHERE (m.timestamp < t.timestamp) - OR (m.timestamp = t.timestamp AND m.id < t.id) - ORDER BY m.timestamp DESC, m.id DESC - LIMIT 1 - ) - UNION ALL - SELECT role, content - FROM messages - WHERE id = ? - UNION ALL - SELECT role, content - FROM ( - SELECT m.id, m.timestamp, m.role, m.content - FROM messages m - JOIN target t ON t.session_id = m.session_id - WHERE (m.timestamp > t.timestamp) - OR (m.timestamp = t.timestamp AND m.id > t.id) - ORDER BY m.timestamp ASC, m.id ASC - LIMIT 1 - )""", - (match["id"], match["id"]), - ) - context_msgs = [] - for r in ctx_cursor.fetchall(): - raw = r["content"] - decoded = self._decode_content(raw) - # Multimodal context: render a compact text-only - # summary for search previews. - if isinstance(decoded, list): - text_parts = [ - p.get("text", "") for p in decoded - if isinstance(p, dict) and p.get("type") == "text" - ] - text = " ".join(t for t in text_parts if t).strip() - preview = text or "[multimodal content]" - elif isinstance(decoded, str): - preview = decoded - else: - preview = "" - context_msgs.append( - {"role": r["role"], "content": preview[:200]} - ) - match["context"] = context_msgs - except Exception: - match["context"] = [] - - # Remove full content from result (snippet is enough, saves tokens) - for match in matches: - match.pop("content", None) - - return matches - - def search_sessions_by_id( - self, - query: str, - limit: int = 20, - include_archived: bool = True, - ) -> List[Dict[str, Any]]: - """Search surfaced sessions by exact/prefix/substring session id. - - Desktop search uses this alongside FTS message search so users can paste - a session id from logs, CLI output, or another Hermes surface and jump - straight to that conversation. Matching also checks ``_lineage_root_id`` - for projected compression-chain tips, so an old root id still resolves to - the live continuation row. - """ - needle = (query or "").strip().lower() - if not needle or limit <= 0: - return [] - - # SQL-bounded: list_sessions_rich pushes the id LIKE filter into the - # query (matching the row's own id AND any id in its forward - # compression chain), so we only materialize matching rows instead of - # scanning every session. Fetch a small multiple of `limit` so the - # in-Python exact/prefix/substring ranking below has enough candidates - # to order, then truncate. - candidates = self.list_sessions_rich( - limit=max(limit * 4, limit), - offset=0, - include_archived=include_archived, - order_by_last_active=True, - id_query=needle, - ) - - def score(row: Dict[str, Any]) -> int: - ids = [str(row.get("id") or ""), str(row.get("_lineage_root_id") or "")] - normalized = [value.lower() for value in ids if value] - if any(value == needle for value in normalized): - return 0 - if any(value.startswith(needle) for value in normalized): - return 1 - return 2 - - ranked = sorted( - enumerate(candidates), - key=lambda item: (score(item[1]), item[0]), - ) - return [row for _, row in ranked[:limit]] - - def search_sessions( - self, - source: str = None, - limit: int = 20, - offset: int = 0, - ) -> List[Dict[str, Any]]: - """List sessions, optionally filtered by source. - - Returns rows enriched with a computed ``last_active`` column (latest - message timestamp for the session, falling back to ``started_at``), - ordered by most-recently-used first. - """ - select_with_last_active = ( - "SELECT s.*, COALESCE(m.last_active, s.started_at) AS last_active " - "FROM sessions s " - "LEFT JOIN (" - "SELECT session_id, MAX(timestamp) AS last_active " - "FROM messages GROUP BY session_id" - ") m ON m.session_id = s.id " - ) - with self._lock: - if source: - cursor = self._conn.execute( - f"{select_with_last_active}" - "WHERE s.source = ? " - "ORDER BY last_active DESC, s.started_at DESC, s.id DESC LIMIT ? OFFSET ?", - (source, limit, offset), - ) - else: - cursor = self._conn.execute( - f"{select_with_last_active}" - "ORDER BY last_active DESC, s.started_at DESC, s.id DESC LIMIT ? OFFSET ?", - (limit, offset), - ) - return [dict(row) for row in cursor.fetchall()] - - # ========================================================================= - # Utility - # ========================================================================= - - def session_count( - self, - source: str = None, - min_message_count: int = 0, - include_archived: bool = False, - archived_only: bool = False, - exclude_children: bool = False, - exclude_sources: List[str] = None, - ) -> int: - """Count sessions, optionally filtered by source. - - Pass ``exclude_children=True`` to count only the conversations that - ``list_sessions_rich`` surfaces (root + branch sessions), hiding - sub-agent runs and compression continuations. Use it whenever the count - is paired with a ``list_sessions_rich`` page (e.g. sidebar "load more" - totals) so the total matches the number of listable rows — otherwise the - raw row count is inflated by children and "load more" never settles. - - Pass ``exclude_sources`` to drop whole source classes from the count - (e.g. ``["cron"]`` so the recents "load more" total matches a - cron-excluded ``list_sessions_rich`` page and doesn't keep "load more" - stuck on for buried scheduler sessions). - """ - where_clauses = [] - params = [] - - if exclude_children: - # Mirror list_sessions_rich's child-exclusion clause exactly so the - # count lines up with the rows: roots (no parent) plus branch - # children (parent ended with end_reason='branched'). - where_clauses.append( - "(s.parent_session_id IS NULL" - " OR EXISTS (SELECT 1 FROM sessions p" - " WHERE p.id = s.parent_session_id" - " AND p.end_reason = 'branched'" - " AND s.started_at >= p.ended_at))" - ) - if source: - where_clauses.append("s.source = ?") - params.append(source) - if exclude_sources: - placeholders = ",".join("?" for _ in exclude_sources) - where_clauses.append(f"s.source NOT IN ({placeholders})") - params.extend(exclude_sources) - if min_message_count > 0: - where_clauses.append("s.message_count >= ?") - params.append(min_message_count) - if archived_only: - where_clauses.append("s.archived = 1") - elif not include_archived: - where_clauses.append("s.archived = 0") - - where_sql = f" WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - with self._lock: - cursor = self._conn.execute(f"SELECT COUNT(*) FROM sessions s{where_sql}", params) - return cursor.fetchone()[0] - - def message_count(self, session_id: str = None) -> int: - """Count messages, optionally for a specific session.""" - with self._lock: - if session_id: - cursor = self._conn.execute( - "SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,) - ) - else: - cursor = self._conn.execute("SELECT COUNT(*) FROM messages") - return cursor.fetchone()[0] - - # ========================================================================= - # Export and cleanup - # ========================================================================= - - def export_session(self, session_id: str) -> Optional[Dict[str, Any]]: - """Export a single session with all its messages as a dict.""" - session = self.get_session(session_id) - if not session: - return None - messages = self.get_messages(session_id) - return {**session, "messages": messages} - - def export_all(self, source: str = None) -> List[Dict[str, Any]]: - """ - Export all sessions (with messages) as a list of dicts. - Suitable for writing to a JSONL file for backup/analysis. - """ - sessions = self.search_sessions(source=source, limit=100000) - results = [] - for session in sessions: - messages = self.get_messages(session["id"]) - results.append({**session, "messages": messages}) - return results - - def clear_messages(self, session_id: str) -> None: - """Delete all messages for a session and reset its counters.""" - def _do(conn): - conn.execute( - "DELETE FROM messages WHERE session_id = ?", (session_id,) - ) - conn.execute( - "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", - (session_id,), - ) - self._execute_write(_do) - - @staticmethod - def _remove_session_files(sessions_dir: Optional[Path], session_id: str) -> None: - """Remove on-disk transcript files for a session. - - Cleans up ``{session_id}.json``, ``{session_id}.jsonl``, and any - ``request_dump_{session_id}_*.json`` files left by the gateway. - Silently skips files that don't exist and swallows OSError so a - filesystem hiccup never blocks a DB operation. - """ - if sessions_dir is None: - return - for suffix in (".json", ".jsonl"): - p = sessions_dir / f"{session_id}{suffix}" - try: - p.unlink(missing_ok=True) - except OSError: - pass - # request_dump files use session_id as a prefix component - try: - for p in sessions_dir.glob(f"request_dump_{session_id}_*.json"): - try: - p.unlink(missing_ok=True) - except OSError: - pass - except OSError: - pass - - def delete_session( - self, - session_id: str, - sessions_dir: Optional[Path] = None, - ) -> bool: - """Delete a session and all its messages. - - Child sessions are orphaned (parent_session_id set to NULL) rather - than cascade-deleted, so they remain accessible independently. - When *sessions_dir* is provided, also removes on-disk transcript - files (``.json`` / ``.jsonl`` / ``request_dump_*``) for the deleted - session. Returns True if the session was found and deleted. - """ - def _do(conn): - cursor = conn.execute( - "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) - ) - if cursor.fetchone()[0] == 0: - return False - # Orphan child sessions so FK constraint is satisfied - conn.execute( - "UPDATE sessions SET parent_session_id = NULL " - "WHERE parent_session_id = ?", - (session_id,), - ) - conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) - return True - - deleted = self._execute_write(_do) - if deleted: - self._remove_session_files(sessions_dir, session_id) - return deleted - - def delete_sessions( - self, - session_ids: List[str], - sessions_dir: Optional[Path] = None, - ) -> int: - """Delete every session in *session_ids* in a single transaction. - - Backs the dashboard's bulk-select-then-delete flow on the - sessions page (``POST /api/sessions/bulk-delete``). Mirrors the - single-session :meth:`delete_session` contract per row: - - * Unknown IDs are silently skipped (no 404) — selection state - in the UI can race against another tab's delete, and we'd - rather succeed-on-the-rest than fail-the-whole-batch. - * Children of every deleted ID are orphaned - (``parent_session_id → NULL``), never cascade-deleted, so a - branch / subagent transcript survives an inadvertent parent - delete. - * Messages and the session row both go in one - ``_execute_write`` call so a partial failure can't leave the - DB in a "messages gone but session row still there" state. - * On-disk transcript / ``request_dump_*`` files are cleaned up - outside the DB transaction when *sessions_dir* is provided, - matching :meth:`prune_sessions` and - :meth:`delete_empty_sessions`. - - Returns the count of sessions that actually existed and were - deleted (may be less than ``len(session_ids)`` if some IDs were - already gone). - """ - if not session_ids: - return 0 - # Dedup + drop any non-string entries up-front. Avoids - # double-counting in the WHERE-IN list and protects against - # callers that pass a list with stray ``None`` values. - unique_ids = list({sid for sid in session_ids if isinstance(sid, str) and sid}) - if not unique_ids: - return 0 - - removed_ids: list[str] = [] - - def _do(conn): - placeholders = ",".join("?" * len(unique_ids)) - # First, filter to IDs that actually exist — we want to - # return the real deleted count, not the input length. - cursor = conn.execute( - f"SELECT id FROM sessions WHERE id IN ({placeholders})", - unique_ids, - ) - existing = [row["id"] for row in cursor.fetchall()] - if not existing: - return 0 - - existing_placeholders = ",".join("?" * len(existing)) - # Orphan children whose parent is in the kill list so the - # FK constraint stays satisfied. Pin children whose parent - # is itself in the kill list rather than NULL-ing parents - # of survivors — the IN list on ``parent_session_id`` does - # exactly this. - conn.execute( - f"UPDATE sessions SET parent_session_id = NULL " - f"WHERE parent_session_id IN ({existing_placeholders})", - existing, - ) - conn.execute( - f"DELETE FROM messages WHERE session_id IN ({existing_placeholders})", - existing, - ) - conn.execute( - f"DELETE FROM sessions WHERE id IN ({existing_placeholders})", - existing, - ) - removed_ids.extend(existing) - return len(existing) - - count = self._execute_write(_do) - for sid in removed_ids: - self._remove_session_files(sessions_dir, sid) - return count - - def count_empty_sessions(self) -> int: - """Return the count of empty, non-active, non-archived sessions. - - "Empty" = ``message_count = 0`` AND the session has ended - (``ended_at IS NOT NULL``) AND is not archived. The ``ended_at`` - guard matches the safety contract used by :meth:`prune_sessions`: - only ended sessions are candidates for bulk deletion, so a freshly - spawned session whose first message hasn't landed yet — or one - held open by the live agent — is never sniped out from under - the runtime. - - Backs the ``GET /api/sessions/empty/count`` endpoint that lets the - web dashboard hide its "Delete empty" button when there's nothing - to clean up, and pre-populate the confirm dialog with the actual - count. - """ - with self._lock: - cursor = self._conn.execute( - "SELECT COUNT(*) FROM sessions " - "WHERE message_count = 0 " - "AND ended_at IS NOT NULL " - "AND archived = 0" - ) - return cursor.fetchone()[0] - - def delete_empty_sessions( - self, - sessions_dir: Optional[Path] = None, - ) -> int: - """Delete every empty, ended, non-archived session. - - Mirrors :meth:`prune_sessions`' transactional shape: - - * Selects candidate IDs first (``message_count = 0`` AND - ``ended_at IS NOT NULL`` AND ``archived = 0``) so we never - touch a live session or one the user deliberately archived. - * Orphans any child whose parent is in the kill list — children - of an empty parent are kept and re-parented to ``NULL`` rather - than cascade-deleted, matching ``delete_session`` / - ``prune_sessions`` semantics so branch/subagent transcripts - survive an inadvertent parent cleanup. - * Deletes the rows in a single ``_execute_write`` callback so - the operation is atomic — a partial failure (e.g. SIGKILL - mid-loop) doesn't leave the DB in a "messages-deleted but - session-row-still-there" half-state. - * Cleans up on-disk transcript files (``.json`` / ``.jsonl`` / - ``request_dump_*``) outside the DB transaction when - ``sessions_dir`` is provided. Empty sessions don't typically - have transcript files, but the gateway can leave a stub - ``request_dump_*`` if it crashed before the first reply — - so we still sweep, matching ``prune_sessions``. - - Returns the number of sessions deleted. - """ - removed_ids: list[str] = [] - - def _do(conn): - cursor = conn.execute( - "SELECT id FROM sessions " - "WHERE message_count = 0 " - "AND ended_at IS NOT NULL " - "AND archived = 0" - ) - session_ids = {row["id"] for row in cursor.fetchall()} - - if not session_ids: - return 0 - - placeholders = ",".join("?" * len(session_ids)) - conn.execute( - f"UPDATE sessions SET parent_session_id = NULL " - f"WHERE parent_session_id IN ({placeholders})", - list(session_ids), - ) - - for sid in session_ids: - # DELETE FROM messages is paranoia — by construction - # these rows have ``message_count = 0`` — but if a - # bookkeeping bug ever lets the counter drift below the - # real row count, we still leave a clean FK state. - conn.execute( - "DELETE FROM messages WHERE session_id = ?", (sid,) - ) - conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) - removed_ids.append(sid) - return len(session_ids) - - count = self._execute_write(_do) - for sid in removed_ids: - self._remove_session_files(sessions_dir, sid) - return count - - def prune_sessions( - self, - older_than_days: int = 90, - source: str = None, - sessions_dir: Optional[Path] = None, - ) -> int: - """Delete sessions older than N days. Returns count of deleted sessions. - - Only prunes ended sessions (not active ones). Child sessions outside - the prune window are orphaned (parent_session_id set to NULL) rather - than cascade-deleted. When *sessions_dir* is provided, also removes - on-disk transcript files (``.json`` / ``.jsonl`` / - ``request_dump_*``) for every pruned session, outside the DB - transaction. - """ - cutoff = time.time() - (older_than_days * 86400) - removed_ids: list[str] = [] - - def _do(conn): - if source: - cursor = conn.execute( - """SELECT id FROM sessions - WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", - (cutoff, source), - ) - else: - cursor = conn.execute( - "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", - (cutoff,), - ) - session_ids = {row["id"] for row in cursor.fetchall()} - - if not session_ids: - return 0 - - # Orphan any sessions whose parent is about to be deleted - placeholders = ",".join("?" * len(session_ids)) - conn.execute( - f"UPDATE sessions SET parent_session_id = NULL " - f"WHERE parent_session_id IN ({placeholders})", - list(session_ids), - ) - - for sid in session_ids: - conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) - removed_ids.append(sid) - return len(session_ids) - - count = self._execute_write(_do) - # Clean up on-disk files outside the DB transaction - for sid in removed_ids: - self._remove_session_files(sessions_dir, sid) - return count - - # ── Meta key/value (for scheduler bookkeeping) ── - - def get_meta(self, key: str) -> Optional[str]: - """Read a value from the state_meta key/value store.""" - with self._lock: - row = self._conn.execute( - "SELECT value FROM state_meta WHERE key = ?", (key,) - ).fetchone() - if row is None: - return None - return row["value"] if isinstance(row, sqlite3.Row) else row[0] - - def set_meta(self, key: str, value: str) -> None: - """Write a value to the state_meta key/value store.""" - def _do(conn): - conn.execute( - "INSERT INTO state_meta (key, value) VALUES (?, ?) " - "ON CONFLICT(key) DO UPDATE SET value = excluded.value", - (key, value), - ) - self._execute_write(_do) - - def apply_telegram_topic_migration(self) -> None: - """Create Telegram DM topic-mode tables on explicit /topic opt-in. - - This migration is deliberately not part of automatic SessionDB startup - reconciliation. Operators must be able to upgrade Hermes, keep the old - Telegram bot behavior running, and only mutate topic-mode state when the - user executes /topic to opt into the feature. - - Schema versions: - v1 — initial shape (no ON DELETE CASCADE on session_id FK) - v2 — session_id FK gets ON DELETE CASCADE so session pruning - automatically clears bindings. - """ - def _do(conn): - conn.executescript( - """ - CREATE TABLE IF NOT EXISTS telegram_dm_topic_mode ( - chat_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - enabled INTEGER NOT NULL DEFAULT 1, - activated_at REAL NOT NULL, - updated_at REAL NOT NULL, - has_topics_enabled INTEGER, - allows_users_to_create_topics INTEGER, - capability_checked_at REAL, - intro_message_id TEXT, - pinned_message_id TEXT - ); - - CREATE TABLE IF NOT EXISTS telegram_dm_topic_bindings ( - chat_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - user_id TEXT NOT NULL, - session_key TEXT NOT NULL, - session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, - managed_mode TEXT NOT NULL DEFAULT 'auto', - linked_at REAL NOT NULL, - updated_at REAL NOT NULL, - PRIMARY KEY (chat_id, thread_id) - ); - - CREATE UNIQUE INDEX IF NOT EXISTS idx_telegram_dm_topic_bindings_session - ON telegram_dm_topic_bindings(session_id); - - CREATE INDEX IF NOT EXISTS idx_telegram_dm_topic_bindings_user - ON telegram_dm_topic_bindings(user_id, chat_id); - """ - ) - - # v1 → v2: rebuild telegram_dm_topic_bindings if its session_id FK - # lacks ON DELETE CASCADE. SQLite can't ALTER a foreign key, so we - # rebuild the table. Only runs once per DB (version gate). - current = conn.execute( - "SELECT value FROM state_meta WHERE key = ?", - ("telegram_dm_topic_schema_version",), - ).fetchone() - current_version = int(current[0]) if current and str(current[0]).isdigit() else 0 - if current_version < 2: - fk_rows = conn.execute( - "PRAGMA foreign_key_list('telegram_dm_topic_bindings')" - ).fetchall() - needs_rebuild = any( - row[2] == "sessions" and (row[6] or "") != "CASCADE" - for row in fk_rows - ) - if needs_rebuild: - conn.executescript( - """ - CREATE TABLE telegram_dm_topic_bindings_new ( - chat_id TEXT NOT NULL, - thread_id TEXT NOT NULL, - user_id TEXT NOT NULL, - session_key TEXT NOT NULL, - session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, - managed_mode TEXT NOT NULL DEFAULT 'auto', - linked_at REAL NOT NULL, - updated_at REAL NOT NULL, - PRIMARY KEY (chat_id, thread_id) - ); - INSERT INTO telegram_dm_topic_bindings_new - SELECT chat_id, thread_id, user_id, session_key, - session_id, managed_mode, linked_at, updated_at - FROM telegram_dm_topic_bindings; - DROP TABLE telegram_dm_topic_bindings; - ALTER TABLE telegram_dm_topic_bindings_new - RENAME TO telegram_dm_topic_bindings; - CREATE UNIQUE INDEX idx_telegram_dm_topic_bindings_session - ON telegram_dm_topic_bindings(session_id); - CREATE INDEX idx_telegram_dm_topic_bindings_user - ON telegram_dm_topic_bindings(user_id, chat_id); - """ - ) - - conn.execute( - "INSERT INTO state_meta (key, value) VALUES (?, ?) " - "ON CONFLICT(key) DO UPDATE SET value = excluded.value", - ("telegram_dm_topic_schema_version", "2"), - ) - self._execute_write(_do) - - def enable_telegram_topic_mode( - self, - *, - chat_id: str, - user_id: str, - has_topics_enabled: Optional[bool] = None, - allows_users_to_create_topics: Optional[bool] = None, - ) -> None: - """Enable Telegram DM topic mode for one private chat/user. - - This method intentionally owns the explicit topic migration. Ordinary - SessionDB startup must not create these side tables. - """ - self.apply_telegram_topic_migration() - now = time.time() - - def _to_int(value: Optional[bool]) -> Optional[int]: - if value is None: - return None - return 1 if value else 0 - - def _do(conn): - conn.execute( - """ - INSERT INTO telegram_dm_topic_mode ( - chat_id, user_id, enabled, activated_at, updated_at, - has_topics_enabled, allows_users_to_create_topics, - capability_checked_at - ) VALUES (?, ?, 1, ?, ?, ?, ?, ?) - ON CONFLICT(chat_id) DO UPDATE SET - user_id = excluded.user_id, - enabled = 1, - updated_at = excluded.updated_at, - has_topics_enabled = excluded.has_topics_enabled, - allows_users_to_create_topics = excluded.allows_users_to_create_topics, - capability_checked_at = excluded.capability_checked_at - """, - ( - str(chat_id), - str(user_id), - now, - now, - _to_int(has_topics_enabled), - _to_int(allows_users_to_create_topics), - now, - ), - ) - self._execute_write(_do) - - def disable_telegram_topic_mode( - self, - *, - chat_id: str, - clear_bindings: bool = True, - ) -> None: - """Disable Telegram DM topic mode for one private chat. - - When ``clear_bindings`` is True (default) the (chat_id, thread_id) - bindings for this chat are also cleared so re-enabling later - starts from a clean slate. Set to False if the operator wants to - preserve bindings for a later re-enable. - - Never creates the topic-mode tables from scratch; if they don't - exist there is nothing to disable and the call is a no-op. - """ - def _do(conn): - try: - conn.execute( - "UPDATE telegram_dm_topic_mode SET enabled = 0, updated_at = ? " - "WHERE chat_id = ?", - (time.time(), str(chat_id)), - ) - if clear_bindings: - conn.execute( - "DELETE FROM telegram_dm_topic_bindings WHERE chat_id = ?", - (str(chat_id),), - ) - except sqlite3.OperationalError: - # Tables don't exist yet — nothing to disable. - return - self._execute_write(_do) - - def is_telegram_topic_mode_enabled(self, *, chat_id: str, user_id: str) -> bool: - """Return whether Telegram DM topic mode is enabled for this chat/user.""" - with self._lock: - try: - row = self._conn.execute( - """ - SELECT enabled FROM telegram_dm_topic_mode - WHERE chat_id = ? AND user_id = ? - """, - (str(chat_id), str(user_id)), - ).fetchone() - except sqlite3.OperationalError: - return False - if row is None: - return False - enabled = row["enabled"] if isinstance(row, sqlite3.Row) else row[0] - return bool(enabled) - - def get_telegram_topic_binding( - self, - *, - chat_id: str, - thread_id: str, - ) -> Optional[Dict[str, Any]]: - """Return the session binding for a Telegram DM topic, if present.""" - with self._lock: - try: - row = self._conn.execute( - """ - SELECT * FROM telegram_dm_topic_bindings - WHERE chat_id = ? AND thread_id = ? - """, - (str(chat_id), str(thread_id)), - ).fetchone() - except sqlite3.OperationalError: - return None - return dict(row) if row else None - - def list_telegram_topic_bindings_for_chat( - self, - *, - chat_id: str, - ) -> List[Dict[str, Any]]: - """All Telegram DM topic bindings for one chat, newest first. - - Read-only; returns [] if the bindings table doesn't exist yet - (does not trigger the topic-mode migration). - """ - with self._lock: - try: - rows = self._conn.execute( - "SELECT * FROM telegram_dm_topic_bindings " - "WHERE chat_id = ? ORDER BY updated_at DESC", - (str(chat_id),), - ).fetchall() - except sqlite3.OperationalError: - return [] - return [dict(row) for row in rows] - - def get_telegram_topic_binding_by_session( - self, - *, - session_id: str, - ) -> Optional[Dict[str, Any]]: - """Return the Telegram DM topic binding for a given session_id, if present. - - Uses the UNIQUE INDEX on telegram_dm_topic_bindings(session_id) for an - efficient reverse lookup. Returns None when the session has no binding or - the table does not exist yet. - """ - with self._lock: - try: - row = self._conn.execute( - """ - SELECT * FROM telegram_dm_topic_bindings - WHERE session_id = ? - """, - (str(session_id),), - ).fetchone() - except sqlite3.OperationalError: - return None - return dict(row) if row else None - - def bind_telegram_topic( - self, - *, - chat_id: str, - thread_id: str, - user_id: str, - session_key: str, - session_id: str, - managed_mode: str = "auto", - ) -> None: - """Bind one Telegram DM topic thread to one Hermes session. - - A Hermes session may only be linked to one Telegram topic in MVP. - Rebinding the same topic to the same session is idempotent; trying to - link the same session to a different topic raises ValueError. - """ - self.apply_telegram_topic_migration() - now = time.time() - chat_id = str(chat_id) - thread_id = str(thread_id) - user_id = str(user_id) - session_key = str(session_key) - session_id = str(session_id) - - def _do(conn): - existing_session = conn.execute( - """ - SELECT chat_id, thread_id FROM telegram_dm_topic_bindings - WHERE session_id = ? - """, - (session_id,), - ).fetchone() - if existing_session is not None: - linked_chat = existing_session["chat_id"] if isinstance(existing_session, sqlite3.Row) else existing_session[0] - linked_thread = existing_session["thread_id"] if isinstance(existing_session, sqlite3.Row) else existing_session[1] - if str(linked_chat) != chat_id or str(linked_thread) != thread_id: - raise ValueError("session is already linked to another Telegram topic") - - conn.execute( - """ - INSERT INTO telegram_dm_topic_bindings ( - chat_id, thread_id, user_id, session_key, session_id, - managed_mode, linked_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(chat_id, thread_id) DO UPDATE SET - user_id = excluded.user_id, - session_key = excluded.session_key, - session_id = excluded.session_id, - managed_mode = excluded.managed_mode, - updated_at = excluded.updated_at - """, - ( - chat_id, - thread_id, - user_id, - session_key, - session_id, - managed_mode, - now, - now, - ), - ) - self._execute_write(_do) - - def is_telegram_session_linked_to_topic(self, *, session_id: str) -> bool: - """Return True if a Hermes session is already bound to any Telegram DM topic. - - Read-only: does NOT trigger the telegram-topic migration. If the - topic-mode tables have not been created yet (i.e. nobody has run - ``/topic`` in this profile), the session is by definition unbound - and we return False. - """ - with self._lock: - try: - row = self._conn.execute( - """ - SELECT 1 FROM telegram_dm_topic_bindings - WHERE session_id = ? - LIMIT 1 - """, - (str(session_id),), - ).fetchone() - except sqlite3.OperationalError: - return False - return row is not None - - def list_unlinked_telegram_sessions_for_user( - self, - *, - chat_id: str, - user_id: str, - limit: int = 10, - ) -> List[Dict[str, Any]]: - """List previous Telegram sessions for this user that are not bound to a topic. - - Read-only: does NOT trigger the telegram-topic migration. If the - topic-mode tables are absent, fall back to a simpler query that - just returns this user's Telegram sessions — there can't be any - bindings yet. - """ - with self._lock: - try: - rows = self._conn.execute( - """ - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active - FROM sessions s - WHERE s.source = 'telegram' - AND s.user_id = ? - AND NOT EXISTS ( - SELECT 1 FROM telegram_dm_topic_bindings b - WHERE b.session_id = s.id - ) - ORDER BY last_active DESC, s.started_at DESC - LIMIT ? - """, - (str(user_id), int(limit)), - ).fetchall() - except sqlite3.OperationalError: - # telegram_dm_topic_bindings doesn't exist yet — no bindings - # means every telegram session for this user is "unlinked". - rows = self._conn.execute( - """ - SELECT s.*, - COALESCE( - (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63) - FROM messages m - WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL - ORDER BY m.timestamp, m.id LIMIT 1), - '' - ) AS _preview_raw, - COALESCE( - (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id), - s.started_at - ) AS last_active - FROM sessions s - WHERE s.source = 'telegram' - AND s.user_id = ? - ORDER BY last_active DESC, s.started_at DESC - LIMIT ? - """, - (str(user_id), int(limit)), - ).fetchall() - - sessions: List[Dict[str, Any]] = [] - for row in rows: - session = dict(row) - raw = str(session.pop("_preview_raw", "") or "").strip() - session["preview"] = raw[:60] + ("..." if len(raw) > 60 else "") if raw else "" - sessions.append(session) - return sessions - - # ── Space reclamation ── - - # FTS5 virtual tables whose b-tree segments we merge on optimize. The - # trigram table is created lazily / may be disabled, so we probe before - # touching it (see optimize_fts). - _FTS_TABLES = ("messages_fts", "messages_fts_trigram") - - def _fts_table_exists(self, name: str) -> bool: - """True if an FTS5 virtual table is queryable in this DB.""" - try: - self._conn.execute(f"SELECT 1 FROM {name} LIMIT 0") - return True - except sqlite3.OperationalError: - return False - - def optimize_fts(self) -> int: - """Merge fragmented FTS5 b-tree segments into one per index. - - FTS5 indexes grow as a series of incremental segments — one per - ``INSERT`` batch driven by the message triggers. Over tens of - thousands of messages these segments accumulate, which both bloats - the ``*_data`` shadow tables and slows ``MATCH`` queries that must - scan every segment. The special ``'optimize'`` command rewrites each - index as a single merged segment. - - This is purely a maintenance operation — it changes neither search - results nor ``snippet()`` output, only on-disk layout and query - speed. It is complementary to VACUUM: ``optimize`` compacts the FTS - index internally, then VACUUM returns the freed pages to the OS. - - Skips any FTS table that does not exist (e.g. the trigram index when - disabled via ``HERMES_DISABLE_FTS_TRIGRAM`` or not yet created), so - it is safe to call unconditionally. - - Returns the number of FTS indexes that were optimized. - """ - optimized = 0 - with self._lock: - for tbl in self._FTS_TABLES: - if not self._fts_table_exists(tbl): - continue - try: - # The column name in the INSERT must match the table name - # for FTS5 special commands. - self._conn.execute( - f"INSERT INTO {tbl}({tbl}) VALUES('optimize')" - ) - optimized += 1 - except sqlite3.OperationalError as exc: - logger.warning( - "FTS optimize failed for %s: %s", tbl, exc - ) - return optimized - - def vacuum(self) -> int: - """Run VACUUM to reclaim disk space after large deletes. - - SQLite does not shrink the database file when rows are deleted — - freed pages just get reused on the next insert. After a prune that - removed hundreds of sessions, the file stays bloated unless we - explicitly VACUUM. - - VACUUM rewrites the entire DB, so it's expensive (seconds per - 100MB) and cannot run inside a transaction. It also acquires an - exclusive lock, so callers must ensure no other writers are - active. Safe to call at startup before the gateway/CLI starts - serving traffic. - - FTS5 segments are merged first via :meth:`optimize_fts` so the - subsequent VACUUM reclaims the pages freed by the merge. This is a - layout-only optimization — search results are unchanged. - - Returns the number of FTS indexes that were optimized (0 if the - merge step failed or no FTS tables exist). - """ - # Merge FTS5 segments before VACUUM so the freed pages are returned - # to the OS in the same pass. optimize_fts() manages its own lock. - optimized = 0 - try: - optimized = self.optimize_fts() - except Exception as exc: - logger.warning("FTS optimize before VACUUM failed: %s", exc) - # VACUUM cannot be executed inside a transaction. - with self._lock: - # Best-effort WAL checkpoint first, then VACUUM. - try: - self._conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") - except Exception: - pass - self._conn.execute("VACUUM") - return optimized - - def maybe_auto_prune_and_vacuum( - self, - retention_days: int = 90, - min_interval_hours: int = 24, - vacuum: bool = True, - sessions_dir: Optional[Path] = None, - ) -> Dict[str, Any]: - """Idempotent auto-maintenance: prune old sessions + optional VACUUM. - - Records the last run timestamp in state_meta so subsequent calls - within ``min_interval_hours`` no-op. Designed to be called once at - startup from long-lived entrypoints (CLI, gateway, cron scheduler). - - When *sessions_dir* is provided, on-disk transcript files - (``.json`` / ``.jsonl`` / ``request_dump_*``) for pruned sessions - are removed as part of the same sweep (issue #3015). - - Never raises. On any failure, logs a warning and returns a dict - with ``"error"`` set. - - Returns a dict with keys: - - ``"skipped"`` (bool) — true if within min_interval_hours of last run - - ``"pruned"`` (int) — number of sessions deleted - - ``"vacuumed"`` (bool) — true if VACUUM ran - - ``"error"`` (str, optional) — present only on failure - """ - result: Dict[str, Any] = {"skipped": False, "pruned": 0, "vacuumed": False} - try: - # Skip if another process/call did maintenance recently. - last_raw = self.get_meta("last_auto_prune") - now = time.time() - if last_raw: - try: - last_ts = float(last_raw) - if now - last_ts < min_interval_hours * 3600: - result["skipped"] = True - return result - except (TypeError, ValueError): - pass # corrupt meta; treat as no prior run - - pruned = self.prune_sessions( - older_than_days=retention_days, - sessions_dir=sessions_dir, - ) - result["pruned"] = pruned - - # Only VACUUM if we actually freed rows — VACUUM on a tight DB - # is wasted I/O. Threshold keeps small DBs from paying the cost. - if vacuum and pruned > 0: - try: - self.vacuum() - result["vacuumed"] = True - except Exception as exc: - logger.warning("state.db VACUUM failed: %s", exc) - - # Record the attempt even if pruned == 0, so we don't retry - # every startup within the min_interval_hours window. - self.set_meta("last_auto_prune", str(now)) - - if pruned > 0: - logger.info( - "state.db auto-maintenance: pruned %d session(s) older than %d days%s", - pruned, - retention_days, - " + VACUUM" if result["vacuumed"] else "", - ) - except Exception as exc: - # Maintenance must never block startup. Log and return error marker. - logger.warning("state.db auto-maintenance failed: %s", exc) - result["error"] = str(exc) - - return result - - # ── Handoff (cross-platform session transfer) ────────────────────────── - # - # State machine: - # None — no handoff in flight - # "pending" — CLI requested handoff, gateway hasn't picked it up yet - # "running" — gateway is processing (session switch + synthetic turn) - # "completed"— gateway successfully delivered the synthetic turn - # "failed" — gateway hit an error; reason in handoff_error - # - # The CLI writes "pending" then poll-waits for terminal state. The gateway - # watcher transitions pending→running→{completed,failed}. - - def request_handoff(self, session_id: str, platform: str) -> bool: - """Mark a session as pending handoff to the given platform. - - Returns True if the row was found and not already in flight; False if - the session is already in a non-terminal handoff state. - """ - def _do(conn): - cur = conn.execute( - "UPDATE sessions " - "SET handoff_state = 'pending', " - " handoff_platform = ?, " - " handoff_error = NULL " - "WHERE id = ? AND (handoff_state IS NULL " - " OR handoff_state IN ('completed', 'failed'))", - (platform, session_id), - ) - return cur.rowcount > 0 - return self._execute_write(_do) - - def get_handoff_state(self, session_id: str) -> Optional[Dict[str, Any]]: - """Read the current handoff state for a session. - - Returns ``{"state", "platform", "error"}`` or None if the session has - no handoff record. - """ - try: - cur = self._conn.execute( - "SELECT handoff_state, handoff_platform, handoff_error " - "FROM sessions WHERE id = ?", - (session_id,), - ) - row = cur.fetchone() - if not row: - return None - return { - "state": row["handoff_state"], - "platform": row["handoff_platform"], - "error": row["handoff_error"], - } - except Exception: - return None - - def list_pending_handoffs(self) -> List[Dict[str, Any]]: - """Return all sessions in handoff_state='pending', oldest first. - - Used by the gateway's handoff watcher. - """ - try: - cur = self._conn.execute( - "SELECT * FROM sessions " - "WHERE handoff_state = 'pending' " - "ORDER BY started_at ASC" - ) - return [dict(r) for r in cur.fetchall()] - except Exception: - return [] - - def claim_handoff(self, session_id: str) -> bool: - """Atomically transition pending → running. Returns True if claimed.""" - def _do(conn): - cur = conn.execute( - "UPDATE sessions SET handoff_state = 'running' " - "WHERE id = ? AND handoff_state = 'pending'", - (session_id,), - ) - return cur.rowcount > 0 - return self._execute_write(_do) - - def complete_handoff(self, session_id: str) -> None: - """Mark a handoff as completed.""" - def _do(conn): - conn.execute( - "UPDATE sessions SET handoff_state = 'completed', " - "handoff_error = NULL WHERE id = ?", - (session_id,), - ) - self._execute_write(_do) - - def fail_handoff(self, session_id: str, error: str) -> None: - """Mark a handoff as failed and record the reason.""" - def _do(conn): - conn.execute( - "UPDATE sessions SET handoff_state = 'failed', " - "handoff_error = ? WHERE id = ?", - (error[:500], session_id), - ) - self._execute_write(_do) diff --git a/scripts/build_prd.py b/scripts/build_prd.py deleted file mode 100644 index 0c6967e..0000000 --- a/scripts/build_prd.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- -import os -path = r"D:\F\NewI\opencode\daily-workspace\projects\AgentsMeeting\docs\PRD.md" -lines = [] -lines.append("# AgentsMeeting -- PRD v0.1") -lines.append("") -lines.append("> \u7248\u672c: \u521d\u7a3f ^| \u5ba2\u6237: hmo (\u8001\u83ab) ^| PM: mohe (\u83ab\u8377) ^| \u7814\u53d1: xxm (\u5c0f\u5c0f\u83ab)") diff --git a/scripts/gen_prd.py b/scripts/gen_prd.py deleted file mode 100644 index 1c6f83c..0000000 --- a/scripts/gen_prd.py +++ /dev/null @@ -1 +0,0 @@ -import os; open(os.path.join(r"D:\F\NewI\opencode\daily-workspace\projects\AgentsMeeting\docs","PRD.md"),"w",encoding="utf-8").write("# AgentsMeeting - PRD v0.1\n\nOK") \ No newline at end of file diff --git a/scripts/gen_prd_v02.py b/scripts/gen_prd_v02.py deleted file mode 100644 index 2218965..0000000 --- a/scripts/gen_prd_v02.py +++ /dev/null @@ -1,2 +0,0 @@ -import os -print("ok") diff --git a/scripts/test_echo.py b/scripts/test_echo.py deleted file mode 100644 index 4473a45..0000000 --- a/scripts/test_echo.py +++ /dev/null @@ -1 +0,0 @@ -print("test123") diff --git a/scripts/write_prd.py b/scripts/write_prd.py deleted file mode 100644 index 1db0a10..0000000 --- a/scripts/write_prd.py +++ /dev/null @@ -1 +0,0 @@ -print("ok") \ No newline at end of file diff --git a/scripts/write_prd_v02.py b/scripts/write_prd_v02.py deleted file mode 100644 index 37b73e7..0000000 --- a/scripts/write_prd_v02.py +++ /dev/null @@ -1,2 +0,0 @@ -import os,sys -open("D:/F/NewI/opencode/daily-workspace/projects/AgentsMeeting/docs/PRD_v0.2.md","w",encoding="utf-8").write("test ok\n") \ No newline at end of file diff --git a/xmpp_agent_core.py b/xmpp_agent_core.py index 5e54579..772acef 100644 --- a/xmpp_agent_core.py +++ b/xmpp_agent_core.py @@ -1,9 +1,35 @@ #!/usr/bin/env python3 -"""XMPP Bot - 统一版,支持 --agent mohe|zhiwei|xiao 参数""" -import asyncio, logging, ssl, json, urllib.request, os, time, sys, re -from slixmpp import ClientXMPP +""" +XMPP Agent Core — 统一版 +========================= +Single bot core for all agents. Supports --agent xxm|mohe|zhiwei|xiaoguo. + +Usage: + python xmpp_agent_core.py --agent xxm # xxm, uses chat_bridge + python xmpp_agent_core.py --agent mohe # mohe, uses Hermes API + python xmpp_agent_core.py --agent zhiwei # zhiwei, uses Hermes API + python xmpp_agent_core.py --agent xiaoguo # xiaoguo, uses Hermes API + +Shares: PID lock, reconnect, MUC join, dedup, batching, + coordinator protocol (GRANT/REVOKE), HTTP bridge. +Differs only in LLM calling method (chat_bridge vs Hermes API). +""" +import os, sys, time, threading, asyncio, logging, json, re, ssl +import urllib.request, http.server, urllib.parse + +# ── Windows selector loop (slixmpp needs it on Windows) ── +if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +# ── PATH: allow imports from gateway/scripts/ (proc_guard, chat_bridge) ── +_GATEWAY_SCRIPTS = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "gateway", "scripts") +sys.path.insert(0, _GATEWAY_SCRIPTS) + +# ═══════════════════════════════════════════════════════════════ +# AGENTS Configuration +# ═══════════════════════════════════════════════════════════════ -# ── Agent 配置 ────────────────────────────────────────────── AGENTS = { "mohe": { "jid": "mohe@yoin.fun", @@ -13,6 +39,9 @@ AGENTS = { "http_port": 5804, "gateway": "http://localhost:8642/v1/chat/completions", "session_id": "xmpp-mohe-v2", + "server": "127.0.0.1", + "port": 5222, + "muc_rooms": ["coregroup@conference.yoin.fun"], "mention": "@mohe/@莫荷", }, "zhiwei": { @@ -23,6 +52,9 @@ AGENTS = { "http_port": 5805, "gateway": "http://localhost:8643/v1/chat/completions", "session_id": "xmpp-zhiwei", + "server": "127.0.0.1", + "port": 5222, + "muc_rooms": ["coregroup@conference.yoin.fun"], "mention": "@zhiwei/@知微", }, "xiaoguo": { @@ -33,316 +65,830 @@ AGENTS = { "http_port": 5806, "gateway": "http://localhost:8645/v1/chat/completions", "session_id": "xmpp-xiaoguo", + "server": "127.0.0.1", + "port": 5222, + "muc_rooms": ["coregroup@conference.yoin.fun"], "mention": "@xiaoguo/@小果", }, + "xxm": { + "jid": "xxm@yoin.fun", + "password": "hermes123", + "nick": "xxm", + "name_cn": "笑笑", + "http_port": 5802, + "bridge": "chat_bridge", # use local chat_bridge instead of Hermes API + "session_id": "ses_xxm_xmpp", + "server": "192.168.1.246", # LAN direct connect + "port": 5222, + "muc_rooms": [ + "coregroup@conference.yoin.fun", + "jujidina@conference.yoin.fun", + ], + "mention": "@xxm/@笑笑", + }, } -agent = sys.argv[sys.argv.index("--agent") + 1] if "--agent" in sys.argv else "mohe" -cfg = AGENTS.get(agent, AGENTS["mohe"]) +# ── Agent selection ── +_agent_name = "mohe" +if "--agent" in sys.argv: + idx = sys.argv.index("--agent") + if idx + 1 < len(sys.argv): + _agent_name = sys.argv[idx + 1] +cfg = AGENTS.get(_agent_name, AGENTS["mohe"]) + +# ═══════════════════════════════════════════════════════════════ +# PID Lock — prevent duplicate instances +# ═══════════════════════════════════════════════════════════════ + +from proc_guard import guard as _proc_guard +_lock = _proc_guard(f"xmpp_bot_{_agent_name}") +if not _lock.ok: + print(_lock.message, flush=True) + sys.exit(1) + +# ═══════════════════════════════════════════════════════════════ +# Logging +# ═══════════════════════════════════════════════════════════════ + +_LOG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gateway", "logs") +os.makedirs(_LOG_DIR, exist_ok=True) +_LOG_FILE = os.path.join(_LOG_DIR, f"xmpp_{_agent_name}.log") +_START_TIME = time.time() + + +def log(m: str): + with open(_LOG_FILE, "a", encoding="utf-8") as f: + f.write(f"{time.strftime('%H:%M:%S')} {m}\n") + logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = cfg["gateway"] -API_KEY = "hermes123" -AGENT_NICK = cfg["nick"] -AGENT_NAME = cfg["name_cn"] -AGENT_JID = cfg["jid"] -AGENT_MENTION = cfg["mention"] -SESSION_ID = cfg["session_id"] -HTTP_PORT = cfg["http_port"] -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) -# ── HTTP 桥(接收本地脚本的主动发送请求) ── -from http.server import HTTPServer, BaseHTTPRequestHandler -import threading, json as json_mod +# ═══════════════════════════════════════════════════════════════ +# LLM Bridge Init — abstracted per agent +# ═══════════════════════════════════════════════════════════════ -_send_queue = [] +_IS_CHAT_BRIDGE = cfg.get("bridge") == "chat_bridge" +_router = None # set only for chat_bridge (xxm) -class SendHandler(BaseHTTPRequestHandler): - def do_POST(self): - length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(length) - try: - data = json_mod.loads(body) - room = data.get('to', 'coregroup@conference.yoin.fun') - text = data.get('body', '') - if text: - _send_queue.append((room, text)) - self.send_response(200) - self.end_headers() - self.wfile.write(b'{"ok":true}') - else: - self.send_response(400) - self.end_headers() - self.wfile.write(b'{"ok":false,"error":"empty body"}') - except Exception as e: - self.send_response(500) - self.end_headers() - self.wfile.write(f'{{"ok":false,"error":"{e}"}}'.encode()) +if _IS_CHAT_BRIDGE: + from chat_bridge import SessionBridge + from session_router import SessionRouter + _bridge = SessionBridge(session_id=cfg["session_id"]) + _router = SessionRouter(bridge=_bridge, default_session=cfg["session_id"]) + log(f"LLM: chat_bridge (session={cfg['session_id']})") +else: + _opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) + log(f"LLM: Hermes API ({cfg['gateway']})") -def _run_http(): - server = HTTPServer(('127.0.0.1', HTTP_PORT), SendHandler) - server.timeout = 1.0 + +def _call_llm(content: str, sender: str, is_group: bool = False) -> str: + """Abstract LLM call. Returns raw response text (or empty string).""" + if _IS_CHAT_BRIDGE: + return _router.route("xmpp", sender, content) or "" + else: + return _call_hermes_api(content) + + +def _call_hermes_api(content: str) -> str: + """POST to Hermes API, return response text or empty string.""" + try: + payload = json.dumps({ + "model": "hermes-agent", + "messages": [{"role": "user", "content": content}] + }).encode() + req = urllib.request.Request(cfg["gateway"], data=payload, method="POST") + req.add_header("Content-Type", "application/json") + req.add_header("Authorization", "Bearer hermes123") + req.add_header("X-Hermes-Session-Id", cfg["session_id"]) + result = _opener.open(req, timeout=600) + data = json.loads(result.read()) + reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") + return reply.strip() + except Exception as e: + log(f"!!! Hermes API error: {e}") + return "" + + +# ═══════════════════════════════════════════════════════════════ +# Message Dedup +# ═══════════════════════════════════════════════════════════════ + +_DEDUP_CACHE: set[str] = set() +_DEDUP_LOCK = threading.Lock() + + +def _is_duplicate(msg_id: str) -> bool: + if not msg_id: + return False + with _DEDUP_LOCK: + if msg_id in _DEDUP_CACHE: + return True + _DEDUP_CACHE.add(msg_id) + if len(_DEDUP_CACHE) > 100: + _DEDUP_CACHE.clear() + return False + + +# ═══════════════════════════════════════════════════════════════ +# Coordinator Protocol (shared across all agents) +# ═══════════════════════════════════════════════════════════════ + +_COORDINATOR: str = "mohe" +_GRANTED: str | None = None +_REVOKED_UNTIL: float = 0.0 +_SHUTUP_PATTERNS = ["闭嘴", "别说话", "安静", "shut", "stfu", "别说了", "停"] + + +def _process_coordinator_signals(nickname: str, body: str) -> bool: + """Parse coordinator/GRANT/REVOKE from incoming messages. + Returns True if message was a control signal (consumed, no further processing).""" + global _COORDINATOR, _GRANTED, _REVOKED_UNTIL + # 1. hmo switches coordinator + if nickname == 'hmo' and 'coordinator=' in body.lower(): + for name in ('mohe', 'zhiwei', 'xxm'): + if f'coordinator={name}' in body.lower(): + _COORDINATOR = name + _GRANTED = None + log(f"Coordinator switched to {name} by hmo") + return True + # 2. GRANT signal (overrides REVOKE) + gm = re.search(r'\[GRANT:(\w+)\]', body) + if gm: + _GRANTED = gm.group(1) + _REVOKED_UNTIL = 0 + log(f"GRANT: {_GRANTED}") + return True + # 3. REVOKE signal (5min auto-restore) + rm = re.search(r'\[REVOKE:(\w+)\]', body) + if rm and rm.group(1) == cfg["nick"]: + _REVOKED_UNTIL = time.time() + 300 + log(f"REVOKEd: {cfg['nick']} silenced for 5min") + return True + return False + + +def _check_shutup(body: str) -> bool: + """hmo says shut up → 5min silence.""" + lower = body.lower().strip() + for pat in _SHUTUP_PATTERNS: + if pat.lower() in lower: + _REVOKED_UNTIL = time.time() + 300 + log(f"(shutup: '{pat}' → 5min silence)") + return True + return False + + +def _process_llm_grant(response: str): + """Parse GRANT signal from LLM's own response.""" + global _GRANTED + gm = re.search(r'\[GRANT:(\w+)\]', response) + if gm: + _GRANTED = gm.group(1) + _REVOKED_UNTIL = 0 + log(f"LLM GRANT: {_GRANTED}") + + +# ═══════════════════════════════════════════════════════════════ +# Response Extraction +# ═══════════════════════════════════════════════════════════════ + +_SILENCE_PATTERNS = [ + "保持沉默", "不应[该]?回复", "没有.*@.*我", "不是对[我我说]", + "跟我无关", "我不用回复", "不该回复", "不参与", + "不是我[应]?该[说回]", +] +_EXEC_RE = re.compile(r"##exec:(.+?)##", re.DOTALL) +_DELAY_RE = re.compile(r"##delay:?(\d+)?##") +_DELAY_DEFAULT = 15 +_EXEC_TIMEOUT = 60 + + +def _strip_toolcall_xml(text: str) -> str: + t = text + t = re.sub(r']*>.*?(|$)', '', t, flags=re.DOTALL) + t = re.sub(r'.*?(|$)', '', t, flags=re.DOTALL) + t = re.sub(r']*>.*?(|$)', '', t, flags=re.DOTALL) + t = re.sub(r'.*?(|$)', '', t, flags=re.DOTALL) + return t.strip() + + +def _extract_response(text: str) -> str | None: + """Strip __SILENT__, reasoning blocks, natural language silence. + Returns actual content to send, or None to stay silent.""" + if not text: + return None + t = text.strip() + if not t: + return None + t = _strip_toolcall_xml(t) + # Natural language silence detection + if not t.startswith("__SILENT__"): + first = t.split("\n", 1)[0] + for pat in _SILENCE_PATTERNS: + if re.search(pat, first): + return None + return t + # Has __SILENT__ prefix + parts = t.split("\n", 1) + if len(parts) < 2: + return None + rest = parts[1].strip() while True: - server.handle_request() + m = re.match(r'^([^)]*)\s*', rest) + if m: + rest = rest[m.end():] + continue + m = re.match(r'^\([^)]*\)\s*', rest) + if m: + rest = rest[m.end():] + continue + break + return rest.strip() or None -threading.Thread(target=_run_http, daemon=True).start() -logging.info(f"🚀 {AGENT_NAME} HTTP 桥启动于 :{HTTP_PORT}") -# ── XMPP Bot 类 ──────────────────────────────────────────────── -class AgentBot(ClientXMPP): +# ═══════════════════════════════════════════════════════════════ +# Message Batching (3s debounce + serialized processing) +# ═══════════════════════════════════════════════════════════════ + +_BATCH_WINDOW = 3.0 +_BATCH_TIMEOUT = 300 +_batch_entries: dict[str, list[str]] = {} +_batch_timers: dict[str, threading.Timer] = {} +_batch_processing: set[str] = set() +_batch_pending: dict[str, list[str]] = {} +_batch_lock = threading.Lock() +_BOT_NICK = cfg["nick"] + + +def _run_command(cmd: str) -> str: + """Execute ##exec:command## shell command.""" + log(f"(exec: {cmd[:120]})") + try: + import subprocess + r = subprocess.run(cmd, shell=True, capture_output=True, + timeout=_EXEC_TIMEOUT, text=True, encoding='utf-8', errors='replace') + out = (r.stdout or "") + (r.stderr or "") + out = out.strip() or f"(no output, exit={r.returncode})" + log(f"(exec done: {len(out)} bytes, exit={r.returncode})") + return out + except subprocess.TimeoutExpired: + log(f"(exec timeout >{_EXEC_TIMEOUT}s)") + return "(命令超时)" + except Exception as e: + log(f"(exec error: {e})") + return f"(命令执行失败: {e})" + + +def _schedule_delayed(delay_sec: int, room: str): + """Schedule ##delay:N## re-invocation.""" + global _xmpp_ref + import subprocess as _sp + from xml.sax.saxutils import escape + + def _fire(): + bot = _xmpp_ref + if not bot: + return + try: + prompt = "时间到,请根据最新的信息汇报结果。" + raw = _call_llm(prompt, room, is_group=True) + reply = _extract_response(raw) + if reply: + safe = escape(reply.strip()) + stanza = f"{safe}" + bot.send_raw(stanza) + log(f"-> [Delay][{room}]: {reply.strip()[:80]}") + except Exception as e: + log(f"!! delay err: {e}") + + t = threading.Timer(delay_sec, _fire) + t.daemon = True + t.start() + log(f"(delay +{delay_sec}s → {room})") + + +def _batch_done(room: str): + """Called when batch LLM finishes. Flush pending if any.""" + with _batch_lock: + _batch_processing.discard(room) + pending = _batch_pending.pop(room, None) + if pending: + _batch_entries[room] = pending + t = threading.Timer(0.1, _fire_batch, args=[room]) + t.daemon = True + t.start() + _batch_timers[room] = t + return + log(f"[Batch][{room}] (idle)") + + +def _fire_batch(room: str): + """Collect batched entries and call LLM.""" + with _batch_lock: + entries = _batch_entries.pop(room, None) + _batch_timers.pop(room, None) + if not entries: + return + _batch_processing.add(room) + combined = "\n".join(entries) + + def _handle(): + timed_out = [False] + + def _timeout(): + timed_out[0] = True + log(f"[Batch][{room}] TIMEOUT ({_BATCH_TIMEOUT}s)") + _batch_done(room) + + timer = threading.Timer(_BATCH_TIMEOUT, _timeout) + timer.daemon = True + timer.start() + try: + raw = _call_llm(combined, room, is_group=True) + if not timed_out[0]: + timer.cancel() + _process_llm_reply(raw, room) + else: + log(f"[Batch][{room}] route returned after timeout, discarded") + except Exception as e: + log(f"!!! BATCH: {e}") + if not timed_out[0]: + timer.cancel() + _batch_done(room) + + threading.Thread(target=_handle, daemon=True).start() + + +def _batch_group_message(room: str, nickname: str, body: str) -> bool: + """Add message to room batch. Returns True if batched, False if @mention (immediate).""" + if f"@{_BOT_NICK}" in body or body.startswith(_BOT_NICK): + return False + formatted = f"[{nickname}]: {body}" + with _batch_lock: + if room in _batch_processing: + _batch_pending.setdefault(room, []).append(formatted) + return True + timer = _batch_timers.pop(room, None) + if timer: + timer.cancel() + _batch_entries.setdefault(room, []).append(formatted) + t = threading.Timer(_BATCH_WINDOW, _fire_batch, args=[room]) + t.daemon = True + t.start() + _batch_timers[room] = t + return True + + +# ═══════════════════════════════════════════════════════════════ +# MAM Recovery — fetch recent history on startup +# ═══════════════════════════════════════════════════════════════ + +_MAM_RECOVERY = True +_MAM_RECOVERY_LOCK = threading.Lock() +_MAM_MARK_DONE = False +_MAM_TIMEOUT = 30 + + +def _set_mam_done(): + global _MAM_RECOVERY + with _MAM_RECOVERY_LOCK: + _MAM_RECOVERY = False + + +def _is_mam_recovery() -> bool: + if time.time() - _START_TIME > _MAM_TIMEOUT: + global _MAM_RECOVERY + with _MAM_RECOVERY_LOCK: + if _MAM_RECOVERY: + _MAM_RECOVERY = False + log("(MAM recovery timed out, force-disabled)") + return _MAM_RECOVERY + with _MAM_RECOVERY_LOCK: + return _MAM_RECOVERY + + +# ═══════════════════════════════════════════════════════════════ +# XML Escape +# ═══════════════════════════════════════════════════════════════ + +def _escape(text: str) -> str: + return (text.replace("&", "&").replace("<", "<") + .replace(">", ">").replace('"', """)) + + +# ═══════════════════════════════════════════════════════════════ +# HTTP Bridge — health, presence, messages, POST send +# ═══════════════════════════════════════════════════════════════ + +_HTTP_PORT = cfg["http_port"] +_MSG_BUF: list[dict] = [] +_MSG_BUF_LOCK = threading.Lock() +_xmpp_ref = None # set after bot creation + + +def _record_group_msg(nickname: str, body: str): + ts = time.strftime("%H:%M:%S") + with _MSG_BUF_LOCK: + _MSG_BUF.append({"ts": ts, "from": nickname, "body": body}) + if len(_MSG_BUF) > 200: + _MSG_BUF[:] = _MSG_BUF[-150:] + + +class _BridgeHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + parsed = urllib.parse.urlparse(self.path) + if parsed.path == "/muc": + try: + muc_info = {"rooms": {}} + bot = _xmpp_ref + if bot is not None and 'xep_0045' in bot.plugin: + muc_plugin = bot.plugin['xep_0045'] + for room_jid in cfg["muc_rooms"]: + room_data = {"jid": room_jid, "participants": []} + try: + if room_jid in muc_plugin.rooms: + room = muc_plugin.rooms[room_jid] + for nick, info in room.get('roster', {}).items(): + room_data["participants"].append({ + "nick": nick, + "jid": str(info.get('jid', '')), + "affiliation": str(info.get('affiliation', '')), + "role": str(info.get('role', '')), + }) + except Exception as e: + room_data["error"] = str(e) + muc_info["rooms"][room_jid] = room_data + self._reply(200, muc_info) + except Exception as e: + self._reply(500, {"ok": False, "error": str(e)}) + return + if parsed.path == "/health": + try: + bot = _xmpp_ref + session_ok = bot.session_started_event.is_set() if (bot and hasattr(bot, 'session_started_event')) else False + socket_ok = bot.is_connected() if (bot and hasattr(bot, 'is_connected')) else False + self._reply(200, { + "ok": True, "xmpp_connected": session_ok or socket_ok, + "agent": _agent_name, "jid": cfg["jid"], + "uptime_sec": int(time.time() - _START_TIME), + "muc_rooms": cfg["muc_rooms"], + }) + except Exception as e: + self._reply(500, {"ok": False, "error": str(e)}) + return + if parsed.path.startswith("/presence"): + jid_to_check = parsed.path[len("/presence/"):].strip() + if not jid_to_check: + self._reply(400, {"ok": False, "error": "missing JID"}) + return + try: + info = {"jid": jid_to_check, "online": False, "resources": []} + bot = _xmpp_ref + if bot and hasattr(bot, 'client_roster'): + roster = bot.client_roster + if jid_to_check in roster: + resources = list(roster[jid_to_check].resources.keys()) + info["online"] = len(resources) > 0 + info["resources"] = resources + self._reply(200, info) + except Exception as e: + self._reply(500, {"ok": False, "error": str(e)}) + return + if parsed.path == "/messages": + try: + qs = urllib.parse.parse_qs(parsed.query) + sender = qs.get("from", [None])[0] + with _MSG_BUF_LOCK: + msgs = list(_MSG_BUF) + if sender: + msgs = [m for m in msgs if m["from"] == sender] + self._reply(200, {"ok": True, "count": len(msgs), "messages": msgs[-50:]}) + except Exception as e: + self._reply(500, {"ok": False, "error": str(e)}) + return + self._reply(404, {"ok": False, "error": "not found"}) + + def do_POST(self): + try: + length = int(self.headers.get('Content-Length', 0)) + body = json.loads(self.rfile.read(length)) + to = body.get('to', cfg["muc_rooms"][0]) + msg = body.get('message', '') or body.get('body', '') + if not msg: + self._reply(400, {"ok": False, "error": "empty message"}) + return + safe = _escape(msg.strip()) + bot = _xmpp_ref + if bot: + stanza = f'{safe}' + bot.send_raw(stanza) + _record_group_msg(cfg["nick"], msg) + log(f"[http] → [{to.split('@')[0]}]: {msg[:80]}") + self._reply(200, {"ok": True}) + except Exception as e: + self._reply(500, {"ok": False, "error": str(e)}) + + def _reply(self, code, data): + body = json.dumps(data, ensure_ascii=False).encode('utf-8') + self.send_response(code) + self.send_header('Content-Type', 'application/json; charset=utf-8') + self.send_header('Content-Length', len(body)) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format, *args): + pass + + +def _start_http_bridge(): + _httpd = http.server.HTTPServer(('0.0.0.0', _HTTP_PORT), _BridgeHandler) + _t = threading.Thread(target=_httpd.serve_forever, daemon=True) + _t.start() + log(f"HTTP bridge ready on :{_HTTP_PORT}") + + +# ═══════════════════════════════════════════════════════════════ +# Reply Processing (shared for all LLM responses) +# ═══════════════════════════════════════════════════════════════ + +def _process_llm_reply(raw_reply: str, room: str): + """Process LLM response: check silence/delay/exec/send.""" + global _xmpp_ref + if not raw_reply: + _batch_done(room) + return + # Parse GRANT signal from LLM response + _process_llm_grant(raw_reply) + # ##delay:N## → schedule later + delay_m = _DELAY_RE.search(raw_reply) + if delay_m: + sec = int(delay_m.group(1)) if delay_m.group(1) else _DELAY_DEFAULT + _schedule_delayed(sec, room) + _batch_done(room) + return + # ##exec:command## → run command, use output as reply + exec_m = _EXEC_RE.search(raw_reply) + if exec_m: + output = _run_command(exec_m.group(1)) + raw_reply = _EXEC_RE.sub(output, raw_reply, count=1) + # Extract actual response + reply_text = _extract_response(raw_reply) + if reply_text: + safe = _escape(reply_text.strip()) + bot = _xmpp_ref + if bot: + stanza = f"{safe}" + bot.send_raw(stanza) + log(f"-> [{room.split('@')[0]}]: {reply_text.strip()[:80]}") + else: + log(f"-> [{room.split('@')[0]}]: (silent)") + _batch_done(room) + + +# ═══════════════════════════════════════════════════════════════ +# Group message handler +# ═══════════════════════════════════════════════════════════════ + +def _handle_group_message(msg): + """Process a groupchat message (runs in thread).""" + global _COORDINATOR, _GRANTED, _REVOKED_UNTIL + if _is_mam_recovery(): + return + msg_id = msg.get("id", "") + if _is_duplicate(msg_id): + return + body = str(msg["body"]).strip() + if not body: + return + full_from = str(msg["from"]) + room = full_from.split("/")[0] + nickname = full_from.split("/")[1] if "/" in full_from else "" + # Self-message skip + if nickname == cfg["nick"]: + log(f"(self) {body[:80]}") + return + _record_group_msg(nickname, body) + # Coordinator signals + if _process_coordinator_signals(nickname, body): + return + # Revoke check + is_revoked = time.time() < _REVOKED_UNTIL + if is_revoked and _GRANTED == cfg["nick"]: + _GRANTED = None + is_revoked = False + log(f"GRANT overrides REVOKE for {cfg['nick']}") + if _check_shutup(body): + return + if is_revoked: + body = f"【只读消息】你被收回发言权。只需了解内容。输出 __SILENT__。\n\n[核心群 {room}] {nickname} 说: {body}" + # Batch or immediate (@mention) + if _batch_group_message(room, nickname, body): + log(f"[{room.split('@')[0]}] {nickname}: {body[:80]} (batched)") + return + log(f"[{room.split('@')[0]}] {nickname}: {body[:80]}") + raw = _call_llm(body, full_from, is_group=True) + _process_llm_reply(raw, room) + + +# ═══════════════════════════════════════════════════════════════ +# Private message handler +# ═══════════════════════════════════════════════════════════════ + +def _handle_private_message(msg): + """Process a private chat message.""" + if msg["type"] == "groupchat": + return + msg_id = msg.get("id", "") + if _is_duplicate(msg_id): + return + body = str(msg["body"]).strip() + sender = str(msg["from"]).split("/")[0] + log(f"<{sender}> {body[:80]}") + if sender == cfg["jid"]: + log("(skipped self)") + return + if time.time() < _REVOKED_UNTIL: + log(f"(silenced) <{sender}> dropped") + return + if _check_shutup(body): + return + raw = _call_llm(body, sender, is_group=False) + if raw: + reply = _extract_response(raw) + if reply: + from xml.sax.saxutils import escape + safe = escape(reply.strip()) + if _IS_CHAT_BRIDGE: + bot = _xmpp_ref + if bot: + stanza = f"{safe}" + bot.send_raw(stanza) + log(f"-> {sender}: {reply.strip()[:80]}") + else: + import subprocess as sp + sp.run(["docker", "exec", "ejabberd", "ejabberdctl", "send_stanza", + cfg["jid"], sender, + f"{safe}" + ], capture_output=True, timeout=10) + log(f"-> {sender}: {reply.strip()[:80]}") + + +# ═══════════════════════════════════════════════════════════════ +# AgentBot Class +# ═══════════════════════════════════════════════════════════════ + +import slixmpp + + +class AgentBot(slixmpp.ClientXMPP): def __init__(self): - super().__init__(AGENT_JID, cfg["password"]) - self.add_event_handler('session_bind', self.on_bind) - self.add_event_handler('message', self.on_msg) - self.add_event_handler('disconnected', self.on_disconnect) - self.add_event_handler('connected', self.on_connected) + super().__init__(cfg["jid"], cfg["password"]) + # Connection settings + self.enable_direct_tls = False + self.enable_starttls = True + self.auto_reconnect = True + self.reconnect_max_delay = 10 + self.whitespace_keepalive = True + self.whitespace_keepalive_interval = 30 + # SSL: accept self-signed certs ctx = ssl.create_default_context() ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE self.ssl_context = ctx - self.ready = asyncio.Event() - self._call_seq = 0 - self._muc_joined = False - self._recent_sent = [] - self._coordinator = 'mohe' # 默认协调者 - self._granted = None + # Event handlers + self.add_event_handler("session_start", self._on_session_start) + self.add_event_handler("message", self._on_any_message) + self.add_event_handler("groupchat_message", self._on_group_msg) + self.add_event_handler("disconnected", self._on_disconnected) + self.add_event_handler("connected", self._on_connected) + self.add_event_handler("session_end", self._on_session_end) + self.add_event_handler("connection_failed", self._on_conn_failed) + # MUC plugin + self.register_plugin('xep_0045') - async def on_connected(self, event): - logging.info(f"🔗 {AGENT_NAME} TCP连接已建立") + def _on_connected(self, event): + log("connection established") - async def on_bind(self, event): + def _on_session_start(self, event): self.send_presence() self.get_roster() + log(f"{cfg['jid']} online") + # Register MAM plugin lazily try: - self.plugin['xep_0045'].join_muc('coregroup@conference.yoin.fun', AGENT_NICK) - logging.info(f"✅ {AGENT_NAME} 加入群聊 coregroup") - except Exception as e: - logging.error(f"❌ {AGENT_NAME} 加入群聊失败: {e}") - self._muc_joined = True - self.ready.set() - logging.info(f"✅ {AGENT_NAME} XMPP 上线") + self.register_plugin('xep_0313') + except Exception: + log("(MAM: xep_0313 not available)") + # Join MUC rooms + async def _join_all(): + for room_jid in cfg["muc_rooms"]: + try: + self.plugin['xep_0045'].join_muc(room_jid, cfg["nick"]) + presence = ( + f"" + f"" + f"" + f"" + ) + self.send_raw(presence) + log(f"Joined {room_jid}") + except Exception as e: + log(f"MUC join failed {room_jid}: {e}") + await asyncio.sleep(2) + await asyncio.sleep(3) + await self._fetch_mam_history() + asyncio.ensure_future(_join_all()) - async def on_disconnect(self, event): - self.ready.clear() - self._muc_joined = False - logging.warning(f"⚠️ {AGENT_NAME} XMPP 断线") - - async def on_msg(self, msg): - body = msg['body'] - sender = str(msg['from']) - msg_type = msg['type'] - if not body: + async def _fetch_mam_history(self): + """Query MAM for recent MUC messages to rebuild context.""" + if 'xep_0313' not in self.plugin: + log("(MAM: no plugin)") + _set_mam_done() return - if msg_type == 'groupchat': - if AGENT_JID in sender: - return - nickname = sender.split('/')[-1] if '/' in sender else '' - # 自己的消息跳过(通过昵称) - if nickname == AGENT_NICK: - return - - # Coordinator 模式 — 全走 XMPP 消息 - - # 1. hmo 切换 coordinator - if nickname == 'hmo' and 'coordinator=' in body.lower(): - for _name in ['mohe', 'zhiwei', 'xxm']: - if f'coordinator={_name}' in body.lower(): - self._coordinator = _name - self._granted = None - logging.info(f"👑 Coordinator 切换为 {_name}") - break - - # 2. 检测授权信号(优先于收回,GRANT 可以覆盖 REVOKE) - _grant_match = re.search(r'\[GRANT:(\w+)\]', body) - if _grant_match: - self._granted = _grant_match.group(1) - self._revoked_until = 0 # 被授权时解除收回 - logging.info(f"🎤 收到授权:{self._granted}") - - # 3. 检测收回信号 - _revoke_match = re.search(r'\[REVOKE:(\w+)\]', body) - _revoked_until = getattr(self, '_revoked_until', 0) - if _revoke_match and _revoke_match.group(1) == AGENT_NICK: - self._revoked_until = time.time() + 300 # 5分钟自动解除 - logging.info(f"🔇 {AGENT_NAME} 发言权被收回(5分钟后自动恢复)") - if time.time() < _revoked_until: - # 被收回者:只读模式,能看到但不能回复 - _rr = sender.split('/')[0] - _readonly_body = f"【只读消息】你目前被暂时收回发言权。只需了解内容。输出 __SILENT__。\n\n[核心群 {_rr}] {nickname} 说: {body}" - await self.call_hermes(_readonly_body, sender, is_group=True) - return - - # 3. 判断角色 - _coordinator = getattr(self, '_coordinator', 'mohe') - _granted = getattr(self, '_granted', None) - _is_coordinator = (_coordinator == AGENT_NICK) - _is_granted = (_granted == AGENT_NICK) - - if _is_granted: - self._granted = None # 用完即收回 - - room = sender.split('/')[0] - - # 4. 被 REVOKE 的人:只读模式 - if time.time() < getattr(self, '_revoked_until', 0): - _readonly = f"【只读消息】你目前被收回发言权。只需了解内容。输出 __SILENT__。\n\n[核心群 {room}] {nickname} 说: {body}" - await self.call_hermes(_readonly, sender, is_group=True) - return - - # 硬闭嘴闸门:hmo 说闭嘴类的话 → 静默 5 分钟 - _silent_until = getattr(self, '_silent_until', 0) - if time.time() < _silent_until: - return - if nickname == 'hmo': - _sk = ['闭嘴', '别说话', '安静', 'shut', 'stfu', '别说了', '停'] - if any(kw in body.lower() for kw in _sk): - self._silent_until = time.time() + 300 - logging.info(f"🔇 {AGENT_NAME} 收到闭嘴指令,静默 5 分钟") - return - - logging.info(f"📩 群消息 [{sender}]: {body[:100]}") - room = sender.split('/')[0] - ctx_body = ( - "【规则】以下是一条群聊消息。判断是否应该回复。\n" - "只有以下3种情况你才回复:\n" - f"1. hmo直接点名问你({AGENT_MENTION})\n" - "2. 你有其他人没说过的独家信息\n" - "3. 别人说错了关键事实,不纠正会有后果\n" - "如果以上都不符合,你的回复必须只包含 __SILENT__ 这10个字符," - "不要有任何其他内容(不要前缀、不要解释、不要标点、不要空格)。\n\n" - "注意:你是协调者(coordinator)。你的第一职责是管理讨论节奏,不是自己说话。\n" - "- 别人能回答的问题,不抢答。\n" - "- 如果其他 Agent 更合适,用 [GRANT:agent名] 授权他们发言(例如 [GRANT:zhiwei])。\n" - "- 如果有人跑题/刷屏,用 [REVOKE:agent名] 收回发言权(例如 [REVOKE:zhiwei])。\n" - "- [GRANT] 可以覆盖 [REVOKE]。标记会显示在消息中。\n\n" - f"[核心群 {room}] {nickname} 说: {body}" - ) - await self.call_hermes(ctx_body, room, is_group=True) - return - if msg_type == 'chat' and 'hmo@yoin.fun' in sender: - self._call_seq += 1 - logging.info(f"📩 老爸(#{self._call_seq}): {body}") - await self.call_hermes(body, sender, seq=self._call_seq) - - async def call_hermes(self, content, sender, is_group=False, seq=None): - msg_type = 'groupchat' if is_group else 'chat' try: - payload = json.dumps({ - "model": "hermes-agent", - "messages": [{"role": "user", "content": content}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type", "application/json") - req.add_header("Authorization", f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id", SESSION_ID) - - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: _opener.open(req, timeout=600)) - - if seq is not None and seq < self._call_seq: - return - - data = json.loads(result.read()) - reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") - reply_stripped = reply.strip() - # 检查 __SILENT__ 标记 - if reply_stripped.startswith('__SILENT__') or reply_stripped.startswith('`__SILENT__`'): - logging.info(f"⏭️ {AGENT_NAME} 决定沉默,不发送") - return - # 如果回复里任意位置出现了 __SILENT__,说明 LLM 没理解协议,整条作废 - if '__SILENT__' in reply: - logging.info(f"⏭️ {AGENT_NAME} 回复中误用 __SILENT__,拦截") - return - # 如果回复里出现了 __REPLY__,也是协议混淆,拦截 - if '__REPLY__' in reply: - logging.info(f"⏭️ {AGENT_NAME} 回复中误用 __REPLY__,拦截") - return - # 额外拦截:LLM 说"我沉默""我不说了"等宣布沉默的话→当 SILENT 处理 - _silent_phrases = ['我沉默', '我不说', '不说了', '不回复', '不插嘴', '我闭嘴', - '闭嘴上', '沉默是', '彻底沉默', '我会沉默', '将保持沉默'] - if any(p in reply for p in _silent_phrases): - logging.info(f"⏭️ {AGENT_NAME} 宣布沉默(命中关键词),拦截") - return - finish = data.get("choices", [{}])[0].get("finish_reason", "") - - if reply.strip() and finish != "silent": - # Coordinator 授权机制:检测回复中的 [GRANT:xxx] 标记 - _grant_match = re.search(r'\[GRANT:(\w+)\]', reply) - if _grant_match: - _grant_name = _grant_match.group(1) - self._granted = _grant_name - logging.info(f"🎤 授权 {_grant_name} 发言(通过 XMPP 发送)") - # 不剥离标记,让其他 bot 从 XMPP 消息中解析 - - if msg_type == 'groupchat': - self.send_message(mto=sender, mbody=reply, mtype='groupchat') - # 防回声:记录已发送的消息正文 - sent_norm = reply.strip()[:100] - self._recent_sent.append(sent_norm) - if len(self._recent_sent) > 10: - self._recent_sent.pop(0) - else: - import subprocess as sp - from xml.sax.saxutils import escape - safe = escape(reply) - sp.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "send_stanza", - AGENT_JID, str(sender), - f"{safe}" - ], capture_output=True, timeout=10) - logging.info(f"✅ {AGENT_NAME} 回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ {AGENT_NAME} 错误: {e}") - -# ── 主入口 ─────────────────────────────────────────────── -async def main(): - retry_delay = 1 - max_delay = 60 - while True: - try: - bot = AgentBot() - bot.register_plugin('xep_0030') - bot.register_plugin('xep_0045') - bot.register_plugin('xep_0199') - - bot.connect(host='127.0.0.1', port=5222) - await asyncio.wait_for(bot.ready.wait(), timeout=30) - logging.info(f"{AGENT_NAME} XMPP 就绪") - retry_delay = 1 - - async def _drain_queue(): - while True: - await asyncio.sleep(1) - while _send_queue: - room, text = _send_queue.pop(0) + for room_jid in cfg["muc_rooms"]: + log(f"(MAM: querying {room_jid} for last 50 messages...)") + results = await self.plugin['xep_0313'].retrieve( + jid=room_jid, rsm={'max': 50}, + ) + count = 0 + for msg in results['mam']['results']: + forwarded = msg['mam_result']['forwarded'] + body = str(forwarded['stanza']['body'] or '').strip() + if not body: + continue + nick = str(forwarded['stanza']['from']).split('/')[-1] if '/' in str(forwarded['stanza']['from']) else '?' + # Feed into context for chat_bridge (xxm) + if _IS_CHAT_BRIDGE: + role = 'user' if nick != cfg["nick"] else 'assistant' try: - bot.send_message(mto=room, mbody=text, mtype='groupchat') - sent_norm = text.strip()[:100] - bot._recent_sent.append(sent_norm) - if len(bot._recent_sent) > 10: - bot._recent_sent.pop(0) - logging.info(f"📤 主动发送到 {room}: {text[:60]}") - except Exception as e: - logging.error(f"❌ 主动发送失败: {e}") - asyncio.create_task(_drain_queue()) - - while True: - await asyncio.sleep(15) - if not bot.is_connected(): - logging.warning("检测到断线,准备重连...") - break - - except asyncio.TimeoutError: - logging.warning("连接超时,准备重连...") + _bridge._append_to_log(role, f"[{nick}]: {body[:300]}") + except Exception: + pass + count += 1 + log(f"(MAM: loaded {count} msgs from {room_jid})") + _set_mam_done() + log("(MAM recovery complete)") except Exception as e: - logging.error(f"❌ 主循环错误: {e}") + log(f"(MAM error: {e})") + _set_mam_done() - logging.info(f"⏳ 等待 {retry_delay} 秒后重连...") - await asyncio.sleep(retry_delay) - retry_delay = min(retry_delay * 2, max_delay) + def _on_group_msg(self, msg): + threading.Thread(target=_handle_group_message, args=[msg], daemon=True).start() + + def _on_any_message(self, msg): + threading.Thread(target=_handle_private_message, args=[msg], daemon=True).start() + + def _on_session_end(self, event): + log(f"session ended") + + def _on_conn_failed(self, event): + log(f"connection failed: {event}") + + def _on_disconnected(self, event): + log(f"disconnected, reconnecting...") + + +# ═══════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════ + +def main(): + log(f"Starting {cfg['jid']} ({cfg['name_cn']}) — agent={_agent_name}") + if _IS_CHAT_BRIDGE: + log(f" LLM: chat_bridge (session={cfg['session_id']})") + else: + log(f" LLM: Hermes API ({cfg['gateway']})") + log(f" Server: {cfg['server']}:{cfg['port']}") + log(f" Rooms: {cfg['muc_rooms']}") + + bot = AgentBot() + global _xmpp_ref + _xmpp_ref = bot + + _start_http_bridge() + + bot.connect(host=cfg["server"], port=cfg["port"]) + log(f"Connecting {cfg['jid']}@{cfg['server']}:{cfg['port']}") + + loop = asyncio.get_event_loop() + + async def _status_check(): + while True: + await asyncio.sleep(60) + log("(alive)") + + asyncio.ensure_future(_status_check()) -if __name__ == '__main__': try: - asyncio.run(main()) + loop.run_forever() except KeyboardInterrupt: - pass + log("Shutdown by user") + except Exception as e: + log(f"!!! MAIN LOOP CRASH: {e}") + import traceback + log(f"!!! {traceback.format_exc()[:500]}") + raise + + +if __name__ == "__main__": + main() diff --git a/xmpp_bot.py b/xmpp_bot.py deleted file mode 100644 index 7d5f186..0000000 --- a/xmpp_bot.py +++ /dev/null @@ -1,269 +0,0 @@ -#!/usr/bin/env python3 -"""XMPP Bot - 统一版,支持 --agent mohe|zhiwei|xiao 参数""" -import asyncio, logging, ssl, json, urllib.request, os, time, sys -from slixmpp import ClientXMPP - -# ── Agent 配置 ────────────────────────────────────────────── -AGENTS = { - "mohe": { - "jid": "mohe@yoin.fun", - "password": "hermes123", - "nick": "mohe", - "name_cn": "莫荷", - "http_port": 5804, - "gateway": "http://localhost:8642/v1/chat/completions", - "session_id": "xmpp-mohe-v2", - "mention": "@mohe/@莫荷", - }, - "zhiwei": { - "jid": "zhiwei@yoin.fun", - "password": "hermes123", - "nick": "zhiwei", - "name_cn": "知微", - "http_port": 5805, - "gateway": "http://localhost:8643/v1/chat/completions", - "session_id": "xmpp-zhiwei", - "mention": "@zhiwei/@知微", - }, - "xiaoguo": { - "jid": "xiaoguo@yoin.fun", - "password": "hermes123", - "nick": "xiaoguo", - "name_cn": "小果", - "http_port": 5806, - "gateway": "http://localhost:8645/v1/chat/completions", - "session_id": "xmpp-xiaoguo", - "mention": "@xiaoguo/@小果", - }, -} - -agent = sys.argv[sys.argv.index("--agent") + 1] if "--agent" in sys.argv else "mohe" -cfg = AGENTS.get(agent, AGENTS["mohe"]) - -logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = cfg["gateway"] -API_KEY = "hermes123" -AGENT_NICK = cfg["nick"] -AGENT_NAME = cfg["name_cn"] -AGENT_JID = cfg["jid"] -AGENT_MENTION = cfg["mention"] -SESSION_ID = cfg["session_id"] -HTTP_PORT = cfg["http_port"] -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) - -# ── HTTP 桥(接收本地脚本的主动发送请求) ── -from http.server import HTTPServer, BaseHTTPRequestHandler -import threading, json as json_mod - -_send_queue = [] - -class SendHandler(BaseHTTPRequestHandler): - def do_POST(self): - length = int(self.headers.get('Content-Length', 0)) - body = self.rfile.read(length) - try: - data = json_mod.loads(body) - room = data.get('to', 'coregroup@conference.yoin.fun') - text = data.get('body', '') - if text: - _send_queue.append((room, text)) - self.send_response(200) - self.end_headers() - self.wfile.write(b'{"ok":true}') - else: - self.send_response(400) - self.end_headers() - self.wfile.write(b'{"ok":false,"error":"empty body"}') - except Exception as e: - self.send_response(500) - self.end_headers() - self.wfile.write(f'{{"ok":false,"error":"{e}"}}'.encode()) - -def _run_http(): - server = HTTPServer(('127.0.0.1', HTTP_PORT), SendHandler) - server.timeout = 1.0 - while True: - server.handle_request() - -threading.Thread(target=_run_http, daemon=True).start() -logging.info(f"🚀 {AGENT_NAME} HTTP 桥启动于 :{HTTP_PORT}") - -# ── XMPP Bot 类 ──────────────────────────────────────────────── -class AgentBot(ClientXMPP): - def __init__(self): - super().__init__(AGENT_JID, cfg["password"]) - self.add_event_handler('session_bind', self.on_bind) - self.add_event_handler('message', self.on_msg) - self.add_event_handler('disconnected', self.on_disconnect) - self.add_event_handler('connected', self.on_connected) - ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - self.ssl_context = ctx - self.ready = asyncio.Event() - self._call_seq = 0 - self._muc_joined = False - self._recent_sent = [] - - async def on_connected(self, event): - logging.info(f"🔗 {AGENT_NAME} TCP连接已建立") - - async def on_bind(self, event): - self.send_presence() - self.get_roster() - try: - self.plugin['xep_0045'].join_muc('coregroup@conference.yoin.fun', AGENT_NICK) - logging.info(f"✅ {AGENT_NAME} 加入群聊 coregroup") - except Exception as e: - logging.error(f"❌ {AGENT_NAME} 加入群聊失败: {e}") - self._muc_joined = True - self.ready.set() - logging.info(f"✅ {AGENT_NAME} XMPP 上线") - - async def on_disconnect(self, event): - self.ready.clear() - self._muc_joined = False - logging.warning(f"⚠️ {AGENT_NAME} XMPP 断线") - - async def on_msg(self, msg): - body = msg['body'] - sender = str(msg['from']) - msg_type = msg['type'] - if not body: - return - if msg_type == 'groupchat': - if AGENT_JID in sender: - return - nickname = sender.split('/')[-1] if '/' in sender else '' - # 自己的消息跳过(通过昵称) - if nickname == AGENT_NICK: - return - - # 硬闭嘴闸门:hmo 说闭嘴类的话 → 静默 5 分钟 - _silent_until = getattr(self, '_silent_until', 0) - if time.time() < _silent_until: - return - if nickname == 'hmo': - _sk = ['闭嘴', '别说话', '安静', 'shut', 'stfu', '别说了', '停'] - if any(kw in body.lower() for kw in _sk): - self._silent_until = time.time() + 300 - logging.info(f"🔇 {AGENT_NAME} 收到闭嘴指令,静默 5 分钟") - return - - logging.info(f"📩 群消息 [{sender}]: {body[:100]}") - room = sender.split('/')[0] - ctx_body = ( - "【规则】以下是一条群聊消息。判断是否应该回复。\n" - "只有以下3种情况你才回复:\n" - f"1. hmo直接点名问你({AGENT_MENTION})\n" - "2. 你有其他人没说过的独家信息\n" - "3. 别人说错了关键事实,不纠正会有后果\n" - "如果以上都不符合,你的回复必须只包含 __SILENT__ 这10个字符," - "不要有任何其他内容(不要前缀、不要解释、不要标点、不要空格)。\n\n" - f"[核心群 {room}] {nickname} 说: {body}" - ) - await self.call_hermes(ctx_body, room, is_group=True) - return - if msg_type == 'chat' and 'hmo@yoin.fun' in sender: - self._call_seq += 1 - logging.info(f"📩 老爸(#{self._call_seq}): {body}") - await self.call_hermes(body, sender, seq=self._call_seq) - - async def call_hermes(self, content, sender, is_group=False, seq=None): - msg_type = 'groupchat' if is_group else 'chat' - try: - payload = json.dumps({ - "model": "hermes-agent", - "messages": [{"role": "user", "content": content}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type", "application/json") - req.add_header("Authorization", f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id", SESSION_ID) - - loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, lambda: _opener.open(req, timeout=600)) - - if seq is not None and seq < self._call_seq: - return - - data = json.loads(result.read()) - reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") - reply_stripped = reply.strip() - if reply_stripped.startswith('__SILENT__') or reply_stripped.startswith('`__SILENT__`'): - logging.info(f"⏭️ {AGENT_NAME} 决定沉默,不发送") - return - finish = data.get("choices", [{}])[0].get("finish_reason", "") - - if reply.strip() and finish != "silent": - if msg_type == 'groupchat': - self.send_message(mto=sender, mbody=reply, mtype='groupchat') - sent_norm = reply.strip()[:100] - self._recent_sent.append(sent_norm) - if len(self._recent_sent) > 10: - self._recent_sent.pop(0) - else: - import subprocess as sp - from xml.sax.saxutils import escape - safe = escape(reply) - sp.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "send_stanza", - AGENT_JID, str(sender), - f"{safe}" - ], capture_output=True, timeout=10) - logging.info(f"✅ {AGENT_NAME} 回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ {AGENT_NAME} 错误: {e}") - -# ── 主入口 ─────────────────────────────────────────────── -async def main(): - retry_delay = 1 - max_delay = 60 - while True: - try: - bot = AgentBot() - bot.register_plugin('xep_0030') - bot.register_plugin('xep_0045') - bot.register_plugin('xep_0199') - - bot.connect(host='127.0.0.1', port=5222) - await asyncio.wait_for(bot.ready.wait(), timeout=30) - logging.info(f"{AGENT_NAME} XMPP 就绪") - retry_delay = 1 - - async def _drain_queue(): - while True: - await asyncio.sleep(1) - while _send_queue: - room, text = _send_queue.pop(0) - try: - bot.send_message(mto=room, mbody=text, mtype='groupchat') - sent_norm = text.strip()[:100] - bot._recent_sent.append(sent_norm) - if len(bot._recent_sent) > 10: - bot._recent_sent.pop(0) - logging.info(f"📤 主动发送到 {room}: {text[:60]}") - except Exception as e: - logging.error(f"❌ 主动发送失败: {e}") - asyncio.create_task(_drain_queue()) - - while True: - await asyncio.sleep(15) - if not bot.is_connected(): - logging.warning("检测到断线,准备重连...") - break - - except asyncio.TimeoutError: - logging.warning("连接超时,准备重连...") - except Exception as e: - logging.error(f"❌ 主循环错误: {e}") - - logging.info(f"⏳ 等待 {retry_delay} 秒后重连...") - await asyncio.sleep(retry_delay) - retry_delay = min(retry_delay * 2, max_delay) - -if __name__ == '__main__': - try: - asyncio.run(main()) - except KeyboardInterrupt: - pass diff --git a/xmpp_bot_rest.py b/xmpp_bot_rest.py deleted file mode 100644 index b5b0b58..0000000 --- a/xmpp_bot_rest.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -"""XMPP Bot mohe@yoin.fun - 通过 ejabberd REST API 实现""" -import asyncio, logging, ssl, json, urllib.request, os, time -import subprocess, threading - -logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') -GATEWAY = "http://localhost:8642/v1/chat/completions" -API_KEY = "hermes123" -EJB_ADMIN = "admin@localhost" -EJB_PASS = "hermes123" -_opener = urllib.request.build_opener(urllib.request.ProxyHandler({})) - -LAST_SEQ = 0 - -def call_api(content, sender, seq): - """同步调 Hermes API 并回复""" - try: - payload = json.dumps({ - "model": "hermes-agent", - "messages": [{"role": "user", "content": content}] - }).encode() - req = urllib.request.Request(GATEWAY, data=payload, method="POST") - req.add_header("Content-Type", "application/json") - req.add_header("Authorization", f"Bearer {API_KEY}") - req.add_header("X-Hermes-Session-Id", "xmpp-mohe") - - result = _opener.open(req, timeout=600) - data = json.loads(result.read()) - reply = data.get("choices", [{}])[0].get("message", {}).get("content", "") - finish = data.get("choices", [{}])[0].get("finish_reason", "") - - global LAST_SEQ - if seq < LAST_SEQ: - logging.info(f"⏭️ 跳过过期 seq={seq}") - return - - if reply.strip() and finish != "silent": - # 通过 ejabberdctl 发送回复 - subprocess.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "send_stanza", - "mohe@yoin.fun", sender, - f"{reply}" - ], capture_output=True, timeout=30) - logging.info(f"✅ 回复: {reply[:80]}") - except Exception as e: - logging.error(f"❌ 错误: {e}") - -def poll_messages(): - """轮询 ejabberd 离线消息""" - global LAST_SEQ - while True: - try: - # 用 ejabberdctl 获取 mohe 的离线消息 - result = subprocess.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "get_offline_count", "mohe", "yoin.fun" - ], capture_output=True, text=True, timeout=10) - count = int(result.stdout.strip()) - - if count > 0: - # 获取消息内容并处理 - result2 = subprocess.run([ - "docker", "exec", "ejabberd", "ejabberdctl", "get_offline_messages", "mohe", "yoin.fun" - ], capture_output=True, text=True, timeout=10) - # 解析消息并处理(简化处理) - except: - pass - time.sleep(5) - -if __name__ == '__main__': - # 实际上需要通过 XMPP 连接或 BOSH/WS - # 这个方案太复杂,直接换个思路:让 ejabberd → webhook → 处理 → reponse - print("需要更简单的方法") diff --git a/xmpp_mohe_bot.py b/xmpp_mohe_bot.py index 68f3f28..66846c5 100644 --- a/xmpp_mohe_bot.py +++ b/xmpp_mohe_bot.py @@ -2,4 +2,4 @@ """Wrapper for xmpp_agent_core.py --agent mohe""" import sys, os sys.argv = [sys.argv[0], '--agent', 'mohe'] -exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py')).read()) +exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py'), encoding='utf-8').read()) diff --git a/xmpp_xiaoguo_bot.py b/xmpp_xiaoguo_bot.py index fe8db77..c4d58e6 100644 --- a/xmpp_xiaoguo_bot.py +++ b/xmpp_xiaoguo_bot.py @@ -2,4 +2,4 @@ """Wrapper for xmpp_agent_core.py --agent xiaoguo""" import sys, os sys.argv = [sys.argv[0], '--agent', 'xiaoguo'] -exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py')).read()) +exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py'), encoding='utf-8').read()) diff --git a/xmpp_zhiwei_bot.py b/xmpp_zhiwei_bot.py index c537f80..cbaa6d0 100644 --- a/xmpp_zhiwei_bot.py +++ b/xmpp_zhiwei_bot.py @@ -2,4 +2,4 @@ """Wrapper for xmpp_agent_core.py --agent zhiwei""" import sys, os sys.argv = [sys.argv[0], '--agent', 'zhiwei'] -exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py')).read()) +exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py'), encoding='utf-8').read()) diff --git a/xxm_bot.py b/xxm_bot.py new file mode 100644 index 0000000..34a9e65 --- /dev/null +++ b/xxm_bot.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +"""Wrapper for xmpp_agent_core.py --agent xxm""" +import sys, os +sys.argv = [sys.argv[0], '--agent', 'xxm'] +exec(open(os.path.join(os.path.dirname(__file__), 'xmpp_agent_core.py'), encoding='utf-8').read())