Files
MoFin/venv/lib/python3.12/site-packages/openai/auth/_workload.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

310 lines
12 KiB
Python

from __future__ import annotations
import time
import threading
from typing import Any, Callable, TypedDict, cast
from pathlib import Path
from typing_extensions import Literal, NotRequired
import httpx
from .._exceptions import OAuthError, OpenAIError, SubjectTokenProviderError
from .._utils._sync import to_thread
TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
DEFAULT_TOKEN_EXCHANGE_URL = "https://auth.openai.com/oauth/token"
DEFAULT_REFRESH_BUFFER_SECONDS = 1200
SUBJECT_TOKEN_TYPES = {
"jwt": "urn:ietf:params:oauth:token-type:jwt",
"id": "urn:ietf:params:oauth:token-type:id_token",
}
class SubjectTokenProvider(TypedDict):
token_type: Literal["jwt", "id"]
get_token: Callable[[], str]
class WorkloadIdentity(TypedDict):
"""Identity provider resource id in WIFAPI."""
identity_provider_id: str
"""Service account id to bind the verified external identity to."""
service_account_id: str
"""The provider configuration for obtaining the subject token."""
provider: SubjectTokenProvider
"""Optional buffer time in seconds to refresh the OpenAI token before it expires. Defaults to 1200 seconds (20 minutes)."""
refresh_buffer_seconds: NotRequired[float]
def k8s_service_account_token_provider(
token_file_path: str | Path = "/var/run/secrets/kubernetes.io/serviceaccount/token",
) -> SubjectTokenProvider:
"""
Get a subject token provider for Kubernetes clusters with Workload Identity configured.
Cloud providers typically mount the subject token as a file in the container.
Args:
token_file_path: path to the mounted service account token file. Defaults to `/var/run/secrets/kubernetes.io/serviceaccount/token`.
"""
def get_token() -> str:
try:
with open(token_file_path, "r") as f:
token = f.read().strip()
if not token:
raise SubjectTokenProviderError(f"The token file at {token_file_path} is empty.")
return token
except Exception as e:
raise SubjectTokenProviderError(f"Failed to read the token file at {token_file_path}: {e}") from e
return {"token_type": "jwt", "get_token": get_token}
def azure_managed_identity_token_provider(
resource: str = "https://management.azure.com/",
*,
object_id: str | None = None,
client_id: str | None = None,
msi_res_id: str | None = None,
api_version: str = "2018-02-01",
timeout: float = 10.0,
http_client: httpx.Client | None = None,
) -> SubjectTokenProvider:
"""
Get a subject token provider for Azure Managed Identities.
See: https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
Args:
resource: the resource URI to request a token for. Defaults to `https://management.azure.com/` (Azure Resource Manager).
object_id: the object ID of the managed identity to use, when multiple are assigned.
client_id: the client ID of the managed identity to use, when multiple are assigned.
msi_res_id: the ARM resource ID of the managed identity to use, when multiple are assigned.
api_version: the Azure IMDS API version. Defaults to `2018-02-01`.
timeout: the request timeout in seconds. Defaults to 10.0.
http_client: optional httpx.Client instance to use for requests. If not provided, a new client will be created for each request.
"""
def get_token() -> str:
try:
url = "http://169.254.169.254/metadata/identity/oauth2/token"
params: dict[str, str] = {"api-version": api_version, "resource": resource}
if object_id is not None:
params["object_id"] = object_id
if client_id is not None:
params["client_id"] = client_id
if msi_res_id is not None:
params["msi_res_id"] = msi_res_id
if http_client is not None:
response = http_client.get(url, params=params, headers={"Metadata": "true"}, timeout=timeout)
else:
with httpx.Client() as client:
response = client.get(url, params=params, headers={"Metadata": "true"}, timeout=timeout)
if response.is_error:
raise SubjectTokenProviderError(
f"Failed to fetch Azure subject token from IMDS: HTTP {response.status_code}",
response=response,
)
data = response.json()
token = data.get("access_token")
if not token:
raise SubjectTokenProviderError(
"Azure IMDS response did not include an access_token", response=response
)
return cast(str, token)
except Exception as e:
raise SubjectTokenProviderError(f"Failed to fetch Azure subject token from IMDS: {e}") from e
return {"token_type": "jwt", "get_token": get_token}
def gcp_id_token_provider(
audience: str = "https://api.openai.com/v1",
*,
timeout: float = 10.0,
http_client: httpx.Client | None = None,
) -> SubjectTokenProvider:
"""
Get a subject token provider for GCP VM instances using the instance metadata server.
See: https://cloud.google.com/compute/docs/instances/verifying-instance-identity
Args:
audience: the unique URI agreed upon by both the instance and the system verifying
the instance's identity. Defaults to `https://api.openai.com/v1`.
timeout: the request timeout in seconds. Defaults to 10.0.
http_client: optional httpx.Client instance to use for requests. If not provided, a new client will be created for each request.
"""
def get_token() -> str:
try:
url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
params = {"audience": audience}
if http_client is not None:
response = http_client.get(url, params=params, headers={"Metadata-Flavor": "Google"}, timeout=timeout)
else:
with httpx.Client() as client:
response = client.get(url, params=params, headers={"Metadata-Flavor": "Google"}, timeout=timeout)
if response.is_error:
raise SubjectTokenProviderError(
f"Failed to fetch GCP subject token from metadata server: HTTP {response.status_code}",
response=response,
)
token = response.text.strip()
if not token:
raise SubjectTokenProviderError("GCP metadata server returned an empty token", response=response)
return token
except Exception as e:
raise SubjectTokenProviderError(f"Failed to fetch GCP subject token from metadata server: {e}") from e
return {"token_type": "id", "get_token": get_token}
class WorkloadIdentityAuth:
def __init__(
self,
*,
workload_identity: WorkloadIdentity,
token_exchange_url: str = DEFAULT_TOKEN_EXCHANGE_URL,
):
self.workload_identity = workload_identity
self.token_exchange_url = token_exchange_url
self._cached_token: str | None = None
self._cached_token_expires_at_monotonic: float | None = None
self._cached_token_refresh_at_monotonic: float | None = None
self._refreshing: bool = False
self._lock = threading.Lock()
self._condition = threading.Condition(self._lock)
def get_token(self) -> str:
with self._lock:
while self._refreshing and self._token_unusable():
self._condition.wait()
if not self._token_unusable() and not self._needs_refresh():
return cast(str, self._cached_token)
if self._refreshing:
while self._refreshing:
self._condition.wait()
token = self._cached_token # type: ignore[unreachable]
if self._token_unusable():
raise RuntimeError("Token is unusable after refresh completed")
return cast(str, token)
self._refreshing = True
try:
self._perform_refresh()
with self._lock:
if self._token_unusable():
raise RuntimeError("Token is unusable after refresh completed")
return cast(str, self._cached_token)
finally:
with self._lock:
self._refreshing = False
self._condition.notify_all()
async def get_token_async(self) -> str:
return await to_thread(self.get_token)
def invalidate_token(self) -> None:
with self._lock:
self._cached_token = None
self._cached_token_expires_at_monotonic = None
self._cached_token_refresh_at_monotonic = None
def _perform_refresh(self) -> None:
token_data = self._fetch_token_from_exchange()
now = time.monotonic()
expires_in = token_data["expires_in"]
with self._lock:
self._cached_token = token_data["access_token"]
self._cached_token_expires_at_monotonic = now + expires_in
self._cached_token_refresh_at_monotonic = now + self._refresh_delay_seconds(expires_in)
def _fetch_token_from_exchange(self) -> dict[str, Any]:
subject_token = self._get_subject_token()
token_type = self.workload_identity["provider"]["token_type"]
subject_token_type = SUBJECT_TOKEN_TYPES.get(token_type)
if subject_token_type is None:
raise OpenAIError(
f"Unsupported token type: {token_type!r}. Supported types: {', '.join(SUBJECT_TOKEN_TYPES.keys())}"
)
with httpx.Client() as client:
response = client.post(
self.token_exchange_url,
json={
"grant_type": TOKEN_EXCHANGE_GRANT_TYPE,
"subject_token": subject_token,
"subject_token_type": subject_token_type,
"identity_provider_id": self.workload_identity["identity_provider_id"],
"service_account_id": self.workload_identity["service_account_id"],
},
timeout=10.0,
)
return self._handle_token_response(response)
def _handle_token_response(self, response: httpx.Response) -> dict[str, Any]:
try:
body = response.json() if response.content else None
except ValueError:
body = None
if response.status_code in (400, 401, 403):
raise OAuthError(response=response, body=body)
if response.is_success:
if body is None:
raise OpenAIError("Token exchange succeeded but response body was empty")
access_token = body.get("access_token")
expires_in = body.get("expires_in")
if not isinstance(access_token, str) or not access_token:
raise OpenAIError("Token exchange response did not include a valid access_token")
if not isinstance(expires_in, (int, float)):
raise OpenAIError("Token exchange response did not include a valid expires_in")
return {"access_token": access_token, "expires_in": float(expires_in)}
raise OpenAIError(
f"Token exchange failed with status {response.status_code}",
)
def _get_subject_token(self) -> str:
provider = self.workload_identity["provider"]
subject_token = provider["get_token"]()
if not subject_token:
raise OpenAIError("The workload identity provider returned an empty subject token")
return subject_token
def _token_unusable(self) -> bool:
return self._cached_token is None or self._token_expired()
def _token_expired(self) -> bool:
if self._cached_token_expires_at_monotonic is None:
return True
return time.monotonic() >= self._cached_token_expires_at_monotonic
def _needs_refresh(self) -> bool:
if self._cached_token_refresh_at_monotonic is None:
return False
return time.monotonic() >= self._cached_token_refresh_at_monotonic
def _refresh_delay_seconds(self, expires_in: float) -> float:
configured_buffer = self.workload_identity.get("refresh_buffer_seconds", DEFAULT_REFRESH_BUFFER_SECONDS)
effective_buffer = min(configured_buffer, expires_in / 2)
return max(expires_in - effective_buffer, 0.0)