Files
MoFin/venv/lib/python3.12/site-packages/litellm/llms/sap/credentials.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

516 lines
17 KiB
Python

from __future__ import annotations
from typing import Any, Callable, Dict, Final, List, Optional, Sequence, Tuple, Union
from datetime import datetime, timedelta, timezone
from threading import Lock
from pathlib import Path
from dataclasses import dataclass
import json
import os
import tempfile
import httpx
from litellm.llms.custom_httpx.http_handler import _get_httpx_client, HTTPHandler
from litellm._logging import verbose_logger
import litellm
# Path appended to an auth/cert base URL to form the OAuth2 token endpoint.
AUTH_ENDPOINT_SUFFIX = "/oauth/token"
# Environment variable holding an explicit config file path (read by init_conf).
CONFIG_FILE_ENV_VAR = "AICORE_CONFIG"
# Environment variable overriding the home directory (read by _get_home).
HOME_PATH_ENV_VAR = "AICORE_HOME"
# Environment variable selecting the config profile (read by init_conf).
PROFILE_ENV_VAR = "AICORE_PROFILE"
# Environment variable holding the VCAP services JSON document.
VCAP_SERVICES_ENV_VAR = "VCAP_SERVICES"
# Service label searched for inside the VCAP services document.
VCAP_AICORE_SERVICE_NAME = "aicore"
# Environment variable that may carry a service key (read by fetch_credentials).
SERVICE_KEY_ENV_VAR = "AICORE_SERVICE_KEY"
# Fallback home directory when HOME_PATH_ENV_VAR is unset: ~/.aicore
DEFAULT_HOME_PATH = os.path.join(os.path.expanduser("~"), ".aicore")
def _get_home() -> str:
    """Return the AI Core home directory.

    The value of the ``HOME_PATH_ENV_VAR`` environment variable is used when
    it is set (even if empty); otherwise ``DEFAULT_HOME_PATH`` is returned.
    """
    override = os.environ.get(HOME_PATH_ENV_VAR)
    return DEFAULT_HOME_PATH if override is None else override
def _get_nested(d: Union[Dict[str, Any], str], path: Sequence[str]) -> Any:
    """Walk ``path`` through nested dictionaries and return the value found.

    Args:
        d: The mapping to traverse. A string is accepted as a fallback and is
            parsed as JSON first.
        path: Keys to follow, outermost first. An empty path returns ``d``
            itself (after JSON parsing, if it was a string).

    Returns:
        The value at the end of ``path``, or ``None`` when a key is missing,
        a non-dict is encountered mid-path, or a string input is not valid
        JSON (the latter two cases also log a warning).
    """
    node: Any = d
    if isinstance(node, str):
        # Service keys are expected to be pre-parsed; this is a safety net.
        try:
            node = json.loads(node)
        except json.JSONDecodeError:
            verbose_logger.warning(
                "SAP service key or VCAP service is a string but not valid JSON."
            )
            return None

    for key in path:
        if isinstance(node, dict):
            if key in node:
                node = node[key]
                continue
            # Missing key is an ordinary "not found"; no warning needed.
            return None
        verbose_logger.warning(
            f"SAP service key or VCAP service traversal hit non-dict type '{type(node).__name__}' at key '{key}'."
        )
        return None
    return node
def _load_json_env(var_name: str) -> Optional[Dict[str, Any]]:
    """Read environment variable ``var_name`` and parse it as a JSON object.

    Args:
        var_name: Name of the environment variable to read.

    Returns:
        The parsed dictionary, or ``None`` when the variable is unset, empty,
        not valid JSON, or valid JSON that is not an object.
    """
    raw = os.environ.get(var_name)
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return None
    # Callers treat the result as a mapping (``.values()`` / ``.get()``), so a
    # JSON list, string or number is rejected rather than returned.
    return parsed if isinstance(parsed, dict) else None
def _str_or_none(value) -> Optional[str]:
    """Convert ``value`` to ``str``, passing ``None`` through unchanged.

    Returns ``None`` as well if the conversion itself raises.
    """
    if value is None:
        return None
    try:
        return str(value)
    except Exception:
        # Best-effort conversion: an object with a broken __str__ counts as absent.
        return None
def _load_vcap() -> Dict[str, Any]:
    """Return the parsed VCAP services document, or an empty dict if unavailable."""
    parsed = _load_json_env(VCAP_SERVICES_ENV_VAR)
    return parsed if parsed else {}
def _get_vcap_service(label: str) -> Optional[Dict[str, Any]]:
    """Return the first VCAP service entry whose ``"label"`` equals ``label``.

    Every service group in the VCAP document is scanned in order; ``None`` is
    returned when no entry matches.
    """
    matches = (
        entry
        for group in _load_vcap().values()
        for entry in group
        if entry.get("label") == label
    )
    return next(matches, None)
@dataclass
class Source:
    """A named place credentials can be read from.

    ``get`` is called with a ``CredentialsValue`` descriptor and returns the
    string this source holds for it, or ``None`` when it has none.
    """

    # Human-readable label; appears in the debug log when this source is chosen.
    name: str
    # Lookup callback mapping a CredentialsValue to this source's value (or None).
    get: Callable[[CredentialsValue], Optional[str]]
def _ensure_url_suffix(url: str, suffix: str) -> str:
    """Return ``url`` without trailing slashes and ending in ``suffix`` exactly once."""
    trimmed = url.rstrip("/")
    # The check must run on the trimmed URL: testing the raw URL would miss
    # inputs such as ".../v2/" and append the suffix a second time.
    return trimmed if trimmed.endswith(suffix) else trimmed + suffix


def _unescape_newlines(pem: str) -> str:
    """Turn literal backslash-n sequences in an inline PEM into real newlines."""
    return pem.replace("\\n", "\n")


@dataclass(frozen=True)
class CredentialsValue:
    """Describes one credential field and how to locate and normalise it."""

    # Canonical field name (key in the resolved credentials dict).
    name: str
    # Key path of this field below "credentials" in a service key / VCAP entry.
    vcap_key: Optional[Tuple[str, ...]] = None
    # Value used when no source provides the field.
    default: Optional[str] = None
    # Optional normaliser applied to the raw string value after extraction.
    transform_fn: Optional[Callable[[str], str]] = None


# Every credential field that extract_credentials() attempts to read.
CREDENTIAL_VALUES: Final[List[CredentialsValue]] = [
    CredentialsValue("client_id", ("clientid",)),
    CredentialsValue("client_secret", ("clientsecret",)),
    CredentialsValue(
        "auth_url",
        ("url",),
        transform_fn=lambda url: _ensure_url_suffix(url, AUTH_ENDPOINT_SUFFIX),
    ),
    CredentialsValue(
        "base_url",
        ("serviceurls", "AI_API_URL"),
        transform_fn=lambda url: _ensure_url_suffix(url, "/v2"),
    ),
    CredentialsValue(
        "cert_url",
        ("certurl",),
        transform_fn=lambda url: _ensure_url_suffix(url, AUTH_ENDPOINT_SUFFIX),
    ),
    # file paths (kept for config compatibility)
    CredentialsValue("cert_file_path"),
    CredentialsValue("key_file_path"),
    # inline PEMs from VCAP
    CredentialsValue("cert_str", ("certificate",), transform_fn=_unescape_newlines),
    CredentialsValue("key_str", ("key",), transform_fn=_unescape_newlines),
]
def init_conf(profile: Optional[str] = None) -> Dict[str, Any]:
    """Load the AI Core JSON configuration file.

    The file is taken from ``$AICORE_CONFIG`` when that variable is set.
    Otherwise it is ``config.json`` inside the AI Core home directory, or
    ``config_<profile>.json`` when a non-default profile is selected (via the
    ``profile`` argument or ``$AICORE_PROFILE``).

    Args:
        profile: Optional profile name; falls back to ``$AICORE_PROFILE``.

    Returns:
        The parsed configuration, or ``{}`` when the default config file does
        not exist.

    Raises:
        KeyError: The config file exists but is not valid JSON.
        FileNotFoundError: An explicit config path or non-default profile was
            requested but the file does not exist.
    """
    home = Path(_get_home())
    profile = profile or os.environ.get(PROFILE_ENV_VAR)
    explicit_path = os.getenv(CONFIG_FILE_ENV_VAR)
    uses_default_profile = profile in (None, "", "default")

    if explicit_path:
        cfg_path = Path(explicit_path)
    elif uses_default_profile:
        cfg_path = home / "config.json"
    else:
        cfg_path = home / f"config_{profile}.json"

    if not cfg_path.exists():
        # Only an explicitly requested file/profile is an error when missing.
        if explicit_path or not uses_default_profile:
            raise FileNotFoundError(
                f"Unable to locate profile config file at '{cfg_path}' in AICORE_HOME '{home}'"
            )
        return {}

    try:
        with cfg_path.open(encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        raise KeyError(f"{cfg_path} is not valid JSON. Please fix or remove it!")
def _env_name(name: str) -> str:
    """Return the ``AICORE_``-prefixed, upper-cased variable name for ``name``."""
    return "AICORE_" + name.upper()
def extract_credentials(source: Source) -> Dict[str, str]:
    """Collect every credential field that ``source`` can supply.

    Each entry of ``CREDENTIAL_VALUES`` is looked up in ``source``; values
    that are present are passed through the field's ``transform_fn`` (when it
    has one) and stored under the field name.

    Returns:
        A dict of the fields found; empty when the source provides none.
    """
    found = {}
    for spec in CREDENTIAL_VALUES:
        raw = source.get(spec)
        if raw is None:
            continue
        normalise = spec.transform_fn
        found[spec.name] = normalise(raw) if normalise else raw
    return found
def resolve_credentials(sources: List[Source]) -> Dict[str, str]:
    """Return the credentials of the first source that yields any.

    Sources are tried in order and values are not merged across sources.

    Raises:
        ValueError: No source provided a single credential value.
    """
    for candidate in sources:
        found = extract_credentials(candidate)
        if not found:
            continue
        verbose_logger.debug(f"Resolved SAP credentials from source {candidate.name}")
        return found
    raise ValueError("No credentials found in any source")
def resolve_resource_group(sources: List[Source]) -> Optional[str]:
    """Return ``resource_group`` from the first source that defines it.

    Falls back to the descriptor's default (``"default"``) when no source
    has a value.
    """
    spec = CredentialsValue("resource_group", default="default")
    for candidate in sources:
        found = candidate.get(spec)
        if found is None:
            continue
        verbose_logger.debug(
            f"Resolved GEN AI Hub resource_group from source {candidate.name}"
        )
        return found
    return spec.default
def _parse_service_key_once(
    service_key: Optional[Union[str, dict]],
) -> Optional[Dict[str, Any]]:
    """Normalise a service key into its parsed form a single time.

    ``None`` and dicts are returned unchanged; strings are decoded as JSON.
    On invalid JSON or an unsupported type a warning is logged and ``None``
    is returned, so that the remaining credential sources can still be tried.
    """
    if service_key is None or isinstance(service_key, dict):
        return service_key

    if not isinstance(service_key, str):
        verbose_logger.warning(
            f"SAP service key has unexpected type '{type(service_key).__name__}'. Expected str or dict. Ignoring."
        )
        return None

    try:
        parsed = json.loads(service_key)
    except json.JSONDecodeError:
        verbose_logger.warning(
            "SAP service key is a string but not valid JSON. Skipping this source."
        )
        return None
    return parsed
def _resolve_credential_from_service_key(
    service_key: Optional[Union[str, dict]], cv: CredentialsValue
) -> Optional[str]:
    """Look up one credential field inside a service key.

    The field is first searched below the ``"credentials"`` key (using the
    field's ``vcap_key`` path) and then at the top level of the service key.
    Fields without a ``vcap_key`` are looked up by their plain name.

    Returns:
        The value as a string, or ``None`` if the key is absent or no service
        key was given.
    """
    if service_key is None:
        return None

    if cv.vcap_key:
        flat_path = cv.vcap_key
        nested_path = ("credentials",) + cv.vcap_key
    else:
        flat_path = nested_path = (cv.name,)

    nested = _str_or_none(_get_nested(service_key, nested_path))
    if nested is not None:
        return nested
    return _str_or_none(_get_nested(service_key, flat_path))
def fetch_credentials(
    service_key: Optional[Union[str, dict]] = None,
    profile: Optional[str] = None,
    **kwargs,
) -> Dict[str, str]:
    """Resolve SAP AI Core credentials from the available sources.

    Sources are consulted in this order, and the first one that yields any
    credential value wins:

        kwargs
          > service key (argument, ``litellm.sap_service_key`` or
            ``$AICORE_SERVICE_KEY``)
          > environment variables (``AICORE_<NAME>``)
          > config file (``AICORE_<NAME>`` or plain ``<name>``)
          > VCAP service

    Important:
        - Values are NOT merged per key across sources; all credentials come
          from the single winning source.
        - ``resource_group`` is the exception: it is resolved independently
          across all sources and falls back to ``"default"``.
        - When a ``cert_url`` is present it replaces ``auth_url`` in the result.

    Warning:
        The returned dict is not validated. Callers must call
        ``validate_credentials()`` on it themselves.

    Args:
        service_key: Service key as a dict or JSON string.
        profile: Config profile name forwarded to ``init_conf``.
        **kwargs: Explicit per-field overrides (``client_id``, ``auth_url``, ...).

    Returns:
        The resolved credential fields, plus ``resource_group``.
    """
    config = init_conf(profile)
    parsed_key = _parse_service_key_once(
        service_key or litellm.sap_service_key or os.environ.get(SERVICE_KEY_ENV_VAR)
    )
    vcap_service = _get_vcap_service(VCAP_AICORE_SERVICE_NAME)

    def _from_kwargs(cv: CredentialsValue) -> Optional[str]:
        return _str_or_none(kwargs.get(cv.name))

    def _from_service_key(cv: CredentialsValue) -> Optional[str]:
        return _resolve_credential_from_service_key(parsed_key, cv)

    def _from_env(cv: CredentialsValue) -> Optional[str]:
        return _str_or_none(os.environ.get(_env_name(cv.name)))

    def _from_config(cv: CredentialsValue) -> Optional[str]:
        # The prefixed key takes precedence over the plain field name.
        prefixed = config.get(_env_name(cv.name))
        if prefixed is not None:
            return _str_or_none(prefixed)
        return _str_or_none(config.get(cv.name))

    def _from_vcap(cv: CredentialsValue) -> Optional[str]:
        if not vcap_service:
            return None
        path = (("credentials",) + cv.vcap_key) if cv.vcap_key else (cv.name,)
        return _str_or_none(_get_nested(vcap_service, path))

    sources = [
        Source("kwargs", _from_kwargs),
        Source("service key", _from_service_key),
        Source("environment variables", _from_env),
        Source("config file", _from_config),
        Source("VCAP service", _from_vcap),
    ]

    credentials = resolve_credentials(sources)

    resource_group = resolve_resource_group(sources)
    if resource_group is not None:
        credentials["resource_group"] = resource_group

    # A certificate-auth URL, when present, supersedes the plain auth URL.
    if "cert_url" in credentials:
        credentials["auth_url"] = credentials.pop("cert_url")
    return credentials
def validate_credentials(
    auth_url: Optional[str] = None,
    base_url: Optional[str] = None,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    cert_str: Optional[str] = None,
    key_str: Optional[str] = None,
    cert_file_path: Optional[str] = None,
    key_file_path: Optional[str] = None,
) -> None:
    """Check SAP AI Core credentials for completeness and consistency.

    Args:
        auth_url: OAuth2 token endpoint URL (required).
        base_url: SAP AI Core API base URL (required).
        client_id: OAuth2 client ID (required).
        client_secret: OAuth2 client secret (secret-based auth).
        cert_str: PEM-encoded certificate string (inline cert auth).
        key_str: PEM-encoded private key string (inline cert auth).
        cert_file_path: Path to a certificate file (file-based cert auth).
        key_file_path: Path to a private key file (file-based cert auth).

    Raises:
        ValueError: A required field is missing, or the number of complete
            authentication methods is not exactly one.

    Note:
        - ``resource_group`` is not validated here (it is resolved separately).
        - A complete authentication method is one of: ``client_secret``,
          (``cert_str`` and ``key_str``), or
          (``cert_file_path`` and ``key_file_path``).
    """
    if not (auth_url and client_id and base_url):
        raise ValueError(
            "SAP AI Core credentials not found. "
            "Please provide credentials by setting appropriate environment variables "
            "(e.g. AICORE_CLIENT_ID, AICORE_CLIENT_SECRET, etc.)"
        )

    has_secret = bool(client_secret)
    has_inline_pair = bool(cert_str) and bool(key_str)
    has_file_pair = bool(cert_file_path) and bool(key_file_path)

    if [has_secret, has_inline_pair, has_file_pair].count(True) != 1:
        raise ValueError(
            "SAP AI Core credentials are incomplete. "
            "Invalid credentials: provide exactly one of client_secret, "
            "(cert_str & key_str), or (cert_file_path & key_file_path)."
        )
def _request_token(
    client_id: str, auth_url: str, timeout: float, cert_pair=None, client_secret=None
) -> tuple[str, datetime]:
    """Request an OAuth2 client-credentials token from ``auth_url``.

    Args:
        client_id: OAuth2 client ID sent in the form body.
        auth_url: Token endpoint URL.
        timeout: Timeout passed to the HTTP POST.
        cert_pair: Optional ``(cert_path, key_path)`` pair; when given, the
            request is sent through a dedicated ``httpx.Client`` configured
            with that client certificate.
        client_secret: Optional client secret added to the form body.

    Returns:
        ``("Bearer <access_token>", expiry)`` where ``expiry`` is a
        timezone-aware UTC datetime computed from ``expires_in``
        (3600 seconds when the response omits it).

    Raises:
        RuntimeError: Any failure while requesting or decoding the token; the
            message carries the response text when a response was received.
    """
    form = {"grant_type": "client_credentials", "client_id": client_id}
    if client_secret:
        form["client_secret"] = client_secret

    resp: Optional[httpx.Response] = None
    try:
        if not cert_pair:
            resp = _get_httpx_client().post(auth_url, data=form, timeout=timeout)  # type: ignore[arg-type]
            payload = resp.json()
        else:
            # The certificate-bound client must stay open until the body is read.
            with httpx.Client(cert=cert_pair) as raw_client:
                resp = HTTPHandler(client=raw_client).post(
                    auth_url, data=form, timeout=timeout  # type: ignore[arg-type]
                )
                payload = resp.json()

        bearer = f"Bearer {payload['access_token']}"
        lifetime = timedelta(seconds=int(payload.get("expires_in", 3600)))
        return bearer, datetime.now(timezone.utc) + lifetime
    except Exception as e:
        detail = resp.text if resp is not None else getattr(e, "text", str(e))
        raise RuntimeError(f"Token request failed: {detail}") from e
def get_token_creator(
    service_key: Optional[Union[str, dict]] = None,
    profile: Optional[str] = None,
    *,
    timeout: float = 30.0,
    expiry_buffer_minutes: int = 60,
    **overrides,
) -> Tuple[Callable[[], str], str, str]:
    """Build a cached OAuth2 bearer-token getter for SAP AI Core.

    Credentials are resolved once via
    ``fetch_credentials(service_key=..., profile=..., **overrides)`` and
    checked with ``validate_credentials()``. The returned getter then:

    - requests a new token only when none is cached or the cached one is
      within ``expiry_buffer_minutes`` of expiring;
    - guards its cache with a lock, so concurrent callers share one token.

    Args:
        service_key: Service key as a dict or JSON string.
        profile: Optional AI Core profile name.
        timeout: Timeout for the token HTTP request.
        expiry_buffer_minutes: Refresh the token this many minutes before expiry.
        **overrides: Explicit credential overrides (client_id, client_secret, ...).

    Returns:
        A 3-tuple ``(get_token, base_url, resource_group)`` where
        ``get_token()`` returns a valid ``"Bearer <token>"`` string.

    Raises:
        ValueError: The resolved credentials are missing or inconsistent.

    NOTE(review): a token whose lifetime does not exceed
    ``expiry_buffer_minutes`` is always considered "near expiry", so it is
    re-requested on every ``get_token()`` call.
    """
    credentials: Dict[str, str] = fetch_credentials(
        service_key=service_key, profile=profile, **overrides
    )
    auth_url = credentials.get("auth_url")
    base_url = credentials.get("base_url")
    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    cert_str = credentials.get("cert_str")
    key_str = credentials.get("key_str")
    cert_file_path = credentials.get("cert_file_path")
    key_file_path = credentials.get("key_file_path")

    # Fail fast, before a getter is handed out.
    validate_credentials(
        auth_url,
        base_url,
        client_id,
        client_secret,
        cert_str,
        key_str,
        cert_file_path,
        key_file_path,
    )

    guard = Lock()
    cached_token: Optional[str] = None
    cached_expiry: Optional[datetime] = None

    def _fetch_token() -> tuple[str, datetime]:
        # Arguments shared by every authentication mode.
        common = dict(auth_url=auth_url, client_id=client_id, timeout=timeout)

        # Mode 1: client secret.
        if client_secret:
            return _request_token(client_secret=client_secret, **common)  # type: ignore[arg-type]

        # Mode 2: inline PEM strings, written to short-lived files because
        # the certificate is handed to the HTTP client as file paths.
        if cert_str and key_str:
            with tempfile.TemporaryDirectory() as tmp:
                cert_path = os.path.join(tmp, "cert.pem")
                key_path = os.path.join(tmp, "key.pem")
                for target, pem in ((cert_path, cert_str), (key_path, key_str)):
                    with open(target, "w") as f:
                        f.write(pem.replace("\\n", "\n"))
                return _request_token(cert_pair=(cert_path, key_path), **common)  # type: ignore[arg-type]

        # Mode 3: certificate and key already on disk.
        if cert_file_path is not None and key_file_path is not None:
            return _request_token(
                cert_pair=(cert_file_path, key_file_path), **common  # type: ignore[arg-type]
            )

        # Defensive guard: validate_credentials() should make this unreachable.
        raise ValueError(
            "Invalid authentication configuration: no valid credentials found. "
        )

    def get_token() -> str:
        nonlocal cached_token, cached_expiry
        with guard:
            now = datetime.now(timezone.utc)
            stale = cached_token is None or cached_expiry is None
            if not stale:
                remaining = cached_expiry - now  # type: ignore[operator]
                stale = remaining < timedelta(minutes=expiry_buffer_minutes)
            if stale:
                cached_token, cached_expiry = _fetch_token()
            return cached_token  # type: ignore[return-value]

    return get_token, credentials["base_url"], credentials["resource_group"]