Files
MoFin/venv/lib/python3.12/site-packages/litellm/integrations/compression_interception/handler.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

405 lines
15 KiB
Python

"""
Compression Interception Handler
CustomLogger that compresses inbound Anthropic Messages requests and fulfills
litellm_content_retrieve tool calls server-side via the typed agentic loop plan.
"""
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from litellm._logging import verbose_logger
from litellm.compression import compress
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.integrations.compression_interception import (
CompressionInterceptionConfig,
)
from litellm.types.integrations.custom_logger import (
AgenticLoopPlan,
AgenticLoopRequestPatch,
)
from litellm.types.utils import CallTypes
# Tool name injected by compression and matched against request tool
# definitions and response tool_use blocks.
LITELLM_CONTENT_RETRIEVE_TOOL_NAME = "litellm_content_retrieve"
# Max age, in seconds, of a per-call compression cache entry (15 minutes).
_CACHE_TTL_SECONDS = 900
class CompressionInterceptionLogger(CustomLogger):
    """
    CustomLogger that implements transparent prompt compression + retrieval loops.

    Flow:
    1. Compress inbound /v1/messages requests in pre-call hook.
    2. Inject litellm_content_retrieve tool and persist compressed cache by call_id.
    3. Detect retrieval tool_use blocks in first model response.
    4. Build typed rerun plan with tool_result blocks from the compressed cache.

    Cache entries live in an in-process dict keyed by call_id and are pruned
    lazily (at the start of the pre-call hook and of plan building) once they
    are older than ``_CACHE_TTL_SECONDS``.
    """

    def __init__(
        self,
        enabled: bool = True,
        compression_trigger: int = 200_000,
        compression_target: Optional[int] = None,
        embedding_model: Optional[str] = None,
        embedding_model_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            enabled: When False, the pre-call hook and
                ``async_should_run_agentic_loop`` return immediately.
            compression_trigger: Passed through to ``compress()``.
            compression_target: Passed through to ``compress()``.
            embedding_model: Passed through to ``compress()``.
            embedding_model_params: Passed through to ``compress()``.
        """
        super().__init__()
        self.enabled = enabled
        self.compression_trigger = compression_trigger
        self.compression_target = compression_target
        self.embedding_model = embedding_model
        self.embedding_model_params = embedding_model_params
        # call_id -> (retrieval key -> cached content, creation time.time())
        self._compression_cache_by_call_id: Dict[
            str, Tuple[Dict[str, str], float]
        ] = {}

    @classmethod
    def from_config_yaml(
        cls, config: CompressionInterceptionConfig
    ) -> "CompressionInterceptionLogger":
        """
        Build a logger from a ``compression_interception`` config mapping.

        Missing keys fall back to the constructor defaults. An explicit
        ``compression_trigger: null`` (e.g. an empty YAML value) is treated
        like a missing key instead of failing in ``int(None)``.
        """
        compression_trigger = config.get("compression_trigger")
        return cls(
            enabled=bool(config.get("enabled", True)),
            compression_trigger=(
                int(compression_trigger)
                if compression_trigger is not None
                else 200_000
            ),
            compression_target=config.get("compression_target"),
            embedding_model=config.get("embedding_model"),
            embedding_model_params=config.get("embedding_model_params"),
        )

    @staticmethod
    def initialize_from_proxy_config(
        litellm_settings: Dict[str, Any],
        callback_specific_params: Dict[str, Any],
    ) -> "CompressionInterceptionLogger":
        """
        Build a logger from proxy configuration.

        ``litellm_settings["compression_interception_params"]`` takes
        precedence over ``callback_specific_params["compression_interception"]``.
        Only dict values are used; anything else (e.g. ``None`` from an empty
        YAML key) is ignored, and with no usable dict all defaults apply.
        """
        compression_params: CompressionInterceptionConfig = {}
        settings_params = litellm_settings.get("compression_interception_params")
        callback_params = callback_specific_params.get("compression_interception")
        if isinstance(settings_params, dict):
            compression_params = cast(CompressionInterceptionConfig, settings_params)
        elif isinstance(callback_params, dict):
            compression_params = cast(CompressionInterceptionConfig, callback_params)
        return CompressionInterceptionLogger.from_config_yaml(compression_params)

    async def async_pre_call_deployment_hook(
        self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
    ) -> Optional[dict]:
        """
        Compress an Anthropic Messages request before it is dispatched.

        Returns ``None`` (request left untouched) when the logger is disabled,
        the call is not an Anthropic Messages call, the request is an
        agentic-loop follow-up (``_agentic_loop_depth > 0``), ``messages`` is
        not a list or ``model`` is not a str, or the retrieval tool is already
        present. Otherwise returns ``kwargs``; it is mutated in place only when
        ``compress()`` produced a non-empty cache.
        """
        if not self.enabled:
            return None
        if call_type is not None and call_type != CallTypes.anthropic_messages:
            return None
        # Only the top-level request is compressed; loop follow-ups pass through.
        if int(kwargs.get("_agentic_loop_depth", 0) or 0) > 0:
            return None
        messages = kwargs.get("messages")
        model = kwargs.get("model")
        if not isinstance(messages, list) or not isinstance(model, str):
            return None
        # Never inject a second copy of the retrieval tool.
        if self._has_retrieval_tool(kwargs.get("tools")):
            return None
        self._prune_expired_cache()
        compressed = compress(  # type: ignore
            messages=messages,
            model=model,
            call_type=CallTypes.anthropic_messages,
            compression_trigger=self.compression_trigger,
            compression_target=self.compression_target,
            embedding_model=self.embedding_model,
            embedding_model_params=self.embedding_model_params,
        )
        cache = cast(Dict[str, str], compressed.get("cache", {}))
        skip_reason = cast(Optional[str], compressed.get("compression_skipped_reason"))
        compressed_tools = cast(List[Dict[str, Any]], compressed.get("tools", []))
        # Only mutate kwargs when compression actually produced a result.
        # If compression was a no-op (below trigger, invalid tool sequence, etc.),
        # leave ``messages`` and ``tools`` untouched — injecting an empty
        # ``tools: []`` onto a request that originally had no tools breaks
        # Anthropic Messages requests.
        if cache:
            kwargs["messages"] = compressed["messages"]
            if compressed_tools:
                kwargs["tools"] = self._merge_tools(
                    existing_tools=cast(
                        Optional[List[Dict[str, Any]]], kwargs.get("tools")
                    ),
                    compressed_tools=compressed_tools,
                )
            call_id = cast(Optional[str], kwargs.get("litellm_call_id"))
            if not call_id:
                # Assign an id so the follow-up plan can find this cache.
                call_id = str(uuid.uuid4())
                kwargs["litellm_call_id"] = call_id
            self._compression_cache_by_call_id[call_id] = (cache, time.time())
            # Token counts use %s: they may be missing (None), and %d would
            # make the log record fail to format.
            verbose_logger.debug(
                "CompressionInterception: compressed request [call_id=%s original=%s compressed=%s cached_keys=%d]",
                call_id,
                compressed.get("original_tokens"),
                compressed.get("compressed_tokens"),
                len(cache),
            )
        elif skip_reason is not None:
            verbose_logger.debug(
                "CompressionInterception: compression skipped [reason=%s original=%s compressed=%s]",
                skip_reason,
                compressed.get("original_tokens"),
                compressed.get("compressed_tokens"),
            )
        return kwargs

    async def async_should_run_agentic_loop(
        self,
        response: Any,
        model: str,
        messages: List[Dict],
        tools: Optional[List[Dict]],
        stream: bool,
        custom_llm_provider: str,
        kwargs: Dict,
    ) -> Tuple[bool, Dict]:
        """
        Decide whether a model response needs a retrieval follow-up.

        Returns ``(True, info)`` when the request ``tools`` contain the
        retrieval tool and ``response`` contains at least one retrieval
        tool_use block. ``info`` holds ``tool_calls``, ``thinking_blocks`` and
        ``tool_type`` for ``async_build_agentic_loop_plan``. Otherwise
        returns ``(False, {})``.
        """
        if not self.enabled:
            return False, {}
        if not self._has_retrieval_tool(tools):
            return False, {}
        tool_calls, thinking_blocks = self._extract_retrieval_tool_calls(
            response=response
        )
        if not tool_calls:
            return False, {}
        return True, {
            "tool_calls": tool_calls,
            "thinking_blocks": thinking_blocks,
            "tool_type": "compression_retrieval",
        }

    async def async_build_agentic_loop_plan(
        self,
        tools: Dict,
        model: str,
        messages: List[Dict],
        response: Any,
        anthropic_messages_provider_config: Any,
        anthropic_messages_optional_request_params: Dict,
        logging_obj: Any,
        stream: bool,
        kwargs: Dict,
    ) -> AgenticLoopPlan:
        """
        Build the follow-up request that answers retrieval tool calls.

        ``tools`` is the info dict returned by ``async_should_run_agentic_loop``.
        The follow-up messages are ``messages`` plus an assistant turn
        (thinking blocks followed by the retrieval tool_use blocks) and a user
        turn with one tool_result per tool call. Each tool_result is filled
        from the cache stored for this call_id by the pre-call hook, or from a
        placeholder string when the key is missing. ``max_tokens`` is split out
        of the optional params into its own field of the request patch.
        """
        self._prune_expired_cache()
        tool_calls = cast(List[Dict[str, Any]], tools.get("tool_calls", []))
        thinking_blocks = cast(List[Dict[str, Any]], tools.get("thinking_blocks", []))
        call_id = self._resolve_call_id(logging_obj=logging_obj, kwargs=kwargs)
        cache = self._get_cache(call_id=call_id)
        retrieval_results = [
            self._resolve_retrieval_content(tc, cache) for tc in tool_calls
        ]
        assistant_message = {
            "role": "assistant",
            "content": thinking_blocks
            + [
                {
                    "type": "tool_use",
                    "id": tc.get("id"),
                    "name": tc.get("name", LITELLM_CONTENT_RETRIEVE_TOOL_NAME),
                    "input": tc.get("input", {}),
                }
                for tc in tool_calls
            ],
        }
        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tc.get("id"),
                    "content": result,
                }
                for tc, result in zip(tool_calls, retrieval_results)
            ],
        }
        follow_up_messages = messages + [assistant_message, user_message]
        max_tokens = cast(
            Optional[int],
            anthropic_messages_optional_request_params.get("max_tokens")
            or kwargs.get("max_tokens"),
        )
        optional_params_without_max_tokens = {
            k: v
            for k, v in anthropic_messages_optional_request_params.items()
            if k != "max_tokens"
        }
        full_model_name = model
        if logging_obj is not None:
            # Prefer the model name recorded in the loop params, if any.
            agentic_params = logging_obj.model_call_details.get(
                "agentic_loop_params", {}
            )
            full_model_name = cast(str, agentic_params.get("model", model))
        request_patch = AgenticLoopRequestPatch(
            model=full_model_name,
            messages=follow_up_messages,
            max_tokens=max_tokens,
            optional_params=optional_params_without_max_tokens,
            kwargs=self._prepare_followup_kwargs(kwargs=kwargs),
        )
        return AgenticLoopPlan(
            run_agentic_loop=True,
            request_patch=request_patch,
            metadata={"tool_type": "compression_retrieval", "call_id": call_id or ""},
        )

    def _prune_expired_cache(self) -> None:
        """Drop cache entries older than ``_CACHE_TTL_SECONDS``."""
        now = time.time()
        self._compression_cache_by_call_id = {
            call_id: (cache, created_at)
            for call_id, (
                cache,
                created_at,
            ) in self._compression_cache_by_call_id.items()
            if now - created_at <= _CACHE_TTL_SECONDS
        }

    def _get_cache(self, call_id: Optional[str]) -> Dict[str, str]:
        """Return the retrieval cache stored for ``call_id``, or ``{}`` if none."""
        if not call_id:
            return {}
        cache_entry = self._compression_cache_by_call_id.get(call_id)
        if cache_entry is None:
            return {}
        return cache_entry[0]

    def _resolve_call_id(
        self, logging_obj: Any, kwargs: Dict[str, Any]
    ) -> Optional[str]:
        """
        Return the call id: the logging object's ``litellm_call_id`` if it is a
        non-empty str, else ``kwargs["litellm_call_id"]`` if it is a str, else None.
        """
        # NOTE(review): assumes logging_obj.litellm_call_id equals the id the
        # pre-call hook used as the cache key — confirm.
        if logging_obj is not None:
            logging_call_id = getattr(logging_obj, "litellm_call_id", None)
            if isinstance(logging_call_id, str) and logging_call_id:
                return logging_call_id
        kwargs_call_id = kwargs.get("litellm_call_id")
        return cast(
            Optional[str], kwargs_call_id if isinstance(kwargs_call_id, str) else None
        )

    def _resolve_retrieval_content(
        self, tool_call: Dict[str, Any], cache: Dict[str, str]
    ) -> str:
        """Map one retrieval tool call to the text returned in its tool_result."""
        raw_input = tool_call.get("input", {})
        key = ""
        if isinstance(raw_input, dict):
            key = str(raw_input.get("key", "") or "")
        if not key:
            return "No retrieval key provided."
        if key in cache:
            return cache[key]
        return f"[compressed content key '{key}' not found]"

    def _extract_retrieval_tool_calls(
        self, response: Any
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Collect retrieval tool_use blocks and all thinking blocks from a response.

        ``response`` may be a dict or an object with a ``content`` attribute;
        content blocks may likewise be dicts or objects. Object blocks are
        normalized to dicts; dict thinking blocks are kept as-is. A missing or
        null tool_use ``input`` becomes ``{}`` in both cases.
        """
        if isinstance(response, dict):
            content = response.get("content", [])
        else:
            content = getattr(response, "content", []) or []
        if not isinstance(content, list):
            return [], []
        tool_calls: List[Dict[str, Any]] = []
        thinking_blocks: List[Dict[str, Any]] = []
        for block in content:
            if isinstance(block, dict):
                block_type = block.get("type")
                block_name = block.get("name")
                if block_type in ("thinking", "redacted_thinking"):
                    thinking_blocks.append(block)
                if (
                    block_type == "tool_use"
                    and block_name == LITELLM_CONTENT_RETRIEVE_TOOL_NAME
                ):
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "tool_use",
                            "name": block_name,
                            "input": block.get("input") or {},
                        }
                    )
            else:
                block_type = getattr(block, "type", None)
                block_name = getattr(block, "name", None)
                if block_type == "thinking":
                    thinking_blocks.append(
                        {
                            "type": "thinking",
                            "thinking": getattr(block, "thinking", ""),
                            "signature": getattr(block, "signature", ""),
                        }
                    )
                elif block_type == "redacted_thinking":
                    thinking_blocks.append(
                        {
                            "type": "redacted_thinking",
                            "data": getattr(block, "data", ""),
                        }
                    )
                if (
                    block_type == "tool_use"
                    and block_name == LITELLM_CONTENT_RETRIEVE_TOOL_NAME
                ):
                    tool_calls.append(
                        {
                            "id": getattr(block, "id", None),
                            "type": "tool_use",
                            "name": block_name,
                            "input": getattr(block, "input", {}) or {},
                        }
                    )
        return tool_calls, thinking_blocks

    def _prepare_followup_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Copy ``kwargs`` for the follow-up request, dropping
        ``litellm_logging_obj`` and any ``_compression_interception*`` keys.
        """
        internal_keys = {"litellm_logging_obj"}
        return {
            k: v
            for k, v in kwargs.items()
            if not k.startswith("_compression_interception") and k not in internal_keys
        }

    def _has_retrieval_tool(self, tools: Any) -> bool:
        """
        Return True if ``tools`` (a list) defines the retrieval tool.

        Recognizes OpenAI-style ``{"type": "function", "function": {"name": ...}}``
        and Anthropic-style ``{"name": ...}`` definitions, whose ``type`` is
        either ``"custom"`` or omitted.
        """
        if not isinstance(tools, list):
            return False
        for tool in tools:
            if not isinstance(tool, dict):
                continue
            tool_type = tool.get("type")
            if tool_type == "function":
                function = tool.get("function")
                if (
                    isinstance(function, dict)
                    and function.get("name") == LITELLM_CONTENT_RETRIEVE_TOOL_NAME
                ):
                    return True
            elif (
                tool_type in (None, "custom")
                and tool.get("name") == LITELLM_CONTENT_RETRIEVE_TOOL_NAME
            ):
                return True
        return False

    def _merge_tools(
        self,
        existing_tools: Optional[List[Dict[str, Any]]],
        compressed_tools: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Return a new list of ``existing_tools`` plus ``compressed_tools``,
        unless the retrieval tool is already present (then existing only).
        """
        merged = list(existing_tools or [])
        if self._has_retrieval_tool(merged):
            return merged
        merged.extend(compressed_tools)
        return merged