Files
MoFin/venv/lib/python3.12/site-packages/alphasift/ranker.py
T
知微 fa45d8aa5f fix: 小果地址统一node122(兼容LAN+EasyTier)
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
  Privoxy对node122:18003返回500,直连正常
2026-06-30 02:56:35 +08:00

936 lines
33 KiB
Python

# -*- coding: utf-8 -*-
"""L2 LLM ranker — relative ranking of shortlisted candidates."""
import json
import logging
import os
from dataclasses import dataclass
from alphasift.models import Pick
from alphasift.normalize import (
bounded_float as _bounded_float,
normalize_code,
safe_string_list as _safe_string_list,
safe_text,
)
def _normalize_code(value: object) -> str:
    """Normalize a stock code while letting ticker-style codes pass through.

    Thin wrapper around ``normalize_code`` with ``allow_ticker=True``. It is
    applied to candidate codes and to the ``code`` fields of the LLM ranking
    JSON, both of which are structured values (see the ``normalize_code``
    docstring for the ticker rules).
    """
    normalized = normalize_code(value, allow_ticker=True)
    return normalized
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default character cap for the ranking prompt (default of ``max_prompt_chars`` / ``max_chars`` below).
_DEFAULT_RANKING_PROMPT_MAX_CHARS = 24_000
# Tag embedded in the "..." marker lines that flag a truncated prompt section.
_PROMPT_TRIM_MARKER = "[prompt_trimmed]"
@dataclass
class RankingParseResult:
    """Outcome of parsing a single LLM ranking response.

    The first three fields are required and are filled positionally by the
    parser, so their order must not change.
    """

    # Picks in response order, followed by candidates the response skipped
    # (or the untouched candidate list when parsing failed).
    picks: list[Pick]
    # Matched candidates divided by the number of candidates; 0.0 on failure.
    coverage: float
    # Diagnostic tags collected while extracting/repairing/parsing the JSON.
    errors: list[str]
    # Global commentary copied from the top-level JSON object, if present.
    market_view: str = ""
    selection_logic: str = ""
    portfolio_risk: str = ""
@dataclass
class LLMRankingResult:
    """Result returned by ``rank_candidates_with_metadata``.

    ``ranked`` is True only when an LLM response was accepted; otherwise
    ``picks`` holds the candidates that were passed in.
    """

    picks: list[Pick]
    ranked: bool = False
    # Global commentary taken from the accepted LLM response.
    market_view: str = ""
    selection_logic: str = ""
    portfolio_risk: str = ""
    # Coverage of the accepted response (0.0 when nothing was accepted).
    coverage: float = 0.0
    # ``None`` is accepted for convenience and replaced by a fresh list.
    errors: list[str] | None = None

    def __post_init__(self) -> None:
        """Replace a missing ``errors`` value with a new empty list."""
        self.errors = [] if self.errors is None else self.errors
def rank_candidates(
    candidates: list[Pick],
    ranking_hints: str,
    llm_api_key: str,
    llm_model: str,
    llm_base_url: str = "",
    *,
    context: str = "",
    rank_weight: float = 0.40,
    max_retries: int = 1,
    min_coverage: float = 0.60,
    fallback_models: list[str] | None = None,
    temperature: float = 0.2,
    json_mode: bool = True,
    silent: bool = True,
    channels: list[dict[str, object]] | None = None,
    config_path: str = "",
    timeout_sec: float = 60.0,
    max_prompt_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS,
    max_tokens: int | None = 2048,
) -> list[Pick]:
    """Re-rank candidates with an LLM and return only the picks.

    Convenience wrapper: forwards every argument unchanged to
    ``rank_candidates_with_metadata`` and returns the ``picks`` of its result,
    discarding the global commentary, coverage and error diagnostics. When the
    LLM ranking fails, that function returns the candidates as passed in.
    """
    options = {
        "context": context,
        "rank_weight": rank_weight,
        "max_retries": max_retries,
        "min_coverage": min_coverage,
        "fallback_models": fallback_models,
        "temperature": temperature,
        "json_mode": json_mode,
        "silent": silent,
        "channels": channels,
        "config_path": config_path,
        "timeout_sec": timeout_sec,
        "max_prompt_chars": max_prompt_chars,
        "max_tokens": max_tokens,
    }
    outcome = rank_candidates_with_metadata(
        candidates,
        ranking_hints,
        llm_api_key,
        llm_model,
        llm_base_url,
        **options,
    )
    return outcome.picks
def rank_candidates_with_metadata(
    candidates: list[Pick],
    ranking_hints: str,
    llm_api_key: str,
    llm_model: str,
    llm_base_url: str = "",
    *,
    context: str = "",
    rank_weight: float = 0.40,
    max_retries: int = 1,
    min_coverage: float = 0.60,
    fallback_models: list[str] | None = None,
    temperature: float = 0.2,
    json_mode: bool = True,
    silent: bool = True,
    channels: list[dict[str, object]] | None = None,
    config_path: str = "",
    timeout_sec: float = 60.0,
    max_prompt_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS,
    degradation: list[str] | None = None,
    max_tokens: int | None = 2048,
) -> LLMRankingResult:
    """Use LLM to re-rank candidates and return global research metadata.

    The prompt is built once (outside the error boundary), then the LLM is
    called up to ``max_retries + 1`` times until the parsed response covers at
    least ``min_coverage`` of the candidates. Accepted picks get a blended
    ``final_score`` and are re-sorted by it.

    Args:
        candidates: Shortlisted picks; an empty list is returned as-is.
        ranking_hints: Strategy-specific ranking guidance for the prompt.
        llm_api_key, llm_model, llm_base_url: Primary LLM endpoint settings.
        context: Extra market/intel context inserted into the prompt.
        rank_weight: Weight of ``llm_score`` in ``final_score``; clamped to [0, 1].
        max_retries: Extra attempts after the first when coverage is too low.
        min_coverage: Minimum matched fraction for a response to be accepted.
        degradation: Optional list that receives a note when the prompt is trimmed.
        Remaining keyword arguments are forwarded to ``_call_llm``.

    Returns:
        ``LLMRankingResult`` with ``ranked=True`` on success. On any exception
        (including insufficient coverage) the candidates are returned as
        passed in, with the error text in ``errors``.
    """
    if not candidates:
        return LLMRankingResult(picks=candidates)

    prompt = _build_ranking_prompt(
        candidates,
        ranking_hints,
        context,
        max_chars=max_prompt_chars,
        degradation=degradation,
    )
    # Appended on every attempt after the first to push for a fuller answer.
    retry_suffix = (
        "\n\n上一次输出没有满足结构化覆盖率要求。"
        "请重新返回严格 JSON,并覆盖尽可能多的候选代码。"
    )

    try:
        last_errors: list[str] = []
        parsed: RankingParseResult | None = None
        for attempt in range(max_retries + 1):
            attempt_prompt = prompt + retry_suffix if attempt else prompt
            response = _call_llm(
                attempt_prompt,
                llm_api_key,
                llm_model,
                llm_base_url,
                fallback_models=fallback_models or [],
                temperature=temperature,
                json_mode=json_mode,
                silent=silent,
                channels=channels or [],
                config_path=config_path,
                timeout_sec=timeout_sec,
                max_tokens=max_tokens,
            )
            parsed = _parse_ranking_response_detail(response, candidates)
            last_errors = parsed.errors
            if parsed.coverage >= min_coverage:
                break

        accepted = parsed is not None and parsed.coverage >= min_coverage
        if not accepted:
            raise ValueError(
                "LLM ranking response coverage below threshold: "
                f"{0 if parsed is None else parsed.coverage:.2f}; "
                f"errors={last_errors}"
            )

        ranked = parsed.picks
        for position, pick in enumerate(ranked):
            pick.rank = position + 1
            if pick.llm_score is None:
                # No usable score from the LLM: derive one from list position.
                pick.llm_score = 100.0 - position * (100.0 / max(len(ranked), 1))
            weight = min(max(rank_weight, 0.0), 1.0)
            screen_part = pick.screen_score * (1 - weight)
            llm_part = (pick.llm_score or 0) * weight
            pick.final_score = screen_part + llm_part

        # Blended scores may reorder the list, so ranks are assigned again.
        ranked.sort(key=lambda item: item.final_score, reverse=True)
        for new_rank, pick in enumerate(ranked, start=1):
            pick.rank = new_rank

        return LLMRankingResult(
            picks=ranked,
            ranked=True,
            market_view=parsed.market_view,
            selection_logic=parsed.selection_logic,
            portfolio_risk=parsed.portfolio_risk,
            coverage=parsed.coverage,
            errors=parsed.errors,
        )
    except Exception as e:
        logger.warning("LLM ranking failed, falling back to screen_score: %s", e)
        return LLMRankingResult(picks=candidates, errors=[str(e)])
def _build_ranking_prompt(
    candidates: list[Pick],
    hints: str,
    context: str = "",
    *,
    max_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS,
    degradation: list[str] | None = None,
) -> str:
    """Build the ranking prompt, trimming it only when it exceeds ``max_chars``.

    Blank hints/context are replaced by fixed placeholder sentences. When
    ``max_chars`` is ``None`` or the full prompt already fits, the full prompt
    is returned; otherwise the bounded builder produces a trimmed version and
    may record a note in ``degradation``.
    """
    hints_text = hints.strip() or "无额外排序提示。"
    context_text = context.strip() or "无额外上下文。只能基于候选池结构化数据和策略偏好判断。"
    candidate_lines = [_format_candidate_for_prompt(pick) for pick in candidates]
    prompt = _render_ranking_prompt(hints_text, context_text, "\n".join(candidate_lines))

    over_budget = max_chars is not None and len(prompt) > max_chars
    if not over_budget:
        return prompt
    return _build_bounded_ranking_prompt(
        candidates,
        hints_text,
        context_text,
        max_chars=max_chars,
        degradation=degradation,
    )
def _render_ranking_prompt(hints: str, context: str, candidates_text: str) -> str:
    """Fill the fixed ranking prompt template.

    Args:
        hints: Text placed under the ranking-criteria heading.
        context: Text placed under the market/intel-context heading.
        candidates_text: Pre-formatted candidate lines.

    Returns:
        The complete prompt: instructions, the three inserted sections and
        the JSON output specification.
    """
    # The template is sent to the model verbatim, so its lines stay flush-left;
    # doubled braces are literal braces of the JSON example.
    rendered = f"""你是一个专业的股票研究员,任务是在“已经由代码硬筛过”的候选池内做相对排序。
你不能推荐候选池外股票,不能修改硬筛条件,不能给目标价或承诺收益。你的价值在于:
1. 结合策略偏好,对候选之间做跨股票比较;
2. 识别结构化数据暴露不出的潜在催化、风格匹配和风险点;
3. 对行业/概念热度和 DSA 补充的行情、基本面、新闻做语义归因,但不能把单日热度当作唯一买入理由;
4. 给出简短、可审计、可复核的排序理由。
## 排序依据
{hints}
## 市场/情报上下文
{context}
## 候选列表
{candidates_text}
## 输出要求
只返回 JSON,不要 Markdown,不要解释 JSON 以外的文本。
格式:
{{
"market_view": "一句话概括当前候选池和市场背景是否适合该策略",
"selection_logic": "说明本次排序最主要的2-3个判断维度",
"portfolio_risk": "说明最终名单可能存在的集中风险或共同风险",
"ranked": [
{{
"code": "股票代码",
"llm_score": 0-100,
"confidence": 0-1,
"sector": "行业/主题短标签,优先参考候选的 industry/concepts,并尽量统一,如 券商、银行、医药、AI算力",
"theme": "主要交易逻辑或主题",
"thesis": "该候选入选的核心投资假设",
"reason": "一句话排序理由",
"risk": "一句话主要风险",
"catalysts": ["潜在催化1", "潜在催化2"],
"risk_flags": ["风险标签1"],
"tags": ["价值", "趋势", "防守", "事件", "流动性"],
"style_fit": "与策略风格的匹配度说明",
"watch_items": ["后续应跟踪的数据或事件"],
"invalidators": ["会推翻该候选逻辑的观察点"]
}}
]
}}
"""
    return rendered
def _build_bounded_ranking_prompt(
    candidates: list[Pick],
    hints: str,
    context: str,
    *,
    max_chars: int,
    degradation: list[str] | None,
) -> str:
    """Build a ranking prompt that fits within ``max_chars`` characters.

    Trimming happens in stages:

    1. Hints are capped at 900 characters; context gets whatever is left after
       a minimal prompt (capped hints + identity-only candidate lines) and an
       80-character safety margin; candidate lines are fitted into the rest.
    2. If the result is still too long, context is shortened by the overflow
       (plus the margin), hints are capped at 600, and candidates are refitted.
    3. As a last resort the prompt is cut and a ``hard_cap`` marker appended.

    Every section that was shortened is recorded; when ``degradation`` is a
    list it receives one summary line naming the trimmed sections.
    """
    cap = int(max_chars)
    trimmed: list[str] = []

    def compose(hint_limit: int, context_text: str) -> str:
        """Render the prompt with candidates fitted into the leftover budget."""
        clipped_hints = _truncate_prompt_text(hints, hint_limit, "hints", trimmed)
        shell = _render_ranking_prompt(clipped_hints, context_text, "")
        room = max(cap - len(shell), 0)
        fitted = _fit_candidate_prompt_lines(candidates, room, trimmed)
        return _render_ranking_prompt(clipped_hints, context_text, fitted)

    # Smallest prompt that still names every candidate; it defines how much
    # room the context may take.
    identity_text = "\n".join(
        _format_candidate_for_prompt(pick, detail="identity") for pick in candidates
    )
    floor_prompt = _render_ranking_prompt(
        _truncate_prompt_text(hints, 900, "hints", trimmed),
        "",
        identity_text,
    )
    context_text = _truncate_prompt_text(
        context,
        max(cap - len(floor_prompt) - 80, 0),
        "context",
        trimmed,
    )
    prompt = compose(900, context_text)

    if len(prompt) > max_chars:
        overflow = len(prompt) - cap
        context_text = _truncate_prompt_text(
            context_text,
            max(len(context_text) - overflow - 80, 0),
            "context",
            trimmed,
        )
        prompt = compose(600, context_text)

    if len(prompt) > max_chars:
        marker = f"\n...{_PROMPT_TRIM_MARKER}:hard_cap"
        prompt = prompt[: max(cap - len(marker), 0)].rstrip() + marker
        trimmed.append("hard_cap")

    if trimmed and degradation is not None:
        # dict.fromkeys de-duplicates labels while keeping first-seen order.
        labels = ",".join(dict.fromkeys(trimmed))
        degradation.append(f"LLM ranking prompt truncated: trimmed={labels}")
    return prompt[:max_chars]
def _format_candidate_for_prompt(p: Pick, *, detail: str = "full") -> str:
    """Render one candidate as a single ``- code name: k=v, k=v`` prompt line.

    Args:
        p: The candidate to describe.
        detail: ``"identity"`` (rank and scores only), ``"compact"`` (key
            market fields plus DSA context) or anything else for the full
            field list.

    Returns:
        The formatted line, without a trailing newline.
    """
    head = f"- {p.code} {p.name}: "
    if detail == "identity":
        fields = [
            f"rank={p.rank}",
            f"screen_score={p.screen_score:.1f}",
            f"final_score={p.final_score:.1f}",
        ]
    elif detail == "compact":
        fields = [
            f"rank={p.rank}",
            f"price={p.price}",
            f"change_pct={p.change_pct}%",
            f"amount={p.amount:.0f}",
            f"screen_score={p.screen_score:.1f}",
            f"industry={p.industry or 'unknown'}",
            f"concepts={p.concepts or 'unknown'}",
            f"board_heat_score={p.board_heat_score}",
            f"signal_score={p.signal_score}",
            f"dsa_context={_format_dsa_context_for_prompt(p)}",
        ]
    else:
        fields = [
            f"price={p.price}",
            f"change_pct={p.change_pct}%",
            f"amount={p.amount:.0f}",
            f"turnover={p.turnover_rate}",
            f"volume_ratio={p.volume_ratio}",
            f"total_mv={p.total_mv}",
            f"PE={p.pe_ratio}",
            f"PB={p.pb_ratio}",
            f"industry={p.industry or 'unknown'}",
            f"concepts={p.concepts or 'unknown'}",
            f"industry_rank={p.industry_rank}",
            f"industry_change_pct={p.industry_change_pct}",
            f"board_heat_score={p.board_heat_score}",
            f"board_heat_summary={p.board_heat_summary or 'unknown'}",
            f"board_heat_latest_score={p.board_heat_latest_score}",
            f"board_heat_trend_score={p.board_heat_trend_score}",
            f"board_heat_persistence_score={p.board_heat_persistence_score}",
            f"board_heat_cooling_score={p.board_heat_cooling_score}",
            f"board_heat_observations={p.board_heat_observations}",
            f"board_heat_state={p.board_heat_state or 'unknown'}",
            f"change_60d={p.change_60d}",
            f"signal_score={p.signal_score}",
            f"macd={p.macd_status}",
            f"rsi={p.rsi_status}",
            f"breakout_20d_pct={p.breakout_20d_pct}",
            f"range_20d_pct={p.range_20d_pct}",
            f"volume_ratio_20d={p.volume_ratio_20d}",
            f"body_pct={p.body_pct}",
            f"pullback_to_ma20_pct={p.pullback_to_ma20_pct}",
            f"consolidation_days_20d={p.consolidation_days_20d}",
            f"screen_score={p.screen_score:.1f}",
            f"factor_scores={p.factor_scores}",
            f"dsa_context={_format_dsa_context_for_prompt(p)}",
        ]
    return head + ", ".join(fields)
def _fit_candidate_prompt_lines(
    candidates: list[Pick],
    budget: int,
    trimmed: list[str],
) -> str:
    """Fit candidate lines into ``budget`` characters, degrading detail as needed.

    Strategy:

    * If the full-detail text fits, return it unchanged (nothing is recorded).
    * Otherwise reserve room for a trailing marker line and add identity-only
      lines; a line that does not fit is counted as omitted and skipped.
    * If nothing was omitted, upgrade lines in candidate order to the richest
      form (full, then compact) that still fits.

    A marker line is appended and a label (``candidate_omitted`` or
    ``candidate_details``) is added to ``trimmed`` whenever anything was cut.
    """
    detail_marker = f"...{_PROMPT_TRIM_MARKER}:candidate_details"
    everything = "\n".join(_format_candidate_for_prompt(pick) for pick in candidates)
    if len(everything) <= budget:
        return everything

    # Room left once the marker line and its joining newline are accounted for.
    room = max(int(budget) - len(detail_marker) - 1, 0)
    if room <= 0:
        trimmed.append("candidate_details")
        return detail_marker[:budget]

    identity_lines = [
        _format_candidate_for_prompt(pick, detail="identity") for pick in candidates
    ]
    kept: list[str] = []
    spent = 0
    dropped = 0
    for line in identity_lines:
        cost = len(line) + (1 if kept else 0)  # +1 for the joining newline
        if spent + cost <= room:
            kept.append(line)
            spent += cost
        else:
            dropped += 1

    if dropped:
        trimmed.append("candidate_omitted")
        kept.append(f"...{_PROMPT_TRIM_MARKER}:candidate_omitted={dropped}")
        return "\n".join(kept)

    # Every candidate is present in identity form, so kept[idx] lines up with
    # candidates[idx]; spend the remaining room on richer detail.
    for idx, pick in enumerate(candidates):
        for detail in ("full", "compact"):
            richer = _format_candidate_for_prompt(pick, detail=detail)
            growth = len(richer) - len(kept[idx])
            if spent + growth <= room:
                kept[idx] = richer
                spent += growth
                break

    trimmed.append("candidate_details")
    kept.append(detail_marker)
    return "\n".join(kept)
def _truncate_prompt_text(text: str, limit: int, label: str, trimmed: list[str]) -> str:
    """Strip ``text`` and cut it to ``limit`` characters with a trim marker.

    Text that already fits is returned stripped and nothing is recorded.
    Otherwise ``label`` is appended to ``trimmed`` and the result ends with a
    marker line naming the label. When ``limit`` is too small to hold the
    marker plus at least a few characters of content, only the (possibly cut)
    marker is returned.
    """
    stripped = text.strip()
    if len(stripped) <= limit:
        return stripped

    tail = f"\n...{_PROMPT_TRIM_MARKER}:{label}"
    trimmed.append(label)
    marker_only = limit <= len(tail) + 8
    if marker_only:
        return tail[:limit]
    keep = max(limit - len(tail), 0)
    return stripped[:keep].rstrip() + tail
def _format_dsa_context_for_prompt(p: Pick) -> str:
    """Summarize a pick's DSA enrichment as one ``key=value; key=value`` string.

    Sections are emitted in this order, each only when it has content:
    ``summary`` (analysis summary, capped at 240 chars), ``quote``
    (price/change_pct/amount), ``fundamental_coverage`` (up to five coverage
    keys whose status is "available" or "partial"), ``news_titles`` (up to
    three titles, each capped at 80 chars) and ``warnings`` (up to three).

    News comes from ``p.dsa_news`` when it is non-empty, otherwise from
    ``dsa_context["news"]["results"]``. A missing/``None`` news list is
    treated as empty instead of raising, because this runs while the prompt
    is being built, before the ranking error boundary.

    Returns:
        The joined sections, or ``"none"`` when there is nothing to report.
    """

    def dict_at(mapping: dict, key: str) -> dict:
        """Return ``mapping[key]`` when it is a dict, else an empty dict."""
        value = mapping.get(key)
        return value if isinstance(value, dict) else {}

    parts: list[str] = []
    if p.dsa_analysis_summary:
        parts.append(f"summary={_truncate_text(p.dsa_analysis_summary, 240)}")

    context = p.dsa_context if isinstance(p.dsa_context, dict) else {}

    quote = dict_at(context, "quote")
    if quote:
        parts.append(
            "quote="
            f"price:{quote.get('price')},change_pct:{quote.get('change_pct')},"
            f"amount:{quote.get('amount')}"
        )

    coverage = dict_at(dict_at(context, "fundamentals"), "coverage")
    available = [
        str(key)
        for key, value in coverage.items()
        if str(value).lower() in {"available", "partial"}
    ]
    if available:
        parts.append(f"fundamental_coverage={','.join(available[:5])}")

    news_items = p.dsa_news
    if not news_items:
        raw_results = dict_at(context, "news").get("results")
        # Fall back to an empty list so the slice below never sees None.
        news_items = (
            [item for item in raw_results if isinstance(item, dict)]
            if isinstance(raw_results, list)
            else []
        )
    titles = [
        _truncate_text(str(item.get("title") or "").strip(), 80)
        for item in news_items[:3]
        if isinstance(item, dict) and item.get("title")
    ]
    if titles:
        parts.append(f"news_titles={';'.join(titles)}")

    warnings = context.get("warnings")
    if not isinstance(warnings, list):
        warnings = []
    warning_text = [str(item) for item in warnings[:3] if item]
    if warning_text:
        parts.append(f"warnings={';'.join(warning_text)}")

    return "; ".join(parts) if parts else "none"
def _truncate_text(value: str, limit: int) -> str:
    """Collapse whitespace in ``value`` and cap the result at ``limit`` chars.

    Runs of whitespace (including newlines) become single spaces. Text longer
    than ``limit`` is cut to ``limit - 1`` characters and an ellipsis ("…") is
    appended, so the result is exactly ``limit`` characters long and visibly
    marked as truncated. A non-positive ``limit`` yields an empty string.
    """
    text = " ".join(value.split())
    if len(text) <= limit:
        return text
    if limit <= 0:
        return ""
    # One character of the budget is reserved for the ellipsis.
    return text[: limit - 1] + "…"
def _call_llm(
    prompt: str,
    api_key: str,
    model: str,
    base_url: str,
    *,
    fallback_models: list[str] | None = None,
    temperature: float = 0.2,
    json_mode: bool = True,
    silent: bool = True,
    channels: list[dict[str, object]] | None = None,
    config_path: str = "",
    timeout_sec: float = 60.0,
    max_tokens: int | None = 2048,
) -> str:
    """Call LLM via litellm with fallback models and channel configs.

    Order of attempts:

    1. If ``config_path`` is set, try the LiteLLM router built from that
       file; a non-``None`` result is returned immediately.
    2. Otherwise (or if the router yields nothing) walk the de-duplicated
       chain ``[model, *fallback_models]`` and, for each model, every
       channel/key combination from ``_build_litellm_attempts``.

    When a provider clearly rejects JSON mode, the same request is repeated
    once without ``response_format`` before moving on.

    Returns:
        The message content of the first successful completion ("" if empty).

    Raises:
        Exception: A timeout from any request (including the no-JSON retry)
            is re-raised immediately; otherwise the last provider error is
            raised once every attempt has failed.
        RuntimeError: If there was nothing to attempt.
    """
    import litellm

    if silent:
        _silence_litellm_logs(litellm)
    messages = [{"role": "user", "content": prompt}]
    model_chain = _dedupe([model, *(fallback_models or [])])
    last_error: Exception | None = None

    if config_path:
        router_result = _call_litellm_router(
            litellm,
            config_path=config_path,
            model_chain=model_chain,
            messages=messages,
            temperature=temperature,
            json_mode=json_mode,
            timeout_sec=timeout_sec,
            max_tokens=max_tokens,
        )
        if router_result is not None:
            return router_result

    for candidate_model in model_chain:
        attempts = _build_litellm_attempts(
            candidate_model,
            api_key=api_key,
            base_url=base_url,
            channels=channels or [],
        )
        for kwargs in attempts:
            kwargs["messages"] = messages
            kwargs["temperature"] = temperature
            kwargs["timeout"] = timeout_sec
            kwargs["num_retries"] = 0
            if max_tokens is not None and int(max_tokens) > 0:
                kwargs["max_tokens"] = int(max_tokens)
            if json_mode:
                kwargs["response_format"] = {"type": "json_object"}

            try:
                response = litellm.completion(**kwargs)
                return response.choices[0].message.content or ""
            except Exception as exc:
                last_error = exc
                # Never retry after a timeout: a local OpenAI-compatible
                # server may keep generating after the client gave up, so
                # another request would duplicate expensive work.
                if _is_timeout_error(exc):
                    raise
                retry_without_json = (
                    json_mode
                    and "response_format" in kwargs
                    and _is_json_mode_unsupported(exc)
                )
            if not retry_without_json:
                continue

            # Some providers do not support JSON mode. Retry the same request
            # without it before moving to the next attempt or fallback model.
            retry_kwargs = dict(kwargs)
            retry_kwargs.pop("response_format", None)
            try:
                response = litellm.completion(**retry_kwargs)
                return response.choices[0].message.content or ""
            except Exception as retry_exc:
                last_error = retry_exc
                # Same timeout policy as the first request.
                if _is_timeout_error(retry_exc):
                    raise

    if last_error is not None:
        raise last_error
    raise RuntimeError("No LLM model configured")
def _is_json_mode_unsupported(exc: Exception) -> bool:
    """Return True only for provider errors that clearly reject JSON mode.

    Timeouts are never treated as a JSON-mode rejection. Otherwise the
    lower-cased error text must either name the feature directly
    ("response_format", "json mode", "json_object") or mention "json"
    together with "not support" / "unsupported".
    """
    if _is_timeout_error(exc):
        return False
    message = str(exc).lower()
    direct_hints = ("response_format", "json mode", "json_object")
    if any(hint in message for hint in direct_hints):
        return True
    rejects_something = "not support" in message or "unsupported" in message
    return rejects_something and "json" in message
def _is_timeout_error(exc: Exception) -> bool:
    """Return True when ``exc`` represents a request timeout.

    An exception counts as a timeout when it is a built-in ``TimeoutError``
    (or subclass), or when its lower-cased message or class name contains one
    of the timeout markers. Checking the type and class name as well as the
    message catches timeouts raised with an empty or unrelated message.
    """
    if isinstance(exc, TimeoutError):
        return True
    timeout_markers = ("timeout", "timed out", "readtimeout", "apitimeout")
    haystacks = (str(exc).lower(), type(exc).__name__.lower())
    return any(marker in text for text in haystacks for marker in timeout_markers)
def _parse_ranking_response(response: str, candidates: list[Pick]) -> list[Pick]:
    """Parse an LLM response and return only the reordered picks.

    Shorthand for ``_parse_ranking_response_detail(...).picks``; coverage and
    error diagnostics are discarded.
    """
    detail = _parse_ranking_response_detail(response, candidates)
    return detail.picks
def _apply_llm_item_to_pick(pick: Pick, item: dict) -> None:
    """Copy the LLM-provided fields of one ranked ``item`` onto ``pick``."""
    pick.ranking_reason = _safe_str(item.get("reason"), max_len=180)
    pick.risk_summary = _safe_str(item.get("risk"), max_len=180)
    pick.llm_score = _bounded_float(item.get("llm_score"), low=0, high=100)
    pick.llm_confidence = _bounded_float(item.get("confidence"), low=0, high=1)
    # Accept a few alternative key names for the sector label.
    pick.llm_sector = _safe_str(
        item.get("sector") or item.get("industry") or item.get("sector_label"),
        max_len=40,
    )
    pick.llm_theme = _safe_str(item.get("theme"), max_len=100)
    pick.llm_thesis = _safe_str(item.get("thesis"), max_len=220)
    pick.llm_catalysts = _safe_string_list(item.get("catalysts"))
    pick.llm_invalidators = _safe_string_list(item.get("invalidators"))
    pick.llm_style_fit = _safe_str(item.get("style_fit"), max_len=120)
    pick.llm_watch_items = _safe_string_list(item.get("watch_items"))
    pick.llm_risks = _safe_string_list(item.get("risk_flags"))
    pick.llm_tags = _safe_string_list(item.get("tags"))
    # Mirror sector/theme/style_fit into the tag list as prefixed tags.
    if pick.llm_sector:
        pick.llm_tags = _dedupe([*pick.llm_tags, f"sector:{pick.llm_sector}"])
    if pick.llm_theme:
        pick.llm_tags = _dedupe([*pick.llm_tags, f"theme:{pick.llm_theme}"])
    if pick.llm_style_fit:
        pick.llm_tags = _dedupe([*pick.llm_tags, f"style_fit:{pick.llm_style_fit}"])


def _parse_ranking_response_detail(response: str, candidates: list[Pick]) -> RankingParseResult:
    """Parse LLM response and return diagnostics.

    The response may be a JSON object with a ``ranked`` list (plus optional
    ``market_view`` / ``selection_logic`` / ``portfolio_risk``) or a bare list
    of items. Items are matched to candidates by normalized code; matched
    picks are updated in place and returned in response order, followed by
    the candidates the response did not mention.

    Diagnostics appended to ``errors`` include ``empty_response``,
    ``no_json_found``, ``ranked_not_list``, ``non_object_item``,
    ``missing_code`` (item without a usable code), ``duplicate_code:<code>``
    and ``unknown_code:<code>``, plus any JSON repair tags.

    Returns:
        ``RankingParseResult`` whose ``coverage`` is matched items divided by
        the number of candidates. On an unusable response the candidates are
        returned unchanged with coverage 0.0.
    """
    errors: list[str] = []
    if not response or not response.strip():
        errors.append("empty_response")
        logger.warning("Empty LLM ranking response")
        return RankingParseResult(candidates, 0.0, errors)

    parsed = _extract_ranking_json(response, errors)
    if parsed is None:
        errors.append("no_json_found")
        logger.warning("No JSON object or array found in LLM response")
        return RankingParseResult(candidates, 0.0, errors)

    if isinstance(parsed, dict):
        items = parsed.get("ranked", [])
        market_view = _safe_str(parsed.get("market_view"), max_len=260)
        selection_logic = _safe_str(parsed.get("selection_logic"), max_len=360)
        portfolio_risk = _safe_str(parsed.get("portfolio_risk"), max_len=360)
    else:
        items = parsed
        market_view = ""
        selection_logic = ""
        portfolio_risk = ""

    if not isinstance(items, list):
        errors.append("ranked_not_list")
        logger.warning("LLM ranking JSON has no ranked list")
        return RankingParseResult(candidates, 0.0, errors)

    # Candidates whose code normalizes to "" cannot be matched and are left out.
    code_to_pick: dict[str, Pick] = {}
    for candidate in candidates:
        key = _normalize_code(candidate.code)
        if key:
            code_to_pick[key] = candidate

    ranked: list[Pick] = []
    seen_codes: set[str] = set()
    for item in items:
        if not isinstance(item, dict):
            errors.append("non_object_item")
            continue
        code = _normalize_code(item.get("code", ""))
        if not code:
            # Report unusable codes explicitly instead of letting repeated
            # empty codes show up as misleading "duplicate_code:" entries.
            errors.append("missing_code")
            continue
        if code in seen_codes:
            errors.append(f"duplicate_code:{code}")
            continue
        seen_codes.add(code)

        pick = code_to_pick.pop(code, None)
        if pick is None:
            errors.append(f"unknown_code:{code}")
            continue
        _apply_llm_item_to_pick(pick, item)
        ranked.append(pick)

    matched = len(ranked)
    # Append any candidates not mentioned by LLM
    ranked.extend(code_to_pick.values())
    coverage = matched / max(len(candidates), 1)
    return RankingParseResult(
        ranked,
        coverage,
        errors,
        market_view=market_view,
        selection_logic=selection_logic,
        portfolio_risk=portfolio_risk,
    )
def _safe_str(value, *, max_len: int) -> str:
    """Sanitize ``value`` into text of at most ``max_len`` via ``safe_text``."""
    cleaned = safe_text(value, max_len=max_len)
    return cleaned
def _close_truncated_json(text: str) -> str | None:
    """Return ``text`` with its unclosed brackets (and string) closed.

    Brackets are tracked with a stack while skipping string contents, so the
    missing closers are appended in reverse nesting order. If the text ends
    inside a string the string is closed first; otherwise a dangling trailing
    comma is dropped. Returns ``None`` when no bracket is left open.
    """
    pending: list[str] = []  # closers still owed, innermost last
    in_string = False
    escaped = False
    for char in text:
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            pending.append("}")
        elif char == "[":
            pending.append("]")
        elif char in "}]" and pending and pending[-1] == char:
            pending.pop()
    if not pending:
        return None

    body = text
    if in_string:
        if escaped:
            # A lone trailing backslash would escape the quote we add.
            body = body[:-1]
        body += '"'
    else:
        body = body.rstrip()
        if body.endswith(","):
            body = body[:-1]
    return body + "".join(reversed(pending))


def _try_parse_json_lenient(raw: str, errors: list[str]):
    """Attempt to parse LLM JSON output, tolerating common formatting drift.

    Steps applied in order: strict parse → strip trailing commas → close an
    output that was cut off mid-document (unterminated string and unclosed
    brackets, closed in nesting order) → return None if all fail. Any repair
    that succeeds is recorded in ``errors`` for diagnostics; a total failure
    records ``json_decode_error:<first error>`` and logs a warning.
    """
    import re

    try:
        return json.loads(raw)
    except json.JSONDecodeError as exc:
        first_error = exc

    # Repair 1: remove trailing commas before } or ].
    repaired = re.sub(r",(\s*[}\]])", r"\1", raw)
    if repaired != raw:
        try:
            result = json.loads(repaired)
        except json.JSONDecodeError:
            pass
        else:
            errors.append("json_repaired:trailing_comma")
            return result

    # Repair 2: close unbalanced brackets caused by truncated output.
    patched = _close_truncated_json(repaired)
    if patched is not None:
        try:
            result = json.loads(patched)
        except json.JSONDecodeError:
            pass
        else:
            errors.append("json_repaired:closed_brackets")
            return result

    errors.append(f"json_decode_error:{first_error}")
    logger.warning("Failed to parse LLM ranking JSON: %s", first_error)
    return None
def _extract_ranking_json(response: str, errors: list[str]):
    """Extract ranking JSON from common LLM response shapes.

    Each candidate payload (fenced blocks first, then balanced JSON values) is
    parsed leniently; the first one that looks like a ranking payload wins.
    If none qualifies, individual ranked objects scattered through the
    response are collected instead. Returns ``None`` when nothing is found.
    """
    for candidate_text in _iter_json_payloads(response):
        value = _try_parse_json_lenient(candidate_text, errors)
        if _looks_like_ranking_payload(value):
            return value
    # Last resort; already returns None when no item can be recovered.
    return _extract_partial_ranking_array(response, errors)
def _looks_like_ranking_payload(value: object) -> bool:
    """Tell whether a parsed JSON value has the shape of a ranking payload.

    Accepted shapes: a dict whose ``ranked`` entry is a list, or a list that
    contains at least one dict with a ``code`` key.
    """
    if isinstance(value, dict):
        return isinstance(value.get("ranked"), list)
    if not isinstance(value, list):
        return False
    for item in value:
        if isinstance(item, dict) and "code" in item:
            return True
    return False
def _iter_json_payloads(response: str):
    """Yield likely JSON payload substrings in priority order.

    First the stripped contents of Markdown code fences (optionally tagged
    ``json``/``JSON``), then every balanced top-level JSON value found after
    the fence markers are removed. Each distinct payload is yielded once.
    """
    import re

    fence_pattern = re.compile(r"```(?:json|JSON)?\s*(.*?)```", re.DOTALL)
    already_yielded: set[str] = set()

    fenced = (found.group(1).strip() for found in fence_pattern.finditer(response))
    for payload in fenced:
        if not payload or payload in already_yielded:
            continue
        already_yielded.add(payload)
        yield payload

    # Drop the fence markers but keep their contents for the balanced scan.
    cleaned = fence_pattern.sub(lambda match: match.group(1), response)
    for payload in _balanced_json_values(cleaned):
        if payload in already_yielded:
            continue
        already_yielded.add(payload)
        yield payload
def _balanced_json_values(text: str) -> list[str]:
    """Return balanced top-level JSON object/array substrings.

    Scans ``text`` once, tracking open brackets on a stack and ignoring
    anything inside double-quoted strings (with backslash escapes). A value is
    recorded when its outermost bracket closes; a mismatched closer discards
    the value in progress.
    """
    closer_for = {"{": "}", "[": "]"}
    found: list[str] = []
    pending: list[str] = []  # closers still expected, innermost last
    begin: int | None = None
    inside_string = False
    skip_next = False

    for pos, ch in enumerate(text):
        if inside_string:
            if skip_next:
                skip_next = False
            elif ch == "\\":
                skip_next = True
            elif ch == '"':
                inside_string = False
        elif ch == '"':
            inside_string = True
        elif ch in closer_for:
            if not pending:
                begin = pos
            pending.append(closer_for[ch])
        elif ch in ("}", "]") and pending:
            if pending.pop() != ch:
                # Wrong closer: abandon the value being tracked.
                pending.clear()
                begin = None
            elif not pending and begin is not None:
                found.append(text[begin : pos + 1])
                begin = None
    return found
def _extract_partial_ranking_array(response: str, errors: list[str]):
    """Recover a ranked list from multiple JSON objects in a noisy response.

    Every balanced JSON value in ``response`` is parsed leniently; dicts that
    carry a ``code`` key are collected. Per-item parse diagnostics are kept in
    a scratch list and not reported. When at least one item is found, the
    ``json_repaired:partial_array`` tag is added to ``errors`` and the items
    are returned wrapped as ``{"ranked": [...]}``; otherwise ``None``.
    """
    scratch_errors: list[str] = []
    parsed_values = [
        _try_parse_json_lenient(raw, scratch_errors)
        for raw in _balanced_json_values(response)
    ]
    items = [value for value in parsed_values if isinstance(value, dict) and "code" in value]
    if not items:
        return None
    errors.append("json_repaired:partial_array")
    return {"ranked": items}
def _call_litellm_router(
    litellm,
    *,
    config_path: str,
    model_chain: list[str],
    messages: list[dict[str, str]],
    temperature: float,
    json_mode: bool,
    timeout_sec: float,
    max_tokens: int | None = 2048,
) -> str | None:
    """Try the completion through a LiteLLM router built from a YAML config.

    The file at ``config_path`` must contain a non-empty ``model_list``;
    otherwise the router is skipped. Each model in ``model_chain`` is tried
    in order, so fallback models are reachable through the router too. When
    a provider clearly rejects JSON mode, the request is repeated once
    without ``response_format`` before moving to the next model.

    Returns:
        The message content of the first successful completion, or ``None``
        when the router cannot be used or every model failed (the caller then
        falls back to direct calls).

    Raises:
        Exception: Timeouts are re-raised immediately instead of falling back,
            to avoid duplicating a request that may still be running.
    """
    try:
        import yaml

        with open(config_path, "r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle) or {}
        model_list = data.get("model_list")
        if not isinstance(model_list, list) or not model_list:
            return None

        router = litellm.Router(model_list=model_list)
        last_error: Exception | None = None
        for model in model_chain:
            kwargs = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "timeout": timeout_sec,
                "num_retries": 0,
            }
            if max_tokens is not None and int(max_tokens) > 0:
                kwargs["max_tokens"] = int(max_tokens)
            if json_mode:
                kwargs["response_format"] = {"type": "json_object"}

            try:
                response = router.completion(**kwargs)
                return response.choices[0].message.content or ""
            except Exception as exc:
                if _is_timeout_error(exc):
                    raise
                last_error = exc
                retry_without_json = (
                    "response_format" in kwargs and _is_json_mode_unsupported(exc)
                )
            if not retry_without_json:
                continue

            kwargs.pop("response_format", None)
            try:
                response = router.completion(**kwargs)
                return response.choices[0].message.content or ""
            except Exception as retry_exc:
                if _is_timeout_error(retry_exc):
                    raise
                last_error = retry_exc

        if last_error is not None:
            # Hand the final failure to the handler below (log + fall back).
            raise last_error
        return None
    except Exception as exc:
        if _is_timeout_error(exc):
            raise
        logger.warning("LiteLLM router config failed, falling back to direct calls: %s", exc)
        return None
def _silence_litellm_logs(litellm) -> None:
    """Turn down LiteLLM's console noise (best effort).

    Sets ``LITELLM_LOG=ERROR`` unless the variable is already defined, switches
    off the module's verbose/debug flags, and raises the "LiteLLM" and
    "litellm" loggers to WARNING. Failures while setting the module flags are
    deliberately ignored.
    """
    os.environ.setdefault("LITELLM_LOG", "ERROR")
    quiet_flags = (("set_verbose", False), ("suppress_debug_info", True))
    try:
        for flag_name, flag_value in quiet_flags:
            setattr(litellm, flag_name, flag_value)
    except Exception:
        # Cosmetic only: never let logging tweaks break the LLM call.
        pass
    for logger_name in ("LiteLLM", "litellm"):
        logging.getLogger(logger_name).setLevel(logging.WARNING)
def _build_litellm_attempts(
    model: str,
    *,
    api_key: str,
    base_url: str,
    channels: list[dict[str, object]],
) -> list[dict[str, object]]:
    """List the distinct completion kwargs to try for ``model``.

    For every channel that serves the model, one attempt is produced per
    channel API key (falling back to ``api_key`` when the channel has none),
    using the channel's ``base_url`` or, if empty, the default ``base_url``.
    A final attempt with the default key and base URL is always added.
    Duplicate attempts are removed while keeping order.
    """
    attempts: list[dict[str, object]] = []
    for channel in channels:
        if not _channel_matches_model(channel, model):
            continue
        channel_keys = channel.get("api_keys", [])
        has_keys = isinstance(channel_keys, list) and bool(channel_keys)
        if not has_keys:
            channel_keys = [api_key] if api_key else [""]
        attempts.extend(
            _completion_kwargs(
                model,
                api_key=str(channel_key or ""),
                base_url=str(channel.get("base_url", "") or base_url or ""),
            )
            for channel_key in channel_keys
        )
    attempts.append(_completion_kwargs(model, api_key=api_key, base_url=base_url))
    return _unique_attempts(attempts)
def _completion_kwargs(model: str, *, api_key: str, base_url: str) -> dict[str, object]:
    """Build base completion kwargs, omitting empty credentials.

    ``model`` is always present; ``api_key`` and ``api_base`` are included
    only when the corresponding argument is truthy.
    """
    optional = {"api_key": api_key, "api_base": base_url}
    kwargs: dict[str, object] = {"model": model}
    kwargs.update({name: value for name, value in optional.items() if value})
    return kwargs
def _channel_matches_model(channel: dict[str, object], model: str) -> bool:
    """Tell whether ``channel`` lists ``model`` among its models.

    Channel model names are normalized with the channel's ``protocol``
    (default ``"openai"``). A match is either an exact normalized name or an
    equal name after the first ``/``-separated prefix is removed on both
    sides. Channels without a non-empty ``models`` list never match.
    """
    models = channel.get("models", [])
    if not isinstance(models, list) or not models:
        return False
    protocol = str(channel.get("protocol", "openai"))
    normalized = {_normalize_model_name(str(entry), protocol) for entry in models}
    if model in normalized:
        return True
    bare_model = model.split("/", 1)[-1]
    bare_names = {name.split("/", 1)[-1] for name in normalized}
    return bare_model in bare_names
def _normalize_model_name(model: str, protocol: str) -> str:
    """Return ``model`` in ``provider/name`` form.

    Names that already contain ``/`` are returned stripped and otherwise
    unchanged. Bare names get the ``ollama``, ``gemini`` or ``deepseek``
    prefix when ``protocol`` is one of those, and ``openai`` otherwise.
    """
    name = model.strip()
    if "/" in name:
        return name
    known_prefixes = ("ollama", "gemini", "deepseek")
    prefix = protocol if protocol in known_prefixes else "openai"
    return f"{prefix}/{name}"
def _unique_attempts(items: list[dict[str, object]]) -> list[dict[str, object]]:
    """Drop attempts that repeat an earlier (model, api_key, api_base) triple.

    The first attempt for each triple is kept and the original order is
    preserved.
    """
    first_by_key: dict[tuple, dict[str, object]] = {}
    for item in items:
        key = (item.get("model"), item.get("api_key"), item.get("api_base"))
        # setdefault keeps the earliest attempt for a given key.
        first_by_key.setdefault(key, item)
    return list(first_by_key.values())
def _dedupe(items: list[str]) -> list[str]:
    """Return the stripped, non-empty items without duplicates, in order.

    Each item is converted with ``str`` and stripped; blanks are dropped and
    only the first occurrence of each resulting string is kept.
    """
    cleaned = (str(item).strip() for item in items)
    # dict.fromkeys keeps first-seen order while removing repeats.
    return list(dict.fromkeys(text for text in cleaned if text))