# -*- coding: utf-8 -*- """Market snapshot fetcher. Fetches full-market real-time snapshots for screening. This is separate from single-stock realtime quotes. """ import logging import json import os import threading import time from datetime import date, datetime, timezone, timedelta from pathlib import Path import pandas as pd import requests logger = logging.getLogger(__name__) _SNAPSHOT_CACHE_VERSION = 1 _DEFAULT_TUSHARE_HTTP_URL = "http://api.waditu.com" _EM_REQUEST_MIN_INTERVAL_SECONDS = 0.25 _SOURCE_HEALTH_FAILURE_THRESHOLD = 3 _SOURCE_HEALTH_COOLDOWN_SECONDS = 5 * 60 _EM_SESSION: requests.Session | None = None _EM_LAST_REQUEST_AT = 0.0 _EM_LOCK = threading.Lock() _SOURCE_HEALTH: dict[str, dict[str, float]] = {} _SOURCE_HEALTH_LOCK = threading.Lock() def fetch_cn_snapshot(source: str = "efinance") -> pd.DataFrame: """Fetch A-share full-market snapshot. Returns a DataFrame with columns: code, name, price, change_pct, amount, total_mv, circ_mv, pe_ratio, pb_ratio, volume_ratio, turnover_rate Raises RuntimeError if the source is unavailable. """ if source == "sina": return _fetch_sina() elif source == "efinance": return _fetch_efinance() elif source == "akshare_em": return _fetch_akshare_em() elif source == "em_datacenter": return _fetch_em_datacenter() elif source == "tushare": return _fetch_tushare() else: raise ValueError(f"Unknown snapshot source: {source}") def fetch_snapshot_with_fallback( sources: list[str], *, required_columns: list[str] | None = None, fallback_snapshot_path: str | Path | None = None, fallback_max_age_hours: float | None = None, market: str = "cn", ) -> pd.DataFrame: """Try live sources, optionally falling back to the last-good snapshot.""" if market == "us": return _fetch_us_snapshot_with_fallback(required_columns) errors = [] required = required_columns or [] for source in sources: disabled_reason = _source_disabled_reason(source) if disabled_reason: errors.append(f"{source}: {disabled_reason}") continue try: df = fetch_cn_snapshot(source) if not df.empty: missing = _missing_required_columns(df, required) if missing: errors.append( f"{source}: missing required columns {','.join(missing)}" ) continue df.attrs.setdefault("snapshot_source", source) df.attrs["source_errors"] = list(errors) df.attrs["fallback_used"] = False df.attrs["stale"] = False df.attrs["stale_age_hours"] = None _write_last_good_snapshot(fallback_snapshot_path, df) _record_source_success(source) logger.info("Snapshot fetched from %s: %d rows", source, len(df)) return df errors.append(f"{source}: returned empty data") _record_source_failure(source) except Exception as e: errors.append(f"{source}: {e}") _record_source_failure(source) logger.warning("Snapshot source %s failed: %s", source, e) cached = _read_last_good_snapshot( fallback_snapshot_path, required_columns=required, source_errors=errors, max_age_hours=fallback_max_age_hours, ) if cached is not None: return cached raise RuntimeError(f"All snapshot sources failed: {'; '.join(errors)}") def _fetch_us_snapshot_with_fallback( required_columns: list[str] | None = None, ) -> pd.DataFrame: """Fetch US equity snapshot via yfinance adapter.""" from alphasift.snapshot_us import fetch_us_snapshot df = fetch_us_snapshot() missing = _missing_required_columns(df, required_columns or []) if missing: logger.warning("US snapshot missing columns: %s", ",".join(missing)) return df def _missing_required_columns(df: pd.DataFrame, required_columns: list[str]) -> list[str]: missing: list[str] = [] for col in required_columns: if col not in df.columns: missing.append(col) continue if df[col].dropna().empty: missing.append(col) return missing def _source_disabled_reason(source: str) -> str | None: now = time.monotonic() with _SOURCE_HEALTH_LOCK: state = _SOURCE_HEALTH.get(source) if not state: return None disabled_until = float(state.get("disabled_until", 0.0)) if disabled_until <= now: if disabled_until: state["disabled_until"] = 0.0 return None return f"temporarily disabled for {disabled_until - now:.1f}s after repeated failures" def _record_source_success(source: str) -> None: with _SOURCE_HEALTH_LOCK: _SOURCE_HEALTH.pop(source, None) def _record_source_failure(source: str) -> None: now = time.monotonic() with _SOURCE_HEALTH_LOCK: state = _SOURCE_HEALTH.setdefault(source, {"failures": 0.0, "disabled_until": 0.0}) failures = float(state.get("failures", 0.0)) + 1.0 state["failures"] = failures state["total_failures"] = float(state.get("total_failures", 0.0)) + 1.0 state["last_failure_at"] = time.time() if failures >= _SOURCE_HEALTH_FAILURE_THRESHOLD: state["disabled_until"] = now + _SOURCE_HEALTH_COOLDOWN_SECONDS def snapshot_source_health_snapshot( sources: list[str] | tuple[str, ...] | None = None, ) -> dict[str, dict[str, float | bool]]: """Return in-process snapshot-source health without exposing credentials.""" now = time.monotonic() requested = tuple(sources or tuple(_SOURCE_HEALTH)) with _SOURCE_HEALTH_LOCK: snapshot: dict[str, dict[str, float | bool]] = {} for source in requested: state = dict(_SOURCE_HEALTH.get(source, {})) disabled_until = float(state.get("disabled_until", 0.0)) snapshot[source] = { "successes": 0.0, "failures": float(state.get("failures", 0.0)), "total_failures": float(state.get("total_failures", 0.0)), "last_rows": 0.0, "disabled": disabled_until > now, } return snapshot def _write_last_good_snapshot( path_like: str | Path | None, df: pd.DataFrame, ) -> None: if path_like is None: return path = Path(path_like) try: path.parent.mkdir(parents=True, exist_ok=True) payload = { "version": _SNAPSHOT_CACHE_VERSION, "created_at": datetime.now(timezone.utc).isoformat(), "metadata": { "snapshot_source": str(df.attrs.get("snapshot_source", "")), "row_count": int(len(df)), "columns": list(df.columns), }, "frame": json.loads( df.to_json(orient="split", date_format="iso", force_ascii=False) ), } tmp_path = path.with_name(f".{path.name}.{time.time_ns()}.tmp") tmp_path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") tmp_path.replace(path) except Exception as exc: # noqa: BLE001 - live snapshot should remain usable. logger.warning("Failed to write last-good snapshot cache %s: %s", path, exc) def _read_last_good_snapshot( path_like: str | Path | None, *, required_columns: list[str], source_errors: list[str], max_age_hours: float | None = None, ) -> pd.DataFrame | None: if path_like is None: return None path = Path(path_like) try: stat = path.stat() except FileNotFoundError: return None try: payload = json.loads(path.read_text(encoding="utf-8")) if payload.get("version") != _SNAPSHOT_CACHE_VERSION: raise ValueError("unsupported cache version") frame = payload.get("frame") if not isinstance(frame, dict): raise ValueError("missing cached frame") columns = frame.get("columns") data = frame.get("data") if not isinstance(columns, list) or not isinstance(data, list): raise ValueError("malformed cached frame") stale_age_hours = _cache_stale_age_hours( stat.st_mtime, created_at=str(payload.get("created_at", "")), ) if max_age_hours is not None and max_age_hours >= 0 and stale_age_hours > max_age_hours: raise ValueError( f"cache stale_age_hours={stale_age_hours:.4g} exceeds max_age_hours={max_age_hours:.4g}" ) cached = pd.DataFrame(data, columns=columns) if cached.empty: raise ValueError("cached snapshot is empty") missing = _missing_required_columns(cached, required_columns) if missing: raise ValueError(f"missing required columns {','.join(missing)}") except Exception as exc: # noqa: BLE001 - invalid cache should not mask live errors. source_errors.append(f"last_good_cache: {exc}") return None cached.attrs["snapshot_source"] = "last_good_cache" cached.attrs["fallback_used"] = True cached.attrs["stale"] = True cached.attrs["stale_age_hours"] = stale_age_hours cached.attrs["source_errors"] = list(source_errors) metadata = payload.get("metadata") if isinstance(metadata, dict): cached.attrs["last_good_snapshot_source"] = str( metadata.get("snapshot_source", "") ) cached.attrs["last_good_created_at"] = str(payload.get("created_at", "")) logger.warning( "Using last-good snapshot cache %s after live source failures: %s", path, "; ".join(source_errors), ) return cached def _cache_stale_age_hours(mtime: float, *, created_at: str = "") -> float: modified = _parse_created_at(created_at) or datetime.fromtimestamp(mtime, tz=timezone.utc) age_hours = (datetime.now(timezone.utc) - modified).total_seconds() / 3600.0 return round(max(age_hours, 0.0), 4) def _parse_created_at(value: str) -> datetime | None: if not value: return None try: parsed = datetime.fromisoformat(value.replace("Z", "+00:00")) except ValueError: return None if parsed.tzinfo is None: return parsed.replace(tzinfo=timezone.utc) return parsed.astimezone(timezone.utc) def _fetch_efinance() -> pd.DataFrame: """Fetch via efinance.""" import efinance as ef df = ef.stock.get_realtime_quotes() if df is None or df.empty: raise RuntimeError("efinance returned empty data") return _normalize(df, source="efinance") def _fetch_akshare_em() -> pd.DataFrame: """Fetch via akshare (eastmoney).""" import akshare as ak df = ak.stock_zh_a_spot_em() if df is None or df.empty: raise RuntimeError("akshare returned empty data") return _normalize(df, source="akshare_em") def _fetch_sina() -> pd.DataFrame: """Fetch A-share full-market snapshot directly from Sina Finance. Sina's market-center endpoint is a lightweight direct HTTP source with PE, PB, turnover and market-cap fields. It gives AlphaSift another non-wrapper, non-Eastmoney-first snapshot option before falling back to Eastmoney-heavy sources. """ url = "https://vip.stock.finance.sina.com.cn/quotes_service/api/json_v2.php/Market_Center.getHQNodeData" page = 1 page_size = 80 all_items = [] while True: resp = requests.get( url, params={ "page": page, "num": page_size, "sort": "symbol", "asc": 1, "node": "hs_a", "symbol": "", "_s_r_a": "page", }, headers={ "User-Agent": "Mozilla/5.0", "Referer": "https://vip.stock.finance.sina.com.cn/mkt/", }, timeout=15, ) resp.raise_for_status() items = resp.json() if not isinstance(items, list): raise RuntimeError("sina snapshot returned malformed data") if not items: break all_items.extend(items) if len(items) < page_size: break page += 1 if not all_items: raise RuntimeError("sina returned empty data") df = pd.DataFrame(all_items) for col in ("mktcap", "nmc"): if col in df.columns: # Sina exposes market caps in ten-thousand yuan; normalize to yuan. df[col] = pd.to_numeric(df[col], errors="coerce") * 10000 return _normalize(df, source="sina") def _fetch_em_datacenter() -> pd.DataFrame: """Fetch via eastmoney datacenter xuangu API. This works even on weekends (returns last trading day data). """ url = "https://data.eastmoney.com/dataapi/xuangu/list" all_items = [] page = 1 page_size = 500 while True: params = { "st": "SECURITY_CODE", "sr": "1", "ps": str(page_size), "p": str(page), "sty": "SECUCODE,SECURITY_CODE,SECURITY_NAME_ABBR,NEW_PRICE," "CHANGE_RATE,VOLUME_RATIO,DEAL_AMOUNT,TURNOVERRATE," "PE9,PBNEWMRQ,TOTAL_MARKET_CAP,CIRCULATION_MARKET_CAP", "filter": '(MARKET+in+("上交所主板","深交所主板","深交所创业板","上交所科创板","北交所"))', "source": "SELECT_SECURITIES", "client": "WEB", } headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", "Referer": "https://data.eastmoney.com/xuangu/", } resp = _eastmoney_get(url, params=params, headers=headers, timeout=30) data = resp.json() if not data.get("success"): raise RuntimeError(f"em_datacenter API error: {data.get('message', 'unknown')}") items = data["result"]["data"] all_items.extend(items) total_count = data["result"]["count"] if page * page_size >= total_count: break page += 1 if not all_items: raise RuntimeError("em_datacenter returned no data") df = pd.DataFrame(all_items) return _normalize(df, source="em_datacenter") def _eastmoney_get(url: str, **kwargs) -> requests.Response: """GET EastMoney endpoints through one throttled shared session. EastMoney endpoints are useful but more sensitive to bursty access than lightweight direct sources. Keeping all direct calls behind one session and a small process-wide interval reduces connection churn and accidental request bursts when snapshot fallbacks are exercised repeatedly. """ global _EM_LAST_REQUEST_AT, _EM_SESSION with _EM_LOCK: if _EM_SESSION is None: _EM_SESSION = requests.Session() elapsed = time.monotonic() - _EM_LAST_REQUEST_AT if elapsed < _EM_REQUEST_MIN_INTERVAL_SECONDS: time.sleep(_EM_REQUEST_MIN_INTERVAL_SECONDS - elapsed) response = _EM_SESSION.get(url, **kwargs) _EM_LAST_REQUEST_AT = time.monotonic() response.raise_for_status() return response def _fetch_tushare() -> pd.DataFrame: """Fetch latest available A-share snapshot via Tushare Pro. Tushare is not a real-time source here. It is used as a resilient fallback by joining the latest open trading day's daily quote and daily_basic data. """ token = ( os.getenv("TUSHARE_TOKEN", "").strip() or os.getenv("TUSHARE_API_TOKEN", "").strip() ) if not token: raise RuntimeError("tushare requires TUSHARE_TOKEN") import tushare as ts pro = ts.pro_api(token) _configure_tushare_client(pro, token=token) trade_date = _resolve_tushare_trade_date(pro) daily = pro.daily( trade_date=trade_date, fields="ts_code,trade_date,close,pct_chg,amount", ) daily_basic = pro.daily_basic( trade_date=trade_date, fields="ts_code,turnover_rate,volume_ratio,pe,pb,total_mv,circ_mv", ) stock_basic = pro.stock_basic( exchange="", list_status="L", fields="ts_code,symbol,name,industry", ) if daily is None or daily.empty: raise RuntimeError(f"tushare daily returned empty data for {trade_date}") if daily_basic is None or daily_basic.empty: raise RuntimeError(f"tushare daily_basic returned empty data for {trade_date}") return _prepare_tushare_snapshot(daily, daily_basic, stock_basic) def _configure_tushare_client(pro: object, *, token: str) -> None: try: setattr(pro, "_DataApi__token", token) except Exception: pass http_url = ( os.getenv("TUSHARE_API_URL", "").strip() or os.getenv("TUSHARE_HTTP_URL", "").strip() or _DEFAULT_TUSHARE_HTTP_URL ) try: setattr(pro, "_DataApi__http_url", http_url) except Exception: pass def _resolve_tushare_trade_date(pro) -> str: """Return the latest open trade date for Tushare requests.""" explicit = os.getenv("TUSHARE_TRADE_DATE", "").strip() if explicit: return explicit end = date.today() start = end - timedelta(days=30) calendar = pro.trade_cal( exchange="", start_date=start.strftime("%Y%m%d"), end_date=end.strftime("%Y%m%d"), is_open="1", fields="cal_date,is_open", ) if calendar is None or calendar.empty or "cal_date" not in calendar.columns: raise RuntimeError("tushare trade_cal returned no open trading days") return str(calendar["cal_date"].max()) def _prepare_tushare_snapshot( daily: pd.DataFrame, daily_basic: pd.DataFrame, stock_basic: pd.DataFrame | None, ) -> pd.DataFrame: """Join and unit-normalize Tushare tables into the common snapshot schema.""" merged = daily.merge(daily_basic, on="ts_code", how="left") if stock_basic is not None and not stock_basic.empty: merged = merged.merge(stock_basic, on="ts_code", how="left") if "symbol" not in merged.columns: merged["symbol"] = merged["ts_code"].astype(str).str.split(".").str[0] else: fallback_symbol = merged["ts_code"].astype(str).str.split(".").str[0] merged["symbol"] = merged["symbol"].fillna(fallback_symbol) # Tushare units: amount is thousand yuan; market caps are ten-thousand yuan. for col, multiplier in { "amount": 1000, "total_mv": 10000, "circ_mv": 10000, }.items(): if col in merged.columns: merged[col] = pd.to_numeric(merged[col], errors="coerce") * multiplier return _normalize(merged, source="tushare") def _normalize(df: pd.DataFrame, source: str) -> pd.DataFrame: """Normalize column names to a standard schema. Standard columns: code, name, price, change_pct, amount, total_mv, circ_mv, pe_ratio, pb_ratio, volume_ratio, turnover_rate """ df = df.copy() if source == "efinance": standard_cols = { "code": ["股票代码", "代码"], "name": ["股票名称", "名称"], "price": ["最新价"], "change_pct": ["涨跌幅"], "amount": ["成交额"], "total_mv": ["总市值"], "circ_mv": ["流通市值"], "pe_ratio": ["动态市盈率", "市盈率(动)"], "pb_ratio": ["市净率"], "volume_ratio": ["量比"], "turnover_rate": ["换手率"], "industry": ["行业", "所属行业", "行业板块"], "concepts": ["概念", "概念题材", "题材"], } elif source == "akshare_em": standard_cols = { "code": ["代码"], "name": ["名称"], "price": ["最新价"], "change_pct": ["涨跌幅"], "amount": ["成交额"], "total_mv": ["总市值"], "circ_mv": ["流通市值"], "pe_ratio": ["市盈率-动态", "市盈率(动)"], "pb_ratio": ["市净率"], "volume_ratio": ["量比"], "turnover_rate": ["换手率"], "industry": ["行业", "所属行业", "行业板块"], "concepts": ["概念", "概念题材", "题材"], } elif source == "sina": standard_cols = { "code": ["code"], "name": ["name"], "price": ["trade"], "change_pct": ["changepercent"], "amount": ["amount"], "total_mv": ["mktcap"], "circ_mv": ["nmc"], "pe_ratio": ["per"], "pb_ratio": ["pb"], "turnover_rate": ["turnoverratio"], } elif source == "em_datacenter": standard_cols = { "code": ["SECURITY_CODE"], "name": ["SECURITY_NAME_ABBR"], "price": ["NEW_PRICE"], "change_pct": ["CHANGE_RATE"], "amount": ["DEAL_AMOUNT"], "total_mv": ["TOTAL_MARKET_CAP"], "circ_mv": ["CIRCULATION_MARKET_CAP"], "pe_ratio": ["PE9"], "pb_ratio": ["PBNEWMRQ"], "volume_ratio": ["VOLUME_RATIO"], "turnover_rate": ["TURNOVERRATE"], "industry": ["INDUSTRY", "INDUSTRY_NAME", "BOARD_NAME"], "concepts": ["CONCEPT", "CONCEPT_NAME", "THEME_NAME"], } elif source == "tushare": standard_cols = { "code": ["symbol", "code"], "name": ["name"], "price": ["close"], "change_pct": ["pct_chg"], "amount": ["amount"], "total_mv": ["total_mv"], "circ_mv": ["circ_mv"], "pe_ratio": ["pe"], "pb_ratio": ["pb"], "volume_ratio": ["volume_ratio"], "turnover_rate": ["turnover_rate"], "industry": ["industry"], "concepts": ["concepts"], } else: standard_cols = {} df = _rename_standard_columns(df, standard_cols) # Coerce numeric columns numeric_cols = [ "price", "change_pct", "amount", "total_mv", "circ_mv", "pe_ratio", "pb_ratio", "volume_ratio", "turnover_rate", ] for col in numeric_cols: if col in df.columns: df[col] = pd.to_numeric(df[col], errors="coerce") # Drop rows without a valid price if "price" in df.columns: df = df.dropna(subset=["price"]) df = df[df["price"] > 0] df.attrs["snapshot_source"] = source return df def _rename_standard_columns( df: pd.DataFrame, standard_cols: dict[str, list[str]], ) -> pd.DataFrame: """Rename the first matching source column for each standard field.""" rename_map: dict[str, str] = {} for standard_name, candidates in standard_cols.items(): for candidate in candidates: if candidate in df.columns: rename_map[candidate] = standard_name break return df.rename(columns=rename_map)