"""
Supports using JWT's for authenticating into the proxy.

Currently only supports admin.

JWT token must have 'litellm_proxy_admin' in scope.
"""

from __future__ import annotations

import fnmatch
import hashlib
import os
import re
from typing import Any, List, Literal, Optional, Set, Tuple, Union, cast

import jwt
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from fastapi import HTTPException, status
from jwt.api_jwk import PyJWK

from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
    RBAC_ROLES,
    JWKKeyValue,
    JWTAuthBuilderResult,
    JWTIssuerConfig,
    JWTKeyItem,
    LiteLLM_EndUserTable,
    LiteLLM_JWTAuth,
    LiteLLM_OrganizationTable,
    LiteLLM_TeamMembership,
    LiteLLM_TeamTable,
    LiteLLM_UserTable,
    LitellmUserRoles,
    Member,
    ProxyErrorTypes,
    ProxyException,
    ScopeMapping,
    Span,
    TeamMemberAddRequest,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import can_team_access_model
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.repositories.user_repository import UserRepository

from .auth_checks import (
    _allowed_routes_check,
    allowed_routes_check,
    get_actual_routes,
    get_end_user_object,
    get_org_object,
    get_org_object_by_alias,
    get_role_based_models,
    get_role_based_routes,
    get_team_membership,
    get_team_object,
    get_team_object_by_alias,
    get_user_object,
)


class NoMatchingJWTPublicKeyError(Exception):
    """Raised when a JWKS endpoint returns no key matching the requested ``kid``."""


class JWTHandler:
    """
    - treat the sub id passed in as the user id
    - return an error if id making request doesn't exist in proxy user table
    - track spend against the user id
    - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
    """

    prisma_client: Optional[PrismaClient]
    user_api_key_cache: UserApiKeyCache

    # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
    # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
    # the key in different ways (e.g. HS* and RS*)."
    SUPPORTED_JWT_ALGORITHMS = [
        "RS256",
        "RS384",
        "RS512",
        "PS256",
        "PS384",
        "PS512",
        "ES256",
        "ES384",
        "ES512",
        "EdDSA",
    ]

    # Internal claim names written by ``_apply_issuer_claim_mappings`` after a
    # token has been validated against a configured issuer. They are stripped
    # from any incoming payload before being (re)populated.
    LITELLM_JWT_ISSUER_CLAIM = "_litellm_jwt_issuer"
    LITELLM_USER_ID_CLAIM = "_litellm_user_id"
    LITELLM_USER_EMAIL_CLAIM = "_litellm_user_email"
    LITELLM_TEAM_ID_CLAIM = "_litellm_team_id"
    LITELLM_TEAM_IDS_CLAIM = "_litellm_team_ids"
    LITELLM_ORG_ID_CLAIM = "_litellm_org_id"
    LITELLM_END_USER_ID_CLAIM = "_litellm_end_user_id"
    LITELLM_INTERNAL_CLAIMS = (
        LITELLM_JWT_ISSUER_CLAIM,
        LITELLM_USER_ID_CLAIM,
        LITELLM_USER_EMAIL_CLAIM,
        LITELLM_TEAM_ID_CLAIM,
        LITELLM_TEAM_IDS_CLAIM,
        LITELLM_ORG_ID_CLAIM,
        LITELLM_END_USER_ID_CLAIM,
    )

    def __init__(
        self,
    ) -> None:
        self.http_handler = HTTPHandler()
        self.leeway = 0

    def update_environment(
        self,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: UserApiKeyCache,
        litellm_jwtauth: LiteLLM_JWTAuth,
        leeway: int = 0,
    ) -> None:
        """Store the DB client, cache, JWT auth config and decode leeway on the handler."""
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache
        self.litellm_jwtauth = litellm_jwtauth
        self.leeway = leeway

    @staticmethod
    def is_jwt(token: Optional[str]) -> bool:
        """Return True when ``token`` has exactly three dot-separated parts."""
        if token is None:
            return False
        parts = token.split(".")
        return len(parts) == 3

    @staticmethod
    def get_unverified_claims(token: str) -> Optional[dict]:
        """
        Decode JWT claims without signature verification.

        Used for routing decisions before selecting validation path.
        Returns None when the token is not JWT-shaped or cannot be decoded.
        """
        if not JWTHandler.is_jwt(token):
            return None
        try:
            claims = jwt.decode(
                token,
                options={"verify_signature": False, "verify_aud": False},
                algorithms=JWTHandler.SUPPORTED_JWT_ALGORITHMS,
            )
            if isinstance(claims, dict):
                return claims
            return None
        except Exception as e:
            verbose_proxy_logger.debug(
                "Failed to decode unverified JWT claims for routing: %s", e
            )
            return None

    @staticmethod
    def _as_role_list(roles: Any) -> Any:
        """
        Wrap a bare string role claim in a list; return any other value unchanged.

        Iterating (or ``set()``-ing) a ``str`` yields its characters, so a
        single-string role claim must be wrapped before membership checks.
        """
        if isinstance(roles, str):
            return [roles]
        return roles

    def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to based on role mappings.

        Args:
            token (dict): The JWT token containing role information

        Returns:
            Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
            None otherwise

        Note:
            The function handles both single string roles and lists of roles from the JWT.
            If multiple mappings match the JWT roles, the first matching mapping is returned.
        """
        if self.litellm_jwtauth.role_mappings is None:
            return None

        jwt_role = self.get_jwt_role(token=token, default_value=None)
        if not jwt_role:
            return None

        # A bare string claim would otherwise become a set of characters.
        jwt_role_set = set(self._as_role_list(jwt_role))

        for role_mapping in self.litellm_jwtauth.role_mappings:
            # Check if the mapping role matches any of the JWT roles
            if role_mapping.role in jwt_role_set:
                return role_mapping.internal_role

        return None

    def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to.

        RBAC roles allowed to make requests:
        - PROXY_ADMIN: can make requests to all routes
        - TEAM: can make requests to routes associated with a team
        - INTERNAL_USER: can make requests to routes associated with a user

        Resolves: https://github.com/BerriAI/litellm/issues/6793

        Returns:
        - PROXY_ADMIN: if token is admin
        - TEAM: if token is associated with a team
        - INTERNAL_USER: if token is associated with a user
        - None: if token is not associated with a team or user
        """
        scopes = self.get_scopes(token=token)
        is_admin = self.is_admin(scopes=scopes)
        user_roles = self.get_user_roles(token=token, default_value=None)

        if is_admin:
            return LitellmUserRoles.PROXY_ADMIN
        elif self.get_team_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.TEAM
        elif self.get_user_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.INTERNAL_USER
        elif user_roles is not None and self.is_allowed_user_role(
            user_roles=user_roles
        ):
            return LitellmUserRoles.INTERNAL_USER
        elif rbac_role := self._rbac_role_from_role_mapping(token=token):
            return rbac_role

        return None

    def is_admin(self, scopes: list) -> bool:
        """Return True when the configured ``admin_jwt_scope`` is among ``scopes``."""
        if self.litellm_jwtauth.admin_jwt_scope in scopes:
            return True
        return False

    def _is_trusted_issuer_normalized_token(self, token: dict) -> bool:
        """Return True when the token's internal issuer claim names a configured issuer."""
        issuer = token.get(self.LITELLM_JWT_ISSUER_CLAIM)
        if not isinstance(issuer, str) or not issuer:
            return False

        litellm_jwtauth = getattr(self, "litellm_jwtauth", None)
        issuer_configs = getattr(litellm_jwtauth, "issuers", None) or []
        return any(issuer_config.issuer == issuer for issuer_config in issuer_configs)

    def _has_trusted_issuer_normalized_claim(self, token: dict, claim: str) -> bool:
        """Return True when ``claim`` is present on a trusted, issuer-normalized token."""
        return self._is_trusted_issuer_normalized_token(token=token) and claim in token

    def get_team_ids_from_jwt(self, token: dict) -> List[str]:
        """
        Return the team ids carried by the token's plural team-ids claim.

        The issuer-scoped internal claim takes precedence; otherwise the
        globally configured ``team_ids_jwt_field`` is read. A single string
        value is wrapped in a list on both paths.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_TEAM_IDS_CLAIM
        ):
            issuer_team_ids = token.get(self.LITELLM_TEAM_IDS_CLAIM)
            if isinstance(issuer_team_ids, list):
                return issuer_team_ids
            if isinstance(issuer_team_ids, str):
                return [issuer_team_ids]
            # Issuer-scoped claim exists but has an unexpected type
            # (e.g. int/dict from an unusual upstream mapping). Don't silently
            # fall through to the global ``team_ids_jwt_field`` path — that
            # would read a semantically unrelated claim on the same token.
            return []

        if self.litellm_jwtauth.team_ids_jwt_field is not None:
            team_ids: Any = get_nested_value(
                data=token,
                key_path=self.litellm_jwtauth.team_ids_jwt_field,
                default=[],
            )
            # Mirror the issuer-scoped path: a bare string is ONE team id, not
            # a sequence of characters for downstream ``list()`` / iteration.
            if isinstance(team_ids, str):
                return [team_ids] if team_ids else []
            return team_ids or []
        return []

    def get_all_jwt_team_ids(self, token: dict) -> List[str]:
        """
        Return team IDs from both the plural ``team_ids_jwt_field`` and the
        singular ``team_id_jwt_field`` claim (string or list of strings), as a
        deduplicated list preserving plural-first order.

        Membership-reconciliation paths (SSO callback, JWT-bearer sync) need to
        consider both claim shapes. Reading only the plural field — as callers
        historically did — silently dropped users whose IdP populates the
        singular field, which is what Okta and Auth0 default to when a user has
        a single primary team.

        This intentionally does NOT consult ``team_id_default``: that fallback
        is a property of how the JWT-bearer auth flow resolves a single
        request-bound team, not of the token's claims. Callers that want the
        default-team behavior should still go through ``get_team_id``.
        """
        team_ids: List[str] = list(self.get_team_ids_from_jwt(token))

        singular: Any = None
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_TEAM_ID_CLAIM
        ):
            singular = token.get(self.LITELLM_TEAM_ID_CLAIM)
        elif self.litellm_jwtauth.team_id_jwt_field is not None:
            singular = get_nested_value(
                data=token,
                key_path=self.litellm_jwtauth.team_id_jwt_field,
                default=None,
            )

        if singular is not None:
            if isinstance(singular, list):
                for item in singular:
                    if item is None:
                        continue
                    sid = str(item)
                    if sid and sid not in team_ids:
                        team_ids.append(sid)
            elif singular and str(singular) not in team_ids:
                team_ids.append(str(singular))

        return team_ids

    def get_end_user_id(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Return the end-user id from the issuer-scoped claim or ``end_user_id_jwt_field``.

        Returns None when ``end_user_id_jwt_field`` is not configured.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_END_USER_ID_CLAIM
        ):
            return token.get(self.LITELLM_END_USER_ID_CLAIM)

        try:
            if self.litellm_jwtauth.end_user_id_jwt_field is not None:
                user_id = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.end_user_id_jwt_field,
                    default=default_value,
                )
            else:
                user_id = None
        except KeyError:
            user_id = default_value
        return user_id

    def is_required_team_id(self) -> bool:
        """
        Returns:
        - True: if 'team_id_jwt_field' or 'team_alias_jwt_field' is set
        - False: if neither is set
        """
        if (
            self.litellm_jwtauth.team_id_jwt_field is None
            and self.litellm_jwtauth.team_alias_jwt_field is None
        ):
            return False
        return True

    def is_enforced_email_domain(self) -> bool:
        """
        Returns:
        - True: if 'user_allowed_email_domain' is set
        - False: if 'user_allowed_email_domain' is None
        """
        if self.litellm_jwtauth.user_allowed_email_domain is not None and isinstance(
            self.litellm_jwtauth.user_allowed_email_domain, str
        ):
            return True
        return False

    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the single team id for this token.

        Resolution order: the issuer-scoped internal claim, then the configured
        ``team_id_jwt_field`` (falling back to ``team_id_default`` when that
        path is absent from the token), then ``team_id_default``. When the
        claim holds a list, its first element is used.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_TEAM_ID_CLAIM
        ):
            team_id = token.get(self.LITELLM_TEAM_ID_CLAIM)
            if isinstance(team_id, list):
                return team_id[0] if team_id else default_value
            return team_id

        try:
            if self.litellm_jwtauth.team_id_jwt_field is not None:
                # Use a sentinel value to detect if the path actually exists
                sentinel = object()
                team_id = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.team_id_jwt_field,
                    default=sentinel,
                )
                if team_id is sentinel:
                    # Path doesn't exist, use team_id_default if available
                    if self.litellm_jwtauth.team_id_default is not None:
                        return self.litellm_jwtauth.team_id_default
                    else:
                        return default_value
                # AAD and other IdPs often send roles/groups as a list of strings.
                # team_id_jwt_field is singular, so take the first element when a list
                # is returned. This avoids "unhashable type: 'list'" errors downstream.
                if isinstance(team_id, list):
                    if not team_id:
                        return default_value
                    verbose_proxy_logger.debug(
                        f"JWT Auth: team_id_jwt_field '{self.litellm_jwtauth.team_id_jwt_field}' "
                        f"returned a list {team_id}; using first element '{team_id[0]}' automatically."
                    )
                    team_id = team_id[0]
                return team_id  # type: ignore[return-value]
            elif self.litellm_jwtauth.team_id_default is not None:
                team_id = self.litellm_jwtauth.team_id_default
            else:
                team_id = None
        except KeyError:
            team_id = default_value
        return team_id

    def get_team_alias(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Extract team name/alias from JWT token using the configured team_alias_jwt_field.

        Args:
            token: The decoded JWT token dictionary
            default_value: Default value to return if field not found

        Returns:
            The team alias from the token, or default_value if not found
        """
        try:
            if self.litellm_jwtauth.team_alias_jwt_field is not None:
                team_alias = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.team_alias_jwt_field,
                    default=default_value,
                )
                return team_alias
            else:
                # No field configured: honour the documented contract and
                # return the caller's default rather than a hard-coded None.
                team_alias = default_value
        except KeyError:
            team_alias = default_value
        return team_alias

    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
        """
        Returns:
        - True: if 'user_id_upsert' is set AND valid_user_email is not False
        - False: if not
        """
        if valid_user_email is False:
            return False
        return self.litellm_jwtauth.user_id_upsert

    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the user id from the issuer-scoped claim or ``user_id_jwt_field``.

        Falls back to ``default_value`` when the field is not configured.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_USER_ID_CLAIM
        ):
            return token.get(self.LITELLM_USER_ID_CLAIM)

        try:
            if self.litellm_jwtauth.user_id_jwt_field is not None:
                user_id = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.user_id_jwt_field,
                    default=default_value,
                )
            else:
                user_id = default_value
        except KeyError:
            user_id = default_value
        return user_id

    def get_user_roles(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Returns the user role from the token.

        Set via 'user_roles_jwt_field' in the config.
        """
        try:
            if self.litellm_jwtauth.user_roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.user_roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def map_jwt_role_to_litellm_role(self, token: dict) -> Optional[LitellmUserRoles]:
        """
        Map roles from JWT to LiteLLM user roles.

        Mappings are tried in configuration order; each mapping's ``jwt_role``
        is an ``fnmatch`` pattern tested against every role on the token. The
        first match wins.
        """
        if not self.litellm_jwtauth.jwt_litellm_role_map:
            return None

        jwt_roles = self.get_jwt_role(token=token, default_value=[])
        if not jwt_roles:
            return None

        # A bare string claim is one role, not a sequence of characters.
        jwt_roles = self._as_role_list(jwt_roles)

        for mapping in self.litellm_jwtauth.jwt_litellm_role_map:
            for role in jwt_roles:
                if fnmatch.fnmatch(role, mapping.jwt_role):
                    return mapping.litellm_role
        return None

    def get_jwt_role(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Generic implementation of `get_user_roles` that can be used for both user and team roles.

        Returns the jwt role from the token.

        Set via 'roles_jwt_field' in the config.
        """
        try:
            if self.litellm_jwtauth.roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
        """
        Return True when any of ``user_roles`` is listed in 'user_allowed_roles'.

        Set via 'user_allowed_roles' in the config. A bare string is treated
        as a single role.
        """
        roles = self._as_role_list(user_roles)
        if (
            roles is not None
            and self.litellm_jwtauth.user_allowed_roles is not None
            and any(role in self.litellm_jwtauth.user_allowed_roles for role in roles)
        ):
            return True
        return False

    def get_user_email(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Return the user email from the issuer-scoped claim or ``user_email_jwt_field``.

        Returns None when ``user_email_jwt_field`` is not configured.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_USER_EMAIL_CLAIM
        ):
            return token.get(self.LITELLM_USER_EMAIL_CLAIM)

        try:
            if self.litellm_jwtauth.user_email_jwt_field is not None:
                user_email = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.user_email_jwt_field,
                    default=default_value,
                )
            else:
                user_email = None
        except KeyError:
            user_email = default_value
        return user_email

    def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """Return the value at ``object_id_jwt_field``, or ``default_value`` if unset/missing."""
        try:
            if self.litellm_jwtauth.object_id_jwt_field is not None:
                object_id = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.object_id_jwt_field,
                    default=default_value,
                )
            else:
                object_id = default_value
        except KeyError:
            object_id = default_value
        return object_id

    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the organization id from the issuer-scoped claim or ``org_id_jwt_field``.

        Returns None when ``org_id_jwt_field`` is not configured.
        """
        if self._has_trusted_issuer_normalized_claim(
            token=token, claim=self.LITELLM_ORG_ID_CLAIM
        ):
            return token.get(self.LITELLM_ORG_ID_CLAIM)

        try:
            if self.litellm_jwtauth.org_id_jwt_field is not None:
                org_id = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.org_id_jwt_field,
                    default=default_value,
                )
            else:
                org_id = None
        except KeyError:
            org_id = default_value
        return org_id

    def get_org_alias(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Extract organization name/alias from JWT token using the configured org_alias_jwt_field.

        Args:
            token: The decoded JWT token dictionary
            default_value: Default value to return if field not found

        Returns:
            The organization alias from the token, or default_value if not found
        """
        try:
            if self.litellm_jwtauth.org_alias_jwt_field is not None:
                org_alias = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.org_alias_jwt_field,
                    default=default_value,
                )
                return org_alias
            else:
                # No field configured: honour the documented contract and
                # return the caller's default rather than a hard-coded None.
                org_alias = default_value
        except KeyError:
            org_alias = default_value
        return org_alias

    def get_scopes(self, token: dict) -> List[str]:
        """
        Return the token's scopes as a list.

        Accepts a space-separated string or a list in the ``scope`` claim and
        returns ``[]`` when the claim is absent.

        Raises:
            Exception: if ``scope`` is present but is neither a str nor a list.
        """
        try:
            if isinstance(token["scope"], str):
                # Assuming the scopes are stored in 'scope' claim and are space-separated
                scopes = token["scope"].split()
            elif isinstance(token["scope"], list):
                scopes = token["scope"]
            else:
                raise Exception(
                    f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
                )
        except KeyError:
            scopes = []
        return scopes

    async def _resolve_jwks_url(self, url: str) -> str:
        """
        If url points to an OIDC discovery document (*.well-known/openid-configuration),
        fetch it and return the jwks_uri contained within. Otherwise return url unchanged.

        This lets JWT_PUBLIC_KEY_URL be set to a well-known discovery endpoint
        instead of requiring operators to manually find the JWKS URL.

        The resolved ``jwks_uri`` is cached under a key derived from ``url``.

        Raises:
            Exception: on a non-200 response, an unparsable document, or a
                document without ``jwks_uri``.
        """
        if ".well-known/openid-configuration" not in url:
            return url

        cache_key = f"litellm_oidc_discovery_{url}"
        cached_jwks_uri = await self.user_api_key_cache.async_get_cache(cache_key)
        if cached_jwks_uri is not None:
            return cached_jwks_uri

        verbose_proxy_logger.debug(
            f"JWT Auth: Fetching OIDC discovery document from {url}"
        )
        response = await self.http_handler.get(url)
        if response.status_code != 200:
            raise Exception(
                f"JWT Auth: OIDC discovery endpoint {url} returned status {response.status_code}: {response.text}"
            )
        try:
            discovery = response.json()
        except Exception as e:
            raise Exception(
                f"JWT Auth: Failed to parse OIDC discovery document at {url}: {e}"
            ) from e

        jwks_uri = discovery.get("jwks_uri")
        if not jwks_uri:
            raise Exception(
                f"JWT Auth: OIDC discovery document at {url} does not contain a 'jwks_uri' field."
            )

        verbose_proxy_logger.debug(
            f"JWT Auth: Resolved OIDC discovery {url} -> jwks_uri={jwks_uri}"
        )
        await self.user_api_key_cache.async_set_cache(
            key=cache_key,
            value=jwks_uri,
            ttl=self._get_public_key_cache_ttl(),
        )
        return jwks_uri

    def _get_public_key_cache_ttl(self) -> float:
        """Return the configured ``public_key_ttl``, or 600 when no config is attached."""
        litellm_jwtauth = getattr(self, "litellm_jwtauth", None)
        if litellm_jwtauth is None:
            return 600
        return litellm_jwtauth.public_key_ttl

    async def _get_public_key_from_jwks_url(
        self, jwks_url: str, kid: Optional[str]
    ) -> dict:
        """
        Fetch (or read from cache) the key set at ``jwks_url`` and return the key for ``kid``.

        ``jwks_url`` may be an OIDC discovery URL; it is resolved first.

        Raises:
            NoMatchingJWTPublicKeyError: when no key in the set matches ``kid``.
            Exception: when the response body cannot be parsed as JSON.
        """
        resolved_jwks_url = await self._resolve_jwks_url(jwks_url)
        cache_key = f"litellm_jwt_auth_keys_{resolved_jwks_url}"

        cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)

        if cached_keys is None:
            response = await self.http_handler.get(resolved_jwks_url)

            try:
                response_json = response.json()
            except Exception as e:
                verbose_proxy_logger.error(
                    f"Error parsing response: {e}. Original Response: {response.text}"
                )
                raise Exception(
                    f"Error parsing response: {e}. Check server logs for original response."
                ) from e

            # Either a JWK Set ({"keys": [...]}) or the key material directly.
            if "keys" in response_json:
                keys: JWKKeyValue = response_json["keys"]
            else:
                keys = response_json

            await self.user_api_key_cache.async_set_cache(
                key=cache_key,
                value=keys,
                ttl=self._get_public_key_cache_ttl(),
            )
        else:
            keys = cached_keys

        public_key = self.parse_keys(keys=keys, kid=kid)
        if public_key is not None:
            return cast(dict, public_key)

        raise NoMatchingJWTPublicKeyError(
            f"No matching public key found. keys={resolved_jwks_url}, kid={kid}"
        )

    async def get_public_key(self, kid: Optional[str]) -> dict:
        """
        Return the public key for ``kid`` from the URL(s) in ``JWT_PUBLIC_KEY_URL``.

        The env var may hold a comma-separated list; URLs are tried in order
        and the first one yielding a matching key wins.

        Raises:
            Exception: when ``JWT_PUBLIC_KEY_URL`` is not set.
            NoMatchingJWTPublicKeyError: when no URL yields a matching key.
        """
        keys_url = os.getenv("JWT_PUBLIC_KEY_URL")

        if keys_url is None:
            raise Exception("Missing JWT Public Key URL from environment.")

        keys_url_list = [url.strip() for url in keys_url.split(",") if url.strip()]

        for key_url in keys_url_list:
            try:
                return await self._get_public_key_from_jwks_url(
                    jwks_url=key_url, kid=kid
                )
            except NoMatchingJWTPublicKeyError as e:
                verbose_proxy_logger.debug(
                    "JWT Auth: No matching public key found at %s: %s", key_url, e
                )

        raise NoMatchingJWTPublicKeyError(
            f"No matching public key found. keys={keys_url_list}, kid={kid}"
        )

    def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
        """
        Select the key matching ``kid`` from ``keys``.

        ``keys`` may be a list of JWK dicts or a single JWK dict.

        - Single JWK dict, or a one-element list: returned when its ``kid``
          matches or when ``kid`` is None.
        - List with several entries: the entry whose ``kid`` equals the
          requested one; nothing is returned when ``kid`` is None.

        Returns:
            The matching key, or None.
        """
        public_key: Optional[JWTKeyItem] = None

        if isinstance(keys, dict) and len(keys) > 1:
            # A bare JWK object has several members (kty, kid, n, e, ...).
            # Iterating it would walk its string keys and never match, so treat
            # it as one candidate key. Requiring "kty" keeps arbitrary JSON
            # objects (e.g. error documents) from being returned as a key.
            if "kty" in keys and (kid is None or keys.get("kid", None) == kid):
                public_key = keys
            return public_key

        if len(keys) == 1:
            if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None):
                public_key = keys
            elif isinstance(keys, list) and (
                keys[0].get("kid", None) == kid or kid is None
            ):
                public_key = keys[0]
        elif len(keys) > 1:
            for key in keys:
                if isinstance(key, dict):
                    key_kid = key.get("kid", None)
                else:
                    key_kid = None
                if (
                    kid is not None
                    and isinstance(key, dict)
                    and key_kid is not None
                    and key_kid == kid
                ):
                    public_key = key

        return public_key

    def is_allowed_domain(self, user_email: str) -> bool:
        """
        Return True when ``user_email``'s domain equals 'user_allowed_email_domain'.

        Always True when no allowed domain is configured. The comparison is
        case-insensitive because DNS names are.
        """
        allowed_domain = self.litellm_jwtauth.user_allowed_email_domain
        if allowed_domain is None:
            return True

        email_domain = user_email.split("@")[-1]  # Extract domain from email

        if isinstance(allowed_domain, str):
            return email_domain.lower() == allowed_domain.lower()
        return email_domain == allowed_domain

    async def get_oidc_userinfo(self, token: str) -> dict:
        """
        Fetch user information from OIDC UserInfo endpoint.

        This follows the OpenID Connect protocol where an access token is sent
        to the identity provider's UserInfo endpoint to retrieve user identity information.

        Responses are cached under a key derived from the SHA-256 of the token.

        Args:
            token: The access token to use for authentication

        Returns:
            dict: User information from the UserInfo endpoint

        Raises:
            Exception: If UserInfo endpoint is not configured or request fails
        """
        if not self.litellm_jwtauth.oidc_userinfo_endpoint:
            raise Exception(
                "OIDC UserInfo endpoint not configured. Set 'oidc_userinfo_endpoint' in JWT auth config."
            )

        # Check cache first
        cache_key = f"oidc_userinfo_{hashlib.sha256(token.encode()).hexdigest()}"
        cached_userinfo = await self.user_api_key_cache.async_get_cache(cache_key)
        if cached_userinfo is not None:
            verbose_proxy_logger.debug("Returning cached OIDC UserInfo")
            return cached_userinfo

        verbose_proxy_logger.debug(
            f"Calling OIDC UserInfo endpoint: {self.litellm_jwtauth.oidc_userinfo_endpoint}"
        )

        try:
            # Call the UserInfo endpoint with the access token
            response = await self.http_handler.get(
                url=self.litellm_jwtauth.oidc_userinfo_endpoint,
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/json",
                },
            )

            if response.status_code != 200:
                raise Exception(
                    f"OIDC UserInfo endpoint returned status {response.status_code}: {response.text}"
                )

            userinfo = response.json()
            verbose_proxy_logger.debug(f"Received OIDC UserInfo: {userinfo}")

            # Cache the userinfo response
            await self.user_api_key_cache.async_set_cache(
                key=cache_key,
                value=userinfo,
                ttl=self.litellm_jwtauth.oidc_userinfo_cache_ttl,
            )

            return userinfo

        except Exception as e:
            verbose_proxy_logger.error(f"Error fetching OIDC UserInfo: {str(e)}")
            raise Exception(f"Failed to fetch OIDC UserInfo: {str(e)}") from e

    # Set once the "unscoped JWT" warning has been logged, so it is emitted
    # a single time per process.
    _unscoped_jwt_warning_emitted = False

    @classmethod
    def _build_decode_kwargs(cls) -> dict:
        """Build the audience/issuer/options kwargs for ``jwt.decode``.

        Setting ``JWT_AUDIENCE`` (and optionally ``JWT_ISSUER``) turns on the
        corresponding PyJWT verifications, blocking cross-tenant tokens minted
        by other applications that share the same IdP signing keys. When both
        are unset PyJWT only checks the signature and expiry, which is
        preserved for backward compatibility but logged once as a warning.

        The warning fires even in mixed deployments that also configure
        ``LiteLLM_JWTAuth.issuers``: tokens whose ``iss`` does not match any
        configured issuer fall through to this global path, and if env-var
        scoping is absent that fallback is itself unscoped.
        """
        audience = os.getenv("JWT_AUDIENCE")
        issuer = os.getenv("JWT_ISSUER")

        if (
            audience is None
            and issuer is None
            and not cls._unscoped_jwt_warning_emitted
        ):
            verbose_proxy_logger.warning(
                "JWT auth is enabled but neither JWT_AUDIENCE nor JWT_ISSUER "
                "is configured. Tokens minted by any application that shares "
                "the same IdP signing keys will be accepted. Set JWT_AUDIENCE "
                "(and ideally JWT_ISSUER) to scope this proxy."
            )
            cls._unscoped_jwt_warning_emitted = True

        options: dict = {}
        if audience is None:
            options["verify_aud"] = False
        if issuer is None:
            options["verify_iss"] = False

        return {
            "audience": audience,
            "issuer": issuer,
            "options": options or None,
        }

    def _get_configured_issuer(self, token: str) -> Optional[JWTIssuerConfig]:
        """
        Return the configured issuer entry whose ``issuer`` equals the token's ``iss``.

        The ``iss`` claim is read WITHOUT signature verification and is used
        only to pick the validation path. Returns None when no issuers are
        configured, the token cannot be decoded, or nothing matches.
        """
        litellm_jwtauth = getattr(self, "litellm_jwtauth", None)
        if litellm_jwtauth is None:
            return None

        issuer_configs = litellm_jwtauth.issuers
        if not issuer_configs:
            return None

        claims = self.get_unverified_claims(token=token)
        if claims is None:
            return None

        issuer = claims.get("iss")
        if not isinstance(issuer, str) or not issuer:
            return None

        for issuer_config in issuer_configs:
            if issuer_config.issuer == issuer:
                return issuer_config
        return None

    def _get_jwks_url_for_issuer(self, issuer_config: JWTIssuerConfig) -> str:
        """Return the issuer's explicit ``jwks_url``, else its OIDC discovery URL."""
        if issuer_config.jwks_url:
            return issuer_config.jwks_url
        # _resolve_jwks_url fetches this OIDC discovery document and follows
        # its jwks_uri, matching JWTIssuerConfig.jwks_url's documented fallback.
        return f"{issuer_config.issuer.rstrip('/')}/.well-known/openid-configuration"

    def _get_claim_value_for_issuer_mapping(self, token: dict, claim_field: str) -> Any:
        """Resolve a mapped claim from ``token``.

        Returns ``None`` when the field is absent or empty so that mapped claims
        behave like the global ``litellm_jwtauth`` path — present claims
        override the normalised value, missing ones simply leave it ``None``.
        """
        sentinel = object()
        claim_value = get_nested_value(
            data=token,
            key_path=claim_field,
            default=sentinel,
        )
        if claim_value is sentinel or claim_value is None or claim_value == "":
            return None
        return claim_value

    def _apply_issuer_claim_mappings(
        self, token: dict, issuer_config: JWTIssuerConfig
    ) -> dict:
        """
        Return a copy of ``token`` carrying LiteLLM's internal, issuer-scoped claims.

        Any internal claim already present on the incoming payload is dropped
        first, so only values derived from ``issuer_config``'s field mappings
        can populate them.
        """
        normalized: dict = {
            k: v for k, v in token.items() if k not in self.LITELLM_INTERNAL_CLAIMS
        }
        normalized[self.LITELLM_JWT_ISSUER_CLAIM] = issuer_config.issuer

        claim_mappings = [
            (issuer_config.user_id_jwt_field, self.LITELLM_USER_ID_CLAIM),
            (issuer_config.user_email_jwt_field, self.LITELLM_USER_EMAIL_CLAIM),
            (issuer_config.team_id_jwt_field, self.LITELLM_TEAM_ID_CLAIM),
            (issuer_config.team_ids_jwt_field, self.LITELLM_TEAM_IDS_CLAIM),
            (issuer_config.org_id_jwt_field, self.LITELLM_ORG_ID_CLAIM),
            (issuer_config.end_user_id_jwt_field, self.LITELLM_END_USER_ID_CLAIM),
        ]

        for source_claim, normalized_claim in claim_mappings:
            if source_claim is None:
                continue
            claim_value = self._get_claim_value_for_issuer_mapping(
                token=token,
                claim_field=source_claim,
            )
            if claim_value is not None:
                normalized[normalized_claim] = claim_value

        return normalized

    def _get_jwk_from_public_key(self, public_key: dict) -> dict:
        """Return only the JWK members (kty, kid, n, e, x, y, crv) present in ``public_key``."""
        jwk_fields = ("kty", "kid", "n", "e", "x", "y", "crv")
        return {field: public_key[field] for field in jwk_fields if field in public_key}

    def _get_decode_options(
        self,
        audience: Optional[Union[str, List[str]]],
        issuer: Optional[str] = None,
        disable_audience_validation: bool = False,
    ) -> Optional[dict]:
        """
        Build the PyJWT ``options`` dict for the given audience/issuer.

        Returns None when both audience and issuer are to be verified.

        Raises:
            ValueError: when ``audience`` is None and audience validation was
                not explicitly disabled.
        """
        # Disabling audience verification must be an explicit choice — never
        # an implicit consequence of ``audience`` being None. Otherwise a
        # caller that accidentally constructs a config with ``audience=None``
        # (bypassing the model validator) would silently lose audience
        # validation. Require callers to opt in via
        # ``disable_audience_validation=True``.
        if audience is None and not disable_audience_validation:
            raise ValueError(
                "audience must be provided unless disable_audience_validation=True"
            )

        options: dict = {}
        if audience is None:
            options["verify_aud"] = False
        if issuer is None:
            options["verify_iss"] = False
        return options or None

    def _decode_jwt_with_public_key(
        self,
        token: str,
        public_key: Union[dict, str],
        audience: Optional[Union[str, List[str]]],
        issuer: Optional[str] = None,
        options: Optional[dict] = None,
        disable_audience_validation: bool = False,
    ) -> dict:
        """
        Verify and decode ``token`` with ``public_key``.

        ``public_key`` is either a JWK dict or a PEM-encoded X.509 certificate
        string. When ``options`` is None they are derived via
        ``_get_decode_options``; explicitly passed options are used as-is.
        """
        decode_options = (
            options
            if options is not None
            else self._get_decode_options(
                audience=audience,
                issuer=issuer,
                disable_audience_validation=disable_audience_validation,
            )
        )

        if isinstance(public_key, dict):
            public_key_obj = PyJWK.from_dict(
                self._get_jwk_from_public_key(public_key=public_key)
            ).key
            return jwt.decode(
                token,
                public_key_obj,  # type: ignore
                algorithms=self.SUPPORTED_JWT_ALGORITHMS,
                options=decode_options,  # type: ignore[arg-type]
                audience=audience,
                issuer=issuer,
                leeway=self.leeway,
            )

        # String key: a PEM certificate; extract its public key as PEM bytes.
        cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
        key = cert.public_key().public_bytes(
            serialization.Encoding.PEM,
            serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        return jwt.decode(
            token,
            key,
            algorithms=self.SUPPORTED_JWT_ALGORITHMS,
            audience=audience,
            issuer=issuer,
            options=decode_options,  # type: ignore[arg-type]
            leeway=self.leeway,
        )

    async def _auth_jwt_with_issuer(
        self, token: str, issuer_config: JWTIssuerConfig, kid: Optional[str]
    ) -> dict:
        """
        Validate ``token`` against one configured issuer and return normalized claims.

        Raises:
            ProxyException: (401) when the token has expired.
            Exception: for any other validation failure.
        """
        public_key = await self._get_public_key_from_jwks_url(
            jwks_url=self._get_jwks_url_for_issuer(issuer_config=issuer_config),
            kid=kid,
        )

        try:
            payload = self._decode_jwt_with_public_key(
                token=token,
                public_key=public_key,
                audience=issuer_config.audience,
                issuer=issuer_config.issuer,
                disable_audience_validation=issuer_config.disable_audience_validation,
            )
        except jwt.ExpiredSignatureError:
            raise ProxyException(
                message="Token Expired",
                type=ProxyErrorTypes.expired_key,
                param=None,
                code=status.HTTP_401_UNAUTHORIZED,
            )
        except Exception as e:
            raise Exception(f"Validation fails: {str(e)}") from e

        return self._apply_issuer_claim_mappings(
            token=payload,
            issuer_config=issuer_config,
        )

    async def auth_jwt(self, token: str) -> dict:
        """
        Validate ``token`` and return its claims.

        Tokens whose ``iss`` matches a configured issuer are validated against
        that issuer. All others use the global path (``JWT_PUBLIC_KEY_URL`` plus
        optional ``JWT_AUDIENCE`` / ``JWT_ISSUER``), where internal LiteLLM
        claims are stripped from the returned payload.

        Raises:
            ProxyException: (401) when the token has expired.
            Exception: for any other validation failure.
        """
        header = jwt.get_unverified_header(token)

        verbose_proxy_logger.debug("header: %s", header)

        kid = header.get("kid", None)

        issuer_config = self._get_configured_issuer(token=token)
        if issuer_config is not None:
            return await self._auth_jwt_with_issuer(
                token=token,
                issuer_config=issuer_config,
                kid=kid,
            )

        decode_kwargs = self._build_decode_kwargs()
        public_key = await self.get_public_key(kid=kid)

        if public_key is not None:
            try:
                payload = self._decode_jwt_with_public_key(
                    token=token,
                    public_key=public_key,
                    audience=decode_kwargs["audience"],
                    issuer=decode_kwargs["issuer"],
                    options=decode_kwargs["options"],
                )
                return {
                    k: v
                    for k, v in payload.items()
                    if k not in self.LITELLM_INTERNAL_CLAIMS
                }
            except jwt.ExpiredSignatureError:
                raise ProxyException(
                    message="Token Expired",
                    type=ProxyErrorTypes.expired_key,
                    param=None,
                    code=status.HTTP_401_UNAUTHORIZED,
                )
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e

        raise Exception("Invalid JWT Submitted")

    async def close(self):
        """Close the underlying HTTP handler."""
        await self.http_handler.close()


class JWTAuthManager:
    """Manages JWT authentication and authorization operations"""

    @staticmethod
    def can_rbac_role_call_route(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        route: str,
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the route, based on their role.

        Raises:
            HTTPException: (403) when the role has a route list that does not
                permit ``route``.
        """
        role_based_routes = get_role_based_routes(
            rbac_role=rbac_role, general_settings=general_settings
        )

        if role_based_routes is None or route is None:
            return True

        is_allowed = _allowed_routes_check(
            user_route=route,
            allowed_routes=role_based_routes,
        )

        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
            )
        return True

    @staticmethod
    def can_rbac_role_call_model(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        model: Optional[str],
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the model, based on their role.

        Raises:
            HTTPException: (403) when the role has a model list that does not
                contain ``model``.
        """
        role_based_models = get_role_based_models(
            rbac_role=rbac_role, general_settings=general_settings
        )
        if role_based_models is None or model is None:
            return True

        if model not in role_based_models:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
            )

        return True

    @staticmethod
    def check_scope_based_access(
        scope_mappings: List[ScopeMapping],
        scopes: List[str],
        request_data: dict,
        general_settings: dict,
    ) -> None:
        """
        Check if scope allows access to the requested model

        No check is performed when there are no scope mappings or the request
        names no model.

        Raises:
            HTTPException: (403) when the requested model is not allowed by
                any mapping whose scope the token holds.
        """
        if not scope_mappings:
            return None

        allowed_models = []
        for sm in scope_mappings:
            if sm.scope in scopes and sm.models:
                allowed_models.extend(sm.models)

        requested_model = request_data.get("model")

        if not requested_model:
            return None

        if requested_model not in allowed_models:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "model={} not allowed. Allowed_models={}".format(
                        requested_model, allowed_models
                    )
                },
            )
        return None

    @staticmethod
    async def check_rbac_role(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        general_settings: dict,
        request_data: dict,
        route: str,
        rbac_role: Optional[RBAC_ROLES],
    ) -> None:
        """Validate RBAC role and model access permissions"""
        if jwt_handler.litellm_jwtauth.enforce_rbac is True:
            if rbac_role is None:
                raise HTTPException(
                    status_code=403,
                    detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
                )
            JWTAuthManager.can_rbac_role_call_model(
                rbac_role=rbac_role,
                general_settings=general_settings,
                model=request_data.get("model"),
            )
            JWTAuthManager.can_rbac_role_call_route(
                rbac_role=rbac_role,
                general_settings=general_settings,
                route=route,
            )

    @staticmethod
    async def check_admin_access(
        jwt_handler: JWTHandler,
        scopes: list,
        route: str,
        user_id: Optional[str],
        org_id: Optional[str],
        api_key: str,
        jwt_valid_token: Optional[dict] = None,
    ) -> Optional[JWTAuthBuilderResult]:
        """
        Check admin status and route access permissions

        Returns None when the token is not an admin token; otherwise a
        ``JWTAuthBuilderResult`` flagged as proxy admin.

        Raises:
            Exception: when the admin token is not allowed to call ``route``.
        """
        if not jwt_handler.is_admin(scopes=scopes):
            return None

        is_allowed = allowed_routes_check(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            user_route=route,
            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
        )
        if not is_allowed:
            allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
            actual_routes = get_actual_routes(allowed_routes=allowed_routes)
            raise Exception(
                f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
            )

        return JWTAuthBuilderResult(
            is_proxy_admin=True,
            team_object=None,
            user_object=None,
            end_user_object=None,
            org_object=None,
            token=api_key,
            team_id=None,
            user_id=user_id,
            end_user_id=None,
            org_id=org_id,
            team_membership=None,
            jwt_claims=jwt_valid_token or {},
        )

    @staticmethod
    async def find_and_validate_specific_team_id(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: UserApiKeyCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """Find and validate specific team ID from team_id_jwt_field or team_alias_jwt_field"""
        individual_team_id = jwt_handler.get_team_id(
            token=jwt_valid_token, default_value=None
        )
        team_object: Optional[LiteLLM_TeamTable] = None

        # First try to get team by team_id
        if individual_team_id:
            try:
                team_object = await get_team_object(
                    team_id=individual_team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                    team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
                )
                return individual_team_id, team_object
            except HTTPException as e:
                if (
                    e.status_code != 404
                    or not jwt_handler.litellm_jwtauth.team_claim_fallback
                ):
                    raise
                # Claim doesn't map to a known team — defer to fallback.
verbose_proxy_logger.debug( "JWT team_id claim '%s' did not resolve to a team: %s", individual_team_id, e.detail, ) return None, None # If no team_id found, try to resolve via team_alias_jwt_field team_alias = jwt_handler.get_team_alias( token=jwt_valid_token, default_value=None ) if team_alias: verbose_proxy_logger.info( f"JWT Auth: Resolving team by alias: '{team_alias}'" ) team_object = await get_team_object_by_alias( team_alias=team_alias, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) if team_object: individual_team_id = team_object.team_id verbose_proxy_logger.info( f"JWT Auth: Resolved team_alias='{team_alias}' to team_id='{individual_team_id}'" ) return individual_team_id, team_object # Check if team is required but not found if jwt_handler.is_required_team_id() is True: team_id_field = jwt_handler.litellm_jwtauth.team_id_jwt_field team_alias_field = jwt_handler.litellm_jwtauth.team_alias_jwt_field hint = "" if team_id_field: # "roles.0" — dot-notation numeric indexing is not supported if "." in team_id_field: parts = team_id_field.rsplit(".", 1) if parts[-1].isdigit(): base_field = parts[0] hint = ( f" Hint: dot-notation array indexing (e.g. '{team_id_field}') is not " f"supported. Use '{base_field}' instead — LiteLLM automatically " f"uses the first element when the field value is a list." ) # "roles[0]" — bracket-notation indexing is also not supported in get_nested_value elif "[" in team_id_field and team_id_field.endswith("]"): m = re.match(r"^(\w+)\[(\d+)\]$", team_id_field) if m: base_field = m.group(1) hint = ( f" Hint: array indexing (e.g. '{team_id_field}') is not supported " f"in team_id_jwt_field. Use '{base_field}' instead — LiteLLM " f"automatically uses the first element when the field value is a list." ) raise Exception( f"No team found in token. 
Checked team_id field '{team_id_field}' and team_alias field '{team_alias_field}'.{hint}" ) return individual_team_id, team_object @staticmethod def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]: """Get combined team IDs from groups and individual team_id""" team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token) all_team_ids = set(team_ids_from_groups) return all_team_ids @staticmethod def _team_has_passthrough_route_access( team_object: Optional[LiteLLM_TeamTable], route: str, request_method: Optional[str] = None, ) -> bool: normalized_request_method = ( request_method.upper() if isinstance(request_method, str) else None ) if not RouteChecks.is_auth_enforced_pass_through_route( route=route, method=normalized_request_method, ): return True # JWT team selection is team-scoped; key metadata is not available here, # so passthrough access is granted only by the selected team's metadata. return RouteChecks.check_passthrough_route_access( route=route, user_api_key_dict=UserAPIKeyAuth( team_metadata=(team_object.metadata or {}) if team_object else {} ), ) @staticmethod def _raise_team_passthrough_route_denial(route: str) -> None: raise HTTPException( status_code=403, detail=( f"Team not allowed to access passthrough route {route}. " "Configure `allowed_passthrough_routes` on the team." 
), ) @staticmethod async def find_team_with_model_access( team_ids: Set[str], requested_model: Optional[str], route: str, jwt_handler: JWTHandler, prisma_client: Optional[PrismaClient], user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, request_method: Optional[str] = None, ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]: """Find first team with access to the requested model""" from litellm.proxy.proxy_server import llm_router denied_auth_enforced_pass_through_route = False if not team_ids: if jwt_handler.litellm_jwtauth.enforce_team_based_model_access: raise HTTPException( status_code=403, detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.", ) return None, None any_claim_team_resolved = False for team_id in team_ids: try: team_object = await get_team_object( team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) if team_object is not None: any_claim_team_resolved = True if team_object and team_object.models is not None: team_models = team_object.models if isinstance(team_models, list) and ( not requested_model or await can_team_access_model( model=requested_model, team_object=team_object, llm_router=llm_router, team_model_aliases=None, ) ): is_allowed = allowed_routes_check( user_role=LitellmUserRoles.TEAM, user_route=route, litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if ( is_allowed and not JWTAuthManager._team_has_passthrough_route_access( team_object=team_object, route=route, request_method=request_method, ) ): is_allowed = False denied_auth_enforced_pass_through_route = True verbose_proxy_logger.debug( f"JWT team route check: team_id={team_id}, route={route}, is_allowed={is_allowed}" ) if is_allowed: return team_id, team_object except Exception: continue if denied_auth_enforced_pass_through_route: 
JWTAuthManager._raise_team_passthrough_route_denial(route=route) if requested_model and ( any_claim_team_resolved or not jwt_handler.litellm_jwtauth.team_claim_fallback ): # Claim resolved but no model access, or fallback disabled — deny. raise HTTPException( status_code=403, detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.", ) # No claim team resolved and fallback enabled — defer to fallback. return None, None @staticmethod async def get_user_info( jwt_handler: JWTHandler, jwt_valid_token: dict, ) -> Tuple[Optional[str], Optional[str], Optional[bool]]: """Get user email and validation status""" user_email = jwt_handler.get_user_email( token=jwt_valid_token, default_value=None ) valid_user_email = None if jwt_handler.is_enforced_email_domain(): valid_user_email = ( False if user_email is None else jwt_handler.is_allowed_domain(user_email=user_email) ) user_id = jwt_handler.get_user_id( token=jwt_valid_token, default_value=user_email ) return user_id, user_email, valid_user_email @staticmethod def _canonical_user_id_from_db( user_id: Optional[str], user_object: Optional[LiteLLM_UserTable], ) -> Optional[str]: """Id used for spend / team-membership attribution. JWT claim (often email) is only a lookup key. If fuzzy match in ``get_user_object`` resolved a legacy row with a different ``user_id``, use that row's id; otherwise keep the claim. GH #26789. 
""" if user_object is not None and user_object.user_id: return user_object.user_id return user_id @staticmethod async def get_objects( user_id: Optional[str], user_email: Optional[str], org_id: Optional[str], end_user_id: Optional[str], team_id: Optional[str], valid_user_email: Optional[bool], jwt_handler: JWTHandler, prisma_client: Optional[PrismaClient], user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, route: str, org_alias: Optional[str] = None, ) -> Tuple[ Optional[LiteLLM_UserTable], Optional[LiteLLM_OrganizationTable], Optional[LiteLLM_EndUserTable], Optional[LiteLLM_TeamMembership], Optional[str], ]: """Get user, org, end-user, and team-membership objects. Returns ``(..., effective_user_id)``: JWT claim unless fuzzy lookup matched a legacy row (GH #26789). """ # Get org object - first try by ID, then by alias org_object: Optional[LiteLLM_OrganizationTable] = None if org_id: org_object = ( await get_org_object( org_id=org_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) if org_id else None ) elif org_alias: verbose_proxy_logger.info( f"JWT Auth: Resolving org by alias: '{org_alias}'" ) org_object = await get_org_object_by_alias( org_alias=org_alias, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) if org_object: verbose_proxy_logger.info( f"JWT Auth: Resolved org_alias='{org_alias}' to org_id='{org_object.organization_id}'" ) # Check if email domain is allowed before attempting to get/create user if valid_user_email is False: raise ProxyException( message=f"Email domain not allowed. User email: {user_email}. 
Allowed domain: {jwt_handler.litellm_jwtauth.user_allowed_email_domain}", type=ProxyErrorTypes.auth_error, param="user_email", code=403, ) user_object: Optional[LiteLLM_UserTable] = None if user_id: user_object = ( await get_user_object( user_id=user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, user_id_upsert=jwt_handler.is_upsert_user_id( valid_user_email=valid_user_email ), parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, user_email=user_email, sso_user_id=user_id, ) if user_id else None ) end_user_object: Optional[LiteLLM_EndUserTable] = None if end_user_id: end_user_object = ( await get_end_user_object( end_user_id=end_user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, route=route, ) if end_user_id else None ) # Rebind to resolved DB user_id for team_membership + auth_builder (GH #26789). effective_user_id = JWTAuthManager._canonical_user_id_from_db( user_id=user_id, user_object=user_object ) if effective_user_id != user_id: verbose_proxy_logger.debug( "JWT Auth: rebinding user_id %r -> DB user_id %r (email/sso match)", user_id, effective_user_id, ) user_id = effective_user_id team_membership_object: Optional[LiteLLM_TeamMembership] = None if user_id and team_id: team_membership_object = ( await get_team_membership( user_id=user_id, team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) if user_id and team_id else None ) return ( user_object, org_object, end_user_object, team_membership_object, user_id, ) @staticmethod def validate_object_id( user_id: Optional[str], team_id: Optional[str], enforce_rbac: bool, is_proxy_admin: bool, ) -> Literal[True]: """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking""" if enforce_rbac and not is_proxy_admin and not user_id and not team_id: 
raise HTTPException( status_code=403, detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.", ) return True @staticmethod def get_team_id_from_header( request_headers: Optional[dict], allowed_team_ids: Set[str], ) -> Optional[str]: """ Extract team_id from x-litellm-team-id header if present. Validates that the team is in the user's allowed teams from JWT. Args: request_headers: Dictionary of request headers allowed_team_ids: Set of team IDs the user is allowed to access (from JWT) Returns: The team_id from header if valid, None otherwise Raises: HTTPException: If team_id is provided but not in allowed_team_ids """ if not request_headers: return None # Normalize headers to lowercase for case-insensitive lookup normalized_headers = {k.lower(): v for k, v in request_headers.items()} header_team_id = normalized_headers.get("x-litellm-team-id") if not header_team_id: return None # Validate that the team_id is in the allowed teams if header_team_id not in allowed_team_ids: raise HTTPException( status_code=403, detail=f"Team '{header_team_id}' from x-litellm-team-id header is not in your JWT's allowed teams. Allowed teams: {list(allowed_team_ids)}", ) verbose_proxy_logger.debug( f"Using team_id from x-litellm-team-id header: {header_team_id}" ) return header_team_id @staticmethod async def map_user_to_teams( user_object: Optional[LiteLLM_UserTable], team_object: Optional[LiteLLM_TeamTable], ): """ Map user to teams. 
- If user is not in team, add them to the team - If user is in team, do nothing """ from litellm.proxy.management_endpoints.team_endpoints import team_member_add if not user_object: return None if not team_object: return None # check if user is in team for member in team_object.members_with_roles: if member.user_id and member.user_id == user_object.user_id: return None data = TeamMemberAddRequest( member=Member( user_id=user_object.user_id, role="user", # [TODO]: allow controlling role within team based on jwt token ), team_id=team_object.team_id, ) # add user to team - make this non-blocking to avoid authentication failures try: await team_member_add( data=data, user_api_key_dict=UserAPIKeyAuth( user_role=LitellmUserRoles.PROXY_ADMIN ), # [TODO]: expose an internal service role, for better tracking ) verbose_proxy_logger.debug( f"Successfully added user {user_object.user_id} to team {team_object.team_id}" ) except ProxyException as e: if e.type == ProxyErrorTypes.team_member_already_in_team: verbose_proxy_logger.debug( f"User {user_object.user_id} is already a member of team {team_object.team_id}" ) return None else: raise e return None @staticmethod async def sync_user_role_and_teams( jwt_handler: JWTHandler, jwt_valid_token: dict, user_object: Optional[LiteLLM_UserTable], prisma_client: Optional[PrismaClient], user_api_key_cache: Optional[UserApiKeyCache] = None, ) -> None: """ Sync user role and team memberships with JWT claims The goal of this method is to ensure: 1. The user role on LiteLLM DB is in sync with the IDP provider role 2. The user is a member of the teams specified in the JWT token This method is only called if sync_user_role_and_teams is set to True in the JWT config. 
""" if not jwt_handler.litellm_jwtauth.sync_user_role_and_teams: return None if user_object is None or prisma_client is None: return None # Update user role new_role = jwt_handler.map_jwt_role_to_litellm_role(jwt_valid_token) if new_role and user_object.user_role != new_role.value: await UserRepository(prisma_client).table.update( where={"user_id": user_object.user_id}, data={"user_role": new_role.value}, ) user_object.user_role = new_role.value if user_api_key_cache is not None: await user_api_key_cache.async_set_cache( key=user_object.user_id, value=user_object, model_type=LiteLLM_UserTable, ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, ) # Sync team memberships jwt_team_ids = set(jwt_handler.get_team_ids_from_jwt(jwt_valid_token)) existing_teams = set(user_object.teams or []) teams_to_add = jwt_team_ids - existing_teams teams_to_remove = existing_teams - jwt_team_ids if teams_to_add or teams_to_remove: from litellm.proxy.management_endpoints.scim.scim_v2 import ( patch_team_membership, ) await patch_team_membership( user_id=user_object.user_id, teams_ids_to_add_user_to=list(teams_to_add), teams_ids_to_remove_user_from=list(teams_to_remove), ) user_object.teams = list(jwt_team_ids) if user_api_key_cache is not None: await user_api_key_cache.async_set_cache( key=user_object.user_id, value=user_object, model_type=LiteLLM_UserTable, ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, ) return None @staticmethod async def _attach_team_from_header_for_admin( admin_result: JWTAuthBuilderResult, route: str, request_headers: Optional[dict], jwt_handler: JWTHandler, prisma_client: Optional[PrismaClient], user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, ) -> None: """Attach team context from x-litellm-team-id to an admin result. Only applies on LLM API routes so team TPM/RPM limits and attribution are enforced when admins act on behalf of a team. 
Admin management routes ignore the header to preserve pre-existing bypass behavior. """ header_team_id = ( request_headers.get("x-litellm-team-id") if request_headers else None ) if not header_team_id or not RouteChecks.is_llm_api_route(route=route): return try: team_object = await get_team_object( team_id=header_team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, ) except Exception as e: # Fall back to pre-PR admin behavior: honor the admin's # authorization but skip team attribution/limits for this # request. Log so operators can find the misconfigured caller. verbose_proxy_logger.warning( "admin x-litellm-team-id=%r on route=%s could not be resolved (%s); " "proceeding with admin access, team context NOT attached.", header_team_id, route, e, ) return admin_result["team_id"] = header_team_id admin_result["team_object"] = team_object @staticmethod async def _resolve_single_team_fallback( user_object: Optional[LiteLLM_UserTable], user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, team_id_upsert: Optional[bool], ) -> tuple: """ If JWT did not resolve team_id, but the user belongs to exactly one team in LiteLLM, load that team (and membership when user_id is set) so that spend / metadata can be attributed correctly. Returns (team_id, team_object, team_membership_object). Any DB error is debug-logged and the tuple is (None, None, None) — no exception ever propagates from this helper. 
""" if user_object is None or not user_object.teams or len(user_object.teams) != 1: return None, None, None _tid = user_object.teams[0] try: team_row = await get_team_object( team_id=_tid, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, team_id_upsert=team_id_upsert, ) if team_row is None: return None, None, None if not user_id: return _tid, team_row, None team_membership = await get_team_membership( user_id=user_id, team_id=_tid, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) return _tid, team_row, team_membership except Exception: verbose_proxy_logger.debug( "JWT single-team fallback error, skipping. team_id=%s", _tid, exc_info=True, ) return None, None, None @staticmethod async def auth_builder( api_key: str, jwt_handler: JWTHandler, request_data: dict, general_settings: dict, route: str, prisma_client: Optional[PrismaClient], user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, request_headers: Optional[dict] = None, request_method: Optional[str] = None, ) -> JWTAuthBuilderResult: """Main authentication and authorization builder""" # Check if OIDC UserInfo endpoint is enabled, but fall back to standard # JWT auth if the token itself is a well-formed JWT (3-part structure). if ( jwt_handler.litellm_jwtauth.oidc_userinfo_enabled and not jwt_handler.is_jwt(token=api_key) ): verbose_proxy_logger.debug( "OIDC UserInfo is enabled. Fetching user info from UserInfo endpoint." 
) # Use the access token to fetch user info from OIDC UserInfo endpoint jwt_valid_token: dict = await jwt_handler.get_oidc_userinfo(token=api_key) else: # Default behavior: decode and validate the JWT token jwt_valid_token = await jwt_handler.auth_jwt(token=api_key) # Check custom validate if jwt_handler.litellm_jwtauth.custom_validate: if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token): raise HTTPException( status_code=403, detail="Invalid JWT token", ) # Check RBAC rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token) await JWTAuthManager.check_rbac_role( jwt_handler, jwt_valid_token, general_settings, request_data, route, rbac_role, ) # Check Scope Based Access scopes = jwt_handler.get_scopes(token=jwt_valid_token) if ( jwt_handler.litellm_jwtauth.enforce_scope_based_access and jwt_handler.litellm_jwtauth.scope_mappings ): JWTAuthManager.check_scope_based_access( scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings, scopes=scopes, request_data=request_data, general_settings=general_settings, ) object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) # Get basic user info user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info( jwt_handler, jwt_valid_token ) # Get IDs org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None) end_user_id = jwt_handler.get_end_user_id( token=jwt_valid_token, default_value=None ) team_id: Optional[str] = None team_object: Optional[LiteLLM_TeamTable] = None object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None) if rbac_role and object_id: if rbac_role == LitellmUserRoles.TEAM: team_id = object_id elif rbac_role == LitellmUserRoles.INTERNAL_USER: user_id = object_id # Check admin access admin_result = await JWTAuthManager.check_admin_access( jwt_handler, scopes, route, user_id, org_id, api_key, jwt_valid_token ) if admin_result: await JWTAuthManager._attach_team_from_header_for_admin( admin_result=admin_result, 
route=route, request_headers=request_headers, jwt_handler=jwt_handler, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) return admin_result # Get team with model access ## Check if team_id is specified via x-litellm-team-id header all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token) specific_team_id = jwt_handler.get_team_id( token=jwt_valid_token, default_value=None ) if specific_team_id: all_team_ids.add(specific_team_id) header_team_id = JWTAuthManager.get_team_id_from_header( request_headers=request_headers, allowed_team_ids=all_team_ids, ) if header_team_id: team_id = header_team_id team_object = await get_team_object( team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, ) elif not team_id: ## SPECIFIC TEAM ID ( team_id, team_object, ) = await JWTAuthManager.find_and_validate_specific_team_id( jwt_handler, jwt_valid_token, prisma_client, user_api_key_cache, parent_otel_span, proxy_logging_obj, ) if not team_object and not team_id: ## CHECK USER GROUP ACCESS team_id, team_object = await JWTAuthManager.find_team_with_model_access( team_ids=all_team_ids, requested_model=request_data.get("model"), route=route, request_method=request_method, jwt_handler=jwt_handler, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) # The RBAC role-claim path (rbac_role == TEAM) sets team_id without # loading team_object, so fetch it here before gating an auth-enforced # passthrough route on the team's allowed_passthrough_routes. 
if ( team_id and team_object is None and RouteChecks.is_auth_enforced_pass_through_route( route=route, method=( request_method.upper() if isinstance(request_method, str) else None ), ) ): team_object = await get_team_object( team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, ) if team_id and not JWTAuthManager._team_has_passthrough_route_access( team_object=team_object, route=route, request_method=request_method, ): JWTAuthManager._raise_team_passthrough_route_denial(route=route) # Extract alias fields for resolution (if configured) org_alias = jwt_handler.get_org_alias(token=jwt_valid_token, default_value=None) # get_objects returns effective_user_id for downstream spend attribution (GH #26789). ( user_object, org_object, end_user_object, team_membership_object, user_id, ) = await JWTAuthManager.get_objects( user_id=user_id, user_email=user_email, org_id=org_id, end_user_id=end_user_id, team_id=team_id, valid_user_email=valid_user_email, jwt_handler=jwt_handler, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, route=route, org_alias=org_alias, ) # Derive org_id from org_object if resolved by alias resolved_org_id = org_object.organization_id if org_object else org_id await JWTAuthManager.sync_user_role_and_teams( jwt_handler=jwt_handler, jwt_valid_token=jwt_valid_token, user_object=user_object, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, ) # If JWT did not resolve team_id, attempt single-team DB fallback. 
if team_id is None: team_id, team_object, team_membership_object = ( await JWTAuthManager._resolve_single_team_fallback( user_object=user_object, user_id=user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert, ) ) ## MAP USER TO TEAMS await JWTAuthManager.map_user_to_teams( user_object=user_object, team_object=team_object, ) # Validate that a valid rbac id is returned for spend tracking JWTAuthManager.validate_object_id( user_id=user_id, team_id=team_id, enforce_rbac=general_settings.get("enforce_rbac", False), is_proxy_admin=False, ) # check if user is proxy admin is_proxy_admin = bool( user_object and user_object.user_role == LitellmUserRoles.PROXY_ADMIN ) return JWTAuthBuilderResult( is_proxy_admin=is_proxy_admin, team_id=team_id, team_object=team_object, user_id=user_id, user_object=user_object, org_id=resolved_org_id, # Use resolved org_id (from alias lookup if applicable) org_object=org_object, end_user_id=end_user_id, end_user_object=end_user_object, token=api_key, team_membership=team_membership_object, jwt_claims=jwt_valid_token, )