Files
MoFin/venv/lib/python3.12/site-packages/litellm/responses/streaming_iterator.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

2305 lines
92 KiB
Python

from __future__ import annotations
import asyncio
import json
import time
import traceback
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Literal, Optional
import httpx
from openai._streaming import SSEDecoder
import litellm
from litellm.constants import (
LITELLM_MAX_STREAMING_DURATION_SECONDS,
STREAM_SSE_DONE_STRING,
)
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.response_metadata import (
update_response_metadata,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import ResponsesAPIStreamEvents
from litellm.types.utils import CallTypes
from litellm.utils import async_post_call_success_deployment_hook
@lru_cache(maxsize=1)
def _get_openai_response_types():
    """Return the ``litellm.types.llms.openai`` module.

    The import happens inside the function, so it is deferred until the first
    call; ``lru_cache`` then hands back the same module object on later calls.
    """
    from litellm.types.llms import openai as responses_types_module

    return responses_types_module
def _log_background_task_failure(task: "asyncio.Task[Any]", *, task_name: str) -> None:
if task.cancelled():
return
exception = task.exception()
if exception is not None:
verbose_logger.error("%s failed: %s", task_name, exception)
class BaseResponsesAPIStreamingIterator:
"""
Base class for streaming iterators that process responses from the Responses API.
This class contains shared logic for both synchronous and asynchronous iterators:
parsing and transforming streamed chunks, remembering the terminal response
event, persisting completed responses to the cache, and dispatching the
success / failure logging handlers.
"""
def __init__(
    self,
    response: httpx.Response,
    model: str,
    responses_api_provider_config: Optional[BaseResponsesAPIConfig],
    logging_obj: LiteLLMLoggingObj,
    litellm_metadata: Optional[Dict[str, Any]] = None,
    custom_llm_provider: Optional[str] = None,
    request_data: Optional[Dict[str, Any]] = None,
    call_type: Optional[str] = None,
):
    """
    Capture the stream's inputs and initialise its bookkeeping state.

    Args:
        response: HTTP response; its headers are copied into the hidden params.
        model: Model name used for chunk transformation and error reporting.
        responses_api_provider_config: Provider config used to transform live
            chunks. May be None for iterators that never parse live chunks.
        logging_obj: Logging object whose handlers run on completion/failure.
        litellm_metadata: Optional request metadata (``model_info`` is read).
        custom_llm_provider: Optional provider name.
        request_data: Optional request payload forwarded to hooks.
        call_type: Optional call type string forwarded to hooks.
    """
    # --- inputs -----------------------------------------------------------
    self.response = response
    self.model = model
    self.logging_obj = logging_obj
    self.responses_api_provider_config = responses_api_provider_config

    # --- stream progress --------------------------------------------------
    self.finished = False
    self.completed_response: Optional[Any] = None
    self.start_time = getattr(logging_obj, "start_time", datetime.now())
    self._stream_created_time: float = time.time()

    # --- one-shot guards for failure handling, caching and logging --------
    self._failure_handled = False  # set once the failure handlers have run
    self._completed_response_cached = False
    self._completed_response_logged = False
    self._completed_response_cache_hit: Optional[bool] = None
    self._persist_completed_response_before_logging = True

    # --- request context handed to hooks ----------------------------------
    self.litellm_metadata = litellm_metadata
    self.custom_llm_provider = custom_llm_provider
    self.request_data: Dict[str, Any] = request_data or {}
    self.call_type: Optional[str] = call_type

    # Hidden params surface values such as the model id through response
    # headers (e.g. x-litellm-model-id), matching the stream wrapper in
    # litellm/litellm_core_utils/streaming_handler.py.
    resolved_api_base = get_api_base(
        model=model or "",
        optional_params=self.logging_obj.model_call_details.get(
            "litellm_params", {}
        ),
    )
    model_info: Dict = {}
    if litellm_metadata:
        model_info = litellm_metadata.get("model_info", {})
    self._hidden_params = {
        "model_id": model_info.get("id", None),
        "api_base": resolved_api_base,
        "custom_llm_provider": custom_llm_provider,
    }
    # Guarantee the OpenAI headers are present on the response.
    self._hidden_params["additional_headers"] = process_response_headers(
        self.response.headers or {}
    )
def _check_max_streaming_duration(self) -> None:
    """Enforce the optional upper bound on how long a stream may stay open.

    Does nothing when ``LITELLM_MAX_STREAMING_DURATION_SECONDS`` is None.

    Raises:
        litellm.Timeout: if more than the configured number of seconds have
            elapsed since this iterator was created.
    """
    if LITELLM_MAX_STREAMING_DURATION_SECONDS is not None:
        elapsed = time.time() - self._stream_created_time
        if elapsed > LITELLM_MAX_STREAMING_DURATION_SECONDS:
            raise litellm.Timeout(
                message=f"Stream exceeded max streaming duration of {LITELLM_MAX_STREAMING_DURATION_SECONDS}s (elapsed {elapsed:.1f}s)",
                model=self.model or "",
                llm_provider=self.custom_llm_provider or "",
            )
def _process_chunk(self, chunk) -> Optional[Any]:
"""
Process a single chunk of data from the stream.

The chunk is JSON-decoded, transformed by the provider config into a
Responses API stream event, post-processed (response id, container ids and
encrypted content are tagged with model/provider information) and returned.

Returns None for an empty chunk, for the "[DONE]" marker (which also sets
``self.finished``) and for chunks that are not valid JSON.

Terminal events (completed / incomplete / failed) are stored on
``self.completed_response`` and handed to the logging handlers.

Raises:
    ValueError: if a JSON object arrives but no provider config was supplied.
    Exception: any other error is passed to ``_handle_failure`` and re-raised.
"""
if not chunk:
return None
# NOTE: ``SSEDecoder`` already strips the SSE ``data:`` field prefix, so
# the value passed in here is the raw field content. Do not re-run
# ``_strip_sse_data_from_chunk`` on it — doing so would incorrectly mangle
# payloads whose actual JSON value happens to start with ``data:``.
# Handle "[DONE]" marker
if chunk == STREAM_SSE_DONE_STRING:
self.finished = True
return None
try:
# Parse the JSON chunk
parsed_chunk = json.loads(chunk)
# Format as ResponsesAPIStreamingResponse
if isinstance(parsed_chunk, dict):
if self.responses_api_provider_config is None:
raise ValueError(
"responses_api_provider_config is required to process live streaming chunks"
)
openai_responses_api_chunk = (
self.responses_api_provider_config.transform_streaming_response(
model=self.model,
parsed_chunk=parsed_chunk,
logging_obj=self.logging_obj,
)
)
# Only when the SSE JSON carries a response body (delta events do not).
# Using getattr(..., "response") alone is unsafe with Mocks: they synthesize a
# truthy child Mock for any attribute, which breaks tests and is wrong on stream.
if "response" in parsed_chunk:
response_object = getattr(
openai_responses_api_chunk, "response", None
)
if response_object is not None:
response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
responses_api_response=response_object,
litellm_metadata=self.litellm_metadata,
custom_llm_provider=self.custom_llm_provider,
)
setattr(openai_responses_api_chunk, "response", response)
# Encode container_id on streaming events so proxy/UI follow-ups route correctly
_event_type = getattr(openai_responses_api_chunk, "type", None)
_stream_model_id = (
self.litellm_metadata.get("model_info", {}).get("id")
if self.litellm_metadata
else None
)
# output_item.added / output_item.done carry the item itself.
if _event_type in (
ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED,
ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
):
_item = getattr(openai_responses_api_chunk, "item", None)
if _item is not None:
ResponsesAPIRequestUtils._encode_container_id_on_output_item(
item=_item,
custom_llm_provider=self.custom_llm_provider,
model_id=_stream_model_id,
)
# Annotation events carry a single annotation object.
elif (
_event_type == ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED
):
_annotation = getattr(
openai_responses_api_chunk, "annotation", None
)
if _annotation is not None:
ResponsesAPIRequestUtils._encode_container_id_on_output_item(
item=_annotation,
custom_llm_provider=self.custom_llm_provider,
model_id=_stream_model_id,
)
# content_part.done carries a part that may be a dict or an object.
elif _event_type == ResponsesAPIStreamEvents.CONTENT_PART_DONE:
_part = getattr(openai_responses_api_chunk, "part", None)
if _part is not None:
if isinstance(_part, dict):
ResponsesAPIRequestUtils._encode_container_ids_in_annotations(
_part.get("annotations"),
self.custom_llm_provider,
_stream_model_id,
)
else:
ResponsesAPIRequestUtils._encode_container_ids_in_annotations(
getattr(_part, "annotations", None),
self.custom_llm_provider,
_stream_model_id,
)
# Wrap encrypted_content in streaming events (output_item.added, output_item.done)
if self.litellm_metadata and self.litellm_metadata.get(
"encrypted_content_affinity_enabled"
):
openai_types = _get_openai_response_types()
event_type = getattr(openai_responses_api_chunk, "type", None)
if event_type in (
openai_types.ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED,
openai_types.ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
):
item = getattr(openai_responses_api_chunk, "item", None)
if item:
encrypted_content = getattr(item, "encrypted_content", None)
if encrypted_content and isinstance(encrypted_content, str):
model_id = (
self.litellm_metadata.get("model_info", {}).get(
"id"
)
if self.litellm_metadata
else None
)
if model_id:
wrapped_content = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
encrypted_content, model_id
)
setattr(item, "encrypted_content", wrapped_content)
# Store the completed response (also for incomplete/failed so logging still fires)
_chunk_type = getattr(openai_responses_api_chunk, "type", None)
openai_types = _get_openai_response_types()
if openai_responses_api_chunk and _chunk_type in (
openai_types.ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
openai_types.ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE,
openai_types.ResponsesAPIStreamEvents.RESPONSE_FAILED,
):
self.completed_response = openai_responses_api_chunk
# Add cost to usage object if include_cost_in_streaming_usage is True
if (
litellm.include_cost_in_streaming_usage
and self.logging_obj is not None
):
response_obj: Optional[Any] = getattr(
openai_responses_api_chunk, "response", None
)
if response_obj:
usage_obj: Optional[Any] = getattr(
response_obj, "usage", None
)
if usage_obj is not None:
try:
cost: Optional[float] = (
self.logging_obj._response_cost_calculator(
result=response_obj
)
)
if cost is not None:
setattr(usage_obj, "cost", cost)
except Exception:
# Best-effort usage cost annotation should not break stream replay.
pass
# Failed responses go to the failure handlers; completed and
# incomplete responses go to the success handlers.
if (
_chunk_type
== openai_types.ResponsesAPIStreamEvents.RESPONSE_FAILED
):
self._handle_logging_failed_response()
else:
self._handle_logging_completed_response()
return openai_responses_api_chunk
return None
except json.JSONDecodeError:
# If we can't parse the chunk, continue
return None
except Exception as e:
# Trigger failure hooks before re-raising
# This ensures failures are logged even when _process_chunk is called directly
self._handle_failure(e)
raise
def _log_completed_response(self, *, is_async: bool) -> None:
"""
Dispatch success logging for the stored terminal response, at most once.

Args:
    is_async: True when called from a running event loop; the async
        success handler is then scheduled with ``asyncio.create_task``.
        When False it is driven through ``run_async_function`` instead.
"""
# One-shot guard: later calls are no-ops.
if self._completed_response_logged:
return
self._completed_response_logged = True
if self._persist_completed_response_before_logging:
self._persist_completed_response_to_cache(is_async=is_async)
# Create a copy for logging to avoid modifying the response object that will be returned to the user
# The logging handlers may transform usage from Responses API format (input_tokens/output_tokens)
# to chat completion format (prompt_tokens/completion_tokens) for internal logging
# Use model_dump + model_validate instead of deepcopy to avoid pickle errors with
# Pydantic ValidatorIterator when response contains tool_choice with allowed_tools (fixes #17192)
logging_response = self.completed_response
if self.completed_response is not None and hasattr(
self.completed_response, "model_dump"
):
try:
logging_response = type(self.completed_response).model_validate(
self.completed_response.model_dump()
)
except Exception:
# Fallback to original if serialization fails
pass
# A single end_time is shared by every handler invoked below.
end_time = datetime.now()
if is_async:
# NOTE(review): the task returned by create_task is not retained and has
# no done-callback, unlike the cache-write task in
# _persist_completed_response_to_cache — consider keeping a reference.
asyncio.create_task(
self.logging_obj.async_success_handler(
result=logging_response,
start_time=self.start_time,
end_time=end_time,
cache_hit=self._completed_response_cache_hit,
)
)
else:
run_async_function(
async_function=self.logging_obj.async_success_handler,
result=logging_response,
start_time=self.start_time,
end_time=end_time,
cache_hit=self._completed_response_cache_hit,
)
# NOTE(review): presumably the two calls below run for both the async and
# the sync path — confirm against the original indentation.
executor.submit(
self.logging_obj.success_handler,
result=logging_response,
cache_hit=self._completed_response_cache_hit,
start_time=self.start_time,
end_time=end_time,
)
self._run_post_success_hooks(end_time=end_time)
def _handle_logging_completed_response(self):
"""Base implementation - should be overridden by subclasses"""
pass
def _handle_logging_failed_response(self):
    """
    Handle logging for RESPONSE_FAILED events by routing to failure handlers.

    Unlike _handle_logging_completed_response (which calls success handlers),
    this constructs an exception from the response error and routes to
    async_failure_handler / failure_handler so logging integrations correctly
    record the call as failed.

    The error message is taken from ``response.error``. Both dict-shaped
    errors (``{"message": ...}``) and object-shaped errors (with a
    ``message`` attribute) are supported. When no usable message is found
    the generic "Response failed" text is used.
    """
    completed = self.completed_response
    response_obj = getattr(completed, "response", None) if completed else None
    error_info = getattr(response_obj, "error", None) if response_obj else None

    if isinstance(error_info, dict):
        detail = error_info.get("message", str(error_info))
    elif error_info is not None:
        # Object-shaped error (e.g. a parsed model): prefer its message,
        # otherwise fall back to its string form.
        detail = getattr(error_info, "message", None) or str(error_info)
    else:
        detail = None

    # A None / empty message would make the exception unreadable, so keep
    # the generic text in that case and always hand APIError a string.
    error_message = str(detail) if detail else "Response failed"

    exception = litellm.APIError(
        status_code=500,
        message=error_message,
        llm_provider=self.custom_llm_provider or "",
        model=self.model or "",
    )
    self._handle_failure(exception)
def _get_completed_response_object(self) -> Optional[Any]:
    """Return the ``ResponsesAPIResponse`` held by the stored terminal event.

    ``self.completed_response`` may be the response itself or an event that
    wraps it in a ``response`` attribute. Returns None when neither is a
    ``ResponsesAPIResponse``.
    """
    openai_types = _get_openai_response_types()
    candidate = self.completed_response
    if not isinstance(candidate, openai_types.ResponsesAPIResponse):
        # Not the response itself: look one level down at ``.response``.
        candidate = getattr(candidate, "response", None)
        if not isinstance(candidate, openai_types.ResponsesAPIResponse):
            return None
    return candidate
def _persist_completed_response_to_cache(self, *, is_async: bool) -> None:
"""
Write the completed response to ``litellm.cache``, at most once.

Nothing is written unless the stored event is RESPONSE_COMPLETED, a
``ResponsesAPIResponse`` can be extracted from it, the logging object
carries a caching handler whose request kwargs have ``stream`` set to
True, the handler agrees the result should be stored, and a global cache
is configured.

Args:
    is_async: True schedules the write as a background task; False
        writes synchronously.
"""
if self._completed_response_cached:
return
completed_response = self.completed_response
openai_types = _get_openai_response_types()
# Incomplete / failed responses are not cached.
if (
getattr(completed_response, "type", None)
!= openai_types.ResponsesAPIStreamEvents.RESPONSE_COMPLETED
):
return
response_obj = self._get_completed_response_object()
if response_obj is None:
return
caching_handler = getattr(self.logging_obj, "_llm_caching_handler", None)
if caching_handler is None:
return
request_kwargs = getattr(caching_handler, "request_kwargs", None)
if (
not isinstance(request_kwargs, dict)
or request_kwargs.get("stream") is not True
):
return
# Work on a copy so the handler's own kwargs are left untouched.
request_kwargs = request_kwargs.copy()
# A preset cache key on the handler wins over one found in the kwargs.
preset_cache_key = getattr(caching_handler, "preset_cache_key", None)
request_cache_key = request_kwargs.pop("cache_key", None)
if preset_cache_key is None:
preset_cache_key = request_cache_key
if request_kwargs.get("metadata") is None:
request_kwargs.pop("metadata", None)
request_kwargs.pop("custom_llm_provider", None)
if preset_cache_key is not None:
request_kwargs["cache_key"] = preset_cache_key
if not caching_handler._should_store_result_in_cache(
original_function=caching_handler.original_function,
kwargs=request_kwargs,
):
return
if litellm.cache is None:
return
# The response is cached in its JSON-serialised form.
cached_response = response_obj.model_dump_json()
if is_async:
cache_write_task = asyncio.create_task(
litellm.cache.async_add_cache(
cached_response,
dynamic_cache_object=getattr(caching_handler, "dual_cache", None),
**request_kwargs,
)
)
# Report a failed background write through the logger.
cache_write_task.add_done_callback(
lambda task: _log_background_task_failure(
task,
task_name="Responses stream cache write",
)
)
else:
litellm.cache.add_cache(
cached_response,
dynamic_cache_object=getattr(caching_handler, "dual_cache", None),
**request_kwargs,
)
self._completed_response_cached = True
async def _call_post_streaming_deployment_hook(self, chunk):
    """
    Give registered callbacks a chance to rewrite a streaming chunk.

    Every entry in ``litellm.callbacks`` that defines
    ``async_post_call_streaming_deployment_hook`` is awaited in order; a
    non-None return value replaces the chunk for the callbacks that follow.
    Any exception aborts the process and the current chunk is returned.
    """
    try:
        # Resolve the call type: the explicit value first, then the one
        # recorded on the logging object.
        resolved_call_type: Optional[CallTypes] = None
        if self.call_type is not None:
            try:
                resolved_call_type = CallTypes(self.call_type)
            except ValueError:
                resolved_call_type = None
        if resolved_call_type is None:
            try:
                resolved_call_type = CallTypes(
                    getattr(self.logging_obj, "call_type", None)
                )
            except Exception:
                resolved_call_type = None

        hook_request_data = self.request_data or getattr(
            self.logging_obj, "model_call_details", {}
        )

        any_hook_ran = False
        for callback in getattr(litellm, "callbacks", None) or []:
            if not hasattr(callback, "async_post_call_streaming_deployment_hook"):
                continue
            any_hook_ran = True
            replacement = await callback.async_post_call_streaming_deployment_hook(
                request_data=hook_request_data,
                response_chunk=chunk,
                call_type=resolved_call_type,
            )
            if replacement is not None:
                chunk = replacement

        # Mark the chunk so the fact that hooks ran is visible on it.
        if any_hook_ran:
            setattr(chunk, "_post_streaming_hooks_ran", True)
        return chunk
    except Exception:
        return chunk
async def call_post_streaming_hooks_for_testing(self, chunk):
    """
    Public wrapper around ``_call_post_streaming_deployment_hook`` for tests.

    Returns whatever the wrapped hook runner returns for ``chunk``.
    """
    hooked_chunk = await self._call_post_streaming_deployment_hook(chunk)
    return hooked_chunk
def _run_post_success_hooks(self, end_time: datetime):
"""
Run post-call deployment hooks and update metadata similar to chat pipeline.
"""
if self.completed_response is None:
return
request_payload: Dict[str, Any] = {}
if isinstance(self.request_data, dict):
request_payload.update(self.request_data)
try:
if hasattr(self.logging_obj, "model_call_details"):
request_payload.update(self.logging_obj.model_call_details)
except Exception:
pass
if "litellm_params" not in request_payload:
try:
request_payload["litellm_params"] = getattr(
self.logging_obj, "model_call_details", {}
).get("litellm_params", {})
except Exception:
request_payload["litellm_params"] = {}
try:
update_response_metadata(
result=self.completed_response,
logging_obj=self.logging_obj,
model=self.model,
kwargs=request_payload,
start_time=self.start_time,
end_time=end_time,
)
except Exception:
# Non-blocking
pass
try:
typed_call_type: Optional[CallTypes] = None
if self.call_type is not None:
try:
typed_call_type = CallTypes(self.call_type)
except ValueError:
typed_call_type = None
except Exception:
typed_call_type = None
if typed_call_type is None:
try:
typed_call_type = CallTypes.responses
except Exception:
typed_call_type = None
try:
# Call synchronously; async hook will be executed via asyncio.run in a new loop
run_async_function(
async_function=async_post_call_success_deployment_hook,
request_data=request_payload,
response=self.completed_response,
call_type=typed_call_type,
)
except Exception:
pass
def _handle_failure(self, exception: Exception):
    """
    Trigger failure handlers before bubbling the exception.

    Only calls handlers once even if called multiple times.

    Args:
        exception: The error to report. It is handed to the logging
            object's async failure handler (run to completion) and to its
            sync failure handler (submitted to the shared executor).

    Handler dispatch is best effort: an error while dispatching is logged at
    debug level and never propagates to the caller.
    """
    # Prevent double-calling failure handlers
    if self._failure_handled:
        return
    self._failure_handled = True

    import sys  # local import: only needed to detect an in-flight exception

    if sys.exc_info()[0] is not None:
        # Called from inside an ``except`` block: report what is being handled.
        traceback_exception = traceback.format_exc()
    else:
        # Called outside an ``except`` block (e.g. for a RESPONSE_FAILED
        # event, where the exception is built but never raised).
        # ``format_exc()`` would produce "NoneType: None" here, so format
        # the supplied exception instead.
        traceback_exception = "".join(
            traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )
        )

    # Both handlers receive the same end timestamp.
    end_time = datetime.now()

    try:
        run_async_function(
            async_function=self.logging_obj.async_failure_handler,
            exception=exception,
            traceback_exception=traceback_exception,
            start_time=self.start_time,
            end_time=end_time,
        )
    except Exception as handler_error:
        verbose_logger.debug(
            "Responses stream async failure handler raised: %s", handler_error
        )

    try:
        executor.submit(
            self.logging_obj.failure_handler,
            exception,
            traceback_exception,
            self.start_time,
            end_time,
        )
    except Exception as handler_error:
        verbose_logger.debug(
            "Responses stream failure handler submission raised: %s",
            handler_error,
        )
async def call_post_streaming_hooks_for_testing(iterator, chunk):
    """
    Module-level test helper that runs an iterator's streaming hooks on ``chunk``.

    Works with any object: when ``iterator`` does not expose
    ``_call_post_streaming_deployment_hook`` (or exposes it as None) the chunk
    is returned unchanged.
    """
    run_hooks = getattr(iterator, "_call_post_streaming_deployment_hook", None)
    return chunk if run_hooks is None else await run_hooks(chunk)
class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Async iterator over the events of a streaming Responses API call.

    Bytes from the HTTP response are decoded as SSE, each event's data is run
    through ``_process_chunk`` and the resulting event is yielded after the
    post-streaming hooks have had a chance to modify it.
    """

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        super().__init__(
            response=response,
            model=model,
            responses_api_provider_config=responses_api_provider_config,
            logging_obj=logging_obj,
            litellm_metadata=litellm_metadata,
            custom_llm_provider=custom_llm_provider,
            request_data=request_data,
            call_type=call_type,
        )
        self.stream_iterator = SSEDecoder().aiter_bytes(response.aiter_bytes())

    def __aiter__(self):
        return self

    async def __anext__(self) -> Any:
        """Return the next stream event, skipping chunks that yield nothing."""
        try:
            self._check_max_streaming_duration()
            while True:
                try:
                    sse_event = await self.stream_iterator.__anext__()
                except StopAsyncIteration:
                    self.finished = True
                    raise StopAsyncIteration
                self._check_max_streaming_duration()
                processed = self._process_chunk(sse_event.data)
                if self.finished:
                    raise StopAsyncIteration
                if processed is None:
                    # Nothing to emit for this chunk; read the next one.
                    continue
                # The hook is awaited directly rather than through
                # run_async_function, which is the sync iterator's route.
                return await self._call_post_streaming_deployment_hook(
                    chunk=processed,
                )
        except StopAsyncIteration:
            # Normal end of stream - not a failure.
            raise
        except Exception as stream_error:
            # HTTP errors and every other error are handled alike: mark the
            # stream finished, fire the failure handlers, re-raise.
            self.finished = True
            self._handle_failure(stream_error)
            raise stream_error

    def _handle_logging_completed_response(self):
        """Log the completed response from within the running event loop."""
        self._log_completed_response(is_async=True)
class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Synchronous iterator over the events of a streaming Responses API call.

    Bytes from the HTTP response are decoded as SSE, each event's data is run
    through ``_process_chunk`` and the resulting event is returned after the
    post-streaming hooks have had a chance to modify it.
    """

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        super().__init__(
            response=response,
            model=model,
            responses_api_provider_config=responses_api_provider_config,
            logging_obj=logging_obj,
            litellm_metadata=litellm_metadata,
            custom_llm_provider=custom_llm_provider,
            request_data=request_data,
            call_type=call_type,
        )
        self.stream_iterator = SSEDecoder().iter_bytes(response.iter_bytes())

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next stream event, skipping chunks that yield nothing."""
        try:
            self._check_max_streaming_duration()
            while True:
                try:
                    sse_event = next(self.stream_iterator)
                except StopIteration:
                    self.finished = True
                    raise StopIteration
                self._check_max_streaming_duration()
                processed = self._process_chunk(sse_event.data)
                if self.finished:
                    raise StopIteration
                if processed is None:
                    # Nothing to emit for this chunk; read the next one.
                    continue
                # Sync path: the async hook is driven via run_async_function.
                return run_async_function(
                    async_function=self._call_post_streaming_deployment_hook,
                    chunk=processed,
                )
        except StopIteration:
            # Normal end of stream - not a failure.
            raise
        except Exception as stream_error:
            # HTTP errors and every other error are handled alike: mark the
            # stream finished, fire the failure handlers, re-raise.
            self.finished = True
            self._handle_failure(stream_error)
            raise stream_error

    def _handle_logging_completed_response(self):
        """Log the completed response from a synchronous context."""
        self._log_completed_response(is_async=False)
class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Fake stream built from a complete (non-streamed) response.

    The response is transformed up front and replayed as synthetic events:
    text is sliced into ``CHUNK_SIZE``-character deltas and a completed event
    is emitted last. Models like o1-pro don't support streaming, so we fake it.
    Supports both sync and async iteration.
    """

    CHUNK_SIZE = 5

    def __init__(
        self,
        response: httpx.Response,
        model: str,
        responses_api_provider_config: BaseResponsesAPIConfig,
        logging_obj: LiteLLMLoggingObj,
        litellm_metadata: Optional[Dict[str, Any]] = None,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        # Transform the whole response before the base class is initialised.
        full_response = responses_api_provider_config.transform_response_api_response(
            model=model,
            raw_response=response,
            logging_obj=logging_obj,
        )
        # The base class gets a placeholder 200 response and no provider
        # config: no live chunks are parsed by this iterator.
        super().__init__(
            response=httpx.Response(200),
            model=model,
            responses_api_provider_config=None,
            logging_obj=logging_obj,
            litellm_metadata=litellm_metadata,
            custom_llm_provider=custom_llm_provider,
            request_data=request_data,
            call_type=call_type,
        )
        self._set_events_from_response(
            transformed=full_response, logging_obj=logging_obj
        )

    def _set_events_from_response(
        self,
        transformed: Any,
        logging_obj: LiteLLMLoggingObj,
    ) -> None:
        """Build the synthetic event list and rewind the cursor to its start."""
        self._events = _build_synthetic_response_events(
            transformed=transformed,
            logging_obj=logging_obj,
            chunk_size=self.CHUNK_SIZE,
        )
        self._idx = 0
        self.completed_response = self._events[-1]

    def _emit_next_event(self, *, is_async: bool) -> Any:
        """Return the event under the cursor and advance past it.

        A completed event is stored and logged before it is returned.
        """
        event = self._events[self._idx]
        self._idx += 1
        openai_types = _get_openai_response_types()
        is_completed = (
            getattr(event, "type", None)
            == openai_types.ResponsesAPIStreamEvents.RESPONSE_COMPLETED
        )
        if is_completed:
            self.completed_response = event
            self._log_completed_response(is_async=is_async)
        return event

    def __aiter__(self):
        return self

    async def __anext__(self) -> Any:
        if self._idx >= len(self._events):
            raise StopAsyncIteration
        return self._emit_next_event(is_async=True)

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        if self._idx >= len(self._events):
            raise StopIteration
        return self._emit_next_event(is_async=False)
class CachedResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    Replays an already available response as a synthetic event stream.

    The iterator reports itself as a cache hit to the logging handlers and
    does not write the response back to the cache before logging. Supports
    both sync and async iteration.
    """

    def __init__(
        self,
        response: Any,
        logging_obj: LiteLLMLoggingObj,
        request_data: Optional[Dict[str, Any]] = None,
        call_type: Optional[str] = None,
    ):
        # The base class gets a placeholder 200 response and no provider
        # config: no live chunks are parsed by this iterator.
        super().__init__(
            response=httpx.Response(200),
            model=getattr(response, "model", ""),
            responses_api_provider_config=None,
            logging_obj=logging_obj,
            litellm_metadata=None,
            custom_llm_provider="cached_response",
            request_data=request_data,
            call_type=call_type,
        )
        self._completed_response_cache_hit = True
        self._persist_completed_response_before_logging = False
        self._events: List[Any] = []
        self._idx = 0
        self._set_events_from_response(transformed=response, logging_obj=logging_obj)

    def _set_events_from_response(
        self,
        transformed: Any,
        logging_obj: LiteLLMLoggingObj,
    ) -> None:
        """Build the synthetic event list and rewind the cursor to its start."""
        self._events = _build_synthetic_response_events(
            transformed=transformed,
            logging_obj=logging_obj,
            chunk_size=MockResponsesAPIStreamingIterator.CHUNK_SIZE,
        )
        self._idx = 0
        self.completed_response = self._events[-1]

    def _emit_next_event(self, *, is_async: bool) -> Any:
        """Return the event under the cursor and advance past it.

        A completed event is stored and logged before it is returned.
        """
        event = self._events[self._idx]
        self._idx += 1
        openai_types = _get_openai_response_types()
        is_completed = (
            getattr(event, "type", None)
            == openai_types.ResponsesAPIStreamEvents.RESPONSE_COMPLETED
        )
        if is_completed:
            self.completed_response = event
            self._log_completed_response(is_async=is_async)
        return event

    def __aiter__(self):
        return self

    async def __anext__(self) -> Any:
        if self._idx >= len(self._events):
            raise StopAsyncIteration
        return self._emit_next_event(is_async=True)

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        if self._idx >= len(self._events):
            raise StopIteration
        return self._emit_next_event(is_async=False)
def _dump_response_object(obj: Any) -> Dict[str, Any]:
if hasattr(obj, "model_dump"):
return obj.model_dump()
if isinstance(obj, dict):
return obj
return {}
def _build_response_status_event(
    event_type: Literal[
        "response.created",
        "response.in_progress",
    ],
    transformed: Any,
) -> Any:
    """Build a ``response.created`` or ``response.in_progress`` event.

    The event carries a deep copy of ``transformed`` whose status is forced to
    "in_progress" and whose output list is emptied, so the original response
    is left untouched.
    """
    openai_types = _get_openai_response_types()
    snapshot = transformed.model_copy(
        deep=True,
        update={"status": "in_progress", "output": []},
    )
    is_created = event_type == openai_types.ResponsesAPIStreamEvents.RESPONSE_CREATED
    event_cls = (
        openai_types.ResponseCreatedEvent
        if is_created
        else openai_types.ResponseInProgressEvent
    )
    return event_cls(type=event_type, response=snapshot)
def _build_content_part_done_event(
    *,
    item_id: str,
    output_index: int,
    content_index: int,
    part_payload: Dict[str, Any],
) -> Optional[Any]:
    """Build the ``content_part.done`` event for one content part.

    Supported part types are "output_text", "refusal" and "reasoning_text";
    None is returned for any other type.
    """
    openai_types = _get_openai_response_types()
    kind = part_payload.get("type")
    if kind not in ("output_text", "refusal", "reasoning_text"):
        return None

    done_part: Any
    if kind == "output_text":
        wrapped_annotations = []
        for raw_annotation in part_payload.get("annotations", []) or []:
            wrapped_annotations.append(
                openai_types.BaseLiteLLMOpenAIResponseObject(**raw_annotation)
            )
        done_part = openai_types.ContentPartDonePartOutputText(
            type="output_text",
            text=str(part_payload.get("text") or ""),
            annotations=wrapped_annotations,
            logprobs=part_payload.get("logprobs"),
        )
    elif kind == "refusal":
        done_part = openai_types.ContentPartDonePartRefusal(
            type="refusal",
            refusal=str(part_payload.get("refusal") or ""),
        )
    else:
        done_part = openai_types.ContentPartDonePartReasoningText(
            type="reasoning_text",
            reasoning=str(part_payload.get("reasoning") or ""),
        )

    return openai_types.ContentPartDoneEvent(
        type=openai_types.ResponsesAPIStreamEvents.CONTENT_PART_DONE,
        item_id=item_id,
        output_index=output_index,
        content_index=content_index,
        part=done_part,
    )
def _add_text_like_part_events(
    *,
    events: List[Any],
    item_id: str,
    output_index: int,
    content_index: int,
    part_payload: Dict[str, Any],
    chunk_size: int,
) -> None:
    """Append the delta / done events for one text-like content part.

    For "output_text" parts this emits one delta per ``chunk_size`` slice of
    the text, one annotation-added event per annotation, then a text-done
    event. For "refusal" parts it emits the refusal deltas followed by a
    refusal-done event. Other part types add nothing. ``events`` is mutated
    in place.
    """
    openai_types = _get_openai_response_types()
    kind = part_payload.get("type")

    if kind == "output_text":
        full_text = str(part_payload.get("text") or "")
        events.extend(
            openai_types.OutputTextDeltaEvent(
                type=openai_types.ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA,
                item_id=item_id,
                output_index=output_index,
                content_index=content_index,
                delta=full_text[start : start + chunk_size],
            )
            for start in range(0, len(full_text), chunk_size)
        )
        events.extend(
            openai_types.OutputTextAnnotationAddedEvent(
                type=openai_types.ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED,
                item_id=item_id,
                output_index=output_index,
                content_index=content_index,
                annotation_index=position,
                annotation=annotation,
            )
            for position, annotation in enumerate(
                part_payload.get("annotations", []) or []
            )
        )
        events.append(
            openai_types.OutputTextDoneEvent(
                type=openai_types.ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE,
                item_id=item_id,
                output_index=output_index,
                content_index=content_index,
                text=full_text,
            )
        )
        return

    if kind == "refusal":
        refusal_text = str(part_payload.get("refusal") or "")
        events.extend(
            openai_types.RefusalDeltaEvent(
                type=openai_types.ResponsesAPIStreamEvents.REFUSAL_DELTA,
                item_id=item_id,
                output_index=output_index,
                content_index=content_index,
                delta=refusal_text[start : start + chunk_size],
            )
            for start in range(0, len(refusal_text), chunk_size)
        )
        events.append(
            openai_types.RefusalDoneEvent(
                type=openai_types.ResponsesAPIStreamEvents.REFUSAL_DONE,
                item_id=item_id,
                output_index=output_index,
                content_index=content_index,
                refusal=refusal_text,
            )
        )
def _build_synthetic_response_events(
*,
transformed: Any,
logging_obj: LiteLLMLoggingObj,
chunk_size: int,
) -> List[Any]:
"""
Build the ordered list of stream events that replays a complete response.

The list starts with a ``response.created`` and a ``response.in_progress``
event, continues with the events for every item in ``transformed.output``
(message content parts, function-call arguments, reasoning summaries), and
ends with a ``response.completed`` event carrying ``transformed`` itself.

Args:
    transformed: The complete response to replay.
    logging_obj: Used to compute the response cost when
        ``litellm.include_cost_in_streaming_usage`` is enabled.
    chunk_size: Number of characters per synthetic delta event.
"""
openai_types = _get_openai_response_types()
# Best-effort: attach the computed cost to the usage object.
if litellm.include_cost_in_streaming_usage and logging_obj is not None:
usage_obj: Optional[Any] = getattr(transformed, "usage", None)
if usage_obj is not None:
try:
cost: Optional[float] = logging_obj._response_cost_calculator(
result=transformed
)
if cost is not None:
setattr(usage_obj, "cost", cost)
except Exception:
pass
events: List[Any] = [
_build_response_status_event(
openai_types.ResponsesAPIStreamEvents.RESPONSE_CREATED, transformed
),
_build_response_status_event(
openai_types.ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS, transformed
),
]
# NOTE(review): sequence_number is only advanced in the reasoning branch
# below, yet every output_item.done event receives it — confirm intended.
sequence_number = 0
for output_index, output_item in enumerate(
getattr(transformed, "output", []) or []
):
output_item_payload = _dump_response_object(output_item)
# Items without their own id fall back to the response id.
item_id = str(output_item_payload.get("id") or transformed.id)
item_type = output_item_payload.get("type")
events.append(
openai_types.OutputItemAddedEvent(
type=openai_types.ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED,
output_index=output_index,
item=openai_types.BaseLiteLLMOpenAIResponseObject(
**output_item_payload
),
)
)
# Messages: one added / deltas / done sequence per content part.
if item_type == "message":
for content_index, part in enumerate(
output_item_payload.get("content", []) or []
):
part_payload = _dump_response_object(part)
events.append(
openai_types.ContentPartAddedEvent(
type=openai_types.ResponsesAPIStreamEvents.CONTENT_PART_ADDED,
item_id=item_id,
output_index=output_index,
content_index=content_index,
part=openai_types.BaseLiteLLMOpenAIResponseObject(
**part_payload
),
)
)
_add_text_like_part_events(
events=events,
item_id=item_id,
output_index=output_index,
content_index=content_index,
part_payload=part_payload,
chunk_size=chunk_size,
)
done_event = _build_content_part_done_event(
item_id=item_id,
output_index=output_index,
content_index=content_index,
part_payload=part_payload,
)
if done_event is not None:
events.append(done_event)
# Function calls: the arguments string is sliced into deltas.
elif item_type == "function_call":
arguments = str(output_item_payload.get("arguments") or "")
for i in range(0, len(arguments), chunk_size):
events.append(
openai_types.FunctionCallArgumentsDeltaEvent(
type=openai_types.ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA,
item_id=item_id,
output_index=output_index,
delta=arguments[i : i + chunk_size],
)
)
events.append(
openai_types.FunctionCallArgumentsDoneEvent(
type=openai_types.ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE,
item_id=item_id,
output_index=output_index,
arguments=arguments,
)
)
# Reasoning: each summary text is sliced into deltas, then closed.
elif item_type == "reasoning":
for summary_index, summary in enumerate(
output_item_payload.get("summary", []) or []
):
summary_payload = _dump_response_object(summary)
summary_text = str(summary_payload.get("text") or "")
for i in range(0, len(summary_text), chunk_size):
events.append(
openai_types.ReasoningSummaryTextDeltaEvent(
type=openai_types.ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DELTA,
item_id=item_id,
output_index=output_index,
summary_index=summary_index,
delta=summary_text[i : i + chunk_size],
)
)
sequence_number += 1
events.append(
openai_types.ReasoningSummaryTextDoneEvent(
type=openai_types.ResponsesAPIStreamEvents.REASONING_SUMMARY_TEXT_DONE,
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
summary_index=summary_index,
text=summary_text,
)
)
sequence_number += 1
events.append(
openai_types.ReasoningSummaryPartDoneEvent(
type=openai_types.ResponsesAPIStreamEvents.REASONING_SUMMARY_PART_DONE,
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
summary_index=summary_index,
part=openai_types.BaseLiteLLMOpenAIResponseObject(
**summary_payload
),
)
)
sequence_number += 1
events.append(
openai_types.OutputItemDoneEvent(
type=openai_types.ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
output_index=output_index,
sequence_number=sequence_number,
item=openai_types.BaseLiteLLMOpenAIResponseObject(
**output_item_payload
),
)
)
# The terminal event carries the original, unmodified response object.
events.append(
openai_types.ResponseCompletedEvent(
type=openai_types.ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
response=transformed,
)
)
return events
# ---------------------------------------------------------------------------
# WebSocket mode streaming (bidirectional forwarding)
# ---------------------------------------------------------------------------
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.thread_pool_executor import executor as _ws_executor
# Event types that ``ResponsesWebSocketStreaming._store_event`` keeps for
# success logging; events whose ``type`` is not listed here are not stored.
RESPONSES_WS_LOGGED_EVENT_TYPES = [
"response.created",
"response.completed",
"response.failed",
"response.incomplete",
"error",
]
# Content-block ``type`` values whose ``text`` field
# ``ResponsesWebSocketStreaming._mask_response_create`` runs through
# ``check_pii`` before a ``response.create`` frame is sent upstream.
RESPONSES_WS_MASKABLE_TEXT_BLOCK_TYPES = frozenset(
{"input_text", "output_text", "text"}
)
class ResponsesWebSocketStreaming:
    """
    Manages bidirectional WebSocket forwarding for the Responses API
    WebSocket mode (wss://.../v1/responses).

    Unlike the Realtime API, the Responses API WebSocket mode:
    - Uses response.create as the client-to-server event
    - Streams back the same events as the HTTP streaming Responses API
    - Supports previous_response_id for incremental continuation
    - Supports generate: false for warmup
    - One response at a time per connection (sequential, no multiplexing)

    Two callback lists drive PII handling:
    - ``guardrail_callbacks`` mask client input (``check_pii``) and unmask
      backend output (``_unmask_pii_text``).
    - ``output_guardrail_callbacks`` mask the ``response.completed`` event;
      while any are configured, delta and ``*.done`` events are not forwarded.
    """

    # Delta event types whose ``delta`` field may contain PII tokens.
    _DELTA_EVENT_TYPES = frozenset(
        {
            "response.output_text.delta",
            "response.reasoning_summary_text.delta",
            "response.refusal.delta",
            "response.function_call_arguments.delta",
        }
    )

    # Terminal events that carry the full output text or tool-call arguments
    # already delivered by ``response.completed``. Suppressed when output masking
    # is active so the unmasked copy never reaches the client before the masked
    # completed event.
    _OUTPUT_DONE_EVENT_TYPES = frozenset(
        {
            "response.output_text.done",
            "response.content_part.done",
            "response.output_item.done",
            "response.function_call_arguments.done",
            "response.reasoning_summary_text.done",
            "response.reasoning_summary_part.done",
        }
    )

    # Strong references to fire-and-forget logging tasks. The event loop only
    # holds weak references to tasks, so a task whose handle is dropped can be
    # garbage-collected before it finishes. Held on the class (not the
    # instance) because logging is scheduled while the connection is closing.
    _background_tasks: "set[asyncio.Task[Any]]" = set()

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: "LiteLLMLoggingObj",
        user_api_key_dict: Optional[Any] = None,
        request_data: Optional[Dict] = None,
        first_message: Optional[str] = None,
        guardrail_callbacks: Optional[List[Any]] = None,
        output_guardrail_callbacks: Optional[List[Any]] = None,
        authorized_model: Optional[str] = None,
    ):
        """
        Args:
            websocket: Client-facing socket; must provide ``receive_text`` and
                ``send_text`` coroutines.
            backend_ws: Provider-facing socket; must provide ``recv``, ``send``
                and ``close`` coroutines.
            logging_obj: LiteLLM logging object used for pre-call and success
                logging.
            user_api_key_dict: Stored as-is; not read by this class.
            request_data: Request context handed to the guardrail callbacks;
                its ``metadata`` entry is where ``pii_tokens`` are looked up.
            first_message: Optional client frame that was already read and
                must be forwarded before entering the receive loop.
            guardrail_callbacks: Callbacks that mask input / unmask output.
            output_guardrail_callbacks: Callbacks that mask
                ``response.completed`` output.
            authorized_model: Model authorized at connection time.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.user_api_key_dict = user_api_key_dict
        self.request_data: Dict = request_data or {}
        self.messages: list[Dict] = []
        self.input_messages: list[Dict[str, str]] = []
        self.first_message = first_message
        self.guardrail_callbacks: List[Any] = guardrail_callbacks or []
        self.output_guardrail_callbacks: List[Any] = output_guardrail_callbacks or []
        # Model name authorized at connection time; enforced on every
        # response.create frame to prevent deployment-substitution attacks.
        self.authorized_model: Optional[str] = authorized_model

    # ------------------------------------------------------------------
    # Logging helpers
    # ------------------------------------------------------------------

    def _should_store_event(self, event_obj: dict) -> bool:
        """Return True when the event's ``type`` is one of the logged types."""
        return event_obj.get("type") in RESPONSES_WS_LOGGED_EVENT_TYPES

    def _store_event(self, event: Any) -> None:
        """
        Record *event* in ``self.messages`` if it is a loggable event.

        Accepts bytes, a JSON string, or an already-parsed dict. Payloads
        that are not valid JSON, or whose JSON value is not an object, are
        ignored.
        """
        if isinstance(event, bytes):
            event = event.decode("utf-8")
        if isinstance(event, str):
            try:
                event_obj = json.loads(event)
            except (json.JSONDecodeError, TypeError):
                return
        else:
            event_obj = event
        # ``json.loads`` happily returns lists, numbers, strings or None; only
        # objects carry a ``type`` and calling ``.get`` on the rest would raise.
        if not isinstance(event_obj, dict):
            return
        if self._should_store_event(event_obj):
            self.messages.append(event_obj)

    @staticmethod
    def _extract_user_texts(input_items: Any) -> List[str]:
        """Return the user-authored text found in a ``response.create`` input."""
        if isinstance(input_items, str):
            return [input_items]
        texts: List[str] = []
        if not isinstance(input_items, list):
            return texts
        for item in input_items:
            if not isinstance(item, dict):
                continue
            if item.get("type") != "message" or item.get("role") != "user":
                continue
            content = item.get("content", [])
            if isinstance(content, str):
                texts.append(content)
            elif isinstance(content, list):
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "input_text":
                        text = part.get("text", "")
                        if text:
                            texts.append(text)
        return texts

    def _collect_input_from_client_event(self, message: Any) -> None:
        """
        Extract user input content from response.create for logging.

        Reads ``input`` from the flat frame, or from the nested ``response``
        object when the flat frame has no ``input`` key, matching the two
        shapes handled by ``_enforce_authorized_model`` and
        ``_mask_response_create``. Malformed frames are ignored.
        """
        try:
            if isinstance(message, str):
                msg_obj = json.loads(message)
            elif isinstance(message, dict):
                msg_obj = message
            else:
                return
            if msg_obj.get("type") != "response.create":
                return
            nested = msg_obj.get("response")
            if "input" not in msg_obj and isinstance(nested, dict):
                input_items = nested.get("input", [])
            else:
                input_items = msg_obj.get("input", [])
            self.input_messages.extend(
                {"role": "user", "content": text}
                for text in self._extract_user_texts(input_items)
            )
        except (json.JSONDecodeError, AttributeError, TypeError):
            # Logging must never break forwarding.
            pass

    def _store_input(self, message: Any) -> None:
        """Record the client frame for logging and fire the pre-call hook."""
        self._collect_input_from_client_event(message)
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    @classmethod
    def _on_logging_task_done(cls, task: "asyncio.Task[Any]") -> None:
        """Drop the strong reference to *task* and log any failure."""
        cls._background_tasks.discard(task)
        _log_background_task_failure(
            task, task_name="Responses WS async_success_handler"
        )

    async def _log_messages(self) -> None:
        """Hand the collected events to the async and sync success handlers."""
        if not self.logging_obj:
            return
        if self.input_messages:
            self.logging_obj.model_call_details["messages"] = self.input_messages
        if self.messages:
            task = asyncio.create_task(
                self.logging_obj.async_success_handler(self.messages)
            )
            # Keep the task alive until it completes and surface its errors.
            self._background_tasks.add(task)
            task.add_done_callback(self._on_logging_task_done)
            _ws_executor.submit(self.logging_obj.success_handler, self.messages)

    # ------------------------------------------------------------------
    # Backend -> client
    # ------------------------------------------------------------------

    @staticmethod
    def _peek_event_type(payload: str) -> Optional[str]:
        """Return the string ``type`` of a JSON object payload, else None."""
        try:
            parsed = json.loads(payload)
        except (json.JSONDecodeError, TypeError):
            return None
        if not isinstance(parsed, dict):
            return None
        event_type = parsed.get("type")
        # A non-string type can never match and may be unhashable.
        return event_type if isinstance(event_type, str) else None

    async def backend_to_client(self) -> None:
        """
        Forward events from backend WebSocket to the client.

        Runs until the backend connection closes or an error occurs; the
        collected events are always passed to ``_log_messages`` on exit.
        """
        import websockets

        try:
            while True:
                try:
                    raw_response = await self.backend_ws.recv(decode=False)  # type: ignore[union-attr]
                except TypeError:
                    # ``recv`` implementation without the ``decode`` keyword.
                    raw_response = await self.backend_ws.recv()  # type: ignore[union-attr, assignment]
                if isinstance(raw_response, bytes):
                    response_str = raw_response.decode("utf-8")
                else:
                    response_str = raw_response
                # When apply_to_output masking is active, suppress delta events
                # and the text-bearing "done" events. Per-fragment Presidio
                # cannot reliably catch PII spanning multiple delta chunks (e.g.
                # "alice@" + "example.com"), and the done events carry the full
                # output text that response.completed already delivers in
                # fully-masked form; forwarding them would leak unmasked PII
                # before response.completed arrives. The client receives only the
                # masked response.completed.
                if self.output_guardrail_callbacks:
                    _evt_type = self._peek_event_type(response_str)
                    if (
                        _evt_type in self._DELTA_EVENT_TYPES
                        or _evt_type in self._OUTPUT_DONE_EVENT_TYPES
                    ):
                        continue
                unmasked_str = self._unmask_response_event(response_str)
                output_masked_str = await self._mask_response_completed(unmasked_str)
                # Log the output-masked form so PII redacted by apply_to_output
                # guardrails does not appear in success logs.
                self._store_event(output_masked_str)
                await self.websocket.send_text(output_masked_str)
        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.debug("Responses WS backend connection closed: %s", e)
        except Exception as e:
            verbose_logger.exception("Error in responses WS backend_to_client: %s", e)
        finally:
            await self._log_messages()

    # ------------------------------------------------------------------
    # Client -> backend masking
    # ------------------------------------------------------------------

    def _enforce_authorized_model(self, msg_obj: dict) -> bool:
        """
        Overwrite any ``model`` field in a ``response.create`` frame with the
        connection-authorized model to prevent deployment-substitution attacks.

        Handles both shapes:
          flat:   ``{"type": "response.create", "model": "...", ...}``
          nested: ``{"type": "response.create", "response": {"model": "...", ...}}``

        Returns True if the object was modified.
        """
        if not self.authorized_model:
            return False
        modified = False
        nested = msg_obj.get("response")
        if isinstance(nested, dict):
            if nested.get("model") != self.authorized_model:
                nested["model"] = self.authorized_model
                modified = True
            # A stray top-level model next to a nested body is overwritten
            # too, but never introduced when absent.
            if "model" in msg_obj and msg_obj["model"] != self.authorized_model:
                msg_obj["model"] = self.authorized_model
                modified = True
        elif msg_obj.get("model") != self.authorized_model:
            msg_obj["model"] = self.authorized_model
            modified = True
        return modified

    async def _check_pii(
        self, cb: Any, text: str, presidio_config: Any, *, output_parse_pii: bool
    ) -> str:
        """Run one guardrail callback's ``check_pii`` over *text*."""
        return await cb.check_pii(
            text=text,
            output_parse_pii=output_parse_pii,
            presidio_config=presidio_config,
            request_data=self.request_data,
        )

    @staticmethod
    def _response_create_text_fields(msg_obj: dict) -> "list[tuple[dict, str]]":
        """
        List the ``(container, key)`` pairs holding client text.

        response.create carries client text in two shapes:
          flat:   {"type": "response.create", "input": ..., "instructions": ...}
          nested: {"type": "response.create", "response": {"input": ..., "instructions": ...}}
        Both are returned so PII is never forwarded unmasked regardless of
        where the client places it.
        """
        nested = msg_obj.get("response")
        fields: "list[tuple[dict, str]]" = []
        for container in (msg_obj, nested if isinstance(nested, dict) else None):
            if container is None:
                continue
            if "input" in container:
                fields.append((container, "input"))
            if isinstance(container.get("instructions"), str):
                fields.append((container, "instructions"))
        return fields

    async def _mask_input_items(
        self, cb: Any, items: list, presidio_config: Any
    ) -> bool:
        """Mask ``content`` / ``output`` text of every input item in place."""
        modified = False
        for item in items:
            if not isinstance(item, dict):
                continue
            for item_field in ("content", "output"):
                value = item.get(item_field)
                if isinstance(value, str):
                    item[item_field] = await self._check_pii(
                        cb, value, presidio_config, output_parse_pii=True
                    )
                    modified = True
                elif isinstance(value, list):
                    for block in value:
                        if not isinstance(block, dict):
                            continue
                        block_type = block.get("type")
                        if (
                            isinstance(block_type, str)
                            and block_type in RESPONSES_WS_MASKABLE_TEXT_BLOCK_TYPES
                            and isinstance(block.get("text"), str)
                        ):
                            block["text"] = await self._check_pii(
                                cb,
                                block["text"],
                                presidio_config,
                                output_parse_pii=True,
                            )
                            modified = True
        return modified

    async def _mask_response_create(self, message: str) -> str:
        """
        Enforce the authorized model and apply Presidio PII masking to a
        ``response.create`` message before it is forwarded to the upstream
        provider.

        - Overwrites any ``model`` field with the connection-authorized model
          to prevent deployment-substitution attacks (always applied).
        - Walks the ``input`` and ``instructions`` fields and calls
          ``check_pii`` on every text block, passing ``self.request_data`` so
          the callback can record what it needs for later unmasking.

        Non-``response.create`` messages, including payloads that are not a
        JSON object, are returned unchanged.
        """
        try:
            msg_obj = json.loads(message)
        except (json.JSONDecodeError, TypeError):
            return message
        if not isinstance(msg_obj, dict) or msg_obj.get("type") != "response.create":
            return message
        # Always enforce the authorized model, even when PII masking is off.
        model_modified = self._enforce_authorized_model(msg_obj)
        if not self.guardrail_callbacks:
            return json.dumps(msg_obj) if model_modified else message
        if "metadata" not in self.request_data:
            self.request_data["metadata"] = {}
        text_fields = self._response_create_text_fields(msg_obj)
        modified = model_modified
        for cb in self.guardrail_callbacks:
            presidio_config = cb.get_presidio_settings_from_request_data(
                self.request_data
            )
            for container, key in text_fields:
                field_value = container[key]
                if isinstance(field_value, str):
                    container[key] = await self._check_pii(
                        cb, field_value, presidio_config, output_parse_pii=True
                    )
                    modified = True
                elif isinstance(field_value, list):
                    if await self._mask_input_items(cb, field_value, presidio_config):
                        modified = True
        return json.dumps(msg_obj) if modified else message

    # ------------------------------------------------------------------
    # Backend -> client masking / unmasking
    # ------------------------------------------------------------------

    def _unmask_response_event(self, response_str: str) -> str:
        """
        Apply Presidio PII unmasking to backend events before forwarding to
        the client.

        Handles two shapes:
        - ``response.completed``: walks ``response.output[*].content[*].text``
        - streaming delta events (``response.output_text.delta``, etc.):
          replaces tokens in the ``delta`` field

        Uses the ``pii_tokens`` map found in ``self.request_data["metadata"]``
        to replace every token (e.g. ``<EMAIL_ADDRESS_1>``) with the original
        value. Only the first guardrail callback is used. Events with no
        stored tokens, or that are not JSON objects, are returned unchanged.
        """
        if not self.guardrail_callbacks:
            return response_str
        pii_tokens: Dict[str, str] = (self.request_data.get("metadata") or {}).get(
            "pii_tokens", {}
        )
        if not pii_tokens:
            return response_str
        try:
            evt_obj = json.loads(response_str)
        except (json.JSONDecodeError, TypeError):
            return response_str
        if not isinstance(evt_obj, dict):
            return response_str
        cb = self.guardrail_callbacks[0]
        event_type = evt_obj.get("type")
        if event_type == "response.completed":
            modified = False
            response_obj = evt_obj.get("response") or {}
            if not isinstance(response_obj, dict):
                return response_str
            for output_item in response_obj.get("output") or []:
                if not isinstance(output_item, dict):
                    continue
                content = output_item.get("content") or []
                if not isinstance(content, list):
                    continue
                for content_block in content:
                    if not isinstance(content_block, dict):
                        continue
                    text = content_block.get("text")
                    if isinstance(text, str):
                        unmasked = cb._unmask_pii_text(text, pii_tokens)
                        if unmasked != text:
                            content_block["text"] = unmasked
                            modified = True
            return json.dumps(evt_obj) if modified else response_str
        if isinstance(event_type, str) and event_type in self._DELTA_EVENT_TYPES:
            delta = evt_obj.get("delta")
            if isinstance(delta, str):
                unmasked = cb._unmask_pii_text(delta, pii_tokens)
                if unmasked != delta:
                    evt_obj["delta"] = unmasked
                    return json.dumps(evt_obj)
        return response_str

    async def _mask_output_field(
        self, cb: Any, holder: dict, key: str, presidio_config: Any
    ) -> bool:
        """Mask ``holder[key]`` in place if it is a string; True when changed."""
        original = holder.get(key)
        if not isinstance(original, str):
            return False
        masked = await self._check_pii(
            cb, original, presidio_config, output_parse_pii=False
        )
        if masked == original:
            return False
        holder[key] = masked
        return True

    async def _mask_response_completed(self, response_str: str) -> str:
        """
        Apply Presidio output masking (apply_to_output=True) to the
        ``response.completed`` event before it is forwarded to the client.

        Walks ``response.output[*].content[*].text`` and masks every text block,
        as well as ``response.output[*].arguments`` on function-call items and
        ``response.output[*].summary[*].text`` on reasoning items. Delta and
        ``*.done`` events are suppressed upstream in ``backend_to_client`` when
        output masking is active, so only the authoritative full-output view
        reaches this method; events of other types are returned unchanged.
        """
        if not self.output_guardrail_callbacks:
            return response_str
        try:
            evt_obj = json.loads(response_str)
        except (json.JSONDecodeError, TypeError):
            return response_str
        if not isinstance(evt_obj, dict) or evt_obj.get("type") != "response.completed":
            return response_str
        response_obj = evt_obj.get("response") or {}
        if not isinstance(response_obj, dict):
            return response_str
        modified = False
        for cb in self.output_guardrail_callbacks:
            presidio_config = cb.get_presidio_settings_from_request_data(
                self.request_data
            )
            for output_item in response_obj.get("output") or []:
                if not isinstance(output_item, dict):
                    continue
                if await self._mask_output_field(
                    cb, output_item, "arguments", presidio_config
                ):
                    modified = True
                # Reasoning summaries first, then message content.
                for list_key in ("summary", "content"):
                    blocks = output_item.get(list_key) or []
                    if not isinstance(blocks, list):
                        continue
                    for block in blocks:
                        if isinstance(block, dict) and await self._mask_output_field(
                            cb, block, "text", presidio_config
                        ):
                            modified = True
        return json.dumps(evt_obj) if modified else response_str

    # ------------------------------------------------------------------
    # Client -> backend
    # ------------------------------------------------------------------

    async def _forward_client_message(self, message: str) -> None:
        """Mask, log and send one client frame to the backend."""
        masked = await self._mask_response_create(message)
        self._store_input(masked)
        self._store_event(masked)
        await self.backend_ws.send(masked)  # type: ignore[union-attr]

    async def client_to_backend(self) -> None:
        """
        Forward response.create events from client to backend.

        Sends ``first_message`` (if any) and then relays every frame received
        from the client until an error or disconnect ends the loop.
        """
        try:
            if self.first_message is not None:
                await self._forward_client_message(self.first_message)
            while True:
                message = await self.websocket.receive_text()
                await self._forward_client_message(message)
        except Exception as e:
            verbose_logger.debug("Responses WS client_to_backend ended: %s", e)

    async def bidirectional_forward(self) -> None:
        """
        Run both forwarding directions concurrently.

        The backend->client direction runs as a task; when the
        client->backend loop ends, that task is cancelled and the backend
        socket is closed.
        """
        forward_task = asyncio.create_task(self.backend_to_client())
        try:
            await self.client_to_backend()
        except Exception as e:
            verbose_logger.debug("Responses WS client_to_backend raised: %s", e)
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
            try:
                await self.backend_ws.close()
            except Exception as e:
                # Closing is best-effort; the socket may already be gone.
                verbose_logger.debug("Responses WS backend close failed: %s", e)
# ---------------------------------------------------------------------------
# Managed WebSocket mode (HTTP-backed, provider-agnostic)
# ---------------------------------------------------------------------------
# Parameter names accepted from a ``response.create`` frame: the union of the
# required and optional keys declared on ``ResponsesAPIRequestParams``.
# ``ManagedResponsesWebSocketHandler._build_base_call_kwargs`` copies only
# these keys out of the client frame.
_RESPONSE_CREATE_PARAMS: frozenset = (
_get_openai_response_types().ResponsesAPIRequestParams.__required_keys__
| _get_openai_response_types().ResponsesAPIRequestParams.__optional_keys__
)
# Constructor kwargs that ``ManagedResponsesWebSocketHandler.__init__`` drops
# instead of keeping in ``extra_kwargs``.
_MANAGED_WS_SKIP_KWARGS: frozenset = frozenset(
{
"litellm_logging_obj",
"litellm_call_id",
"aresponses",
"_aresponses_websocket",
"user_api_key_dict",
}
)
# Prefix of the synthetic response IDs minted by ``_build_warmup_response``;
# ``_is_warmup_response_id`` uses it to recognise those IDs later.
_WARMUP_RESPONSE_ID_PREFIX = "resp_warmup_"
class ManagedResponsesWebSocketHandler:
"""
Handles Responses API WebSocket mode for providers that do not expose a
native ``wss://`` responses endpoint.
Instead of proxying to a provider WebSocket, this handler:
- Listens for ``response.create`` events from the client
- Makes HTTP streaming calls via ``litellm.aresponses(stream=True)``
- Serialises and forwards every streaming event back over the WebSocket
- Supports ``previous_response_id`` for multi-turn conversations via
in-memory session tracking (avoids async DB-write timing issues)
- Supports sequential requests over a single persistent connection
This makes every provider that LiteLLM can reach over HTTP available on
the WebSocket transport without any provider-specific changes.
"""
def __init__(
self,
websocket: Any,
model: str,
logging_obj: "LiteLLMLoggingObj",
user_api_key_dict: Optional[Any] = None,
litellm_metadata: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Optional[float] = None,
custom_llm_provider: Optional[str] = None,
first_message: Optional[str] = None,
**kwargs: Any,
) -> None:
self.websocket = websocket
self.model = model
self.logging_obj = logging_obj
self.user_api_key_dict = user_api_key_dict
self.litellm_metadata: Dict[str, Any] = litellm_metadata or {}
self.model_group: Optional[str] = self.litellm_metadata.get(
"model_group"
) or self.litellm_metadata.get("deployment_model_name")
self.api_key = api_key
self.api_base = api_base
self.timeout = timeout
self.custom_llm_provider = custom_llm_provider
self._connection_provider = self._resolve_provider(model) or custom_llm_provider
self.first_message = first_message
# Carry through safe pass-through kwargs (e.g. extra_headers)
self.extra_kwargs: Dict[str, Any] = {
k: v for k, v in kwargs.items() if k not in _MANAGED_WS_SKIP_KWARGS
}
# In-memory session history: response_id → full accumulated message list.
# Keyed by the DECODED (pre-encoding) response ID from response.completed.
# This avoids the async DB-write race condition where spend logs haven't
# been committed yet when the next response.create arrives.
self._session_history: Dict[str, List[Dict[str, Any]]] = {}
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _serialize_chunk(chunk: Any) -> Optional[str]:
"""Serialize a streaming chunk to a JSON string for WebSocket transmission."""
try:
if hasattr(chunk, "model_dump_json"):
return chunk.model_dump_json(exclude_none=True)
if hasattr(chunk, "model_dump"):
return json.dumps(chunk.model_dump(exclude_none=True), default=str)
if isinstance(chunk, dict):
return json.dumps(chunk, default=str)
return json.dumps(str(chunk))
except Exception as exc:
verbose_logger.debug(
"ManagedResponsesWS: failed to serialize chunk: %s", exc
)
return None
async def _send_error(self, message: str, error_type: str = "server_error") -> None:
try:
await self.websocket.send_text(
json.dumps(
{"type": "error", "error": {"type": error_type, "message": message}}
)
)
except Exception:
pass
def _get_history_messages(self, previous_response_id: str) -> List[Dict[str, Any]]:
    """
    Look up the stored conversation for *previous_response_id*.

    The ID is decoded first; history is keyed by the decoded ``response_id``
    (falling back to the ID as given when the decoder returns none). A copy
    of the stored list is returned, or an empty list when nothing is stored.
    """
    id_parts = ResponsesAPIRequestUtils._decode_responses_api_response_id(
        previous_response_id
    )
    history_key = id_parts.get("response_id", previous_response_id)
    if history_key not in self._session_history:
        return []
    # Copy so callers can extend the result without touching stored history.
    return list(self._session_history[history_key])
def _store_history(self, response_id: str, messages: List[Dict[str, Any]]) -> None:
"""
Store the complete accumulated message history for *response_id*.
Replaces any prior value — callers are responsible for passing the full
history (prior turns + current input + new output).
"""
self._session_history[response_id] = messages
@staticmethod
def _extract_response_id(completed_event: Dict[str, Any]) -> Optional[str]:
"""
Pull the raw (decoded) response ID out of a ``response.completed`` event.
Returns *None* if the event doesn't contain a usable ID.
"""
resp_obj = completed_event.get("response", {})
encoded_id: Optional[str] = (
resp_obj.get("id") if isinstance(resp_obj, dict) else None
)
if not encoded_id:
return None
decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(encoded_id)
return decoded.get("response_id", encoded_id)
@staticmethod
def _extract_output_messages(
completed_event: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""
Convert the output items in a ``response.completed`` event into
Responses API message dicts suitable for the next turn's ``input``.
"""
resp_obj = completed_event.get("response", {})
if not isinstance(resp_obj, dict):
return []
messages: List[Dict[str, Any]] = []
for item in resp_obj.get("output", []) or []:
if not isinstance(item, dict):
continue
item_type = item.get("type")
role = item.get("role", "assistant")
if item_type == "message":
content_parts = item.get("content") or []
text_parts = [
p.get("text", "")
for p in content_parts
if isinstance(p, dict) and p.get("type") in ("output_text", "text")
]
text = "".join(text_parts)
if text:
messages.append(
{
"type": "message",
"role": role,
"content": [{"type": "output_text", "text": text}],
}
)
elif item_type == "function_call":
messages.append(item)
return messages
@staticmethod
def _input_to_messages(input_val: Any) -> List[Dict[str, Any]]:
"""
Normalise the ``input`` field of a ``response.create`` event to a list
of Responses API message dicts.
"""
if isinstance(input_val, str):
return [
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": input_val}],
}
]
if isinstance(input_val, list):
return [item for item in input_val if isinstance(item, dict)]
return []
# ------------------------------------------------------------------
# _process_response_create sub-methods
# ------------------------------------------------------------------
async def _parse_message(self, raw_message: str) -> Optional[Dict[str, Any]]:
"""Parse raw WS text; return the message dict or None (JSON error / ignored type)."""
try:
msg_obj = json.loads(raw_message)
except json.JSONDecodeError:
await self._send_error(
"Invalid JSON in response.create event", "invalid_request_error"
)
return None
if msg_obj.get("type") != "response.create":
# Silently ignore non-response.create messages (e.g. warmup pings)
return None
return msg_obj
@staticmethod
def _is_warmup_frame(msg_obj: Dict[str, Any]) -> bool:
"""Return True for a response.create whose generate flag is false."""
nested = msg_obj.get("response")
source = nested if isinstance(nested, dict) and nested else msg_obj
return source.get("generate") is False
@staticmethod
def _is_warmup_response_id(response_id: Optional[str]) -> bool:
"""Return True for synthetic warmup IDs that only exist on this connection."""
if not response_id:
return False
decoded = ResponsesAPIRequestUtils._decode_responses_api_response_id(
response_id
)
raw_id = decoded.get("response_id", response_id)
return str(raw_id).startswith(_WARMUP_RESPONSE_ID_PREFIX)
@staticmethod
def _warmup_source_params(msg_obj: Dict[str, Any]) -> Dict[str, Any]:
nested = msg_obj.get("response")
if isinstance(nested, dict) and nested:
return nested
return {k: v for k, v in msg_obj.items() if k != "type"}
def _build_warmup_response(self, msg_obj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Assemble the minimal completed response object used to ack a warmup.

    The reported model is the one named in the frame, else the connection's
    ``model_group``, else ``self.model``. The ID is a fresh random value
    carrying the warmup prefix; output is empty and all usage counts are 0.
    """
    wire_model = self._warmup_source_params(msg_obj).get("model")
    if not wire_model:
        wire_model = self.model_group
    if not wire_model:
        wire_model = self.model
    synthetic_id = _WARMUP_RESPONSE_ID_PREFIX + uuid.uuid4().hex
    created_at = int(time.time())
    zero_usage = dict.fromkeys(
        ("input_tokens", "output_tokens", "total_tokens"), 0
    )
    return dict(
        id=synthetic_id,
        object="response",
        created_at=created_at,
        status="completed",
        model=wire_model,
        output=[],
        usage=zero_usage,
    )
async def _send_warmup_ack(self, msg_obj: Dict[str, Any]) -> None:
    """
    Answer a generate=false prewarm locally, without calling the provider.

    Codex blocks on the warmup turn until it receives response.created and
    response.completed over the WebSocket. Managed HTTP providers cannot
    honor an empty-input warmup, so both events are synthesized here from
    one warmup response object, differing only in ``status``.
    """
    base_response = self._build_warmup_response(msg_obj)
    lifecycle = {
        "response.created": "in_progress",
        "response.completed": "completed",
    }
    for event_type, status in lifecycle.items():
        frame = self._serialize_chunk(
            {"type": event_type, "response": dict(base_response, status=status)}
        )
        # An event that cannot be serialized is skipped, not fatal.
        if frame is not None:
            await self.websocket.send_text(frame)
@staticmethod
def _build_base_call_kwargs(msg_obj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pick the Responses API parameters out of a ``response.create`` frame.

    Both wire formats are accepted:
      Nested: {"type": "response.create", "response": {"input": [...], ...}}
      Flat:   {"type": "response.create", "input": [...], "model": "...", ...}

    Only names in ``_RESPONSE_CREATE_PARAMS`` are kept, and parameters
    whose value is ``None`` are left out.
    """
    inner = msg_obj.get("response")
    if isinstance(inner, dict) and inner:
        params: Dict[str, Any] = inner
    else:
        params = {key: value for key, value in msg_obj.items() if key != "type"}

    selected: Dict[str, Any] = {}
    for name in _RESPONSE_CREATE_PARAMS:
        value = params.get(name)
        if value is not None:
            selected[name] = value
    return selected
def _apply_history(
self,
call_kwargs: Dict[str, Any],
previous_response_id: Optional[str],
current_messages: List[Dict[str, Any]],
prior_history: List[Dict[str, Any]],
) -> None:
"""Prepend in-memory turn history, or fall back to DB-based reconstruction."""
if not previous_response_id:
return
if self._is_warmup_response_id(previous_response_id):
verbose_logger.debug(
"ManagedResponsesWS: ignoring synthetic warmup previous_response_id=%s",
previous_response_id,
)
return
if prior_history:
call_kwargs["input"] = prior_history + current_messages
verbose_logger.debug(
"ManagedResponsesWS: prepended %d history messages for previous_response_id=%s",
len(prior_history),
previous_response_id,
)
else:
verbose_logger.debug(
"ManagedResponsesWS: no in-memory history for previous_response_id=%s; "
"falling back to DB-based session reconstruction",
previous_response_id,
)
# Fall back to DB-based session reconstruction (may work for
# cross-connection multi-turn when spend logs are committed)
call_kwargs["previous_response_id"] = previous_response_id
@staticmethod
def _resolve_provider(model: Optional[str]) -> Optional[str]:
"""Resolve the LLM provider for a model string, or None if unresolvable."""
if not model:
return None
try:
from litellm import get_llm_provider
_, provider, _, _ = get_llm_provider(model=model)
return provider
except Exception:
return None
def _same_provider(self, model: Optional[str]) -> bool:
"""Return True if model uses the same LLM provider as the connection model."""
if model is None or model == self.model:
return True
event_provider = self._resolve_provider(model)
if event_provider is None:
return False
return event_provider == self._connection_provider
def _inject_credentials(
self, call_kwargs: Dict[str, Any], model: Optional[str] = None
) -> None:
"""Inject connection-level credentials and metadata into call_kwargs."""
if self.api_key is not None:
call_kwargs["api_key"] = self.api_key
if self.api_base is not None:
call_kwargs["api_base"] = self.api_base
if self.timeout is not None:
call_kwargs["timeout"] = self.timeout
# Only force connection-level custom_llm_provider when the per-event model
# uses the same provider as the connection model. If the provider differs
# (e.g., connection is vertex_ai but event says openai/gpt-4), let litellm
# re-resolve from the model string. Same-provider model variants (e.g.,
# vertex_ai/gemini-2.0 -> vertex_ai/gemini-1.5) still inherit the provider.
if self.custom_llm_provider is not None and self._same_provider(model):
call_kwargs["custom_llm_provider"] = self.custom_llm_provider
if self.litellm_metadata:
call_kwargs["litellm_metadata"] = dict(self.litellm_metadata)
@staticmethod
def _update_proxy_request(call_kwargs: Dict[str, Any], model: str) -> None:
"""Update proxy_server_request body so spend logs record the full request."""
proxy_server_request = (call_kwargs.get("litellm_metadata") or {}).get(
"proxy_server_request"
) or {}
if not isinstance(proxy_server_request, dict):
return
body = dict(proxy_server_request.get("body") or {})
body["input"] = call_kwargs.get("input")
body["store"] = call_kwargs.get("store")
body["model"] = model
for k in ("tools", "tool_choice", "instructions", "metadata"):
if k in call_kwargs and call_kwargs[k] is not None:
body[k] = call_kwargs[k]
proxy_server_request = {**proxy_server_request, "body": body}
if "litellm_metadata" not in call_kwargs:
call_kwargs["litellm_metadata"] = {}
call_kwargs["litellm_metadata"]["proxy_server_request"] = proxy_server_request
call_kwargs.setdefault("litellm_params", {})
call_kwargs["litellm_params"]["proxy_server_request"] = proxy_server_request
async def _stream_and_forward(
    self, model: str, call_kwargs: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Run ``litellm.aresponses`` and relay each streamed chunk to the client.

    The chunk's ``type`` is read from the object itself (attribute first,
    then dict key) so the serialized text only has to be parsed again for
    the one ``response.completed`` chunk that is captured.

    Returns:
        The first ``response.completed`` event as a dict, or ``None`` if none
        was seen. If sending to the client fails, streaming stops and
        whatever was captured so far is returned.
    """
    final_event: Optional[Dict[str, Any]] = None
    stream = await litellm.aresponses(model=model, **call_kwargs)
    async for chunk in stream:  # type: ignore[union-attr]
        if chunk is None:
            continue

        # Read type from the object before serializing to avoid double JSON parse
        chunk_type = getattr(chunk, "type", None)
        if not chunk_type:
            chunk_type = chunk.get("type") if isinstance(chunk, dict) else None

        payload = self._serialize_chunk(chunk)
        if payload is None:
            continue

        if chunk_type == "response.completed" and final_event is None:
            try:
                final_event = json.loads(payload)
            except Exception:
                # Keep streaming even if the completed event cannot be parsed.
                pass

        try:
            await self.websocket.send_text(payload)
        except Exception as send_exc:
            verbose_logger.debug(
                "ManagedResponsesWS: error sending chunk to client: %s", send_exc
            )
            break  # Client disconnected
    return final_event
def _save_turn_history(
self,
completed_event: Optional[Dict[str, Any]],
prior_history: List[Dict[str, Any]],
current_messages: List[Dict[str, Any]],
) -> None:
"""Store this turn in in-memory history for future previous_response_id lookups."""
if completed_event is None:
return
new_response_id = self._extract_response_id(completed_event)
if not new_response_id:
return
output_msgs = self._extract_output_messages(completed_event)
all_messages = prior_history + current_messages + output_msgs
self._store_history(new_response_id, all_messages)
verbose_logger.debug(
"ManagedResponsesWS: stored %d messages for response_id=%s",
len(all_messages),
new_response_id,
)
# ------------------------------------------------------------------
# Core request handler
# ------------------------------------------------------------------
async def _process_response_create(self, raw_message: str) -> None:
    """
    Handle a single ``response.create`` frame from start to finish.

    The frame is parsed, converted into keyword arguments for a streaming
    ``litellm.aresponses`` call, and handed to ``_stream_and_forward``,
    which relays the streamed events to the client.

    Conversation continuity
    -----------------------
    When the frame carries ``previous_response_id``, the messages stored
    for that ID are fetched once and passed to ``_apply_history`` together
    with the current ``input`` so the model receives the earlier turns.
    After the stream finishes without raising, ``_save_turn_history`` is
    given the prior history, the current input messages and the completed
    event so the next turn can be chained onto this one.

    History is kept in memory to avoid the async DB-write race that occurs
    when spend logs have not been committed by the time the next
    ``response.create`` arrives on the same WebSocket connection.

    If streaming raises, the error is logged and reported to the client via
    ``_send_error``; no history is saved for that turn.
    """
    frame = await self._parse_message(raw_message)
    if frame is None:
        return

    # A generate=false frame is a prompt-cache warmup hint (codex prewarm).
    # Native provider sockets deal with it server-side; there is no HTTP
    # equivalent and the frame has empty input, so for managed providers a
    # completion is synthesized to let clients such as Codex continue.
    if self._is_warmup_frame(frame):
        try:
            await self._send_warmup_ack(frame)
        except Exception as ack_exc:
            verbose_logger.debug(
                "ManagedResponsesWS: error sending warmup ack: %s", ack_exc
            )
        return

    request_kwargs = self._build_base_call_kwargs(frame)
    request_kwargs["stream"] = True

    # When the frame merely repeats the connection's public alias
    # (model_group), fall back to the router-resolved self.model: handing the
    # raw alias to litellm.aresponses fails in get_llm_provider. Any other
    # provider-prefixed per-frame model is passed through unchanged.
    requested_model = request_kwargs.pop("model", None)
    model = (
        self.model
        if requested_model is None or requested_model == self.model_group
        else requested_model
    )

    previous_response_id: Optional[str] = request_kwargs.pop(
        "previous_response_id", None
    )
    current_messages = self._input_to_messages(request_kwargs.get("input"))

    # Look the history up a single time; both _apply_history and
    # _save_turn_history consume the same list.
    if previous_response_id:
        prior_history = self._get_history_messages(previous_response_id)
    else:
        prior_history = []

    self._apply_history(
        request_kwargs, previous_response_id, current_messages, prior_history
    )
    self._inject_credentials(request_kwargs, model=model)
    self._update_proxy_request(
        request_kwargs, requested_model or self.model_group or model
    )
    request_kwargs.update(self.extra_kwargs)

    try:
        completed_event = await self._stream_and_forward(model, request_kwargs)
    except Exception as stream_exc:
        verbose_logger.exception(
            "ManagedResponsesWS: error processing response.create: %s", stream_exc
        )
        await self._send_error(str(stream_exc))
    else:
        # Only reached when streaming did not raise.
        self._save_turn_history(completed_event, prior_history, current_messages)
# ------------------------------------------------------------------
# Main entry point
# ------------------------------------------------------------------
async def run(self) -> None:
    """
    Serve the connection: process ``response.create`` frames one at a time.

    The optional ``first_message`` is handled first. After that each frame
    read from the WebSocket is fully processed before the next one is
    awaited. A failure while reading a frame is treated as a client
    disconnect and ends the loop; any other exception is logged and reported
    to the client through ``_send_error``.
    """
    try:
        if self.first_message is not None:
            await self._process_response_create(self.first_message)

        while True:
            try:
                incoming = await self.websocket.receive_text()
            except Exception as recv_exc:
                # A failed read is taken to mean the client has gone away.
                verbose_logger.debug(
                    "ManagedResponsesWS: client disconnected: %s", recv_exc
                )
                return
            else:
                await self._process_response_create(incoming)
    except Exception as fatal_exc:
        verbose_logger.exception(
            "ManagedResponsesWS: unexpected error: %s", fatal_exc
        )
        await self._send_error(f"Internal server error: {fatal_exc}")