fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
310 lines
12 KiB
Python
310 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
import threading
|
|
from typing import Any, Callable, TypedDict, cast
|
|
from pathlib import Path
|
|
from typing_extensions import Literal, NotRequired
|
|
|
|
import httpx
|
|
|
|
from .._exceptions import OAuthError, OpenAIError, SubjectTokenProviderError
|
|
from .._utils._sync import to_thread
|
|
|
|
TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
|
|
DEFAULT_TOKEN_EXCHANGE_URL = "https://auth.openai.com/oauth/token"
|
|
DEFAULT_REFRESH_BUFFER_SECONDS = 1200
|
|
|
|
SUBJECT_TOKEN_TYPES = {
|
|
"jwt": "urn:ietf:params:oauth:token-type:jwt",
|
|
"id": "urn:ietf:params:oauth:token-type:id_token",
|
|
}
|
|
|
|
|
|
class SubjectTokenProvider(TypedDict):
|
|
token_type: Literal["jwt", "id"]
|
|
get_token: Callable[[], str]
|
|
|
|
|
|
class WorkloadIdentity(TypedDict):
|
|
"""Identity provider resource id in WIFAPI."""
|
|
|
|
identity_provider_id: str
|
|
|
|
"""Service account id to bind the verified external identity to."""
|
|
service_account_id: str
|
|
|
|
"""The provider configuration for obtaining the subject token."""
|
|
provider: SubjectTokenProvider
|
|
|
|
"""Optional buffer time in seconds to refresh the OpenAI token before it expires. Defaults to 1200 seconds (20 minutes)."""
|
|
refresh_buffer_seconds: NotRequired[float]
|
|
|
|
|
|
def k8s_service_account_token_provider(
|
|
token_file_path: str | Path = "/var/run/secrets/kubernetes.io/serviceaccount/token",
|
|
) -> SubjectTokenProvider:
|
|
"""
|
|
Get a subject token provider for Kubernetes clusters with Workload Identity configured.
|
|
|
|
Cloud providers typically mount the subject token as a file in the container.
|
|
|
|
Args:
|
|
token_file_path: path to the mounted service account token file. Defaults to `/var/run/secrets/kubernetes.io/serviceaccount/token`.
|
|
"""
|
|
|
|
def get_token() -> str:
|
|
try:
|
|
with open(token_file_path, "r") as f:
|
|
token = f.read().strip()
|
|
if not token:
|
|
raise SubjectTokenProviderError(f"The token file at {token_file_path} is empty.")
|
|
return token
|
|
except Exception as e:
|
|
raise SubjectTokenProviderError(f"Failed to read the token file at {token_file_path}: {e}") from e
|
|
|
|
return {"token_type": "jwt", "get_token": get_token}
|
|
|
|
|
|
def azure_managed_identity_token_provider(
|
|
resource: str = "https://management.azure.com/",
|
|
*,
|
|
object_id: str | None = None,
|
|
client_id: str | None = None,
|
|
msi_res_id: str | None = None,
|
|
api_version: str = "2018-02-01",
|
|
timeout: float = 10.0,
|
|
http_client: httpx.Client | None = None,
|
|
) -> SubjectTokenProvider:
|
|
"""
|
|
Get a subject token provider for Azure Managed Identities.
|
|
|
|
See: https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
|
|
|
|
Args:
|
|
resource: the resource URI to request a token for. Defaults to `https://management.azure.com/` (Azure Resource Manager).
|
|
object_id: the object ID of the managed identity to use, when multiple are assigned.
|
|
client_id: the client ID of the managed identity to use, when multiple are assigned.
|
|
msi_res_id: the ARM resource ID of the managed identity to use, when multiple are assigned.
|
|
api_version: the Azure IMDS API version. Defaults to `2018-02-01`.
|
|
timeout: the request timeout in seconds. Defaults to 10.0.
|
|
http_client: optional httpx.Client instance to use for requests. If not provided, a new client will be created for each request.
|
|
"""
|
|
|
|
def get_token() -> str:
|
|
try:
|
|
url = "http://169.254.169.254/metadata/identity/oauth2/token"
|
|
params: dict[str, str] = {"api-version": api_version, "resource": resource}
|
|
if object_id is not None:
|
|
params["object_id"] = object_id
|
|
if client_id is not None:
|
|
params["client_id"] = client_id
|
|
if msi_res_id is not None:
|
|
params["msi_res_id"] = msi_res_id
|
|
|
|
if http_client is not None:
|
|
response = http_client.get(url, params=params, headers={"Metadata": "true"}, timeout=timeout)
|
|
else:
|
|
with httpx.Client() as client:
|
|
response = client.get(url, params=params, headers={"Metadata": "true"}, timeout=timeout)
|
|
|
|
if response.is_error:
|
|
raise SubjectTokenProviderError(
|
|
f"Failed to fetch Azure subject token from IMDS: HTTP {response.status_code}",
|
|
response=response,
|
|
)
|
|
data = response.json()
|
|
token = data.get("access_token")
|
|
if not token:
|
|
raise SubjectTokenProviderError(
|
|
"Azure IMDS response did not include an access_token", response=response
|
|
)
|
|
return cast(str, token)
|
|
except Exception as e:
|
|
raise SubjectTokenProviderError(f"Failed to fetch Azure subject token from IMDS: {e}") from e
|
|
|
|
return {"token_type": "jwt", "get_token": get_token}
|
|
|
|
|
|
def gcp_id_token_provider(
|
|
audience: str = "https://api.openai.com/v1",
|
|
*,
|
|
timeout: float = 10.0,
|
|
http_client: httpx.Client | None = None,
|
|
) -> SubjectTokenProvider:
|
|
"""
|
|
Get a subject token provider for GCP VM instances using the instance metadata server.
|
|
|
|
See: https://cloud.google.com/compute/docs/instances/verifying-instance-identity
|
|
|
|
Args:
|
|
audience: the unique URI agreed upon by both the instance and the system verifying
|
|
the instance's identity. Defaults to `https://api.openai.com/v1`.
|
|
timeout: the request timeout in seconds. Defaults to 10.0.
|
|
http_client: optional httpx.Client instance to use for requests. If not provided, a new client will be created for each request.
|
|
"""
|
|
|
|
def get_token() -> str:
|
|
try:
|
|
url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
|
|
params = {"audience": audience}
|
|
|
|
if http_client is not None:
|
|
response = http_client.get(url, params=params, headers={"Metadata-Flavor": "Google"}, timeout=timeout)
|
|
else:
|
|
with httpx.Client() as client:
|
|
response = client.get(url, params=params, headers={"Metadata-Flavor": "Google"}, timeout=timeout)
|
|
|
|
if response.is_error:
|
|
raise SubjectTokenProviderError(
|
|
f"Failed to fetch GCP subject token from metadata server: HTTP {response.status_code}",
|
|
response=response,
|
|
)
|
|
token = response.text.strip()
|
|
if not token:
|
|
raise SubjectTokenProviderError("GCP metadata server returned an empty token", response=response)
|
|
return token
|
|
except Exception as e:
|
|
raise SubjectTokenProviderError(f"Failed to fetch GCP subject token from metadata server: {e}") from e
|
|
|
|
return {"token_type": "id", "get_token": get_token}
|
|
|
|
|
|
class WorkloadIdentityAuth:
|
|
def __init__(
|
|
self,
|
|
*,
|
|
workload_identity: WorkloadIdentity,
|
|
token_exchange_url: str = DEFAULT_TOKEN_EXCHANGE_URL,
|
|
):
|
|
self.workload_identity = workload_identity
|
|
self.token_exchange_url = token_exchange_url
|
|
|
|
self._cached_token: str | None = None
|
|
self._cached_token_expires_at_monotonic: float | None = None
|
|
self._cached_token_refresh_at_monotonic: float | None = None
|
|
self._refreshing: bool = False
|
|
self._lock = threading.Lock()
|
|
self._condition = threading.Condition(self._lock)
|
|
|
|
def get_token(self) -> str:
|
|
with self._lock:
|
|
while self._refreshing and self._token_unusable():
|
|
self._condition.wait()
|
|
|
|
if not self._token_unusable() and not self._needs_refresh():
|
|
return cast(str, self._cached_token)
|
|
|
|
if self._refreshing:
|
|
while self._refreshing:
|
|
self._condition.wait()
|
|
token = self._cached_token # type: ignore[unreachable]
|
|
if self._token_unusable():
|
|
raise RuntimeError("Token is unusable after refresh completed")
|
|
return cast(str, token)
|
|
|
|
self._refreshing = True
|
|
|
|
try:
|
|
self._perform_refresh()
|
|
with self._lock:
|
|
if self._token_unusable():
|
|
raise RuntimeError("Token is unusable after refresh completed")
|
|
return cast(str, self._cached_token)
|
|
finally:
|
|
with self._lock:
|
|
self._refreshing = False
|
|
self._condition.notify_all()
|
|
|
|
async def get_token_async(self) -> str:
|
|
return await to_thread(self.get_token)
|
|
|
|
def invalidate_token(self) -> None:
|
|
with self._lock:
|
|
self._cached_token = None
|
|
self._cached_token_expires_at_monotonic = None
|
|
self._cached_token_refresh_at_monotonic = None
|
|
|
|
def _perform_refresh(self) -> None:
|
|
token_data = self._fetch_token_from_exchange()
|
|
now = time.monotonic()
|
|
expires_in = token_data["expires_in"]
|
|
|
|
with self._lock:
|
|
self._cached_token = token_data["access_token"]
|
|
self._cached_token_expires_at_monotonic = now + expires_in
|
|
self._cached_token_refresh_at_monotonic = now + self._refresh_delay_seconds(expires_in)
|
|
|
|
def _fetch_token_from_exchange(self) -> dict[str, Any]:
|
|
subject_token = self._get_subject_token()
|
|
|
|
token_type = self.workload_identity["provider"]["token_type"]
|
|
subject_token_type = SUBJECT_TOKEN_TYPES.get(token_type)
|
|
if subject_token_type is None:
|
|
raise OpenAIError(
|
|
f"Unsupported token type: {token_type!r}. Supported types: {', '.join(SUBJECT_TOKEN_TYPES.keys())}"
|
|
)
|
|
|
|
with httpx.Client() as client:
|
|
response = client.post(
|
|
self.token_exchange_url,
|
|
json={
|
|
"grant_type": TOKEN_EXCHANGE_GRANT_TYPE,
|
|
"subject_token": subject_token,
|
|
"subject_token_type": subject_token_type,
|
|
"identity_provider_id": self.workload_identity["identity_provider_id"],
|
|
"service_account_id": self.workload_identity["service_account_id"],
|
|
},
|
|
timeout=10.0,
|
|
)
|
|
return self._handle_token_response(response)
|
|
|
|
def _handle_token_response(self, response: httpx.Response) -> dict[str, Any]:
|
|
try:
|
|
body = response.json() if response.content else None
|
|
except ValueError:
|
|
body = None
|
|
|
|
if response.status_code in (400, 401, 403):
|
|
raise OAuthError(response=response, body=body)
|
|
|
|
if response.is_success:
|
|
if body is None:
|
|
raise OpenAIError("Token exchange succeeded but response body was empty")
|
|
access_token = body.get("access_token")
|
|
expires_in = body.get("expires_in")
|
|
if not isinstance(access_token, str) or not access_token:
|
|
raise OpenAIError("Token exchange response did not include a valid access_token")
|
|
if not isinstance(expires_in, (int, float)):
|
|
raise OpenAIError("Token exchange response did not include a valid expires_in")
|
|
return {"access_token": access_token, "expires_in": float(expires_in)}
|
|
|
|
raise OpenAIError(
|
|
f"Token exchange failed with status {response.status_code}",
|
|
)
|
|
|
|
def _get_subject_token(self) -> str:
|
|
provider = self.workload_identity["provider"]
|
|
subject_token = provider["get_token"]()
|
|
if not subject_token:
|
|
raise OpenAIError("The workload identity provider returned an empty subject token")
|
|
return subject_token
|
|
|
|
def _token_unusable(self) -> bool:
|
|
return self._cached_token is None or self._token_expired()
|
|
|
|
def _token_expired(self) -> bool:
|
|
if self._cached_token_expires_at_monotonic is None:
|
|
return True
|
|
return time.monotonic() >= self._cached_token_expires_at_monotonic
|
|
|
|
def _needs_refresh(self) -> bool:
|
|
if self._cached_token_refresh_at_monotonic is None:
|
|
return False
|
|
return time.monotonic() >= self._cached_token_refresh_at_monotonic
|
|
|
|
def _refresh_delay_seconds(self, expires_in: float) -> float:
|
|
configured_buffer = self.workload_identity.get("refresh_buffer_seconds", DEFAULT_REFRESH_BUFFER_SECONDS)
|
|
effective_buffer = min(configured_buffer, expires_in / 2)
|
|
return max(expires_in - effective_buffer, 0.0)
|