Files
MoFin/venv/lib/python3.12/site-packages/litellm/integrations/rubrik.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

606 lines
24 KiB
Python

"""Rubrik LiteLLM Plugin for tool blocking and batch logging."""
import asyncio
import os
import random
import time
import urllib.parse
import uuid
from collections import Counter
from typing import TYPE_CHECKING, Any, Literal, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
ModifyResponseException,
)
from litellm.litellm_core_utils.core_helpers import safe_deep_copy
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Function,
GenericGuardrailAPIInputs,
StandardLoggingPayload,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
# URL path suffix used to recognise Anthropic-style ``/v1/messages`` requests
# when preparing log payloads.
_ENDPOINT_ANTHROPIC_MESSAGES = "/v1/messages"
# Paths appended to the configured Rubrik webhook base URL to form the
# tool-blocking and batch-logging endpoints.
_WEBHOOK_PATH_TOOL_BLOCKING = "/v1/after_completion/openai/v1"
_WEBHOOK_PATH_LOGGING_BATCH = "/v1/litellm/batch"
# Upper bound on the number of queued log events; the logger drops the oldest
# events once the queue grows beyond this size.
_MAX_QUEUE_SIZE = 10_000
# Minimum number of seconds between two "events were dropped" warnings.
_DROP_WARNING_INTERVAL_SECONDS = 60.0
class _MalformedToolBlockingResponseError(Exception):
    """Signal a structurally invalid reply from the tool blocking service.

    Raised when the service answers but the payload has the wrong shape
    (for example an empty ``choices`` list). It is kept separate from
    transient network/HTTP failures so that callers can log it loudly, as a
    probable misconfiguration, instead of treating it as a routine fail-open.
    """
class RubrikLogger(CustomGuardrail, CustomBatchLogger):
    """Guardrail and batch logger that talks to a Rubrik webhook.

    Two responsibilities are combined in this class:

    * As a guardrail it sends the tool calls of an LLM response to the
      tool-blocking endpoint and raises ``ModifyResponseException`` when the
      service does not return every tool call as allowed. Any failure of the
      check itself is fail-open: the original response is returned unchanged.
    * As a batch logger it queues logging payloads in memory (bounded by
      ``max_queue_size``) and posts them to the batch-logging endpoint.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        **kwargs,
    ):
        """Configure endpoints, HTTP clients, sampling and the flush task.

        Args:
            api_key: Bearer token for the webhook. Falls back to the
                ``RUBRIK_API_KEY`` environment variable; when neither is set
                requests are sent without an ``Authorization`` header.
            api_base: Base URL of the webhook. Falls back to
                ``RUBRIK_WEBHOOK_URL``. A trailing ``/`` and a trailing
                ``/v1`` are stripped before the endpoint paths are appended.
            **kwargs: Forwarded to the parent initializers.

        Raises:
            ValueError: If no (non-empty) webhook URL is configured.
        """
        self.flush_lock = asyncio.Lock()
        kwargs.setdefault("guardrail_name", "rubrik")
        # `initialize_guardrail` always passes these kwargs explicitly, with
        # value `None` when the user omits `mode` / `default_on` from the
        # guardrail config. Coerce None (omitted) to the desired default
        # while preserving any explicit value the caller did set --
        # in particular `default_on=False` if the user wants the guardrail
        # off by default.
        kwargs["event_hook"] = kwargs.get("event_hook") or GuardrailEventHooks.post_call
        if kwargs.get("default_on") is None:
            kwargs["default_on"] = True
        super().__init__(
            flush_lock=self.flush_lock,
            **kwargs,
        )
        verbose_logger.debug("initializing rubrik logger")

        # Fraction of log events to forward, clamped to [0.0, 1.0].
        self.sampling_rate = 1.0
        rbrk_sampling_rate = os.getenv("RUBRIK_SAMPLING_RATE")
        if rbrk_sampling_rate is not None:
            try:
                parsed_rate = float(rbrk_sampling_rate.strip())
                self.sampling_rate = max(0.0, min(1.0, parsed_rate))
                if parsed_rate != self.sampling_rate:
                    verbose_logger.warning(
                        f"RUBRIK_SAMPLING_RATE={parsed_rate} clamped to "
                        f"{self.sampling_rate}"
                    )
            except ValueError:
                verbose_logger.warning(
                    f"Invalid RUBRIK_SAMPLING_RATE: {rbrk_sampling_rate!r}, using 1.0"
                )

        self.key = api_key or os.getenv("RUBRIK_API_KEY")
        if not self.key:
            verbose_logger.warning(
                "Rubrik: No API key configured. Requests will be unauthenticated."
            )

        # Optional override of the batch size; when unset or invalid the
        # value already present on the instance is left untouched.
        _batch_size = os.getenv("RUBRIK_BATCH_SIZE")
        if _batch_size:
            try:
                self.batch_size = int(_batch_size)
            except ValueError:
                verbose_logger.warning(
                    f"Invalid RUBRIK_BATCH_SIZE: {_batch_size!r}, using default"
                )

        # Cap the in-memory retry queue so a Rubrik webhook outage cannot let
        # authenticated traffic accumulate prompt/response payloads until the
        # proxy runs out of memory. Once the cap is reached, oldest events are
        # dropped to make room for fresh ones (drop-oldest backpressure).
        self.max_queue_size = _MAX_QUEUE_SIZE
        self._dropped_since_warning = 0
        self._last_drop_warning_time = 0.0

        _webhook_url = api_base or os.getenv("RUBRIK_WEBHOOK_URL")
        # Reject empty strings as well as None: an empty base URL would
        # otherwise yield relative endpoints that can never be reached.
        if not _webhook_url:
            raise ValueError(
                "Rubrik webhook URL not configured. "
                "Set RUBRIK_WEBHOOK_URL or pass api_base."
            )
        _webhook_url = _webhook_url.rstrip("/").removesuffix("/v1")
        self.tool_blocking_endpoint = f"{_webhook_url}{_WEBHOOK_PATH_TOOL_BLOCKING}"
        self.logging_endpoint = f"{_webhook_url}{_WEBHOOK_PATH_LOGGING_BATCH}"

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        # The tool-blocking call uses its own client with short timeouts
        # (5s overall, 2s connect).
        self.tool_blocking_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback,
            params={"timeout": httpx.Timeout(5.0, connect=2.0)},
        )

        self._headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.key:
            self._headers["Authorization"] = f"Bearer {self.key}"

        # Start the periodic flush right away when an event loop is already
        # running. When the logger is instantiated outside a running loop
        # (sync init) this is None and the task is started lazily on the
        # first log event, so low-traffic deployments still get drained.
        self._flush_task: Optional[asyncio.Task[Any]] = (
            self._start_periodic_flush_task()
        )

    def _start_periodic_flush_task(self) -> Optional[asyncio.Task[Any]]:
        """Start the periodic flush task only when an event loop is already running.

        Returns:
            The created task, or ``None`` when no event loop is running.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            verbose_logger.debug(
                "Rubrik logger init: no running event loop, "
                "periodic flush will start on first log event."
            )
            return None
        return loop.create_task(self.periodic_flush())

    def _ensure_periodic_flush_task(self) -> None:
        """(Re)start the periodic flush task if it is missing or finished."""
        # Synchronous helper: in asyncio's cooperative model there is no await
        # between the check and assignment, so two callers cannot race here.
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = self._start_periodic_flush_task()

    async def aclose(self):
        """Cancel the periodic flush task and close both HTTP clients."""
        # Cancel the periodic flush task before closing the HTTP clients so
        # the loop doesn't wake up and try to POST via a closed client.
        if self._flush_task is not None and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except (asyncio.CancelledError, Exception):
                # Best effort: the task is being torn down, so neither the
                # expected cancellation nor a late failure matters here.
                pass
        self._flush_task = None
        # Close the second client even if closing the first one fails.
        try:
            await self.tool_blocking_client.close()
        finally:
            await self.async_httpx_client.close()

    # -- Guardrail hook --------------------------------------------------------

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> GenericGuardrailAPIInputs:
        """Validate tool calls against the blocking service (fail-open).

        Only responses that carry tool calls are checked; everything else is
        returned untouched.

        Returns:
            ``inputs`` unchanged when nothing is blocked or the check fails.

        Raises:
            ModifyResponseException: When the service blocks a tool call.
        """
        if input_type != "response":
            return inputs
        tool_calls = inputs.get("tool_calls")
        if not tool_calls:
            return inputs
        try:
            return await self._check_tool_calls(
                inputs, tool_calls, request_data, logging_obj
            )
        except ModifyResponseException:
            raise
        except _MalformedToolBlockingResponseError as e:
            # Distinct from transient errors: the service responded but the
            # payload was structurally invalid, which usually indicates a
            # misconfigured webhook or a breaking change in its response
            # format. Log loudly so operators notice their tool-blocking
            # policy is not actually being enforced.
            verbose_logger.critical(
                "Tool blocking service returned a malformed response: %s. "
                "Tool calls are NOT being checked -- verify the webhook "
                "configuration. Returning original response unchanged.",
                e,
                exc_info=True,
            )
            return inputs
        except Exception as e:
            verbose_logger.error(
                f"Tool blocking hook failed: {e}. "
                "Returning original response unchanged.",
                exc_info=True,
            )
            return inputs

    async def _check_tool_calls(
        self,
        inputs: GenericGuardrailAPIInputs,
        tool_calls: Any,
        request_data: dict,
        logging_obj: Optional["LiteLLMLoggingObj"],
    ) -> GenericGuardrailAPIInputs:
        """Send tool calls to blocking service, raise if any are blocked.

        Raises:
            ModifyResponseException: With the service's explanation when at
                least one tool call was not returned as allowed.
        """
        message_tool_calls = self._normalize_tool_calls(tool_calls)

        # Normalise a missing or None ``model_call_details`` to an empty dict
        # so the helpers below can rely on dict semantics.
        call_details: dict[str, Any] = {}
        if logging_obj:
            call_details = getattr(logging_obj, "model_call_details", None) or {}

        response = request_data.get("response")
        request_id = getattr(response, "id", None) if response else None

        if logging_obj and not call_details:
            verbose_logger.warning(
                "Rubrik: logging_obj present but model_call_details is empty "
                "-- request context will be missing"
            )

        response_data = self._build_tool_call_payload(message_tool_calls, request_id)
        req_data = self._extract_request_data(call_details)
        service_response = await self._post_to_tool_blocking_service(
            response_data, req_data
        )
        blocked_explanation = self._extract_blocked_tools(
            service_response, message_tool_calls
        )
        if blocked_explanation is not None:
            model = self._resolve_model(request_data, call_details)
            raise ModifyResponseException(
                message=blocked_explanation,
                model=model,
                request_data=request_data,
                guardrail_name=self.guardrail_name,
            )
        return inputs

    @staticmethod
    def _normalize_tool_calls(tool_calls: Any) -> list[ChatCompletionMessageToolCall]:
        """Convert tool_calls from inputs to ChatCompletionMessageToolCall objects.

        Accepts ready-made ``ChatCompletionMessageToolCall`` instances, plain
        dicts, and any object exposing ``id`` and ``function`` attributes.

        Raises:
            TypeError: For an entry that matches none of those shapes.
        """
        result = []
        for tc in tool_calls:
            if isinstance(tc, ChatCompletionMessageToolCall):
                result.append(tc)
            elif isinstance(tc, dict):
                # ``or {}`` also covers an explicit ``"function": None``.
                func = tc.get("function") or {}
                result.append(
                    ChatCompletionMessageToolCall(
                        id=tc.get("id", ""),
                        type=tc.get("type", "function"),
                        function=Function(
                            name=func.get("name", ""),
                            arguments=func.get("arguments", ""),
                        ),
                    )
                )
            elif hasattr(tc, "id") and hasattr(tc, "function"):
                result.append(
                    ChatCompletionMessageToolCall(
                        id=tc.id or "",
                        type=getattr(tc, "type", None) or "function",
                        function=tc.function,
                    )
                )
            else:
                raise TypeError(
                    f"Cannot normalize tool_call of type {type(tc).__name__}"
                )
        return result

    @staticmethod
    def _build_tool_call_payload(
        tool_calls: list[ChatCompletionMessageToolCall],
        request_id: str | None,
    ) -> dict[str, Any]:
        """Build a full OpenAI ChatCompletion-format dict for the blocking service.

        A random ``chatcmpl-<uuid4>`` id is generated when ``request_id`` is
        missing or empty.
        """
        return {
            "id": request_id or f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            tc.model_dump(exclude_none=True) for tc in tool_calls
                        ],
                    },
                    "finish_reason": "tool_calls",
                }
            ],
        }

    @staticmethod
    def _extract_request_data(call_details: dict[str, Any]) -> dict[str, Any]:
        """Extract original request data from model_call_details.

        Returns an empty dict when ``call_details`` is empty; otherwise the
        messages, the model and a sanitized ``proxy_server_request``.
        """
        if not call_details:
            return {}
        litellm_params = call_details.get("litellm_params", {}) or {}
        return {
            "messages": call_details.get("messages"),
            "model": call_details.get("model"),
            "proxy_server_request": RubrikLogger._sanitize_proxy_server_request(
                litellm_params.get("proxy_server_request")
            ),
        }

    @staticmethod
    def _sanitize_proxy_server_request(proxy_server_request: Any) -> Any:
        """Allowlist only routing fields (``url``, ``method``) when forwarding
        ``proxy_server_request`` to the external Rubrik webhook, dropping
        inbound ``headers`` (Authorization, Cookie, x-api-key, ...) and the raw
        request ``body`` so proxy credentials are not exfiltrated.

        Non-dict values are returned as they are.
        """
        if not isinstance(proxy_server_request, dict):
            return proxy_server_request
        return {
            key: proxy_server_request[key]
            for key in ("url", "method")
            if key in proxy_server_request
        }

    @staticmethod
    def _resolve_model(
        request_data: dict[str, Any], call_details: dict[str, Any]
    ) -> str:
        """Get the model name for the ModifyResponseException.

        Prefers the ``model`` attribute of ``request_data["response"]``, then
        ``call_details["model"]``; falls back to ``"unknown"`` when the chosen
        source is empty or None.
        """
        response = request_data.get("response")
        if response and hasattr(response, "model"):
            return response.model or "unknown"
        # ``or`` (rather than a .get default) also covers an explicit None.
        return (call_details or {}).get("model") or "unknown"

    # -- Logging hooks ---------------------------------------------------------

    async def _prepare_log_payload(
        self, kwargs: dict, event_type: str
    ) -> StandardLoggingPayload | None:
        """Shared logic for success and failure logging.

        Returns:
            A deep copy of ``kwargs["standard_logging_object"]`` ready to be
            queued, or ``None`` when the event is skipped by sampling.
        """
        # random.random() is in [0.0, 1.0): ``>=`` guarantees that a rate of
        # 0.0 never logs while a rate of 1.0 always does.
        if random.random() >= self.sampling_rate:
            verbose_logger.debug(
                f"Skipping Rubrik {event_type} logging "
                f"(sampling_rate={self.sampling_rate})"
            )
            return None

        # Deep-copy so mutations don't affect other callbacks sharing this object
        standard_logging_payload: StandardLoggingPayload = safe_deep_copy(
            kwargs["standard_logging_object"]
        )

        # For Anthropic /v1/messages requests, LiteLLM creates a separate
        # ModelResponse (with a generated chatcmpl-* id) for logging, which
        # differs from the original Anthropic msg-* id on the response dict.
        # Normalize to litellm_call_id so that the logging and tool-blocking
        # endpoints see the same request identifier.
        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_request = litellm_params.get("proxy_server_request", {}) or {}
        # ``or ""`` keeps urlparse on str input even when the url is None.
        url_path = urllib.parse.urlparse(proxy_request.get("url") or "").path
        if url_path.endswith(_ENDPOINT_ANTHROPIC_MESSAGES):
            _litellm_call_id = kwargs.get("litellm_call_id")
            if _litellm_call_id:
                standard_logging_payload["id"] = _litellm_call_id  # type: ignore[literal-required]

        if "system" in kwargs:
            system_prompt_msg_list = kwargs["system"]
            try:
                if system_prompt_msg_list:
                    # Prepend the separately supplied system prompt so it is
                    # part of the logged message list.
                    system_scaffold = {
                        "role": "system",
                        "content": system_prompt_msg_list,
                    }
                    if isinstance(standard_logging_payload["messages"], list):
                        standard_logging_payload["messages"].insert(0, system_scaffold)
                    elif isinstance(standard_logging_payload["messages"], (dict, str)):
                        standard_logging_payload["messages"] = [
                            system_scaffold,
                            standard_logging_payload["messages"],
                        ]
            except Exception as e:
                verbose_logger.warning(
                    f"Rubrik: failed to prepend system prompt: {e}",
                    exc_info=True,
                )
        return standard_logging_payload

    async def _enqueue_log_event(self, kwargs: dict, event_type: str):
        """Queue one log event and flush once the batch size is reached.

        Never raises: any failure is logged and the event is skipped.
        """
        try:
            self._ensure_periodic_flush_task()
            payload = await self._prepare_log_payload(kwargs, event_type)
            if payload is None:
                return
            self.log_queue.append(payload)
            self._enforce_max_queue_size()
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception as e:
            verbose_logger.error(
                f"Rubrik {event_type} logging hook failed: {e}. "
                "Skipping logging for this event.",
                exc_info=True,
            )

    def _enforce_max_queue_size(self) -> None:
        """Drop the oldest queued events beyond ``max_queue_size``.

        Emits a rate-limited warning (at most one per
        ``_DROP_WARNING_INTERVAL_SECONDS``) with the number of events dropped
        since the previous warning.
        """
        overflow = len(self.log_queue) - self.max_queue_size
        if overflow <= 0:
            return
        del self.log_queue[:overflow]
        self._dropped_since_warning += overflow
        now = time.time()
        if now - self._last_drop_warning_time >= _DROP_WARNING_INTERVAL_SECONDS:
            verbose_logger.warning(
                "Rubrik: log queue exceeded max_queue_size=%s; dropped %s "
                "oldest events since the last warning. The Rubrik webhook may "
                "be unhealthy or undersized for current traffic.",
                self.max_queue_size,
                self._dropped_since_warning,
            )
            self._dropped_since_warning = 0
            self._last_drop_warning_time = now

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a success event for batch delivery."""
        await self._enqueue_log_event(kwargs, "success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a failure event for batch delivery."""
        await self._enqueue_log_event(kwargs, "failure")

    # -- Batch logging ---------------------------------------------------------

    async def _log_batch_to_rubrik(self, data):
        """POST ``data`` to the logging endpoint, logging and re-raising errors."""
        # NOTE: this method intentionally re-raises on failure so that
        # flush_queue keeps the unsent events in the queue for the next flush
        # attempt instead of silently dropping them.
        try:
            response = await self.async_httpx_client.post(
                url=self.logging_endpoint,
                json=data,
                headers=self._headers,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"Rubrik HTTP Error: {e.response.status_code} - {e.response.text}"
            )
            raise
        except Exception:
            verbose_logger.exception("Rubrik Layer Error")
            raise

    async def async_send_batch(self):
        """Handles sending batches of responses to Rubrik.

        Note: the canonical flush path is :meth:`flush_queue`, which takes a
        single snapshot used for both sending and queue draining. This method
        is kept for direct callers / tests; it intentionally does NOT remove
        events from the queue.
        """
        if not self.log_queue:
            return
        log_queue_snapshot = list(self.log_queue)
        verbose_logger.debug(
            "Rubrik: Flushing batch of %s events", len(log_queue_snapshot)
        )
        await self._log_batch_to_rubrik(
            data=log_queue_snapshot,
        )

    async def flush_queue(self):
        """Snapshot, send, and drain in one consistent step.

        The same snapshot drives both the HTTP send and the queue cleanup.
        On a failed send the queue is left untouched so the events are
        retried on the next flush. On success exactly the delivered events
        are removed from the queue.
        """
        if self.flush_lock is None:
            return
        async with self.flush_lock:
            if not self.log_queue:
                return
            snapshot = list(self.log_queue)
            verbose_logger.debug("Rubrik: Flushing batch of %s events", len(snapshot))
            try:
                await self._log_batch_to_rubrik(data=snapshot)
            except Exception:
                # Already logged with traceback inside _log_batch_to_rubrik.
                # Preserve the in-flight events for retry on the next flush.
                return
            # While the POST was awaited, new events may have been appended
            # and _enforce_max_queue_size may have dropped snapshot events
            # from the front. Removing a fixed-length prefix could therefore
            # discard events that were never sent, so remove by identity.
            # ``snapshot`` keeps the objects alive, which keeps id() unique.
            delivered = {id(event) for event in snapshot}
            self.log_queue[:] = [
                event for event in self.log_queue if id(event) not in delivered
            ]
            self.last_flush_time = time.time()

    # -- Tool blocking service -------------------------------------------------

    async def _post_to_tool_blocking_service(
        self,
        response_data: dict[str, Any],
        request_data: dict[str, Any],
    ) -> dict[str, Any]:
        """Post a payload to the tool blocking service and return the response.

        Args:
            response_data: The OpenAI-formatted response payload to send.
            request_data: Original LLM request data to include alongside
                the response for additional context. Empty dict if unavailable.

        Returns:
            The decoded JSON body of the service response.

        Raises:
            Exception: If the service is unavailable or returns an error.
        """
        envelope = {
            "request": request_data,
            "response": response_data,
        }
        verbose_logger.debug(
            f"Sending request to tool blocking service: "
            f"{self.tool_blocking_endpoint}"
        )
        http_response = await self.tool_blocking_client.post(
            self.tool_blocking_endpoint,
            json=envelope,
            headers=self._headers,
        )
        http_response.raise_for_status()
        result: dict[str, Any] = http_response.json()
        return result

    @staticmethod
    def _extract_blocked_tools(
        service_response: dict[str, Any],
        all_tool_calls: list[ChatCompletionMessageToolCall],
    ) -> Optional[str]:
        """Return the blocking explanation if any tool calls were blocked.

        Compares the service response (which contains only allowed tools) against
        the full set of tool calls. Returns ``None`` if all tools are allowed, or
        the explanation string (prefixed with newlines) otherwise.

        Expects service_response in OpenAI chat completion format:
        {"choices": [{"message": {"tool_calls": [...], "content": "..."}}]}

        Raises:
            _MalformedToolBlockingResponseError: If the response is not an
                object, has no choices, or its first choice / message is not
                an object.
        """
        if not isinstance(service_response, dict):
            raise _MalformedToolBlockingResponseError(
                "Tool blocking service returned a non-object response"
            )
        choices = service_response.get("choices", [])
        if not choices:
            raise _MalformedToolBlockingResponseError(
                "Tool blocking service returned empty response"
            )
        first_choice = choices[0] if isinstance(choices, list) else None
        if not isinstance(first_choice, dict):
            raise _MalformedToolBlockingResponseError(
                "Tool blocking service returned a malformed choice"
            )
        message = first_choice.get("message", {})
        if not isinstance(message, dict):
            raise _MalformedToolBlockingResponseError(
                "Tool blocking service returned a malformed message"
            )

        returned_tool_calls = message.get("tool_calls") or []
        blocking_explanation = message.get("content", "")

        # Count ids on both sides so that duplicated ids must be returned at
        # least as often as they were requested.
        allowed_id_counts: Counter = Counter(
            tc["id"]
            for tc in returned_tool_calls
            if isinstance(tc, dict) and tc.get("id")
        )
        required_id_counts: Counter = Counter(tc.id for tc in all_tool_calls if tc.id)

        # The length check covers tool calls without an id, which cannot be
        # matched individually.
        all_allowed = len(returned_tool_calls) >= len(all_tool_calls) and all(
            allowed_id_counts.get(tc_id, 0) >= count
            for tc_id, count in required_id_counts.items()
        )
        if all_allowed:
            return None
        explanation = blocking_explanation or "Tool call blocked by policy."
        return f"\n\n{explanation}"