fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
741 lines
30 KiB
Python
741 lines
30 KiB
Python
"""
|
|
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import random
|
|
import signal
|
|
import subprocess
|
|
import time
|
|
import urllib
|
|
import urllib.parse
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Callable, Union
|
|
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.secret_managers.main import str_to_bool
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class IAMEndpoint:
|
|
"""Static parts of an RDS IAM-authenticated Postgres connection.
|
|
|
|
The IAM token rotates every ~15 minutes; everything else (host, port, user,
|
|
database name, schema) stays fixed. We capture the static fields once so
|
|
refresh just regenerates the token and reassembles the URL.
|
|
"""
|
|
|
|
host: str
|
|
port: str
|
|
user: str
|
|
name: str
|
|
schema: str | None = None
|
|
|
|
def build_url(self, token: str) -> str:
|
|
url = f"postgresql://{self.user}:{token}@{self.host}:{self.port}/{self.name}"
|
|
if self.schema:
|
|
url += f"?schema={self.schema}"
|
|
return url
|
|
|
|
|
|
def parse_iam_endpoint_from_url(url: str) -> IAMEndpoint:
    """Build an IAMEndpoint from the components of a Postgres URL.

    Lets a reader URL drive its own IAM refresh, so callers do not have to
    set parallel DATABASE_HOST_READ_REPLICA / etc. env vars.

    The port falls back to "5432" when the URL has none, and the schema is
    the first ``schema`` query parameter (``None`` when absent).

    Raises:
        ValueError: if the URL lacks a hostname, a username, or a database
            name.
    """
    parts = urllib.parse.urlparse(url)
    if not (parts.hostname and parts.username):
        raise ValueError("Cannot parse IAM endpoint from URL: missing host or username")

    db_name = (parts.path or "/").lstrip("/")
    if not db_name:
        raise ValueError("Cannot parse IAM endpoint from URL: missing database name")

    db_port = str(parts.port) if parts.port else "5432"

    # parse_qs maps each key to a list of values; keep the first `schema`.
    schema_values = (
        urllib.parse.parse_qs(parts.query).get("schema") if parts.query else None
    )
    db_schema = schema_values[0] if schema_values else None

    return IAMEndpoint(
        host=parts.hostname,
        port=db_port,
        user=parts.username,
        name=db_name,
        schema=db_schema,
    )
|
|
|
|
|
|
class PrismaWrapper:
|
|
"""
|
|
Wrapper around Prisma client that handles RDS IAM token authentication.
|
|
|
|
When iam_token_db_auth is enabled, this wrapper:
|
|
1. Proactively refreshes IAM tokens before they expire (background task)
|
|
2. Falls back to synchronous refresh if a token is found expired
|
|
3. Uses proper locking to prevent race conditions during reconnection
|
|
|
|
RDS IAM tokens are valid for 15 minutes. This wrapper refreshes them
|
|
3 minutes before expiration to ensure uninterrupted database connectivity.
|
|
"""
|
|
|
|
# Buffer time in seconds before token expiration to trigger refresh
|
|
# Refresh 3 minutes (180 seconds) before the token expires
|
|
TOKEN_REFRESH_BUFFER_SECONDS = 180
|
|
|
|
# Fallback refresh interval if token parsing fails (10 minutes)
|
|
FALLBACK_REFRESH_INTERVAL_SECONDS = 600
|
|
|
|
def __init__(
|
|
self,
|
|
original_prisma: Any,
|
|
iam_token_db_auth: bool,
|
|
*,
|
|
db_url_env_var: str = "DATABASE_URL",
|
|
iam_endpoint: IAMEndpoint | None = None,
|
|
recreate_uses_datasource: bool = False,
|
|
log_prefix: str = "",
|
|
):
|
|
self._original_prisma = original_prisma
|
|
self.iam_token_db_auth = iam_token_db_auth
|
|
|
|
# Per-connection knobs so the same wrapper can be used for the writer
|
|
# (defaults: DATABASE_URL env, IAM endpoint from DATABASE_HOST/etc.,
|
|
# recreate via env reload) or for a reader (DATABASE_URL_READ_REPLICA
|
|
# env, IAM endpoint parsed from that URL, recreate via datasource
|
|
# override since Prisma only auto-reads DATABASE_URL).
|
|
self._db_url_env_var = db_url_env_var
|
|
self._iam_endpoint = iam_endpoint
|
|
self._recreate_uses_datasource = recreate_uses_datasource
|
|
# Tag every log line emitted by this wrapper instance so writer and
|
|
# reader can be told apart in interleaved output (e.g. "[writer] RDS
|
|
# IAM token refresh scheduled in 720 seconds"). Empty string (default)
|
|
# keeps backward-compatible logs for the single-DB case.
|
|
self._log_prefix = f"{log_prefix} " if log_prefix else ""
|
|
|
|
# Background token refresh task management
|
|
self._token_refresh_task: asyncio.Task | None = None
|
|
self._reconnection_lock = asyncio.Lock()
|
|
self._last_refresh_time: datetime | None = None
|
|
|
|
# Coordination for planned engine restarts (issue #29176). Every
|
|
# `recreate_prisma_client` SIGTERMs the running query-engine on
|
|
# purpose. The engine-death watcher (in `PrismaClient`) must be able
|
|
# to tell that planned kill apart from a real crash, otherwise it
|
|
# triggers its own reconnect and kills the freshly-spawned engine.
|
|
# - `_expected_engine_deaths`: PIDs we intentionally killed; the
|
|
# watcher consumes these instead of reconnecting.
|
|
# - `_engine_generation`: monotonic counter bumped on every
|
|
# successful recreate, used by callers as an optimistic-lock token
|
|
# so racing/cascading recreates collapse into a single restart.
|
|
# - `on_engine_replaced`: optional callback fired after a recreate so
|
|
# the owner (PrismaClient) can re-arm its watcher on the new PID.
|
|
self._expected_engine_deaths: set[int] = set()
|
|
self._engine_generation: int = 0
|
|
self.on_engine_replaced: Callable[[], None] | None = None
|
|
|
|
def _get_engine_pid(self) -> int:
|
|
"""Get the PID of the current Prisma engine subprocess, or 0 if unavailable."""
|
|
try:
|
|
engine = self._original_prisma._engine
|
|
process = getattr(engine, "process", None) if engine is not None else None
|
|
if process is not None:
|
|
pid = process.pid
|
|
if isinstance(pid, int):
|
|
return pid
|
|
except (AttributeError, TypeError):
|
|
pass
|
|
return 0
|
|
|
|
@staticmethod
|
|
async def _kill_engine_process(pid: int) -> None:
|
|
"""Force-kill the engine subprocess to prevent DB connection pool leaks.
|
|
|
|
Called on every reconnect (in `recreate_prisma_client`) to retire the
|
|
old query-engine subprocess without invoking prisma-client-py's
|
|
synchronous `disconnect()` — which blocks the asyncio event loop on
|
|
`subprocess.Popen.wait()` for 30-120+ seconds when the engine is
|
|
stuck on TCP close.
|
|
|
|
Sends SIGTERM for graceful shutdown, waits briefly, then SIGKILL as
|
|
a backstop.
|
|
"""
|
|
if pid <= 0:
|
|
return
|
|
try:
|
|
os.kill(pid, signal.SIGTERM)
|
|
except (ProcessLookupError, PermissionError, OSError):
|
|
return # Already dead or inaccessible
|
|
verbose_proxy_logger.warning(
|
|
"Sent SIGTERM to prisma-query-engine PID %s during reconnect.",
|
|
pid,
|
|
)
|
|
# Brief wait for graceful shutdown, then force-kill
|
|
await asyncio.sleep(0.5)
|
|
try:
|
|
os.kill(pid, getattr(signal, "SIGKILL", signal.SIGTERM))
|
|
verbose_proxy_logger.warning(
|
|
"Sent SIGKILL to prisma-query-engine PID %s (did not exit after SIGTERM).",
|
|
pid,
|
|
)
|
|
except (ProcessLookupError, PermissionError, OSError):
|
|
pass # Exited after SIGTERM — expected
|
|
|
|
def _extract_token_from_db_url(self, db_url: str | None) -> str | None:
|
|
"""
|
|
Extract the token (password) from the DATABASE_URL.
|
|
|
|
The token contains the AWS signature with X-Amz-Date and X-Amz-Expires parameters.
|
|
|
|
Important: We must parse the URL while it's still encoded to preserve structure,
|
|
then decode the password portion. Otherwise the '?' in the token breaks URL parsing.
|
|
"""
|
|
if db_url is None:
|
|
return None
|
|
try:
|
|
# Parse URL while still encoded to preserve structure
|
|
parsed = urllib.parse.urlparse(db_url)
|
|
if parsed.password:
|
|
# Now decode just the password/token
|
|
return urllib.parse.unquote(parsed.password)
|
|
return None
|
|
except Exception:
|
|
return None
|
|
|
|
def _parse_token_expiration(self, token: str | None) -> datetime | None:
|
|
"""
|
|
Parse the token to extract its expiration time.
|
|
|
|
Returns the datetime when the token expires, or None if parsing fails.
|
|
"""
|
|
if token is None:
|
|
return None
|
|
|
|
try:
|
|
# Token format: ...?X-Amz-Date=YYYYMMDDTHHMMSSZ&X-Amz-Expires=900&...
|
|
if "?" not in token:
|
|
return None
|
|
|
|
query_string = token.split("?", 1)[1]
|
|
params = urllib.parse.parse_qs(query_string)
|
|
|
|
expires_str = params.get("X-Amz-Expires", [None])[0]
|
|
date_str = params.get("X-Amz-Date", [None])[0]
|
|
|
|
if not expires_str or not date_str:
|
|
return None
|
|
|
|
token_created = datetime.strptime(date_str, "%Y%m%dT%H%M%SZ")
|
|
expires_in = int(expires_str)
|
|
|
|
return token_created + timedelta(seconds=expires_in)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(f"Failed to parse token expiration: {e}")
|
|
return None
|
|
|
|
def _calculate_seconds_until_refresh(self) -> float:
|
|
"""
|
|
Calculate exactly how many seconds until we need to refresh the token.
|
|
|
|
Uses precise timing: sleeps until (token_expiration - buffer_seconds).
|
|
For a 15-minute (900s) token with 180s buffer, this returns ~720s (12 min).
|
|
|
|
Returns:
|
|
Number of seconds to sleep before the next refresh.
|
|
Returns 0 if token should be refreshed immediately.
|
|
Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
|
|
"""
|
|
db_url = os.getenv(self._db_url_env_var)
|
|
token = self._extract_token_from_db_url(db_url)
|
|
expiration_time = self._parse_token_expiration(token)
|
|
|
|
if expiration_time is None:
|
|
# If we can't parse the token, use fallback interval
|
|
verbose_proxy_logger.debug(
|
|
f"Could not parse token expiration, using fallback interval of "
|
|
f"{self.FALLBACK_REFRESH_INTERVAL_SECONDS}s"
|
|
)
|
|
return self.FALLBACK_REFRESH_INTERVAL_SECONDS
|
|
|
|
# Calculate when we should refresh (expiration - buffer)
|
|
refresh_at = expiration_time - timedelta(
|
|
seconds=self.TOKEN_REFRESH_BUFFER_SECONDS
|
|
)
|
|
|
|
# How long until refresh time?
|
|
now = datetime.utcnow()
|
|
seconds_until_refresh = (refresh_at - now).total_seconds()
|
|
|
|
# If already past refresh time, return 0 (refresh immediately)
|
|
return max(0, seconds_until_refresh)
|
|
|
|
def is_token_expired(self, token_url: str | None) -> bool:
|
|
"""Check if the token in the given URL is expired."""
|
|
if token_url is None:
|
|
return True
|
|
|
|
token = self._extract_token_from_db_url(token_url)
|
|
expiration_time = self._parse_token_expiration(token)
|
|
|
|
if expiration_time is None:
|
|
# If we can't parse the token, assume it's expired to trigger refresh
|
|
verbose_proxy_logger.debug(
|
|
"Could not parse token expiration, treating as expired"
|
|
)
|
|
return True
|
|
|
|
return datetime.utcnow() > expiration_time
|
|
|
|
def get_rds_iam_token(self) -> str | None:
|
|
"""Generate a new RDS IAM token and update the configured DB URL env var.
|
|
|
|
When the wrapper was constructed with an explicit `iam_endpoint`
|
|
(typical for a reader wrapper whose host/port/user came from a parsed
|
|
URL), use that. Otherwise fall back to the legacy DATABASE_HOST/PORT/
|
|
USER/NAME/SCHEMA env vars (writer behavior).
|
|
"""
|
|
if not self.iam_token_db_auth:
|
|
return None
|
|
|
|
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
|
|
|
if self._iam_endpoint is not None:
|
|
endpoint = self._iam_endpoint
|
|
token = generate_iam_auth_token(
|
|
db_host=endpoint.host, db_port=endpoint.port, db_user=endpoint.user
|
|
)
|
|
_db_url = endpoint.build_url(token)
|
|
else:
|
|
db_host = os.getenv("DATABASE_HOST")
|
|
# Default to the Postgres standard port; passing None to
|
|
# `generate_iam_auth_token` makes botocore embed the literal
|
|
# string "None" in the presigned URL, which then fails to parse.
|
|
db_port = os.getenv("DATABASE_PORT", "5432")
|
|
db_user = os.getenv("DATABASE_USER")
|
|
db_name = os.getenv("DATABASE_NAME")
|
|
db_schema = os.getenv("DATABASE_SCHEMA")
|
|
|
|
token = generate_iam_auth_token(
|
|
db_host=db_host, db_port=db_port, db_user=db_user
|
|
)
|
|
|
|
_db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
|
|
if db_schema:
|
|
_db_url += f"?schema={db_schema}"
|
|
|
|
os.environ[self._db_url_env_var] = _db_url
|
|
return _db_url
|
|
|
|
async def recreate_prisma_client(
|
|
self,
|
|
new_db_url: str,
|
|
http_client: Any | None = None,
|
|
*,
|
|
expected_generation: int | None = None,
|
|
) -> bool:
|
|
"""Disconnect and reconnect the Prisma client with a new database URL.
|
|
|
|
Kills the old engine subprocess directly (SIGTERM → SIGKILL) rather than
|
|
calling `disconnect()`. prisma-client-py's `disconnect()` calls a
|
|
synchronous `subprocess.Popen.wait()` that can freeze the asyncio event
|
|
loop for 30-120+ seconds when the engine is stuck on TCP close,
|
|
breaking `/health/liveliness` and causing Kubernetes pod restarts.
|
|
|
|
The writer wrapper relies on Prisma re-reading `DATABASE_URL` from env;
|
|
the reader wrapper opts into `recreate_uses_datasource=True` so the
|
|
new URL is passed explicitly via `datasource={"url": ...}` (Prisma
|
|
does not auto-read alternate env vars like DATABASE_URL_READ_REPLICA).
|
|
|
|
Serializes all recreations through `self._reconnection_lock` so the
|
|
IAM-refresh path and the engine-death/transport-error reconnect paths
|
|
cannot recreate concurrently (issue #29176). `expected_generation`, if
|
|
given, is an optimistic-lock token: when it no longer matches
|
|
`self._engine_generation` once the lock is held, another path already
|
|
replaced the engine, so this call is a no-op and returns ``False``.
|
|
|
|
Returns:
|
|
bool: ``True`` if the client was actually recreated, ``False`` if
|
|
the recreate was skipped because the engine generation moved on.
|
|
"""
|
|
async with self._reconnection_lock:
|
|
return await self._recreate_prisma_client_locked(
|
|
new_db_url,
|
|
http_client=http_client,
|
|
expected_generation=expected_generation,
|
|
)
|
|
|
|
async def _recreate_prisma_client_locked(
|
|
self,
|
|
new_db_url: str,
|
|
http_client: Any | None = None,
|
|
*,
|
|
expected_generation: int | None = None,
|
|
) -> bool:
|
|
"""Core recreate logic. Caller MUST hold `self._reconnection_lock`.
|
|
|
|
Split out so callers that already hold the lock (e.g.
|
|
`_safe_refresh_token`, which double-checks token freshness under the
|
|
lock) don't re-acquire it — `asyncio.Lock` is not reentrant.
|
|
"""
|
|
from prisma import Prisma # type: ignore
|
|
|
|
if (
|
|
expected_generation is not None
|
|
and expected_generation != self._engine_generation
|
|
):
|
|
verbose_proxy_logger.info(
|
|
"%sSkipping Prisma client recreate: engine already replaced "
|
|
"(generation %s != expected %s).",
|
|
self._log_prefix,
|
|
self._engine_generation,
|
|
expected_generation,
|
|
)
|
|
return False
|
|
|
|
old_engine_pid = self._get_engine_pid()
|
|
if old_engine_pid > 0:
|
|
# Record BEFORE the kill so the engine-death watcher, which may
|
|
# fire the instant the process dies, recognizes this as a planned
|
|
# restart and does not launch its own reconnect.
|
|
#
|
|
# A stale entry can linger when the watcher re-arms on the new PID
|
|
# before the old PID's death callback runs (the callback then
|
|
# early-returns on PID mismatch without consuming it). Such entries
|
|
# are harmless but would accumulate on a long-running proxy (~one
|
|
# per IAM refresh), so cap the set — those old PIDs are long dead.
|
|
if len(self._expected_engine_deaths) >= 64:
|
|
self._expected_engine_deaths.clear()
|
|
self._expected_engine_deaths.add(old_engine_pid)
|
|
await self._kill_engine_process(old_engine_pid)
|
|
|
|
kwargs: dict[str, Any] = {}
|
|
if http_client is not None:
|
|
kwargs["http"] = http_client
|
|
if self._recreate_uses_datasource:
|
|
kwargs["datasource"] = {"url": new_db_url}
|
|
self._original_prisma = Prisma(**kwargs)
|
|
|
|
await self._original_prisma.connect()
|
|
self._engine_generation += 1
|
|
|
|
# Let the owner (PrismaClient) re-arm its engine-death watcher on the
|
|
# newly-spawned engine PID. Scheduled, never awaited, so a slow watcher
|
|
# can't stall the refresh while we hold the reconnection lock.
|
|
if self.on_engine_replaced is not None:
|
|
self.on_engine_replaced()
|
|
|
|
return True
|
|
|
|
async def start_token_refresh_task(self) -> None:
|
|
"""
|
|
Start the background token refresh task.
|
|
|
|
This task proactively refreshes RDS IAM tokens before they expire,
|
|
preventing connection failures. Should be called after the initial
|
|
Prisma client connection is established.
|
|
"""
|
|
if not self.iam_token_db_auth:
|
|
verbose_proxy_logger.debug(
|
|
"IAM token auth not enabled, skipping token refresh task"
|
|
)
|
|
return
|
|
|
|
if self._token_refresh_task is not None:
|
|
verbose_proxy_logger.debug("Token refresh task already running")
|
|
return
|
|
|
|
self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
|
|
verbose_proxy_logger.info(
|
|
"%sStarted RDS IAM token proactive refresh background task",
|
|
self._log_prefix,
|
|
)
|
|
|
|
async def stop_token_refresh_task(self) -> None:
|
|
"""
|
|
Stop the background token refresh task gracefully.
|
|
|
|
Should be called during application shutdown to clean up resources.
|
|
"""
|
|
if self._token_refresh_task is None:
|
|
return
|
|
|
|
self._token_refresh_task.cancel()
|
|
try:
|
|
await self._token_refresh_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self._token_refresh_task = None
|
|
verbose_proxy_logger.info(
|
|
"%sStopped RDS IAM token refresh background task", self._log_prefix
|
|
)
|
|
|
|
async def _token_refresh_loop(self) -> None:
|
|
"""
|
|
Background loop that proactively refreshes RDS IAM tokens before expiration.
|
|
|
|
Uses precise timing: calculates the exact sleep duration until the token
|
|
needs to be refreshed (expiration - 3 minute buffer), then refreshes.
|
|
This is more efficient than polling, requiring only 1 wake-up per token cycle.
|
|
"""
|
|
verbose_proxy_logger.info(
|
|
f"{self._log_prefix}RDS IAM token refresh loop started. "
|
|
f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
|
|
)
|
|
|
|
while True:
|
|
try:
|
|
# Calculate exactly how long to sleep until next refresh
|
|
sleep_seconds = self._calculate_seconds_until_refresh()
|
|
|
|
if sleep_seconds > 0:
|
|
verbose_proxy_logger.info(
|
|
f"{self._log_prefix}RDS IAM token refresh scheduled in "
|
|
f"{sleep_seconds:.0f} seconds ({sleep_seconds / 60:.1f} minutes)"
|
|
)
|
|
await asyncio.sleep(sleep_seconds)
|
|
|
|
# Refresh the token
|
|
verbose_proxy_logger.info(
|
|
"%sProactively refreshing RDS IAM token...", self._log_prefix
|
|
)
|
|
await self._safe_refresh_token()
|
|
|
|
except asyncio.CancelledError:
|
|
verbose_proxy_logger.info(
|
|
"%sRDS IAM token refresh loop cancelled", self._log_prefix
|
|
)
|
|
break
|
|
except Exception as e:
|
|
verbose_proxy_logger.error(
|
|
f"{self._log_prefix}Error in RDS IAM token refresh loop: {e}. "
|
|
f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
|
|
)
|
|
# On error, wait before retrying to avoid tight error loops
|
|
try:
|
|
await asyncio.sleep(self.FALLBACK_REFRESH_INTERVAL_SECONDS)
|
|
except asyncio.CancelledError:
|
|
break
|
|
|
|
async def _safe_refresh_token(self) -> None:
|
|
"""
|
|
Refresh the RDS IAM token with proper locking to prevent race conditions.
|
|
|
|
Uses an asyncio lock to ensure only one refresh operation happens at a time,
|
|
preventing multiple concurrent reconnection attempts.
|
|
"""
|
|
async with self._reconnection_lock:
|
|
# Double-checked under the lock: another trigger (e.g. the
|
|
# proactive loop racing a __getattr__ fallback) may have already
|
|
# refreshed while we waited. Recreating again would needlessly kill
|
|
# the engine that refresh just spawned (issue #29176), so coalesce
|
|
# by skipping when the current token still has comfortable runway.
|
|
if self._token_refresh_not_needed(os.getenv(self._db_url_env_var)):
|
|
verbose_proxy_logger.debug(
|
|
"%sRDS IAM token still fresh; skipping redundant refresh.",
|
|
self._log_prefix,
|
|
)
|
|
return
|
|
|
|
new_db_url = self.get_rds_iam_token()
|
|
if new_db_url:
|
|
# We already hold `_reconnection_lock`; call the locked core
|
|
# directly (the public method would re-acquire and deadlock).
|
|
await self._recreate_prisma_client_locked(new_db_url)
|
|
self._last_refresh_time = datetime.utcnow()
|
|
verbose_proxy_logger.info(
|
|
"%sRDS IAM token refreshed successfully. New token valid for ~15 minutes.",
|
|
self._log_prefix,
|
|
)
|
|
else:
|
|
verbose_proxy_logger.error(
|
|
"%sFailed to generate new RDS IAM token during proactive refresh",
|
|
self._log_prefix,
|
|
)
|
|
|
|
def _token_refresh_not_needed(self, token_url: str | None) -> bool:
|
|
"""True iff the token in ``token_url`` has more than the refresh buffer
|
|
of runway left, so a refresh would be redundant.
|
|
|
|
Used to coalesce stacked refresh triggers. Deliberately mirrors the
|
|
proactive loop's schedule (refresh at ``expiration - buffer``): a token
|
|
with exactly ``buffer`` seconds left is NOT considered fresh, so the
|
|
legitimate proactive refresh still fires. Unparseable tokens return
|
|
``False`` (refresh) — skipping them would mean never refreshing.
|
|
"""
|
|
token = self._extract_token_from_db_url(token_url)
|
|
expiration_time = self._parse_token_expiration(token)
|
|
if expiration_time is None:
|
|
return False
|
|
seconds_left = (expiration_time - datetime.utcnow()).total_seconds()
|
|
return seconds_left > self.TOKEN_REFRESH_BUFFER_SECONDS
|
|
|
|
def __getattr__(self, name: str):
|
|
"""
|
|
Proxy attribute access to the underlying Prisma client.
|
|
|
|
If IAM token auth is enabled and the token is found expired here, the
|
|
proactive refresh task has missed its window. Behavior depends on
|
|
whether we're called from inside a running event loop:
|
|
|
|
- Inside the loop (typical: from a coroutine): schedule a refresh as a
|
|
background task and return the (stale) attribute. The caller's await
|
|
will likely fail with a connection error and be retried by upper
|
|
layers (`call_with_db_reconnect_retry`); by that time the refresh
|
|
has either completed or escalated to the proactive loop's error
|
|
path. We CANNOT block here — `run_coroutine_threadsafe(...)` +
|
|
`future.result()` from inside the same loop deadlocks the loop
|
|
(loop thread is blocked, scheduled coroutine never runs, 30s timeout).
|
|
|
|
- No running loop (sync caller, mostly tests): run the refresh in a
|
|
fresh loop and re-fetch the attribute.
|
|
"""
|
|
original_attr = getattr(self._original_prisma, name)
|
|
|
|
if self.iam_token_db_auth:
|
|
db_url = os.getenv(self._db_url_env_var)
|
|
|
|
# Check if token is expired (should be rare if background task is running)
|
|
if self.is_token_expired(db_url):
|
|
try:
|
|
running_loop = asyncio.get_running_loop()
|
|
except RuntimeError:
|
|
running_loop = None
|
|
|
|
if running_loop is not None:
|
|
verbose_proxy_logger.warning(
|
|
"%sRDS IAM token expired in __getattr__ — proactive refresh "
|
|
"may have failed. Scheduling async refresh; the current "
|
|
"request may fail and be retried with the fresh token.",
|
|
self._log_prefix,
|
|
)
|
|
# Non-blocking: schedule the locked refresh on the
|
|
# running loop. The reconnection lock inside
|
|
# `_safe_refresh_token` coalesces concurrent triggers.
|
|
running_loop.create_task(self._safe_refresh_token())
|
|
else:
|
|
verbose_proxy_logger.warning(
|
|
"%sRDS IAM token expired in __getattr__ — proactive refresh "
|
|
"may have failed. Triggering synchronous fallback refresh...",
|
|
self._log_prefix,
|
|
)
|
|
new_db_url = self.get_rds_iam_token()
|
|
if new_db_url:
|
|
asyncio.run(self.recreate_prisma_client(new_db_url))
|
|
# Re-fetch attribute against the recreated Prisma instance.
|
|
original_attr = getattr(self._original_prisma, name)
|
|
verbose_proxy_logger.info(
|
|
"%sSynchronous token refresh completed successfully",
|
|
self._log_prefix,
|
|
)
|
|
else:
|
|
raise ValueError("Failed to get RDS IAM token")
|
|
|
|
return original_attr
|
|
|
|
|
|
class PrismaManager:
    """Helpers for applying the Prisma schema to the database."""

    # Total number of `prisma db push` attempts before giving up.
    _DB_PUSH_MAX_ATTEMPTS = 4

    @staticmethod
    def _get_prisma_dir() -> str:
        """Return the directory used as the working directory for Prisma commands.

        This is the parent of the directory that contains this file.
        """
        abspath = os.path.abspath(__file__)
        return os.path.dirname(os.path.dirname(abspath))

    @staticmethod
    def setup_database(
        use_migrate: bool = False, use_v2_resolver: bool = False
    ) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        The process working directory is switched to the Prisma directory for
        the duration of each attempt and always restored afterwards.

        Args:
            use_migrate: Use `prisma migrate deploy` instead of `db push`.
                This path is delegated to
                `ProxyExtrasDBManager.setup_database` and is not retried here.
            use_v2_resolver: Opt into the v2 migration resolver that avoids
                the diff-and-force recovery behavior (which caused schema
                thrashing during rolling deploys). Defaults to False.

        Returns:
            bool: True if setup was successful, False otherwise
        """
        max_attempts = PrismaManager._DB_PUSH_MAX_ATTEMPTS
        prisma_dir = PrismaManager._get_prisma_dir()

        for attempt in range(max_attempts):
            attempts_left = max_attempts - 1 - attempt
            original_dir = os.getcwd()
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    try:
                        from litellm_proxy_extras.utils import ProxyExtrasDBManager
                    except ImportError as e:
                        verbose_proxy_logger.error(
                            f"\033[1;31mLiteLLM: Failed to import proxy extras. Got {e}\033[0m"
                        )
                        return False

                    return ProxyExtrasDBManager.setup_database(
                        use_migrate=use_migrate,
                        use_v2_resolver=use_v2_resolver,
                    )

                # Use prisma db push with increased timeout
                subprocess.run(
                    [
                        "prisma",
                        "db",
                        "push",
                        "--accept-data-loss",
                        "--skip-generate",
                    ],
                    timeout=60,
                    check=True,
                )
                return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
            except subprocess.CalledProcessError as e:
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
            finally:
                os.chdir(original_dir)

            # Only reached after a failed attempt. Back off with jitter before
            # the next try, but do not delay the final `return False`.
            if attempts_left > 0:
                time.sleep(random.randrange(5, 15))
        return False
|
|
|
|
|
|
def should_update_prisma_schema(
|
|
disable_updates: Union[bool, str] | None = None,
|
|
) -> bool:
|
|
"""
|
|
Determines if Prisma Schema updates should be applied during startup.
|
|
|
|
Args:
|
|
disable_updates: Controls whether schema updates are disabled.
|
|
Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var.
|
|
|
|
Returns:
|
|
bool: True if schema updates should be applied, False if updates are disabled.
|
|
|
|
Examples:
|
|
>>> should_update_prisma_schema() # Checks DISABLE_SCHEMA_UPDATE env var
|
|
>>> should_update_prisma_schema(True) # Explicitly disable updates
|
|
>>> should_update_prisma_schema("false") # Enable updates using string
|
|
"""
|
|
if disable_updates is None:
|
|
disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
|
|
|
|
if isinstance(disable_updates, str):
|
|
disable_updates = str_to_bool(disable_updates)
|
|
|
|
return not bool(disable_updates)
|