Files
MoFin/venv/lib/python3.12/site-packages/litellm/proxy/auth/model_checks.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

432 lines
15 KiB
Python

# What is this?
## Common checks for /v1/models and `/model/info`
from typing import Dict, List, Optional, Set
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.router import Router
from litellm.router_utils.fallback_event_handlers import get_fallback_model_group
from litellm.types.router import CredentialLiteLLMParams, LiteLLM_Params
from litellm.types.utils import LlmProviders
from litellm.utils import get_valid_models
# Field names of CredentialLiteLLMParams. Only keys in this set are copied from a
# stored credential onto LiteLLM_Params by _hydrate_litellm_credential_name.
_CREDENTIAL_LITELLM_PARAM_FIELDS = set(CredentialLiteLLMParams.model_fields)
def _check_wildcard_routing(model: str) -> bool:
"""
Returns True if a model is a provider wildcard.
eg:
- anthropic/*
- openai/*
- *
"""
if "*" in model:
return True
return False
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Look up the models litellm knows about for a provider.

    Args:
        provider: Provider name, or ``"*"`` to ask for models of every provider.
        litellm_params: Optional deployment params forwarded to
            ``get_valid_models``.

    Returns:
        The result of ``get_valid_models`` when ``provider`` is ``"*"`` or is a
        key of ``litellm.models_by_provider``; ``None`` for any other provider.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)
    if provider not in litellm.models_by_provider:
        return None
    return get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
def _get_models_from_access_groups(
model_access_groups: Dict[str, List[str]],
all_models: List[str],
include_model_access_groups: Optional[bool] = False,
) -> List[str]:
idx_to_remove = []
new_models = []
for idx, model in enumerate(all_models):
if model in model_access_groups:
if (
not include_model_access_groups
): # remove access group, unless requested - e.g. when creating a key
idx_to_remove.append(idx)
new_models.extend(model_access_groups[model])
for idx in sorted(idx_to_remove, reverse=True):
all_models.pop(idx)
all_models.extend(new_models)
return all_models
async def get_mcp_server_ids(
    user_api_key_dict: UserAPIKeyAuth,
) -> List[str]:
    """
    Return the MCP server ids attached to a key's object permission.

    The key's ``object_permission_id`` is looked up in the object_permission
    table through ``ObjectPermissionRepository``, and the row's ``mcp_servers``
    list is returned.

    Returns:
        The ``mcp_servers`` list, or an empty list when there is no database
        client, the key has no object permission, the row is missing or has no
        MCP servers, or the lookup raises.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        return []
    if user_api_key_dict.object_permission_id is None:
        return []

    # Fetch the object-permission row via the repository and read mcp_servers.
    try:
        result = await ObjectPermissionRepository(prisma_client).table.find_unique(
            where={"object_permission_id": user_api_key_dict.object_permission_id},
        )
        if result and result.mcp_servers:
            return result.mcp_servers
        return []
    except Exception as e:
        # Best-effort lookup: callers still get [] on failure, but record the
        # reason instead of swallowing it silently.
        verbose_proxy_logger.debug(
            "get_mcp_server_ids: failed to load object permission %s: %s",
            user_api_key_dict.object_permission_id,
            e,
        )
        return []
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Resolve the model names a key may use.

    When the key lists at least one model:
    - the key's ``models`` list is the starting point;
    - if it contains the all-team-models sentinel and the key has a
      ``team_id``, the key's ``team_models`` replace it;
    - if the resulting list contains the all-proxy-models sentinel,
      ``proxy_model_list`` replaces it. When ``include_model_access_groups``
      is truthy, the access-group names are also appended, e.g.
      ``{"beta-models": ["gpt-4", "claude-v1"]}`` adds ``'beta-models'``.

    Access-group names are then expanded into their member models via
    ``_get_models_from_access_groups``. The group names themselves survive
    only when ``include_model_access_groups`` is truthy. Duplicates are
    dropped, keeping the first occurrence.

    Note: ``only_model_access_groups`` is accepted for signature compatibility
    but is not used here.

    Returns:
        List of model name strings; empty if the key has no models set.
    """
    resolved: List[str] = []
    key_models = user_api_key_dict.models
    if len(key_models) > 0:
        # Work on copies so cached auth objects / the caller's list are never mutated.
        resolved = list(key_models)
        wants_team_models = SpecialModelNames.all_team_models.value in resolved
        if wants_team_models and user_api_key_dict.team_id is not None:
            resolved = list(user_api_key_dict.team_models)
        if SpecialModelNames.all_proxy_models.value in resolved:
            resolved = list(proxy_model_list)
            if include_model_access_groups:
                resolved.extend(model_access_groups.keys())

    expanded = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=resolved,
        include_model_access_groups=include_model_access_groups,
    )
    # dict preserves insertion order, so this dedups while keeping first occurrences
    unique_models = list(dict.fromkeys(expanded))
    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(unique_models)))
    return unique_models
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
    include_model_access_groups: Optional[bool] = False,
) -> List[str]:
    """
    Resolve the model names a team may use.

    - Starts from ``team_models``.
    - If it contains the all-proxy-models sentinel, ``proxy_model_list`` is
      added. When ``include_model_access_groups`` is truthy, the access-group
      names are added too.
    - Access-group names are expanded into their member models via
      ``_get_models_from_access_groups``. Group names are kept only when
      ``include_model_access_groups`` is truthy.
    - Duplicates are dropped, keeping the first occurrence.

    Returns:
        List of model name strings in a deterministic order; empty list if no
        models are set.
    """
    # Collect into an ordered list rather than a set: iterating a set of str
    # depends on per-process hash randomization, which made the returned order
    # unstable and the "preserve order" dedup below meaningless.
    collected: List[str] = []
    if len(team_models) > 0:
        collected.extend(team_models)
        if SpecialModelNames.all_proxy_models.value in collected:
            collected.extend(proxy_model_list)
            if include_model_access_groups:
                collected.extend(model_access_groups.keys())

    all_models = _get_models_from_access_groups(
        model_access_groups=model_access_groups,
        all_models=list(dict.fromkeys(collected)),
        include_model_access_groups=include_model_access_groups,
    )
    # deduplicate while preserving order
    all_models = list(dict.fromkeys(all_models))
    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    model_access_groups: Optional[Dict[str, List[str]]] = None,
    include_model_access_groups: Optional[bool] = False,
    only_model_access_groups: Optional[bool] = False,
    team_id: Optional[str] = None,
) -> List[str]:
    """
    Build the complete model list for a key + team pair.

    - If the key list is non-empty it is used. Otherwise the team list is
      used, and if that is also empty, the proxy model list.
    - The following are appended:
        - the access-group names, when ``include_model_access_groups`` is set;
        - ``user_model``;
        - every model from ``get_valid_models()``, when
          ``infer_model_from_keys`` is set.
    - Duplicates are dropped, keeping the first occurrence.
    - If ``only_model_access_groups`` is set, only entries that are
      access-group names are returned, and wildcards are not expanded.
    - Otherwise wildcard routes are expanded into known provider models via
      ``_get_wildcard_models``, which may remove expanded wildcards from the
      list. The expansions are appended at the end.

    ``model_access_groups`` defaults to ``None``, which is treated as an empty
    mapping, to avoid a shared mutable default argument.
    """
    if model_access_groups is None:
        model_access_groups = {}

    unique_models: List[str] = []
    # Set mirror of unique_models for O(1) membership checks.
    seen: Set[str] = set()

    def append_unique(models: List[str]) -> None:
        for model in models:
            if model not in seen:
                seen.add(model)
                unique_models.append(model)

    if key_models:
        append_unique(key_models)
    elif team_models:
        append_unique(team_models)
    else:
        append_unique(proxy_model_list)
    if include_model_access_groups:
        append_unique(list(model_access_groups.keys()))  # TODO: keys order

    if user_model:
        append_unique([user_model])

    if infer_model_from_keys:
        valid_models = get_valid_models()
        append_unique(valid_models)

    if only_model_access_groups:
        return [model for model in unique_models if model in model_access_groups]

    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
        team_id=team_id,
    )
    complete_model_list = unique_models + all_wildcard_models
    return complete_model_list
def _hydrate_litellm_credential_name(
litellm_params: Optional[LiteLLM_Params],
) -> Optional[LiteLLM_Params]:
if litellm_params is None or litellm_params.litellm_credential_name is None:
return litellm_params
credential_values = CredentialAccessor.get_credential_values(
litellm_params.litellm_credential_name
)
if not credential_values:
return litellm_params
litellm_params = litellm_params.model_copy()
for key, value in credential_values.items():
if (
key in _CREDENTIAL_LITELLM_PARAM_FIELDS
and getattr(litellm_params, key, None) is None
):
setattr(litellm_params, key, value)
litellm_params.litellm_credential_name = None
return litellm_params
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a provider wildcard route (e.g. ``openai/*``) into concrete model names.

    Args:
        wildcard_model: Route of the form ``<prefix>/<suffix>``. The suffix is
            ``*`` or a partial pattern such as ``gemini-*``.
        litellm_params: Params of a router deployment backing the wildcard, if
            any. When given, the provider is the first ``/``-segment of
            ``litellm_params.model``, and the params (with any
            ``litellm_credential_name`` resolved) are passed to the model
            lookup. When omitted, the wildcard prefix is used as the provider.

    Returns:
        Model names, each starting with ``<prefix>/`` so they match the
        wildcard route. Returns an empty list when ``wildcard_model`` has no
        ``/`` or the provider has no known models.
    """
    try:
        wildcard_provider_prefix, wildcard_suffix = wildcard_model.split("/", 1)
    except ValueError:  # no "/" -> nothing to expand; safely fail
        return []

    # Use provider from litellm_params when available, otherwise from wildcard prefix
    # (e.g., "openai" from "openai/*" - needed for BYOK where wildcard isn't in router)
    if litellm_params is not None:
        provider = litellm_params.model.split("/", 1)[0]
    else:
        provider = wildcard_provider_prefix

    litellm_params = _hydrate_litellm_credential_name(litellm_params)
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []

    if wildcard_suffix != "*":
        ## CHECK IF PARTIAL FILTER e.g. `gemini-*`
        model_prefix = wildcard_suffix.replace("*", "")
        matching = [m for m in wildcard_models if m.startswith(model_prefix)]
        if matching:
            wildcard_models = matching
        else:
            # add model prefix to wildcard models
            wildcard_models = [f"{model_prefix}{m}" for m in wildcard_models]

    known_providers = {llm_provider.value for llm_provider in LlmProviders}
    # Match on "<prefix>/" rather than the bare prefix so ids that merely share
    # leading characters (e.g. "gemini-pro" for prefix "gemini", or
    # "ollama_chat/llama3" for prefix "ollama") still get routed under the wildcard.
    route_prefix = f"{wildcard_provider_prefix}/"
    prefixed_models: List[str] = []
    for model in wildcard_models:
        if not model.startswith(route_prefix):
            # `get_provider_models` returns provider-prefixed ids (e.g. "ollama/gemma3:1b").
            # When the wildcard uses a custom prefix (e.g. "ollama_server1/*" to distinguish
            # multiple instances), replace that existing provider prefix instead of stacking
            # both, which would otherwise yield an uncallable "ollama_server1/ollama/gemma3:1b".
            # Only strip the leading segment when it is a known provider, so ids whose first
            # segment is an org rather than a provider (e.g. "meta-llama/Llama-3-8B") keep it.
            leading, sep, model_suffix = model.partition("/")
            if sep and leading in known_providers:
                model = f"{route_prefix}{model_suffix}"
            else:
                model = f"{route_prefix}{model}"
        prefixed_models.append(model)
    return prefixed_models
def _get_wildcard_models(
    unique_models: List[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
    team_id: Optional[str] = None,
) -> List[str]:
    """
    Expand every wildcard route in ``unique_models`` into concrete model names.

    For each entry containing ``*``:
    - When ``return_wildcard_routes`` is truthy, the route itself is added to
      the result first (e.g. ``anthropic/*``).
    - If ``llm_router`` returns deployments for the route (looked up with
      ``team_id``), each deployment's ``litellm_params`` drive an expansion.
      The route stays in ``unique_models``.
    - Otherwise (no router, or no deployments, e.g. BYOK team models), the
      route is expanded from the provider's known models. If that yields
      anything, the route is removed from ``unique_models`` in place.

    Returns:
        The requested wildcard routes and their expansions in encounter order.
        No deduplication is applied.
    """
    expansions: List[str] = []
    expanded_without_router: Set[str] = set()

    for route in unique_models:
        if not _check_wildcard_routing(model=route):
            continue
        if return_wildcard_routes:
            expansions.append(route)

        deployments = (
            None
            if llm_router is None
            else llm_router.get_model_list(model_name=route, team_id=team_id)
        )
        if deployments:
            for deployment in deployments:
                expansions.extend(
                    get_known_models_from_wildcard(
                        wildcard_model=route,
                        litellm_params=LiteLLM_Params(
                            **deployment["litellm_params"]  # type: ignore
                        ),
                    )
                )
            continue

        # No usable deployment: fall back to the provider's known models.
        provider_models = get_known_models_from_wildcard(
            wildcard_model=route, litellm_params=None
        )
        if provider_models:
            expanded_without_router.add(route)
            expansions.extend(provider_models)

    # Removal happens after the loop so unique_models is not mutated mid-iteration.
    for route in expanded_without_router:
        unique_models.remove(route)
    return expansions
def get_all_fallbacks(
    model: str,
    llm_router: Optional[Router] = None,
    fallback_type: str = "general",
) -> List[str]:
    """
    Return the fallback models configured on the router for ``model``.

    Args:
        model: Model group name whose fallbacks are wanted.
        llm_router: Router holding the fallback configuration. With ``None``,
            nothing is looked up.
        fallback_type: ``"general"``, ``"context_window"`` or
            ``"content_policy"``. These select the router's ``fallbacks``,
            ``context_window_fallbacks`` or ``content_policy_fallbacks``
            attribute respectively.

    Returns:
        The fallback model group from ``get_fallback_model_group``. Returns an
        empty list when there is no router, the type is unknown (a warning is
        logged), nothing is configured, no group matches, or the lookup raises
        (an error is logged).
    """
    if llm_router is None:
        return []

    router_attribute = {
        "general": "fallbacks",
        "context_window": "context_window_fallbacks",
        "content_policy": "content_policy_fallbacks",
    }.get(fallback_type)
    if router_attribute is None:
        verbose_proxy_logger.warning(f"Unknown fallback_type: {fallback_type}")
        return []

    fallbacks_config: list = getattr(llm_router, router_attribute, [])
    if not fallbacks_config:
        return []

    try:
        # Reuse the router's own resolution of which fallback group applies.
        fallback_model_group, _ = get_fallback_model_group(
            fallbacks=fallbacks_config, model_group=model
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error getting fallbacks for model {model}: {e}")
        return []
    return [] if fallback_model_group is None else fallback_model_group