fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
514 lines
18 KiB
Python
514 lines
18 KiB
Python
"""
|
|
Main compress() function — normalizes input messages, orchestrates BM25/embedding
|
|
scoring, message stubbing, and retrieval tool injection.
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
|
|
|
from litellm.caching.dual_cache import DualCache
|
|
from litellm.compression.message_stubbing import (
|
|
extract_key,
|
|
stub_message,
|
|
truncate_message,
|
|
)
|
|
from litellm.compression.retrieval_tool import build_retrieval_tool
|
|
from litellm.compression.scoring.bm25 import bm25_score_messages
|
|
from litellm.litellm_core_utils.token_counter import token_counter
|
|
from litellm.types.compression import CompressedResult
|
|
from litellm.types.utils import CallTypes
|
|
|
|
# Call types whose payloads follow the Anthropic Messages schema, where a
# message's ``content`` may be a list of structured blocks. Every other
# supported call type is handled as OpenAI chat-completions shaped.
_ANTHROPIC_CALL_TYPES = frozenset([CallTypes.anthropic_messages.value])

# Call types that compress() accepts. Compression rewrites a list of
# role/content messages, so only call types whose payload has that shape
# are allowed.
_SUPPORTED_CALL_TYPES = frozenset(
    member.value
    for member in (
        CallTypes.completion,
        CallTypes.acompletion,
        CallTypes.anthropic_messages,
    )
)
|
|
|
|
|
|
def _normalize_call_type(call_type: Union[CallTypes, str]) -> str:
    """
    Coerce ``call_type`` to its plain string form.

    ``CallTypes`` members are unwrapped to their ``.value``; anything else is
    returned unchanged (callers are expected to pass a string in that case).
    """
    return call_type.value if isinstance(call_type, CallTypes) else call_type
|
|
|
|
|
|
def _is_anthropic_call_type(call_type: str) -> bool:
    """Return ``True`` when ``call_type`` uses the Anthropic Messages schema."""
    uses_anthropic_schema = call_type in _ANTHROPIC_CALL_TYPES
    return uses_anthropic_schema
|
|
|
|
|
|
def _build_retrieval_tools(keys: List[str], call_type: str) -> List[dict]:
    """
    Return the retrieval tool list for ``keys`` in the request's tool schema.

    With no keys there is nothing to retrieve, so an empty list is returned.
    For chat-completions call types the definition produced by
    ``build_retrieval_tool`` is returned as-is (OpenAI function-tool schema).
    For the Anthropic Messages call type it is converted with
    ``AnthropicConfig._map_tools``.
    """
    if not keys:
        return []

    function_tools = [build_retrieval_tool(keys)]
    if _is_anthropic_call_type(call_type):
        # Imported here so non-Anthropic call paths never load the Anthropic
        # provider transformation module.
        from litellm.llms.anthropic.chat.transformation import AnthropicConfig

        # Only the mapped tools are needed; the second return value is discarded.
        anthropic_tools, _ = AnthropicConfig()._map_tools(function_tools)
        return cast(List[dict], anthropic_tools)
    return function_tools
|
|
|
|
|
|
def _content_to_text(content: Any) -> str:
|
|
"""
|
|
Convert OpenAI/Anthropic message content blocks to plain text.
|
|
|
|
Text extraction policy:
|
|
- Include text-bearing fields only (`text` blocks + string values).
|
|
- For `tool_result`, expand into nested `content` items.
|
|
- Ignore non-textual blocks (images/documents/tool metadata/thinking metadata).
|
|
|
|
Implemented iteratively (stack-based) to avoid unbounded recursion.
|
|
"""
|
|
parts: List[str] = []
|
|
stack: List[Any] = [content]
|
|
while stack:
|
|
item = stack.pop()
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, list):
|
|
# Push list items in reverse order so they are processed left-to-right.
|
|
for element in reversed(item):
|
|
stack.append(element)
|
|
elif isinstance(item, dict):
|
|
item_type = item.get("type")
|
|
if item_type == "text":
|
|
parts.append(str(item.get("text", "")))
|
|
elif item_type == "tool_result":
|
|
stack.append(item.get("content", ""))
|
|
return " ".join(parts)
|
|
|
|
|
|
def _normalize_messages_for_compression(
    messages: List[dict],
    call_type: str,
) -> Tuple[List[dict], List[dict]]:
    """
    Pair every message with a plain-text surrogate used for scoring.

    Each input message is shallow-copied. The surrogate is the copy with its
    ``content`` replaced by the text extracted via ``_content_to_text``.

    Returns:
        (normalized_messages, original_messages_copy)

    Raises:
        ValueError: If ``call_type`` is not one of ``_SUPPORTED_CALL_TYPES``.
    """
    if call_type not in _SUPPORTED_CALL_TYPES:
        supported = sorted(_SUPPORTED_CALL_TYPES)
        raise ValueError(
            f"Unsupported call_type={call_type!r} for compression. "
            f"Expected one of: {supported}."
        )

    copies: List[Dict[str, Any]] = [dict(message) for message in messages]
    surrogates: List[dict] = [
        {**message, "content": _content_to_text(message.get("content", ""))}
        for message in copies
    ]
    return surrogates, copies
|
|
|
|
|
|
def _extract_last_user_message(messages: List[dict]) -> str:
    """
    Return the plain-text content of the most recent ``user`` message.

    Returns an empty string when no message has role ``user``.
    """
    user_messages = (m for m in reversed(messages) if m.get("role") == "user")
    latest = next(user_messages, None)
    return "" if latest is None else _content_to_text(latest.get("content", ""))
|
|
|
|
|
|
def _extract_tool_use_ids(content: Any) -> List[str]:
|
|
if not isinstance(content, list):
|
|
return []
|
|
tool_use_ids: List[str] = []
|
|
for part in content:
|
|
if not isinstance(part, dict):
|
|
continue
|
|
if part.get("type") != "tool_use":
|
|
continue
|
|
tool_use_id = part.get("id")
|
|
if isinstance(tool_use_id, str) and tool_use_id:
|
|
tool_use_ids.append(tool_use_id)
|
|
return tool_use_ids
|
|
|
|
|
|
def _extract_tool_result_ids(content: Any) -> Set[str]:
|
|
if not isinstance(content, list):
|
|
return set()
|
|
tool_result_ids: Set[str] = set()
|
|
for part in content:
|
|
if not isinstance(part, dict):
|
|
continue
|
|
if part.get("type") != "tool_result":
|
|
continue
|
|
tool_use_id = part.get("tool_use_id")
|
|
if isinstance(tool_use_id, str) and tool_use_id:
|
|
tool_result_ids.add(tool_use_id)
|
|
return tool_result_ids
|
|
|
|
|
|
def _extract_anthropic_tool_exchange_spans(
    messages: List[dict],
) -> Tuple[List[Set[int]], Optional[str]]:
    """
    Return atomic 2-message spans for Anthropic tool exchanges.

    Each assistant message containing `tool_use` blocks must be immediately
    followed by a user message whose `tool_result` blocks cover every one of
    those tool_use ids. Each valid pair is reported as ``{i, i + 1}``.

    Returns:
        ``(spans, None)`` on success, or ``([], "invalid_anthropic_tool_sequence")``
        as soon as one assistant tool_use message is not properly answered.
    """
    invalid_reason = "invalid_anthropic_tool_sequence"
    spans: List[Set[int]] = []
    total = len(messages)
    idx = 0
    while idx < total:
        message = messages[idx]
        use_ids = (
            _extract_tool_use_ids(message.get("content"))
            if message.get("role") == "assistant"
            else []
        )
        if not use_ids:
            idx += 1
            continue

        reply_idx = idx + 1
        if reply_idx >= total:
            return [], invalid_reason

        reply = messages[reply_idx]
        if reply.get("role") != "user":
            return [], invalid_reason

        result_ids = _extract_tool_result_ids(reply.get("content"))
        if not result_ids or any(uid not in result_ids for uid in use_ids):
            return [], invalid_reason

        spans.append({idx, reply_idx})
        # The reply cannot start a new exchange, so skip past both messages.
        idx = reply_idx + 1

    return spans, None
|
|
|
|
|
|
def _get_protected_indices(messages: List[dict]) -> List[int]:
|
|
"""
|
|
Return indices of messages that must never be compressed:
|
|
- All system messages
|
|
- The last user message
|
|
- The last assistant message
|
|
"""
|
|
protected: List[int] = []
|
|
|
|
last_user_idx = None
|
|
last_assistant_idx = None
|
|
|
|
for i, msg in enumerate(messages):
|
|
role = msg.get("role", "")
|
|
if role == "system":
|
|
protected.append(i)
|
|
elif role == "user":
|
|
last_user_idx = i
|
|
elif role == "assistant":
|
|
last_assistant_idx = i
|
|
|
|
if last_user_idx is not None:
|
|
protected.append(last_user_idx)
|
|
if last_assistant_idx is not None:
|
|
protected.append(last_assistant_idx)
|
|
|
|
return protected
|
|
|
|
|
|
def _combine_scores(
|
|
bm25_scores: List[float],
|
|
emb_scores: List[float],
|
|
bm25_weight: float = 0.4,
|
|
) -> List[float]:
|
|
"""Weighted average of BM25 and embedding scores, with min-max normalization."""
|
|
|
|
def _normalize(scores: List[float]) -> List[float]:
|
|
min_s = min(scores) if scores else 0.0
|
|
max_s = max(scores) if scores else 0.0
|
|
rng = max_s - min_s
|
|
if rng == 0:
|
|
return [0.0] * len(scores)
|
|
return [(s - min_s) / rng for s in scores]
|
|
|
|
norm_bm25 = _normalize(bm25_scores)
|
|
norm_emb = _normalize(emb_scores)
|
|
emb_weight = 1.0 - bm25_weight
|
|
|
|
return [bm25_weight * b + emb_weight * e for b, e in zip(norm_bm25, norm_emb)]
|
|
|
|
|
|
def _select_kept_indices_for_budget(
    normalized_messages: List[dict],
    original_messages: List[dict],
    combined_scores: List[float],
    compression_target: int,
    model: str,
    initial_kept_indices: Set[int],
    tool_exchange_spans: List[Set[int]],
) -> Tuple[Set[int], Dict[int, dict]]:
    """
    Greedily choose which messages to keep within a token budget.

    Starting from ``initial_kept_indices`` (whose tokens count against the
    budget), candidate units are visited in descending ``combined_scores``
    order. A unit is either a single message index or an Anthropic
    tool-exchange span that must be kept or dropped atomically (scored by its
    best member). A unit that fits in the remaining budget is kept whole. A
    single message that does not fit is truncated to the remaining budget
    when at least 100 tokens remain; spans are never truncated. Iteration
    stops once the budget is exhausted.

    Token counts use the plain-text content of the messages.

    Returns:
        (kept_indices, truncated_overrides), where ``truncated_overrides``
        maps an index to the truncated replacement for
        ``original_messages[idx]``.
    """

    def _text_tokens(text: Any) -> int:
        return token_counter(model=model, text=cast(str, text or ""))

    kept_indices = set(initial_kept_indices)
    current_tokens = sum(
        _text_tokens(normalized_messages[i].get("content", "")) for i in kept_indices
    )
    truncated_overrides: Dict[int, dict] = {}  # idx -> truncated message dict

    # Messages that belong to a tool-exchange span are only ever considered
    # as part of their span, never individually.
    span_member_indices: Set[int] = set()
    for span in tool_exchange_spans:
        span_member_indices.update(span)

    # Single-message candidate units (truncatable).
    candidate_units: List[Tuple[float, Tuple[int, ...], bool]] = [
        (combined_scores[idx], (idx,), True)
        for idx in range(len(normalized_messages))
        if idx not in span_member_indices and idx not in kept_indices
    ]

    # Span candidate units (atomic keep/drop, never truncated).
    for span in tool_exchange_spans:
        span_indices = tuple(sorted(span))
        if any(idx in kept_indices for idx in span_indices):
            continue
        span_score = max(combined_scores[idx] for idx in span_indices)
        candidate_units.append((span_score, span_indices, False))

    # Sort by descending relevance score (stable for ties).
    candidate_units.sort(key=lambda unit: unit[0], reverse=True)

    for _score, indices, can_truncate in candidate_units:
        if any(idx in kept_indices for idx in indices):
            continue

        remaining = compression_target - current_tokens
        if remaining <= 0:
            break  # budget exhausted; no need to count this unit's tokens

        unit_tokens = sum(
            _text_tokens(normalized_messages[idx].get("content", ""))
            for idx in indices
        )
        if unit_tokens <= remaining:
            # Fits entirely
            kept_indices.update(indices)
            current_tokens += unit_tokens
        elif can_truncate and len(indices) == 1 and remaining >= 100:
            # Too large to fit whole, but enough budget left to keep a
            # truncated version.
            idx = indices[0]
            truncated = truncate_message(original_messages[idx], remaining)
            truncated_overrides[idx] = truncated
            kept_indices.add(idx)
            # truncate_message receives the original message, whose content
            # may be a list of blocks rather than a string; count its text
            # surrogate, as is done for every other message in this function.
            current_tokens += _text_tokens(
                _content_to_text(truncated.get("content", ""))
            )

    return kept_indices, truncated_overrides
|
|
|
|
|
|
def _get_dropped_tool_span_indices(
|
|
kept_indices: Set[int], tool_exchange_spans: List[Set[int]]
|
|
) -> Set[int]:
|
|
dropped_tool_span_indices: Set[int] = set()
|
|
for span in tool_exchange_spans:
|
|
if not any(idx in kept_indices for idx in span):
|
|
dropped_tool_span_indices.update(span)
|
|
return dropped_tool_span_indices
|
|
|
|
|
|
def _skipped_result(
    original_messages: List[dict],
    original_tokens: int,
    reason: str,
) -> CompressedResult:
    """
    Build the pass-through result returned when compression is skipped.

    Messages are returned unchanged with no cache entries and no retrieval
    tools; ``reason`` is reported as ``compression_skipped_reason``.
    """
    return CompressedResult(
        messages=original_messages,
        original_tokens=original_tokens,
        compressed_tokens=original_tokens,
        compression_ratio=0.0,
        cache={},
        tools=[],
        compression_skipped_reason=reason,
    )


def _score_messages(
    query: str,
    normalized_messages: List[dict],
    embedding_model: Optional[str],
    embedding_model_params: Optional[Dict[str, Any]],
    compression_cache: Optional[DualCache],
) -> List[float]:
    """
    Return one relevance score per message for ``query``.

    Uses BM25 alone when ``embedding_model`` is falsy; otherwise BM25 and
    embedding scores are blended with ``_combine_scores`` (BM25 weight 0.4).
    """
    bm25_scores = bm25_score_messages(query, normalized_messages)
    if not embedding_model:
        return bm25_scores

    # Imported lazily so the embedding scorer is only loaded when an
    # embedding model is configured.
    from litellm.compression.scoring.embedding_scorer import (
        embedding_score_messages,
    )

    emb_scores = embedding_score_messages(
        query,
        normalized_messages,
        model=embedding_model,
        cache=compression_cache,
        embedding_model_params=embedding_model_params,
    )
    return _combine_scores(bm25_scores, emb_scores, bm25_weight=0.4)


def compress(
    messages: List[dict],
    model: str,
    call_type: Union[CallTypes, str] = CallTypes.completion,
    compression_trigger: int = 200_000,
    compression_target: Optional[int] = None,
    embedding_model: Optional[str] = None,
    embedding_model_params: Optional[Dict[str, Any]] = None,
    compression_cache: Optional[DualCache] = None,
) -> CompressedResult:
    """
    Compress a list of messages by replacing low-relevance content with stubs.

    Inputs at or below ``compression_trigger`` tokens pass through unchanged.
    Above it, messages are scored against the last user message with BM25
    (and optionally embeddings), ranked, and kept greedily until
    ``compression_target`` tokens are used; the best message that does not
    fit whole may be truncated to fill the remaining budget. System messages
    and the last user and last assistant messages are always kept. Every
    other message that is not kept is replaced with a stub; its text is
    stored in ``cache`` and a retrieval tool is returned so the model can
    recover it on demand.

    For the Anthropic Messages call type, an assistant ``tool_use`` message
    and the following ``tool_result`` user message are kept or dropped
    together. Dropped tool exchanges are removed from the output entirely:
    they are neither stubbed nor added to ``cache``.

    Parameters:
        messages: The conversation messages to (potentially) compress.
        model: The LLM model name — used for token counting.
        call_type: The LiteLLM call type whose message schema these messages
            follow. Supported values:
            - ``CallTypes.completion`` / ``CallTypes.acompletion`` — OpenAI
              chat-completions shape (default)
            - ``CallTypes.anthropic_messages`` — Anthropic Messages shape
              (structured content blocks + atomic tool exchanges)
        compression_trigger: Only compress if input exceeds this token count.
        compression_target: Target token count after compression.
            Defaults to 70% of ``compression_trigger``
            (``compression_trigger * 7 // 10``).
        embedding_model: If provided, use BM25 + embeddings for scoring.
            If ``None``, BM25 only.
        embedding_model_params: Optional kwargs forwarded to the embedding
            scorer when ``embedding_model`` is set.
        compression_cache: Forwarded to the embedding scorer (as ``cache``)
            for caching of embedding results.

    Returns:
        A ``CompressedResult`` dict containing compressed messages, token
        counts, a cache of original content, and the retrieval tool
        definition. When compression is skipped, messages are returned
        unchanged and ``compression_skipped_reason`` is ``"below_trigger"``
        or ``"invalid_anthropic_tool_sequence"``.

    Raises:
        ValueError: If ``call_type`` is not a supported call type.
    """
    call_type_str = _normalize_call_type(call_type)
    normalized_messages, original_messages = _normalize_messages_for_compression(
        messages=messages,
        call_type=call_type_str,
    )

    if compression_target is None:
        compression_target = compression_trigger * 7 // 10

    original_tokens = token_counter(
        model=model,
        messages=cast(List[Any], original_messages),
    )

    # Pass through if below trigger
    if original_tokens <= compression_trigger:
        return _skipped_result(original_messages, original_tokens, "below_trigger")

    # Validate Anthropic tool exchanges before scoring: an invalid sequence
    # aborts compression, so scoring first (which may issue embedding
    # requests) would be wasted work.
    tool_exchange_spans: List[Set[int]] = []
    if _is_anthropic_call_type(call_type_str):
        tool_exchange_spans, tool_sequence_error = (
            _extract_anthropic_tool_exchange_spans(original_messages)
        )
        if tool_sequence_error is not None:
            return _skipped_result(
                original_messages, original_tokens, tool_sequence_error
            )

    # Score each message for relevance to the latest user turn.
    query = _extract_last_user_message(normalized_messages)
    combined_scores = _score_messages(
        query,
        normalized_messages,
        embedding_model=embedding_model,
        embedding_model_params=embedding_model_params,
        compression_cache=compression_cache,
    )

    # Protected messages are never compressed; a tool exchange containing a
    # protected message is kept whole.
    kept_indices: Set[int] = set(_get_protected_indices(normalized_messages))
    for span in tool_exchange_spans:
        if any(idx in kept_indices for idx in span):
            kept_indices.update(span)

    kept_indices, truncated_overrides = _select_kept_indices_for_budget(
        normalized_messages=normalized_messages,
        original_messages=original_messages,
        combined_scores=combined_scores,
        compression_target=compression_target,
        model=model,
        initial_kept_indices=kept_indices,
        tool_exchange_spans=tool_exchange_spans,
    )

    # Build compressed messages and cache
    compressed_messages: List[dict] = []
    cache: Dict[str, str] = {}
    used_keys: Set[str] = set()
    dropped_tool_span_indices = _get_dropped_tool_span_indices(
        kept_indices=kept_indices, tool_exchange_spans=tool_exchange_spans
    )

    for i, msg in enumerate(original_messages):
        if i in dropped_tool_span_indices:
            # Dropped tool exchanges are omitted entirely (not stubbed/cached).
            continue
        if i in kept_indices:
            # Use the truncated version if we made one, otherwise the original
            compressed_messages.append(truncated_overrides.get(i, msg))
            continue
        key = extract_key(
            normalized_messages[i], fallback_index=i, used_keys=used_keys
        )
        cache[key] = _content_to_text(msg.get("content", ""))
        compressed_messages.append(stub_message(msg, key))

    # Build retrieval tool in the target request schema
    tools = _build_retrieval_tools(list(cache.keys()), call_type=call_type_str)

    compressed_tokens = token_counter(
        model=model,
        messages=cast(List[Any], compressed_messages),
    )

    return CompressedResult(
        messages=compressed_messages,
        original_tokens=original_tokens,
        compressed_tokens=compressed_tokens,
        compression_ratio=(
            round(1 - (compressed_tokens / original_tokens), 4)
            if original_tokens > 0
            else 0.0
        ),
        cache=cache,
        tools=tools,
    )
|