Files
MoFin/venv/lib/python3.12/site-packages/litellm/proxy/auth/auth_utils.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

1566 lines
57 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import re
import sys
from functools import lru_cache
from logging import Logger
from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union
from fastapi import HTTPException, Request, status
import litellm
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.constants import STANDARD_CUSTOMER_ID_HEADERS
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.litellm_core_utils.url_utils import SSRFError, validate_url
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
from litellm.types.utils import CustomPricingLiteLLMParams
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """Resolve the caller's IP address for ``request``.

    Args:
        request: The incoming request.
        use_x_forwarded_for: Only when this is exactly ``True`` and the
            request carries an ``x-forwarded-for`` header is that header
            used. The raw header value is returned unchanged; it is not
            split or parsed here.

    Returns:
        The ``x-forwarded-for`` value, else ``request.client.host`` when a
        client is attached to the request, else an empty string.
    """
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]
    if request.client is not None:
        return request.client.host
    return ""
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """Decide whether the caller's IP is on the allowlist.

    Args:
        allowed_ips: Permitted IP strings, or ``None`` when no allowlist is
            configured (every caller is then accepted).
        request: The incoming request.
        use_x_forwarded_for: Forwarded to ``_get_request_ip_address`` to
            choose between the ``x-forwarded-for`` header and the socket
            peer address.

    Returns:
        ``(is_allowed, client_ip)``. ``client_ip`` is ``None`` when no
        allowlist is configured, because the address is not resolved then.
    """
    if allowed_ips is None:
        # No allowlist configured: accept without resolving the address.
        return True, None

    caller_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )
    return caller_ip in allowed_ips, caller_ip
def check_complete_credentials(request_body: dict) -> bool:
    """Report whether the request carries complete client-side credentials.

    Supplying an ``api_key`` is necessary but not sufficient: even with
    credentials supplied, an ``api_base`` / ``base_url`` that resolves to a
    private/internal/cloud-metadata address would still allow the proxy to
    be used as an SSRF pivot, so URL fields are validated here as well.

    Args:
        request_body: The parsed (untrusted) request body.

    Returns:
        ``True`` only when ``model`` is a string that does not name a
        complex-credential provider and ``api_key`` is a non-blank string.
        ``False`` otherwise, including when ``model`` is missing or is not
        a string.

    Raises:
        ValueError: If URL validation is enabled and ``api_base`` /
            ``base_url`` is rejected by the SSRF guard.
    """
    given_model = request_body.get("model")
    # Fail closed on a missing or non-string model. The body is untrusted, and
    # substring checks on e.g. an int would raise TypeError instead of
    # producing a clean "not complete" answer.
    if not isinstance(given_model, str):
        return False

    # Complex credentials - easier to make a malicious request.
    # "vertex_ai" also covers "vertex_ai_beta".
    complex_credential_providers = ("sagemaker", "bedrock", "vertex_ai")
    if any(provider in given_model for provider in complex_credential_providers):
        return False

    api_key_value = request_body.get("api_key")
    if not (isinstance(api_key_value, str) and api_key_value.strip()):
        return False

    # ``validate_url`` itself doesn't consult the toggle; ``safe_get`` /
    # ``async_safe_get`` do. Mirror that here so admins who explicitly
    # disabled URL validation (e.g. for an internal Ollama endpoint they
    # accept the SSRF risk for) aren't blocked at the proxy boundary.
    if getattr(litellm, "user_url_validation", False):
        for url_field in ("api_base", "base_url"):
            url_value = request_body.get(url_field)
            if not url_value or not isinstance(url_value, str):
                continue
            try:
                validate_url(url_value)
            except SSRFError as e:
                # Chain the cause so the guard's traceback is preserved.
                raise ValueError(
                    f"Rejected request: client-side {url_field}={url_value!r} "
                    f"is rejected by the SSRF guard ({e})."
                ) from e
    return True
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """Check whether a request value equals, or regex-matches, an allowed entry.

    Args:
        request_body_value: The (untrusted) value taken from the request body.
        regex_str: The admin-configured entry, treated both as a literal
            string and as a regular expression.

    Returns:
        ``True`` if the value equals ``regex_str`` or ``re.match`` succeeds;
        ``False`` otherwise, including for non-string values and for entries
        that are not valid regular expressions.
    """
    # Literal equality first, so an entry that is not a valid regex (e.g. it
    # contains an unbalanced "(") still works as an exact-match allowlist item.
    if regex_str == request_body_value:
        return True
    # Only strings can be regex-matched; any other JSON type is a non-match
    # rather than a TypeError.
    if not isinstance(request_body_value, str):
        return False
    try:
        # NOTE: ``re.match`` anchors only at the start of the value, so the
        # pattern must carry its own ``$`` if a full match is required.
        return re.match(regex_str, request_body_value) is not None
    except (re.error, TypeError):
        # Invalid pattern (or non-string pattern): treat as "no regex match".
        return False
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """Check ``param`` against a deployment's client-side auth allowlist.

    Allowlist entries are either plain strings (the param name is allowed
    with any value) or dicts. Dict entries are only consulted for
    ``api_base``, whose request value must equal or regex-match the entry's
    ``"api_base"`` value.

    Args:
        param: Name of the request-body param being checked.
        request_body_value: Value the caller supplied for ``param``.
        configurable_clientside_auth_params: The allowlist, or ``None``.

    Returns:
        ``True`` if any entry permits the param, otherwise ``False``.
    """
    if configurable_clientside_auth_params is None:
        return False

    for entry in configurable_clientside_auth_params:
        if isinstance(entry, str):
            if entry == param:
                return True
        elif isinstance(entry, Dict) and param == "api_base":
            # The dict's "api_base" value is treated as a regex / literal.
            if check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=entry["api_base"],
            ):
                return True
    return False
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """Check whether the model's deployment opts in to a client-side param.

    Looks up the model group on the router. When there is no exact match and
    the part of ``model`` before the first ``/`` is a known provider, the
    lookup is retried with that provider prefix (wildcard deployments). The
    group's ``configurable_clientside_auth_params`` then decides.

    Args:
        model: Model name from the request.
        param: Name of the client-supplied param.
        request_body_value: Value the caller supplied for ``param``.
        llm_router: The proxy router, or ``None`` when not initialised.

    Returns:
        ``True`` only when a matching model group exists and its allowlist
        permits the param.
    """
    if llm_router is None:
        return False

    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # Fall back to a wildcard group keyed by the provider prefix.
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            model_info = llm_router.get_model_group_info(model_group=provider_prefix)

    if model_info is None:
        return False

    allowed_params = model_info.configurable_clientside_auth_params
    if allowed_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=allowed_params,
    )
# Config dicts whose entries are spread as ``**dict`` into outbound LLM
# API calls. ``litellm_embedding_config`` is consumed by the Milvus
# vector store transformer. ``extra_body`` is the OpenAI-SDK passthrough
# container: provider modules pull provider-auth fields out of it
# (e.g. Azure's ``extra_body.azure_ad_token``, Bedrock's
# ``extra_body.aws_web_identity_token``) without re-validating, so the
# banned-key check has to descend into it the same way it descends into
# ``litellm_embedding_config``.
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config", "extra_body")
# Metadata containers that carry per-request configuration consumed by the
# observability callbacks. The same banned-param list applies — a value
# under ``metadata.langfuse_host`` redirects the same Langfuse client and
# leaks the same credentials as the root-level ``langfuse_host``, but the
# original check only walked the request-body root, so the metadata path
# was an unintentional bypass.
_NESTED_METADATA_KEYS: Tuple[str, ...] = ("metadata", "litellm_metadata")
# Banned request-body params (the list itself is ``_BANNED_REQUEST_BODY_PARAMS``,
# defined further down). The same list applies to every entry in
# ``_NESTED_CONFIG_KEYS`` (dicts spread as ``**kwargs`` into outbound
# calls) and ``_NESTED_METADATA_KEYS`` (dicts read directly by integration
# callbacks), so a single banned name is enforced wherever the field can
# reach the call path from.
# Per-request observability params that are SAFE to accept from clients.
# These describe the request being logged (prompt version, sampling rate)
# without choosing the destination or the credentials, so they don't
# contribute to the data-exfil primitive that the rest of
# ``_supported_callback_params`` does.
_SAFE_CLIENT_CALLBACK_PARAMS: FrozenSet[str] = frozenset(
{
"langfuse_prompt_version",
"langsmith_sampling_rate",
}
)
# Observability fields that integrations read from the request body or
# metadata but that are not (yet) listed in ``_supported_callback_params``.
# Listed here so the proxy bans them today; the long-term cleanup is to
# fold these into the canonical allowlist so they share one source of
# truth with the rest.
_EXTRA_BANNED_OBSERVABILITY_PARAMS: FrozenSet[str] = frozenset(
{
"posthog_api_url",
"phoenix_project_name",
"phoenix_project_name_override",
# Server-reserved: written exclusively by add_user_api_key_auth_to_request_metadata
# from the authenticated key's database record. A caller-supplied value
# would survive the server merge and let an authenticated user redirect
# their Arize/Phoenix telemetry into arbitrary projects.
"user_api_key_auth_metadata",
"wandb_api_key",
"weave_project_id",
}
)
def _build_banned_observability_params() -> FrozenSet[str]:
    """Derive the observability ban list from the canonical allowlist.

    ``_supported_callback_params`` and ``_request_blocked_callback_params``
    (from ``litellm/litellm_core_utils/initialize_dynamic_callback_params.py``)
    enumerate the observability fields integrations resolve from
    kwargs/metadata, plus the fields integration code explicitly blocks from
    request-supplied callback params.

    The result is: supported params minus the informational
    ``_SAFE_CLIENT_CALLBACK_PARAMS``, plus the request-blocked params, plus
    ``_EXTRA_BANNED_OBSERVABILITY_PARAMS``. New integrations added to the
    canonical allowlist are therefore banned by default, which is the safe
    failure mode.

    Returns:
        The frozen set of banned observability param names.
    """
    # Imported lazily, inside the function, as in the original code.
    from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
        _request_blocked_callback_params,
        _supported_callback_params,
    )

    supported_minus_safe = (
        frozenset(_supported_callback_params) - _SAFE_CLIENT_CALLBACK_PARAMS
    )
    explicitly_blocked = frozenset(_request_blocked_callback_params)
    return supported_minus_safe | explicitly_blocked | _EXTRA_BANNED_OBSERVABILITY_PARAMS
# Params a client may not set without admin opt-in. Order: the hand-written
# entries first, then the derived observability names (sorted), then the
# custom-pricing field names (sorted).
_BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
    (
        "api_base",
        "base_url",
        "user_config",
        "aws_sts_endpoint",
        "aws_web_identity_token",
        "aws_role_name",
        "vertex_credentials",
        # Azure managed-identity / federated-auth token. The Azure provider
        # transformer reads ``azure_ad_token`` (top-level or via
        # ``extra_body``) and resolves it through ``get_secret`` before
        # passing it as the bearer token to the Azure endpoint, so a
        # caller-supplied value is the same exfil shape as
        # ``aws_web_identity_token`` on the Bedrock path.
        "azure_ad_token",
        # Endpoint-targeting fields that retarget the outbound request or
        # an observability callback. An attacker-controlled value either
        # exfiltrates the request payload (incl. messages + admin-set
        # tokens) to the attacker's host, or coerces the proxy into
        # authenticating against the attacker's host with admin secrets.
        "aws_bedrock_runtime_endpoint",
        # Bedrock project/workspace association. Deployments pin this to
        # enforce a data-retention policy, so a caller-supplied value would
        # re-route the request's retention and accounting to any project
        # reachable with the deployment's shared AWS credentials.
        "aws_bedrock_project_id",
        # Provider-specific endpoint overrides that flow into the outbound
        # request via ``optional_params``. Same threat as ``api_base``:
        # ``s3_endpoint_url`` redirects Bedrock file uploads to attacker
        # S3; ``sagemaker_base_url`` redirects all SageMaker traffic;
        # ``deployment_url`` redirects SAP deployments.
        "s3_endpoint_url",
        "sagemaker_base_url",
        "deployment_url",
    )
    # Observability credentials, hosts, and project identifiers: derived
    # from the canonical ``_supported_callback_params`` allowlist so new
    # integrations are covered automatically. Sorted for stable iteration
    # order and reviewable diffs.
    + tuple(sorted(_build_banned_observability_params()))
    # Field names of the custom-pricing params model, also sorted.
    + tuple(sorted(CustomPricingLiteLLMParams.model_fields.keys()))
)
def _check_banned_params(
    body: dict,
    general_settings: dict,
    llm_router: Optional[Router],
    model: str,
) -> None:
    """Reject ``body`` if it carries a banned param without admin opt-in.

    Shared by the root-level check and the nested-container checks so a new
    banned param only has to be added to ``_BANNED_REQUEST_BODY_PARAMS``.

    Args:
        body: The dict to inspect (request root or a nested container).
        general_settings: Proxy general settings;
            ``allow_client_side_credentials`` is the proxy-wide opt-in.
        llm_router: Router used to resolve the per-deployment opt-in.
        model: Model name from the request.

    Raises:
        ValueError: For the first banned param present in ``body`` that is
            covered by neither opt-in.
    """
    present_banned = (name for name in _BANNED_REQUEST_BODY_PARAMS if name in body)
    for param in present_banned:
        if general_settings.get("allow_client_side_credentials") is True:
            # Proxy-wide opt-in: every banned param is permitted, so there
            # is nothing further to check.
            return

        deployment_opted_in = _allow_model_level_clientside_configurable_parameters(
            model=model,
            param=param,
            request_body_value=body[param],
            llm_router=llm_router,
        )
        if deployment_opted_in is True:
            # Per-param opt-in covers only THIS param. Keep scanning so a
            # body pairing an allowed ``api_base`` with an unallowed
            # ``langfuse_host`` is still rejected for the second field.
            continue

        raise ValueError(
            f"Rejected Request: {param} is not allowed in request body. "
            "Clientside passthrough requires explicit admin opt-in via "
            "either `general_settings.allow_client_side_credentials = true` "
            "(proxy-wide) or `configurable_clientside_auth_params` on the "
            "deployment in your proxy config.yaml. "
            "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
        )
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """Enforce the banned-param blocklist on a request body.

    A malicious user can set ``api_base`` to their own domain and invoke
    POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    The blocklist is enforced unconditionally. Legitimate client-side
    credential / endpoint passthrough goes through one of the two explicit
    admin opt-ins (``general_settings.allow_client_side_credentials``
    proxy-wide, or ``configurable_clientside_auth_params`` per deployment).

    Historically there was a third, implicit, caller-controlled path:
    ``check_complete_credentials`` returned True for any non-empty
    ``api_key``, which made the whole blocklist a no-op and turned every
    missing entry into an exploitable SSRF / credential-exfil hole (see
    GHSA-jh89-88fc-qrfp, GHSA-3frq-6r6h-7j64, and the veria-admin findings
    Dv_m860l, b_yRJeQ5, stN90yjP, LBlyOAc8, U2TD78kg). That path was
    removed, so a missing entry now fails predictably with a 400 rather
    than a credential leak.

    The body root is checked first, then one level into each key of
    ``_NESTED_CONFIG_KEYS`` and ``_NESTED_METADATA_KEYS``. Single-level
    iterative descent (not recursion) covers nested-config attacks such as
    Milvus's ``litellm_embedding_config.api_base`` (VERIA-6) without
    exposing a recursion-depth DoS surface.

    Returns:
        ``True`` when no check rejects the body.

    Raises:
        ValueError: Propagated from ``_check_banned_params``.
    """
    _check_banned_params(request_body, general_settings, llm_router, model)

    # Config containers first, then metadata containers.
    for container_key in (*_NESTED_CONFIG_KEYS, *_NESTED_METADATA_KEYS):
        container = _coerce_metadata_to_dict(request_body.get(container_key))
        if container is None:
            continue
        _check_banned_params(container, general_settings, llm_router, model)

    return True
def _coerce_metadata_to_dict(value: Any) -> Optional[Dict[str, Any]]:
"""Return ``value`` as a dict, parsing it from JSON if delivered as a string.
Multipart/form-data and ``extra_body`` callers send ``litellm_metadata``
as a JSON-encoded string; the proxy parses it into a dict later in
``add_litellm_data_to_request``, but the auth-time bouncer runs first
and would otherwise miss the banned-param check on a still-stringified
metadata blob.
"""
if isinstance(value, dict):
return value
if isinstance(value, str):
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
parsed = safe_json_loads(value)
if isinstance(parsed, dict):
return parsed
return None
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """Run the auth checks that need no database read.

    Checks, in order:
        1. Request size is under ``max_request_size_mb`` (if set).
        2. Request body is safe (e.g. the caller has not set ``api_base``).
        3. Caller IP is on ``allowed_ips`` (if set).
        4. ``route`` is in ``allowed_routes`` (if set).

    Args:
        request: The incoming request.
        request_data: The parsed request body.
        route: The route being called.

    Returns:
        None. The function returns normally when every check passes.

    Raises:
        HTTPException: 403 when the IP or the route is not allowed.
        ValueError: When ``is_request_body_safe`` rejects the body.
        ProxyException: When ``check_if_request_size_is_safe`` rejects the
            request size.
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. Request size.
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe.
    # [TODO] use model passed in url as well (azure openai routes)
    requested_model = request_data.get("model", "")
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=requested_model,
    )

    # Check 3. Caller IP is allowed.
    ip_allowed, caller_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )
    if not ip_allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {caller_ip} not allowed.",
        )

    # Check 4. Route is an allowed route on the proxy.
    if "allowed_routes" not in general_settings:
        return
    permitted_routes = general_settings["allowed_routes"]
    if premium_user is not True:
        # Logged only; the route check below still runs.
        verbose_proxy_logger.error(
            f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
        )
    if route in permitted_routes:
        return
    verbose_proxy_logger.error(
        f"Route {route} not in allowed_routes={permitted_routes}"
    )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Access forbidden: Route {route} not allowed",
    )
def route_in_additonal_public_routes(current_route: str):
    """Check whether ``current_route`` is a user-defined public route.

    Public routes come from ``general_settings.public_routes`` in the
    litellm config.yaml. Wildcard patterns are supported, e.g. ``"/api/*"``
    matches ``"/api/users"`` and ``"/api/users/123"``:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate", "/api/*"]
    ```

    Args:
        current_route: The route the caller is trying to reach.

    Returns:
        ``True`` if the route is listed exactly or matches a wildcard
        pattern. ``False`` otherwise, and also when the proxy is not
        premium, when settings are missing, or when the lookup raises.
    """
    from litellm.proxy.auth.route_checks import RouteChecks
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        if premium_user is not True or general_settings is None:
            return False

        configured_routes = general_settings.get("public_routes", [])
        # Exact match first, then wildcard patterns.
        if current_route in configured_routes:
            return True
        return any(
            RouteChecks._route_matches_wildcard_pattern(
                route=current_route, pattern=candidate
            )
            for candidate in configured_routes
        )
    except Exception as e:
        # Best effort: an unreadable config must never make a route public.
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
def get_request_route(request: Request) -> str:
    """Resolve the request route from the ASGI scope, with ``root_path`` stripped.

    Prefer this over ``request.url.path`` for any auth, ACL, routing, or
    audit-log decision: Starlette reconstructs ``url.path`` by interpolating
    the Host header into a URL string and re-parsing with ``urlsplit``, so a
    malformed Host (e.g. ``localhost/?x=1``) collapses ``url.path`` to ``"/"``
    while FastAPI continues to dispatch on ``scope["path"]``. ``scope["path"]``
    is uvicorn's parse of the HTTP request line and matches the actual
    handler, so it is the authoritative route.

    Sub-path deployments are normalised by stripping ``scope["app_root_path"]``
    (falling back to ``scope["root_path"]``), e.g.
    ``/genai/chat/completions`` -> ``/chat/completions``.

    Args:
        request: The incoming request.

    Returns:
        The route with the root path removed; ``request.url.path`` when the
        scope is not a dict or when resolution raises.
    """
    try:
        scope = request.scope
        if not isinstance(scope, dict):
            return str(request.url.path)

        # ``str()`` guarantees both values are strings from here on, so no
        # further type guard is needed.
        raw_path: str = str(scope.get("path", request.url.path))
        root_path: str = str(
            scope.get("app_root_path", scope.get("root_path", ""))
        ).rstrip("/")

        # Strip root_path only when it matches whole path segments, guarding
        # against sibling paths like "/apifoo" being truncated under
        # root_path="/api". Trailing slashes on root_path are stripped above,
        # so bare "/" or "/prefix/" still leave the leading "/" intact.
        if root_path and (
            raw_path == root_path or raw_path.startswith(root_path + "/")
        ):
            return raw_path[len(root_path) :] or "/"
        return raw_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return str(request.url.path)
def get_request_route_template(request: Request) -> Optional[str]:
    """Return the low-cardinality route template for ``request``.

    Example: ``/v1/threads/{thread_id}/runs``, as opposed to the literal
    path returned by ``get_request_route``. FastAPI sets ``scope["route"]``
    before endpoint dependencies run.

    Args:
        request: The incoming request.

    Returns:
        The ``path`` attribute of ``scope["route"]`` when it is a non-empty
        string; ``None`` when unavailable (unmatched path, Mount), when the
        scope is not a dict, or when the lookup raises.
    """
    try:
        scope = request.scope
        if not isinstance(scope, dict):
            return None
        template = getattr(scope.get("route"), "path", None)
        if isinstance(template, str) and template:
            return template
        return None
    except Exception as e:
        verbose_proxy_logger.debug(f"error on get_request_route_template: {str(e)}")
        return None
@lru_cache(maxsize=256)
def normalize_request_route(route: str) -> str:
    """Replace dynamic path parameters in ``route`` with placeholders.

    Keeps Prometheus metric cardinality low by collapsing routes such as:
    - /v1/responses/1234567890 -> /v1/responses/{response_id}
    - /v1/threads/thread_123 -> /v1/threads/{thread_id}

    Args:
        route: The request route path.

    Returns:
        The result of the first rewrite rule that changes ``route``, or
        ``route`` unchanged when no rule applies.

    Examples:
        >>> normalize_request_route("/v1/responses/abc123")
        '/v1/responses/{response_id}'
        >>> normalize_request_route("/v1/responses/abc123/cancel")
        '/v1/responses/{response_id}/cancel'
        >>> normalize_request_route("/chat/completions")
        '/chat/completions'
    """
    # (regex, replacement) pairs. Order matters: the more specific rule for a
    # resource must come before its shorter, more generic sibling.
    rewrite_rules = (
        # Responses API - must come before generic patterns
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/(?:openai/)?v1/responses)/([^/]+)$", r"\1/{response_id}"),
        (r"^(/responses)/([^/]+)(/input_items)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)(/cancel)$", r"\1/{response_id}\3"),
        (r"^(/responses)/([^/]+)$", r"\1/{response_id}"),
        # Threads API
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}\5/{step_id}",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/steps)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/cancel)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)(/submit_tool_outputs)$",
            r"\1/{thread_id}\3/{run_id}\5",
        ),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)/([^/]+)$",
            r"\1/{thread_id}\3/{run_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/runs)$", r"\1/{thread_id}\3"),
        (
            r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)/([^/]+)$",
            r"\1/{thread_id}\3/{message_id}",
        ),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)(/messages)$", r"\1/{thread_id}\3"),
        (r"^(/(?:openai/)?v1/threads)/([^/]+)$", r"\1/{thread_id}"),
        # Vector Stores API
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)/([^/]+)$",
            r"\1/{vector_store_id}\3/{file_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/files)$",
            r"\1/{vector_store_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)/([^/]+)$",
            r"\1/{vector_store_id}\3/{batch_id}",
        ),
        (
            r"^(/(?:openai/)?v1/vector_stores)/([^/]+)(/file_batches)$",
            r"\1/{vector_store_id}\3",
        ),
        (r"^(/(?:openai/)?v1/vector_stores)/([^/]+)$", r"\1/{vector_store_id}"),
        # Assistants API
        (r"^(/(?:openai/)?v1/assistants)/([^/]+)$", r"\1/{assistant_id}"),
        # Files API
        (r"^(/(?:openai/)?v1/files)/([^/]+)(/content)$", r"\1/{file_id}\3"),
        (r"^(/(?:openai/)?v1/files)/([^/]+)$", r"\1/{file_id}"),
        # Batches API
        (r"^(/(?:openai/)?v1/batches)/([^/]+)(/cancel)$", r"\1/{batch_id}\3"),
        (r"^(/(?:openai/)?v1/batches)/([^/]+)$", r"\1/{batch_id}"),
        # Fine-tuning API
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/events)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/cancel)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (
            r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)(/checkpoints)$",
            r"\1/{fine_tuning_job_id}\3",
        ),
        (r"^(/(?:openai/)?v1/fine_tuning/jobs)/([^/]+)$", r"\1/{fine_tuning_job_id}"),
        # Models API
        (r"^(/(?:openai/)?v1/models)/([^/]+)$", r"\1/{model}"),
    )

    # Lazily apply the rules in order and keep the first rewrite that
    # actually changes the route; fall back to the route itself.
    rewritten_candidates = (
        re.sub(regex, template, route) for regex, template in rewrite_rules
    )
    return next(
        (candidate for candidate in rewritten_candidates if candidate != route),
        route,
    )
async def check_if_request_size_is_safe(request: Request) -> bool:
    """Enterprise only: check that the request size is within the limit.

    The limit is ``general_settings["max_request_size_mb"]``. The size is
    taken from the ``content-length`` header when present; otherwise the
    body is read and measured.

    Args:
        request: The incoming request.

    Returns:
        ``True`` if no limit is configured, if the proxy is not premium (the
        check is skipped with a warning), or if the size is within the limit.

    Raises:
        ProxyException: If the request size exceeds the limit.
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is None:
        return True

    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    content_length = request.headers.get("content-length")
    if content_length:
        # Trust the declared length to avoid reading the body.
        request_size_mb = bytes_to_mb(bytes_value=int(content_length))
        verbose_proxy_logger.debug(
            f"content_length request size in MB={request_size_mb}"
        )
    else:
        # No Content-Length header: measure the body itself.
        payload = await request.body()
        request_size_mb = bytes_to_mb(bytes_value=len(payload))
        verbose_proxy_logger.debug(
            f"request body request size in MB={request_size_mb}"
        )

    if request_size_mb > max_request_size_mb:
        raise ProxyException(
            message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )
    return True
async def check_response_size_is_safe(response: Any) -> bool:
    """Enterprise only: check that the response size is within the limit.

    The limit is ``general_settings["max_response_size_mb"]``.

    Args:
        response: The response object to measure.

    Returns:
        ``True`` if no limit is configured, if the proxy is not premium (the
        check is skipped with a warning), or if the size is within the limit.

    Raises:
        ProxyException: If the measured response size exceeds the limit.
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is None:
        return True

    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    # NOTE(review): ``sys.getsizeof`` is shallow - it does not follow
    # references into contained objects - so this likely under-measures
    # container or model responses. Confirm this is the intended metric.
    response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
    verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
    if response_size_mb > max_response_size_mb:
        raise ProxyException(
            message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )
    return True
def bytes_to_mb(bytes_value: int):
    """Convert a byte count to megabytes (1 MB = 1024 * 1024 bytes).

    Args:
        bytes_value: Size in bytes.

    Returns:
        The size in MB as a float (true division).
    """
    bytes_per_mb = 1024 * 1024
    return bytes_value / bytes_per_mb
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def _get_deployment_default_limit(model_name: str, field: str) -> Optional[int]:
    """Return the smallest ``field`` value across deployments of ``model_name``.

    When multiple deployments share the same model name, taking the minimum
    is the safest choice for load-balanced setups: no deployment is
    over-consumed regardless of which one actually serves a given request.

    Args:
        model_name: Model name whose deployments are inspected.
        field: Key to read from each deployment's ``litellm_params``.

    Returns:
        The minimum value converted with ``int()``. ``None`` when the router
        is not initialised, the model has no deployments, or no deployment
        has a usable value for the field.
    """
    from litellm.proxy.proxy_server import llm_router

    if llm_router is None:
        return None

    deployments = llm_router.get_model_list(model_name=model_name)
    if not deployments:
        return None

    limits: List[int] = []
    for deployment in deployments:
        raw = deployment.get("litellm_params", {}).get(field)
        # Unset values (None) and unsupported types are skipped.
        if not isinstance(raw, (int, float, str, bytes, bytearray)):
            continue
        try:
            limits.append(int(raw))
        except (ValueError, TypeError, OverflowError):
            # Non-numeric text, NaN (ValueError) or +/-inf (OverflowError):
            # ignore this deployment's value instead of failing the lookup.
            continue

    return min(limits) if limits else None
def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]:
    """Return the deployment-level ``default_api_key_rpm_limit`` for ``model_name``."""
    return _get_deployment_default_limit(
        model_name=model_name, field="default_api_key_rpm_limit"
    )
def _get_deployment_default_tpm_limit(model_name: str) -> Optional[int]:
    """Return the deployment-level ``default_api_key_tpm_limit`` for ``model_name``."""
    return _get_deployment_default_limit(
        model_name=model_name, field="default_api_key_tpm_limit"
    )
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
    model_name: Optional[str] = None,
) -> Optional[Dict[str, int]]:
    """Get the per-model rpm limit that applies to an api key.

    Sources, in priority order (the first one found wins):
        1. Key metadata (``model_rpm_limit``), used only when truthy.
        2. Key ``model_max_budget`` (``rpm_limit`` of each dict entry).
        3. Team metadata (``model_rpm_limit``), used when not ``None``.
        4. Deployment ``default_api_key_rpm_limit`` (needs ``model_name``).

    Returns:
        A mapping of model name to rpm limit, or ``None`` when no source
        provides one.
    """
    # 1. Key metadata takes priority.
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        key_limit = key_metadata.get("model_rpm_limit")
        if key_limit:
            return key_limit

    # 2. Per-model rpm limits declared inside model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        budget_limits: Dict[str, Any] = {
            model: budget["rpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("rpm_limit") is not None
        }
        if budget_limits:
            return budget_limits

    # 3. Team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        team_limit = team_metadata.get("model_rpm_limit")
        if team_limit is not None:
            return team_limit

    # 4. Deployment default for the requested model.
    if model_name is not None:
        default_limit = _get_deployment_default_rpm_limit(model_name)
        if default_limit is not None:
            return {model_name: default_limit}

    return None
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
    model_name: Optional[str] = None,
) -> Optional[Dict[str, int]]:
    """Get the per-model tpm limit that applies to an api key.

    Sources, in priority order (the first one found wins):
        1. Key metadata (``model_tpm_limit``), used only when truthy.
        2. Key ``model_max_budget`` (``tpm_limit`` of each dict entry).
        3. Team metadata (``model_tpm_limit``), used when not ``None``.
        4. Deployment ``default_api_key_tpm_limit`` (needs ``model_name``).

    Returns:
        A mapping of model name to tpm limit, or ``None`` when no source
        provides one.
    """
    # 1. Key metadata takes priority.
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        key_limit = key_metadata.get("model_tpm_limit")
        if key_limit:
            return key_limit

    # 2. Per-model tpm limits declared inside model_max_budget.
    budgets = user_api_key_dict.model_max_budget
    if budgets:
        budget_limits: Dict[str, Any] = {
            model: budget["tpm_limit"]
            for model, budget in budgets.items()
            if isinstance(budget, dict) and budget.get("tpm_limit") is not None
        }
        if budget_limits:
            return budget_limits

    # 3. Team metadata.
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        team_limit = team_metadata.get("model_tpm_limit")
        if team_limit is not None:
            return team_limit

    # 4. Deployment default for the requested model.
    if model_name is not None:
        default_limit = _get_deployment_default_tpm_limit(model_name)
        if default_limit is not None:
            return {model_name: default_limit}

    return None
def get_model_rate_limit_from_metadata(
    user_api_key_dict: UserAPIKeyAuth,
    metadata_accessor_key: Literal[
        "team_metadata", "organization_metadata", "project_metadata"
    ],
    rate_limit_key: Literal["model_rpm_limit", "model_tpm_limit"],
) -> Optional[Dict[str, int]]:
    """Read a per-model rate limit from one of the key's metadata containers.

    Args:
        user_api_key_dict: The authenticated key object.
        metadata_accessor_key: Name of the attribute holding the metadata.
        rate_limit_key: Key to look up inside that metadata.

    Returns:
        The stored value, or ``None`` when the container is empty/unset or
        lacks the key.
    """
    container = getattr(user_api_key_dict, metadata_accessor_key)
    if not container:
        return None
    return container.get(rate_limit_key)
def get_team_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return ``model_rpm_limit`` from team metadata, or ``None`` if unavailable."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_rpm_limit") if team_metadata else None
def get_team_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """Return ``model_tpm_limit`` from team metadata, or ``None`` if unavailable."""
    team_metadata = user_api_key_dict.team_metadata
    return team_metadata.get("model_tpm_limit") if team_metadata else None
def get_key_mcp_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Get the per-MCP-server rpm limit for a given api key.

    Sources are consulted in priority order and the first non-None
    ``mcp_rpm_limit`` value wins:
    1. Key metadata
    2. Team metadata

    The returned dict is keyed by MCP server name (alias if set, else the
    configured server name).
    """
    # Attribute names rather than values, so team metadata is only read when
    # the key metadata did not yield a limit.
    for metadata_attr in ("metadata", "team_metadata"):
        metadata = getattr(user_api_key_dict, metadata_attr)
        if not metadata:
            continue
        limit = metadata.get("mcp_rpm_limit")
        if limit is not None:
            return limit
    return None
def get_team_mcp_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the ``mcp_rpm_limit`` entry of the team metadata.

    Returns None when the team metadata is falsy or has no such entry.
    """
    team_metadata = user_api_key_dict.team_metadata
    if team_metadata:
        return team_metadata.get("mcp_rpm_limit")
    return None
def get_project_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the ``model_rpm_limit`` entry of the project metadata.

    Returns None when the project metadata is falsy or has no such entry.
    """
    project_metadata = user_api_key_dict.project_metadata
    return project_metadata.get("model_rpm_limit") if project_metadata else None
def get_project_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the ``model_tpm_limit`` entry of the project metadata.

    Returns None when the project metadata is falsy or has no such entry.
    """
    project_metadata = user_api_key_dict.project_metadata
    if not project_metadata:
        return None
    return project_metadata.get("model_tpm_limit")
def custom_auth_common_checks_warning(
*,
custom_auth_configured: bool,
run_common_checks: bool,
) -> str | None:
if not custom_auth_configured or run_common_checks:
return None
return (
"custom_auth is configured but 'custom_auth_run_common_checks' is not set. "
"Problem: budgets, model-access allowlists, and per-model rate limits configured "
"on your DB team/project records will NOT be enforced for custom-auth requests "
"(rate limits set directly on the returned UserAPIKeyAuth still apply). "
"Fix: set 'general_settings.custom_auth_run_common_checks: true'. "
"Docs: https://docs.litellm.ai/docs/proxy/custom_auth"
)
# Process-wide latch read and set by warn_once_if_custom_auth_skips_common_checks():
# it becomes True once the warning has been logged, so later calls return early.
_custom_auth_common_checks_warning_emitted = False
def warn_once_if_custom_auth_skips_common_checks(
    *,
    custom_auth_configured: bool,
    run_common_checks: bool,
    logger: Logger = verbose_proxy_logger,
) -> None:
    """
    Log the custom-auth/common-checks warning at most once per process.

    The message comes from ``custom_auth_common_checks_warning``. When that
    returns None nothing is logged and the module-level "already emitted"
    flag is left untouched, so a later call may still warn.

    Args:
        custom_auth_configured: Whether a custom auth handler is configured.
        run_common_checks: Whether common checks are enabled for custom auth.
        logger: Logger that receives the warning.
    """
    global _custom_auth_common_checks_warning_emitted
    if not _custom_auth_common_checks_warning_emitted:
        warning_text = custom_auth_common_checks_warning(
            custom_auth_configured=custom_auth_configured,
            run_common_checks=run_common_checks,
        )
        if warning_text is not None:
            logger.warning(warning_text)
            _custom_auth_common_checks_warning_emitted = True
def is_pass_through_provider_route(route: str) -> bool:
    """
    Tell whether ``route`` is a provider-specific pass-through route.

    A route qualifies when it contains any of the known provider markers
    as a substring (currently only ``vertex-ai``).
    """
    provider_markers = ("vertex-ai",)
    return any(marker in route for marker in provider_markers)
def _has_user_setup_sso():
    """
    Check whether single sign-on (SSO) has been configured through the environment.

    SSO counts as set up when at least one of ``MICROSOFT_CLIENT_ID``,
    ``GOOGLE_CLIENT_ID`` or ``GENERIC_CLIENT_ID`` is present in the
    environment. Presence is what matters: an empty value still counts.

    Returns a boolean indicating whether SSO has been set up.
    """
    sso_client_id_vars = (
        "MICROSOFT_CLIENT_ID",
        "GOOGLE_CLIENT_ID",
        "GENERIC_CLIENT_ID",
    )
    return any(os.getenv(var_name, None) is not None for var_name in sso_client_id_vars)
def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[list]:
    """
    Collect the header names that the mapping assigns to the CUSTOMER role.

    Args:
        user_id_mapping: A single mapping dict or a list of them. Each dict is
            read for the ``litellm_user_role`` and ``header_name`` keys;
            non-dict entries and entries missing either value are ignored.

    Returns:
        The lower-cased header names whose role equals
        ``LitellmUserRoles.CUSTOMER`` (compared case-insensitively as strings),
        in input order; None when the mapping is empty or nothing matches.
    """
    if not user_id_mapping:
        return None

    if isinstance(user_id_mapping, list):
        entries = user_id_mapping
    else:
        entries = [user_id_mapping]

    customer_headers = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        mapped_role = entry.get("litellm_user_role")
        mapped_header = entry.get("header_name")
        if mapped_role is None or not mapped_header:
            continue
        if str(mapped_role).lower() != str(LitellmUserRoles.CUSTOMER).lower():
            continue
        customer_headers.append(mapped_header.lower())

    return customer_headers or None
def _get_customer_id_from_standard_headers(
    request_headers: Optional[dict],
) -> Optional[str]:
    """
    Look for a customer/end-user ID in the standard customer ID headers.

    This enables tools like Claude Code to pass customer IDs via
    ANTHROPIC_CUSTOM_HEADERS. No configuration is required - the headers
    listed in ``STANDARD_CUSTOMER_ID_HEADERS`` are always checked, in that
    order, and header names are compared case-insensitively.

    Args:
        request_headers: The request headers dict

    Returns:
        The first header value that ``_coerce_user_id_to_str`` accepts,
        None when no headers were given or none yields a usable ID.
    """
    if request_headers is None:
        return None

    for standard_header in STANDARD_CUSTOMER_ID_HEADERS:
        wanted = standard_header.lower()
        for name, raw_value in request_headers.items():
            if name.lower() != wanted:
                continue
            customer_id = _coerce_user_id_to_str(raw_value)
            if customer_id:
                return customer_id
    return None
def _coerce_user_id_to_str(value: Any) -> Optional[str]:
    """Return a usable end-user identifier string, or None if the value isn't one.

    Accepted inputs are ints/floats (stringified) and non-blank strings
    (returned stripped). None, bools and every other type - dict, list,
    tuple, set, arbitrary objects - are dropped, because stringifying
    structured values produces garbage spend-log rows like
    ``"{'device_id': ...}"``.

    Strings that *decode* to a JSON object/array are only rejected when
    ``litellm.validate_end_user_id_in_db`` is enabled - operators who
    currently pass JSON-encoded identifiers keep their existing behavior
    until they opt in. See
    auth_utils.py:get_end_user_id_from_request_body for the extraction chain.
    """
    # bool is an int subclass; reject it before the numeric branch so it never
    # turns into "True"/"False".
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return str(value)
    if not isinstance(value, str):
        # dict, list, tuple, set, arbitrary objects -> drop.
        return None

    candidate = value.strip()
    if not candidate:
        return None

    # Only when the operator opted into end-user validation: reject strings
    # that decode to a structured payload. Gating behind the flag preserves
    # backwards compatibility for deployments that intentionally pass
    # JSON-encoded user identifiers.
    if litellm.validate_end_user_id_in_db and candidate[:1] in ("{", "["):
        if isinstance(safe_json_loads(candidate), (dict, list)):
            return None
    return candidate
def get_end_user_id_from_request_body(
    request_body: dict, request_headers: Optional[dict] = None
) -> Optional[str]:
    """
    Resolve the end-user (customer) ID of a request.

    Sources are tried in this order and the first value accepted by
    ``_coerce_user_id_to_str`` is returned:

    1. Standard customer ID headers (always checked, no configuration).
    2. Configured headers from ``general_settings`` - ``user_header_mappings``
       first, then the deprecated ``user_header_name`` (only when
       ``request_headers`` is provided).
    3. ``user`` in the request body (commonly OpenAI).
    4. ``litellm_metadata.user`` in the request body (commonly Anthropic).
    5. ``metadata.user_id`` in the request body.
    6. ``safety_identifier`` in the request body (OpenAI Responses API).

    Args:
        request_body: Parsed request body.
        request_headers: Request headers, if available.

    Returns:
        The end-user ID string, or None when no source yields one.
    """
    # Imported here to avoid potential circular import issues at module level
    # and to ensure the settings are fetched at runtime.
    from litellm.proxy.proxy_server import general_settings

    def _as_dict(value: Any) -> dict:
        # metadata / litellm_metadata can arrive as JSON strings from
        # multipart/form-data or extra_body; coerce so string-encoded
        # payloads can't evade end-user attribution.
        if isinstance(value, str):
            value = safe_json_loads(value)
        return value if isinstance(value, dict) else {}

    # Source 1: standard customer ID headers.
    standard_header_id = _get_customer_id_from_standard_headers(
        request_headers=request_headers
    )
    if standard_header_id is not None:
        return standard_header_id

    # Source 2: operator-configured header name(s).
    if request_headers is not None:
        configured_headers: Optional[Union[list, str]] = None

        # Prefer user mappings (new behavior).
        mapping_config = general_settings.get("user_header_mappings", None)
        if mapping_config:
            configured_headers = get_customer_user_header_from_mapping(mapping_config)

        # Fall back to the deprecated user_header_name when the mappings did
        # not name a customer header.
        if not configured_headers:
            legacy_header = general_settings.get("user_header_name")
            if isinstance(legacy_header, str) and legacy_header.strip() != "":
                configured_headers = legacy_header

        if isinstance(configured_headers, list):
            # The expected names are looked up against lower-cased request
            # header names.
            lowered_headers = {
                name.lower(): header_value
                for name, header_value in request_headers.items()
            }
            for wanted in configured_headers:
                found = _coerce_user_id_to_str(lowered_headers.get(wanted))
                if found:
                    return found
        elif isinstance(configured_headers, str):
            wanted = configured_headers.lower()
            for name, header_value in request_headers.items():
                if name.lower() != wanted:
                    continue
                found = _coerce_user_id_to_str(header_value)
                if found:
                    return found

    # Source 3: top-level 'user' field.
    if "user" in request_body:
        body_user = _coerce_user_id_to_str(request_body["user"])
        if body_user:
            return body_user

    # Sources 4 and 5: 'litellm_metadata.user', then 'metadata.user_id'.
    # Only the keys are listed up front, so each container is read and
    # decoded only when the earlier sources produced nothing.
    for container_key, id_key in (
        ("litellm_metadata", "user"),
        ("metadata", "user_id"),
    ):
        container = _as_dict(request_body.get(container_key))
        nested_id = _coerce_user_id_to_str(container.get(id_key))
        if nested_id:
            return nested_id

    # Source 6: 'safety_identifier'.
    # SECURITY NOTE: safety_identifier can be set by any caller in the request body.
    # Only use this for end-user identification in trusted environments where you control
    # the calling application. For untrusted callers, prefer using headers or server-side
    # middleware to set the end_user_id to prevent impersonation.
    safety_id = _coerce_user_id_to_str(request_body.get("safety_identifier"))
    return safety_id if safety_id else None
# Request header read (case-insensitively) as a model candidate by
# _extract_model_candidates_from_request() on header/query routing routes.
MODEL_ROUTING_HEADER_NAME = "x-litellm-model"
# All marker tuples below hold substrings that are matched against the
# lower-cased route (see _route_matches_any_marker()).
#
# Routes for which _route_uses_model_routing_sources() returns True.
_MODEL_ROUTING_ROUTE_MARKERS = (
    "/files",
    "/batches",
    "/vector_stores",
    "/skills",
    "/evals",
    "/fine_tuning",
    "/videos",
)
# Routes where the ``model`` query parameter and the MODEL_ROUTING_HEADER_NAME
# header are read as model candidates.
_MODEL_ROUTING_HEADER_OR_QUERY_ROUTE_MARKERS = (
    "/files",
    "/batches",
    "/skills",
    "/evals",
)
# Routes where the ``target_model_names`` query parameter is read.
_MODEL_ROUTING_QUERY_TARGET_MODEL_ROUTE_MARKERS = (
    "/files",
    "/batches",
    "/fine_tuning",
)
# Routes where ``target_model_names`` from the request body is read even when
# the body also carries a ``model``.
_MODEL_ROUTING_BODY_TARGET_MODEL_ROUTE_MARKERS = (
    "/files",
    "/batches",
    "/vector_stores",
)
# Routes where the model nested under the body's ``completion`` dict is read.
_MODEL_ROUTING_COMPLETION_MODEL_ROUTE_MARKERS = ("/evals",)
# Realtime WebRTC routes carry the effective model inside the nested
# ``session.model`` field (see realtime_endpoints.endpoints), so the model the
# request will actually use is not present at the top level. Extract it here so
# can_key_call_model() validates the real target model.
_MODEL_ROUTING_SESSION_MODEL_ROUTE_MARKERS = (
    "/realtime/client_secrets",
    "/realtime/calls",
)
# Request-body fields whose values are handed to
# _extract_models_from_managed_resource_id() to decode model candidates.
_MODEL_ROUTING_ID_FIELDS = (
    "file_id",
    "input_file_id",
    "output_file_id",
    "error_file_id",
    "batch_id",
    "fine_tuning_job_id",
    "training_file",
    "validation_file",
    "vector_store_id",
    "video_id",
    "character_id",
)
def _append_model_candidates(candidates: List[str], value: Any) -> None:
    """
    Append the model names found in ``value`` to ``candidates`` (in place).

    Args:
        candidates: List that receives the extracted names.
        value: None (ignored), a single value, or a list/tuple/set of values.
            String items are split on commas; any other non-None item is
            stringified. Names are stripped and empty names are skipped.
    """
    if value is None:
        return

    items = value if isinstance(value, (list, tuple, set)) else [value]
    for entry in items:
        if entry is None:
            continue
        # A string may carry several comma-separated names; anything else is
        # treated as one name.
        pieces = entry.split(",") if isinstance(entry, str) else [str(entry)]
        for piece in pieces:
            name = piece.strip()
            if name:
                candidates.append(name)
def _dedupe_model_candidates(candidates: List[str]) -> List[str]:
    """
    Return ``candidates`` without duplicates, keeping first-seen order.

    Args:
        candidates: Model names, possibly repeated.

    Returns:
        A new list with each name once, at the position of its first
        occurrence. The input list is not modified.
    """
    # dict preserves insertion order, so this is an order-preserving dedupe
    # in linear time (the names are strings, hence hashable) instead of a
    # list-membership scan per element.
    return list(dict.fromkeys(candidates))
def _get_case_insensitive_mapping_value(
    mapping: Optional[Mapping[str, Any]], key: str
) -> Any:
    """
    Look up ``key`` in ``mapping``, ignoring case when there is no exact match.

    Args:
        mapping: Mapping to search; None or empty yields None.
        key: Key to look for.

    Returns:
        The value of the exact key when present; otherwise the value of the
        first key (in iteration order) whose lower-cased string form equals
        the lower-cased ``key``; None when nothing matches.
    """
    if not mapping:
        return None
    if key in mapping:
        return mapping[key]

    wanted = key.lower()
    return next(
        (
            candidate_value
            for candidate_key, candidate_value in mapping.items()
            if str(candidate_key).lower() == wanted
        ),
        None,
    )
def _route_matches_any_marker(route: str, markers: Tuple[str, ...]) -> bool:
    """
    Tell whether the lower-cased ``route`` contains any of ``markers``.

    Only the route is lower-cased; markers are compared as given.
    """
    lowered_route = route.lower()
    for marker in markers:
        if marker in lowered_route:
            return True
    return False
def _route_uses_model_routing_sources(route: str) -> bool:
    """Tell whether ``route`` contains one of ``_MODEL_ROUTING_ROUTE_MARKERS``."""
    routing_markers = _MODEL_ROUTING_ROUTE_MARKERS
    return _route_matches_any_marker(route, routing_markers)
def _extract_models_from_managed_resource_id(
    resource_id: Any,
    resource_id_field: Optional[str] = None,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Collect the model names/IDs that can be decoded from a managed resource ID.

    Up to three decoding stages run in order. Each stage is best-effort: any
    exception is logged at debug level and ends that stage only.

    1. File/batch helpers from ``openai_files_endpoints.common_utils``.
    2. ``parse_unified_id`` (its ``model_id`` and ``target_model_names``).
    3. For ``video_id`` / ``character_id`` fields only: the provider-encoded
       ID decoders, with the model ID passed through
       ``_resolve_model_id_with_router``.

    Args:
        resource_id: The ID to decode; anything but a non-empty string
            yields an empty list.
        resource_id_field: Name of the request field the ID came from.
        llm_router: Router used to resolve decoded model IDs, if any.

    Returns:
        The decoded candidates, de-duplicated, in discovery order.
    """
    if not (isinstance(resource_id, str) and resource_id):
        return []

    found: List[str] = []

    # Stage 1: managed file / batch IDs.
    try:
        from litellm.proxy.openai_files_endpoints.common_utils import (
            _is_base64_encoded_unified_file_id,
            decode_model_from_file_id,
            get_model_id_from_unified_batch_id,
            get_models_from_unified_file_id,
        )

        _append_model_candidates(
            candidates=found, value=decode_model_from_file_id(resource_id)
        )
        unified_file_id = _is_base64_encoded_unified_file_id(resource_id)
        if unified_file_id:
            for extract in (
                get_models_from_unified_file_id,
                get_model_id_from_unified_batch_id,
            ):
                _append_model_candidates(
                    candidates=found, value=extract(unified_file_id)
                )
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unable to extract model from managed file/batch ID: %s", str(e)
        )

    # Stage 2: unified managed resource IDs.
    try:
        from litellm.llms.base_llm.managed_resources.utils import parse_unified_id

        parsed_id = parse_unified_id(resource_id)
        if parsed_id:
            for parsed_key in ("model_id", "target_model_names"):
                _append_model_candidates(
                    candidates=found, value=parsed_id.get(parsed_key)
                )
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unable to extract model from unified managed resource ID: %s", str(e)
        )

    # Stage 3: provider-encoded video / character IDs.
    if resource_id_field in ("video_id", "character_id"):
        try:
            from litellm.types.videos.utils import (
                decode_character_id_with_provider,
                decode_video_id_with_provider,
            )

            if resource_id_field == "video_id":
                decode_with_provider = decode_video_id_with_provider
            else:
                decode_with_provider = decode_character_id_with_provider
            decoded_model_id = decode_with_provider(resource_id).get("model_id")
            _append_model_candidates(
                candidates=found,
                value=_resolve_model_id_with_router(decoded_model_id, llm_router),
            )
        except Exception as e:
            verbose_proxy_logger.debug(
                "Unable to extract model from managed video/character ID: %s", str(e)
            )

    return _dedupe_model_candidates(found)
def _resolve_model_id_with_router(
    model_id: Optional[str], llm_router: Optional[Router]
) -> Optional[str]:
    """
    Translate a model ID into a model name using the router, when possible.

    Args:
        model_id: The ID to resolve.
        llm_router: Router providing ``resolve_model_name_from_model_id``.

    Returns:
        The router's result when it is truthy. ``model_id`` is returned
        unchanged when either argument is None, when the router returns a
        falsy value, or when resolution raises (logged at debug level).
    """
    if model_id is None or llm_router is None:
        return model_id

    try:
        resolved = llm_router.resolve_model_name_from_model_id(model_id) or model_id
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unable to resolve model_id from managed resource ID: %s", str(e)
        )
        resolved = model_id
    return resolved
def _extract_model_candidates_from_request(
    request_data: dict,
    route: str,
    request_headers: Optional[Mapping[str, Any]] = None,
    request_query_params: Optional[Mapping[str, Any]] = None,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Gather every model name a request may target, in discovery order.

    Which sources are consulted depends on which marker tuples match
    ``route``. Candidates are appended in this order:

    1. ``model`` from the request body.
    2. ``target_model_names`` from the body (on body-target routes, or on
       any route when the body has no ``model``).
    3. ``session.model`` from the body (realtime session routes).
    4. ``completion.model`` from the body (completion-model routes).
    5. On model-routing routes: the ``model`` query parameter and the
       ``MODEL_ROUTING_HEADER_NAME`` header, the ``target_model_names``
       query parameter, and models decoded from the ID fields listed in
       ``_MODEL_ROUTING_ID_FIELDS``.

    Args:
        request_data: Parsed request body.
        route: Request route; matched against the marker tuples.
        request_headers: Request headers, looked up case-insensitively.
        request_query_params: Query parameters, looked up case-insensitively.
        llm_router: Router forwarded to the managed-resource ID decoder.

    Returns:
        The de-duplicated candidates; empty when no source yields a model.
    """
    candidates: List[str] = []
    # Each flag records whether the route contains one of the markers of the
    # corresponding tuple.
    uses_model_routing_sources = _route_uses_model_routing_sources(route=route)
    uses_header_or_query_model_sources = _route_matches_any_marker(
        route=route, markers=_MODEL_ROUTING_HEADER_OR_QUERY_ROUTE_MARKERS
    )
    uses_query_target_model_sources = _route_matches_any_marker(
        route=route, markers=_MODEL_ROUTING_QUERY_TARGET_MODEL_ROUTE_MARKERS
    )
    uses_body_target_model_sources = _route_matches_any_marker(
        route=route, markers=_MODEL_ROUTING_BODY_TARGET_MODEL_ROUTE_MARKERS
    )
    uses_completion_model_sources = _route_matches_any_marker(
        route=route, markers=_MODEL_ROUTING_COMPLETION_MODEL_ROUTE_MARKERS
    )
    body_model = request_data.get("model")
    _append_model_candidates(candidates, body_model)
    # target_model_names is read on body-target routes, and on any other route
    # when the body carries no (truthy) model.
    if uses_body_target_model_sources or not body_model:
        _append_model_candidates(candidates, request_data.get("target_model_names"))
    # Realtime session routes: the model sits under the ``session`` dict.
    if _route_matches_any_marker(
        route=route, markers=_MODEL_ROUTING_SESSION_MODEL_ROUTE_MARKERS
    ):
        session = request_data.get("session")
        if isinstance(session, dict):
            _append_model_candidates(candidates, session.get("model"))
    # Completion-model routes: the model sits under the ``completion`` dict.
    if uses_completion_model_sources and isinstance(
        request_data.get("completion"), dict
    ):
        _append_model_candidates(candidates, request_data["completion"].get("model"))
    if uses_model_routing_sources:
        if uses_header_or_query_model_sources:
            # Query parameter first, then the routing header.
            _append_model_candidates(
                candidates,
                _get_case_insensitive_mapping_value(request_query_params, "model"),
            )
            _append_model_candidates(
                candidates,
                _get_case_insensitive_mapping_value(
                    request_headers, MODEL_ROUTING_HEADER_NAME
                ),
            )
        if uses_query_target_model_sources:
            _append_model_candidates(
                candidates,
                _get_case_insensitive_mapping_value(
                    request_query_params, "target_model_names"
                ),
            )
        # Decode model names out of every managed resource ID in the body.
        for field in _MODEL_ROUTING_ID_FIELDS:
            _append_model_candidates(
                candidates,
                _extract_models_from_managed_resource_id(
                    request_data.get(field),
                    resource_id_field=field,
                    llm_router=llm_router,
                ),
            )
    return _dedupe_model_candidates(candidates)
def _format_model_candidates(
    candidates: List[str],
) -> Optional[Union[str, List[str]]]:
    """
    Collapse a candidate list into the shape returned to callers.

    Returns:
        None for an empty list, the sole element for a one-element list,
        and the list itself (same object) when it holds several names.
    """
    if not candidates:
        return None
    return candidates[0] if len(candidates) == 1 else candidates
def get_model_from_request(
    request_data: dict,
    route: str,
    request_headers: Optional[Mapping[str, Any]] = None,
    request_query_params: Optional[Mapping[str, Any]] = None,
    llm_router: Optional[Router] = None,
) -> Optional[Union[str, List[str]]]:
    """
    Determine the model (or models) a request targets.

    Explicit sources (body, headers, query parameters, managed resource IDs)
    are collected by ``_extract_model_candidates_from_request``. When they
    yield nothing, the model is parsed out of the route itself:

    1. ``/openai/deployments/{model}/...`` at the start of the route.
    2. For routes starting with ``/vertex``: ``/models/{model_id}`` anywhere
       in the route.
    3. For all other routes, Google generateContent-style paths:
       ``/v1beta/models/{model}:`` or ``/beta/models/{model}:`` anywhere,
       then ``/models/{model}:`` at the start of the route.

    Returns:
        A single model name, a list of names when several explicit
        candidates were found, or None when no model can be determined.
    """
    candidates = _extract_model_candidates_from_request(
        request_data=request_data,
        route=route,
        request_headers=request_headers,
        request_query_params=request_query_params,
        llm_router=llm_router,
    )
    explicit_model = _format_model_candidates(candidates)
    if explicit_model is not None:
        return explicit_model

    # No explicit model: fall back to parsing the route.
    # Pattern: /openai/deployments/{model}/*
    deployment_match = re.match(r"/openai/deployments/([^/]+)", route)
    if deployment_match:
        return deployment_match.group(1)

    if route.lower().startswith("/vertex"):
        # Vertex AI passthrough route.
        # Pattern: /vertex_ai/.../models/{model_id}:*
        # Example: /vertex_ai/v1/.../models/gemini-1.5-pro:generateContent
        vertex_match = re.search(r"/models/([^:]+)", route)
        return vertex_match.group(1) if vertex_match else None

    # Google generateContent-style routes put the model in the path and
    # allow "/" inside the model id.
    # Examples:
    # - /v1beta/models/gemini-2.0-flash:generateContent
    # - /v1beta/models/bedrock/claude-sonnet-3.7:generateContent
    # - /models/custom/ns/model:streamGenerateContent
    for google_pattern in (
        r"/(?:v1beta|beta)/models/([^:]+):",
        r"^/models/([^:]+):",
    ):
        google_match = re.search(google_pattern, route)
        if google_match:
            return google_match.group(1)
    return None
def abbreviate_api_key(api_key: str) -> str:
    """
    Return a shortened display form of ``api_key``.

    The result is ``sk-...`` followed by the last four characters of the key
    (or the whole key when it is shorter than four characters).
    """
    visible_suffix = api_key[-4:]
    return "sk-...{}".format(visible_suffix)