feat: search_sessions 跨session搜索工具 + 提示词强化

This commit is contained in:
hmo
2026-06-14 12:00:38 +08:00
parent 94c511479d
commit 28523beccf
+83
View File
@@ -193,6 +193,27 @@ _TOOLS = [
} }
} }
} }
},
{
"type": "function",
"function": {
"name": "search_sessions",
"description": "跨所有 session 搜索指定关键词,自动定位最相关的 session。不知道应该查哪个 session 时用这个。返回匹配到的 session 标题、匹配条数和消息摘要片段。",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索关键词,如项目名、文件名、话题等"
},
"max_sessions": {
"type": "integer",
"description": "最多搜几个 session,默认 5"
}
},
"required": ["query"]
}
}
} }
] ]
_MAX_TOOL_LOOPS = 30 # 超限后走 clean final force,不再泄漏 XML _MAX_TOOL_LOOPS = 30 # 超限后走 clean final force,不再泄漏 XML
@@ -212,6 +233,51 @@ def _run_tool_command(cmd: str) -> str:
return f"(执行失败: {e})" return f"(执行失败: {e})"
def _search_all_sessions(query: str, max_sessions: int = 5) -> str:
    """Search recent sessions for messages containing *query*.

    The ``max_sessions * 3`` most recently updated sessions are taken as
    candidates.  For each candidate the newest 50 messages are scanned for a
    case-insensitive substring match, keeping at most 3 snippets per session.
    Scanning stops as soon as ``max_sessions`` sessions have produced a match.

    Args:
        query: Keyword to look for.  Surrounding whitespace is ignored; a
            blank or non-string query is rejected instead of matching
            every message.
        max_sessions: Maximum number of matching sessions to report.  Values
            below 1 are clamped to 1 (a negative SQLite ``LIMIT`` would
            otherwise mean "no limit").

    Returns:
        A human-readable result string.  Failures are reported as a
        parenthesised message rather than raised, because the result is fed
        straight back to the caller as tool output.
    """
    import sqlite3
    from contextlib import closing

    # An empty needle is a substring of everything; refuse it up front.
    if not isinstance(query, str) or not query.strip():
        return "(搜索关键词为空)"
    needle = query.strip().lower()
    max_sessions = max(1, max_sessions)

    db = _SERVE_DB
    # sqlite3.connect() would silently create a missing file, so check first.
    if not os.path.exists(db):
        return f"(session 数据库不存在: {db})"

    results = []
    try:
        # closing() guarantees the connection is released on every path;
        # a bare ``with conn:`` only manages transactions, not the handle.
        with closing(sqlite3.connect(db)) as conn:
            conn.row_factory = sqlite3.Row
            sessions = conn.execute(
                "SELECT id, title, time_updated FROM session "
                "ORDER BY time_updated DESC LIMIT ?",
                (max_sessions * 3,),
            ).fetchall()
            for s in sessions:
                sid = s["id"]
                title = s["title"] or "(无标题)"
                msgs = conn.execute(
                    "SELECT m.id, m.data FROM message m "
                    "WHERE m.session_id=? "
                    "ORDER BY m.time_created DESC LIMIT 50",
                    (sid,),
                )
                matches = []
                for m in msgs:
                    snippet = _message_snippet(m["data"], needle)
                    if snippet is None:
                        continue
                    matches.append(snippet)
                    if len(matches) >= 3:
                        break
                if not matches:
                    continue
                results.append(
                    f"[{title}]({str(sid)[:16]}...): {len(matches)}条匹配\n"
                    + "\n".join(f" · {m}" for m in matches)
                )
                # Honour the caller's cap on reported sessions.
                if len(results) >= max_sessions:
                    break
    except Exception as e:
        # Tool boundary: always hand a string back instead of raising.
        return f"(搜索出错: {e})"

    if not results:
        return f'(搜索 "{query}" 未在任何 session 中找到匹配)'
    return "搜索到相关 session\n\n" + "\n\n".join(results)


def _message_snippet(raw, needle: str, width: int = 200):
    """Return the first *width* characters of a message if it contains *needle*.

    *raw* is the JSON text stored in a message row's ``data`` column and
    *needle* must already be lower-cased.  Returns ``None`` when the row is
    not valid JSON, is not a JSON object, has a non-string ``content`` value,
    or does not contain the needle, so one odd row cannot abort the search.
    """
    import json

    try:
        payload = json.loads(raw)
    except (TypeError, ValueError):
        # TypeError: NULL column; ValueError covers json.JSONDecodeError.
        return None
    content = payload.get("content") if isinstance(payload, dict) else None
    if not isinstance(content, str) or needle not in content.lower():
        return None
    return content[:width]
# ── Serve session DB path ── # ── Serve session DB path ──
_SERVE_DB = os.path.join( _SERVE_DB = os.path.join(
os.environ.get("USERPROFILE", "C:\\Users\\hmo"), os.environ.get("USERPROFILE", "C:\\Users\\hmo"),
@@ -469,6 +535,17 @@ class SessionBridge:
ctx = extract_session_context(sid, limit=limit) ctx = extract_session_context(sid, limit=limit)
output = ctx if ctx else f"(session {sid}: no messages)" output = ctx if ctx else f"(session {sid}: no messages)"
_logger.info(" tool: session_search → %s (%d chars)", sid[:32], len(output)) _logger.info(" tool: session_search → %s (%d chars)", sid[:32], len(output))
elif fn_name == "search_sessions":
try:
fn_args = json.loads(fn_args_str)
query = fn_args.get("query", "")
max_sessions = min(int(fn_args.get("max_sessions", 5)), 20)
except (json.JSONDecodeError, ValueError, TypeError):
query = ""
max_sessions = 5
output = _search_all_sessions(query, max_sessions)
_logger.info(" tool: search_sessions query=%s (%d chars)", query[:40], len(output))
else: else:
output = f"(unknown tool: {fn_name})" output = f"(unknown tool: {fn_name})"
@@ -552,6 +629,12 @@ class SessionBridge:
"否则被视为空话。只说不做比不说更糟糕。\n" "否则被视为空话。只说不做比不说更糟糕。\n"
"不确定该怎么做时,先 run_command 查一下再决定。\n" "不确定该怎么做时,先 run_command 查一下再决定。\n"
"\n" "\n"
"=== 回答问题前先搜 session ===\n"
"当被问到项目状态、代码位置、近期工作、其他人说过什么等事实性问题时,\n"
"**必须先使用 search_sessions 或 session_search 找到证据再回答**。\n"
"不要凭你当前 session 里的记忆回答——你的 session 可能不是最相关的。\n"
"正确流程:search_sessions(关键词) → 找到相关 session → session_search(具体 session) → 再回答\n"
"\n"
"=== 写文件的正确方式 ===\n" "=== 写文件的正确方式 ===\n"
"用 Python 一次性写完所有内容,不要分多次调用。\n" "用 Python 一次性写完所有内容,不要分多次调用。\n"
"错误示例(会覆盖,每调用一次就清空一次):python -c \"open('file', 'w').write('一行')\"\n" "错误示例(会覆盖,每调用一次就清空一次):python -c \"open('file', 'w').write('一行')\"\n"