fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
371 lines
14 KiB
Python
371 lines
14 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
import httpx
|
|
|
|
from .. import constants
|
|
from . import hf_raise_for_status, http_backoff, validate_hf_hub_args
|
|
|
|
|
|
# A cached xet token counts as expired this many seconds *before* its real
# expiration (see `_is_expired`), so a token is never reused right as it lapses.
XET_CONNECTION_INFO_SAFETY_PERIOD: int = 60  # seconds

# Upper bound on the number of entries kept in XET_CONNECTION_INFO_CACHE.
XET_CONNECTION_INFO_CACHE_SIZE: int = 1_000

# Cache of fetched xet connection info, keyed by the string built by `_cache_key`.
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}
|
|
|
|
|
|
class XetTokenType(str, Enum):
    """Kind of xet access token that can be requested from the Hub.

    The member value is the string embedded in the token endpoint path
    (``.../xet-{value}-token``, see `xet_connection_info_refresh_url`).
    """

    READ = "read"  # token with read permission
    WRITE = "write"  # token with write permission
|
|
|
|
|
|
@dataclass(frozen=True)
class XetFileData:
    """Immutable xet metadata for one file, parsed from a Hub HTTP response.

    Attributes:
        file_hash (`str`): xet hash of the file.
        refresh_route (`str`): URL queried to fetch/refresh the xet connection info for this file.
    """

    file_hash: str
    refresh_route: str
|
|
|
|
|
|
@dataclass(frozen=True)
class XetConnectionInfo:
    """Immutable credentials and location needed to talk to the xet storage service.

    Attributes:
        access_token (`str`): short-lived xet access token.
        expiration_unix_epoch (`int`): expiration time of the token, as a Unix timestamp in seconds.
        endpoint (`str`): URL of the xet storage service.
    """

    access_token: str
    expiration_unix_epoch: int
    endpoint: str
|
|
|
|
|
|
def parse_xet_file_data_from_response(response: httpx.Response, endpoint: str | None = None) -> XetFileData | None:
    """
    Parse XET file metadata from an HTTP response.

    This function extracts XET file metadata from the HTTP headers or HTTP links
    of a given response object. If the required metadata is not found, it returns `None`.

    Args:
        response (`httpx.Response`):
            The HTTP response object containing headers dict and links dict to extract the XET metadata from.
            `None` is accepted and yields `None`.
        endpoint (`str`, `optional`):
            Endpoint that replaces the `huggingface.co` home URL at the start of the refresh route.
            Defaults to `constants.ENDPOINT`.

    Returns:
        `Optional[XetFileData]`:
            An instance of `XetFileData` containing the file hash and refresh route if the metadata
            is found. Returns `None` if the required metadata is missing.
    """
    if response is None:
        return None

    try:
        file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
        # Prefer the refresh route advertised in the response links; otherwise fall back
        # to the dedicated refresh-route header.
        if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
            refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
        else:
            refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
    except KeyError:
        return None

    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
    # Point the refresh route at the configured endpoint. Only the leading home URL is
    # rewritten: `str.replace` would also alter any later occurrence of it (for example
    # one embedded in a query string).
    hub_home = constants.HUGGINGFACE_CO_URL_HOME.rstrip("/")
    if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
        refresh_route = endpoint.rstrip("/") + refresh_route[len(hub_home) :]

    return XetFileData(
        file_hash=file_hash,
        refresh_route=refresh_route,
    )
|
|
|
|
|
|
def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> XetConnectionInfo | None:
    """
    Build a `XetConnectionInfo` from the xet headers of a Hub response.

    Args:
        headers (`dict`):
            HTTP headers expected to carry the xet endpoint, access token and expiration.

    Returns:
        `XetConnectionInfo` or `None`:
            The information needed to connect to the XET storage service, or `None` when any of
            the three headers is missing or the expiration cannot be converted to `int`.
    """
    try:
        xet_endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
        token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
        expires_at = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
    except (KeyError, ValueError, TypeError):
        # Missing header, or an expiration value that is not an integer.
        return None

    return XetConnectionInfo(
        access_token=token,
        expiration_unix_epoch=expires_at,
        endpoint=xet_endpoint,
    )
|
|
|
|
|
|
@validate_hf_hub_args
def refresh_xet_connection_info(
    *,
    file_data: XetFileData,
    headers: dict[str, str],
) -> XetConnectionInfo:
    """
    Fetch (or refresh) the xet connection info for a file from its Hub refresh route.

    The returned info holds the access token, its expiration and the XET service URL.

    Args:
        file_data: (`XetFileData`):
            Parsed file metadata; its `refresh_route` is the URL that gets queried.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.

    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the metadata has no refresh route, or the Hub API response is improperly formatted.
    """
    route = file_data.refresh_route
    if route is None:
        raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
    return _fetch_xet_connection_info_with_url(route, headers)
|
|
|
|
|
|
@validate_hf_hub_args
def xet_connection_info_refresh_url(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: str | None = None,
    endpoint: str | None = None,
) -> str:
    """
    Return the URL that issues (or refreshes) a Xet access token for a repo.

    The URL has the shape ``{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token[/{revision}]``.

    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo (e.g. `"model"`, `"dataset"`, `"space"`, `"bucket"`).
        revision (`str`, `optional`):
            The revision of the repo to get the token for.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.

    Returns:
        `str`:
            The fully-qualified URL of the token refresh endpoint.
    """
    base = constants.ENDPOINT if endpoint is None else endpoint
    url = f"{base}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token"

    # Bucket repos never need a revision, so the suffix is left out for them unless one is
    # explicitly given. Every other repo type always gets the suffix, even when revision is
    # None: when opening a PR on a git-based repo the caller needs write access before the
    # revision is known, and the server answers ".../None" with a token for PR refs.
    omit_revision = repo_type == "bucket" and revision is None
    if not omit_revision:
        url += f"/{revision}"
    return url
|
|
|
|
|
|
@validate_hf_hub_args
def fetch_xet_connection_info_from_repo_info(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: str | None = None,
    headers: dict[str, str],
    endpoint: str | None = None,
    params: dict[str, str] | None = None,
) -> XetConnectionInfo:
    """
    Request a xet access token from the Hub for the given repo.

    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo (e.g. `"model"`, `"dataset"`, `"space"`, `"bucket"`).
        revision (`str`, `optional`):
            The revision of the repo to get the token for.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.
        params (`dict[str, str]`, `optional`):
            Additional parameters to pass with the request.

    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    token_url = xet_connection_info_refresh_url(
        token_type=token_type,
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
    )
    # The repo-scoped cache-key prefix lets `reset_xet_connection_info_cache_for_repo`
    # drop these cached entries later.
    repo_prefix = f"{repo_type}-{repo_id}"
    return _fetch_xet_connection_info_with_url(token_url, headers, params, cache_key_prefix=repo_prefix)
|
|
|
|
|
|
@validate_hf_hub_args
def _fetch_xet_connection_info_with_url(
    url: str,
    headers: dict[str, str],
    params: dict[str, str] | None = None,
    cache_key_prefix: str | None = None,
) -> XetConnectionInfo:
    """
    Requests the xet connection info from the supplied URL. This includes the
    access token, expiration time, and endpoint to use for the xet storage service.

    Result is cached in `XET_CONNECTION_INFO_CACHE` to avoid redundant requests; a
    cached entry is reused until `_is_expired` reports it as expired.

    Args:
        url: (`str`):
            The access token endpoint URL.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        params (`dict[str, str]`, `optional`):
            Additional parameters to pass with the request.
        cache_key_prefix (`str`, `optional`):
            Prefix of the cache key (e.g. `"{repo_type}-{repo_id}"`), which lets
            `reset_xet_connection_info_cache_for_repo` drop the entries of a given repo.

    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    # Serve from cache while the cached token is not (about to be) expired.
    cache_key = _cache_key(url, headers, params, prefix=cache_key_prefix)
    cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
    if cached_info is not None and not _is_expired(cached_info):
        return cached_info

    # Fetch from server
    resp = http_backoff("GET", url, headers=headers, params=params)
    hf_raise_for_status(resp)

    metadata = parse_xet_connection_info_from_headers(resp.headers)  # type: ignore
    if metadata is None:
        raise ValueError("Xet headers have not been correctly set by the server.")

    # Delete expired cache entries (this includes a stale entry for `cache_key`, if any).
    # Iterate over a snapshot so the dict can be mutated in the loop.
    for key, info in list(XET_CONNECTION_INFO_CACHE.items()):
        if _is_expired(info):
            XET_CONNECTION_INFO_CACHE.pop(key, None)

    # Enforce the size limit by evicting the oldest entries (dict insertion order).
    # A loop rather than a single pop keeps the cache bounded even if the limit is
    # lowered at runtime; the emptiness check avoids StopIteration on an empty cache
    # (e.g. limit 0); `pop(..., None)` tolerates an entry removed concurrently.
    while XET_CONNECTION_INFO_CACHE and len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
        XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)), None)

    # Update cache
    XET_CONNECTION_INFO_CACHE[cache_key] = metadata

    return metadata
|
|
|
|
|
|
def reset_xet_connection_info_cache_for_repo(repo_type: str | None, repo_id: str) -> None:
    """Drop every cached XET connection info entry belonging to one repo.

    Entries are matched on the ``"{repo_type}-{repo_id}|"`` key prefix that
    `fetch_xet_connection_info_from_repo_info` passes to `_cache_key`.
    Used when a repo is deleted.

    Args:
        repo_type (`str`, *optional*): Repo type; `None` means `constants.REPO_TYPE_MODEL`.
        repo_id (`str`): Repo id.
    """
    effective_type = constants.REPO_TYPE_MODEL if repo_type is None else repo_type
    key_prefix = f"{effective_type}-{repo_id}|"
    # Snapshot the keys first so the cache can be mutated while we walk them.
    stale_keys = [key for key in list(XET_CONNECTION_INFO_CACHE.keys()) if key.startswith(key_prefix)]
    for key in stale_keys:
        XET_CONNECTION_INFO_CACHE.pop(key, None)
|
|
|
|
|
|
def _cache_key(url: str, headers: dict[str, str], params: dict[str, str] | None, prefix: str | None = None) -> str:
|
|
"""Return a unique cache key for the given request parameters."""
|
|
lower_headers = {k.lower(): v for k, v in headers.items()} # casing is not guaranteed here
|
|
auth_header = lower_headers.get("authorization", "")
|
|
params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
|
|
return f"{prefix}|{url}|{auth_header}|{params_str}"
|
|
|
|
|
|
def _is_expired(connection_info: XetConnectionInfo) -> bool:
    """Return True when the token has expired or expires within XET_CONNECTION_INFO_SAFETY_PERIOD seconds."""
    deadline = int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD
    return connection_info.expiration_unix_epoch <= deadline
|
|
|
|
|
|
class XetSessionHolder:
    """Holds an optional XetSession; supports safe re-creation after sigint_abort or fork.

    Thread-safe: a ``threading.Lock`` guards all state mutations, which matters
    for free-threaded Python (3.14t) where multiple threads can race on ``get()``
    or ``sigint_abort()`` without the GIL serialising them.

    Fork-safe: ``get()`` discards a session inherited from the parent process. On
    platforms that provide ``os.register_at_fork``, the lock is also re-created in
    the child. Otherwise a lock held by another parent thread at fork time would be
    inherited in the locked state, and the child's ``get()`` would deadlock.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._session = None
        self._session_pid: int | None = None
        self._register_fork_handler()

    def _register_fork_handler(self) -> None:
        """Re-create ``self._lock`` in forked children (no-op where fork hooks are unavailable)."""
        register_at_fork = getattr(os, "register_at_fork", None)
        if register_at_fork is None:  # not available on every platform (Unix-only API)
            return

        import weakref

        # Fork hooks are process-global and cannot be unregistered; a weak reference
        # keeps the hook from pinning this holder alive.
        holder_ref = weakref.ref(self)

        def _reinit_lock_in_child() -> None:
            # Runs in the child right after fork, while only the forking thread exists,
            # so replacing the lock cannot race with another thread.
            holder = holder_ref()
            if holder is not None:
                holder._lock = threading.Lock()

        register_at_fork(after_in_child=_reinit_lock_in_child)

    def get(self):
        """Return the current session, creating one if needed.

        Fork-safe: if the current process PID differs from the PID that created
        the session (i.e. we are in a forked child), the old session is discarded
        and a fresh session is created for this process.
        """
        with self._lock:
            current_pid = os.getpid()

            if self._session is not None and self._session_pid != current_pid:
                # Fork detected. Discard the parent's session; the Rust Drop will
                # call discard_runtime() (std::mem::forget) rather than the normal
                # shutdown path, so this returns immediately without blocking.
                self._session = None

            if self._session is None:
                from hf_xet import XetSession

                self._session = XetSession()
                self._session_pid = current_pid

            return self._session

    def sigint_abort(self):
        """Abort the current session and clear it so the next get() creates a fresh one."""
        with self._lock:
            if self._session is not None:
                try:
                    self._session.sigint_abort()
                except Exception:
                    # Best effort: the session is discarded below regardless, and a failure
                    # here must not mask the interrupt the caller is handling.
                    pass
            self._session = None
            self._session_pid = None
|
|
|
|
|
|
# Process-wide holder shared by `get_xet_session` and `abort_xet_session`.
_GLOBAL_XET_HOLDER: XetSessionHolder = XetSessionHolder()
|
|
|
|
|
|
def get_xet_session():
    """Return the process-wide :class:`hf_xet.XetSession`, creating it lazily on first use.

    Like the HTTP client returned by :func:`~huggingface_hub.utils._http.get_session`,
    one session is shared by every caller in the process. Creation and reuse are
    delegated to the module-level ``XetSessionHolder``, which documents itself as
    fork-safe and thread-safe.
    """
    holder = _GLOBAL_XET_HOLDER
    return holder.get()
|
|
|
|
|
|
def xet_headers_without_auth(headers: dict[str, str]) -> dict[str, str]:
    """Return a new dict with every ``Authorization`` header (any casing) removed.

    Xet storage requests authenticate with a short-lived xet access token, so the
    Hub authorization header must not be forwarded to xet storage endpoints.
    The input mapping is left untouched.
    """
    filtered: dict[str, str] = {}
    for name, value in headers.items():
        if name.lower() == "authorization":
            continue
        filtered[name] = value
    return filtered
|
|
|
|
|
|
def abort_xet_session():
    """Abort the process-wide xet session, e.g. after a KeyboardInterrupt.

    Cancels any in-flight Rust operation and clears the session so that the next
    :func:`get_xet_session` call builds a fresh one (notebook-friendly).
    """
    holder = _GLOBAL_XET_HOLDER
    holder.sigint_abort()
|