# -*- coding: utf-8 -*- """L2 LLM ranker — relative ranking of shortlisted candidates.""" import json import logging import os from dataclasses import dataclass from alphasift.models import Pick from alphasift.normalize import ( bounded_float as _bounded_float, normalize_code, safe_string_list as _safe_string_list, safe_text, ) def _normalize_code(value: object) -> str: # Candidate codes and LLM ranking JSON code fields are structured, so # US tickers may pass through (see normalize_code docstring). return normalize_code(value, allow_ticker=True) logger = logging.getLogger(__name__) _DEFAULT_RANKING_PROMPT_MAX_CHARS = 24_000 _PROMPT_TRIM_MARKER = "[prompt_trimmed]" @dataclass class RankingParseResult: picks: list[Pick] coverage: float errors: list[str] market_view: str = "" selection_logic: str = "" portfolio_risk: str = "" @dataclass class LLMRankingResult: picks: list[Pick] ranked: bool = False market_view: str = "" selection_logic: str = "" portfolio_risk: str = "" coverage: float = 0.0 errors: list[str] | None = None def __post_init__(self) -> None: if self.errors is None: self.errors = [] def rank_candidates( candidates: list[Pick], ranking_hints: str, llm_api_key: str, llm_model: str, llm_base_url: str = "", *, context: str = "", rank_weight: float = 0.40, max_retries: int = 1, min_coverage: float = 0.60, fallback_models: list[str] | None = None, temperature: float = 0.2, json_mode: bool = True, silent: bool = True, channels: list[dict[str, object]] | None = None, config_path: str = "", timeout_sec: float = 60.0, max_prompt_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS, max_tokens: int | None = 2048, ) -> list[Pick]: """Use LLM to re-rank candidates and add ranking_reason / risk_summary. Falls back to screen_score order if LLM call fails. """ return rank_candidates_with_metadata( candidates, ranking_hints, llm_api_key, llm_model, llm_base_url, context=context, rank_weight=rank_weight, max_retries=max_retries, min_coverage=min_coverage, fallback_models=fallback_models, temperature=temperature, json_mode=json_mode, silent=silent, channels=channels, config_path=config_path, timeout_sec=timeout_sec, max_prompt_chars=max_prompt_chars, max_tokens=max_tokens, ).picks def rank_candidates_with_metadata( candidates: list[Pick], ranking_hints: str, llm_api_key: str, llm_model: str, llm_base_url: str = "", *, context: str = "", rank_weight: float = 0.40, max_retries: int = 1, min_coverage: float = 0.60, fallback_models: list[str] | None = None, temperature: float = 0.2, json_mode: bool = True, silent: bool = True, channels: list[dict[str, object]] | None = None, config_path: str = "", timeout_sec: float = 60.0, max_prompt_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS, degradation: list[str] | None = None, max_tokens: int | None = 2048, ) -> LLMRankingResult: """Use LLM to re-rank candidates and return global research metadata.""" if not candidates: return LLMRankingResult(picks=candidates) prompt = _build_ranking_prompt( candidates, ranking_hints, context, max_chars=max_prompt_chars, degradation=degradation, ) try: last_errors: list[str] = [] parsed: RankingParseResult | None = None for attempt in range(max_retries + 1): attempt_prompt = prompt if attempt: attempt_prompt += ( "\n\n上一次输出没有满足结构化覆盖率要求。" "请重新返回严格 JSON,并覆盖尽可能多的候选代码。" ) response = _call_llm( attempt_prompt, llm_api_key, llm_model, llm_base_url, fallback_models=fallback_models or [], temperature=temperature, json_mode=json_mode, silent=silent, channels=channels or [], config_path=config_path, timeout_sec=timeout_sec, max_tokens=max_tokens, ) parsed = _parse_ranking_response_detail(response, candidates) last_errors = parsed.errors if parsed.coverage >= min_coverage: break if parsed is None or parsed.coverage < min_coverage: raise ValueError( "LLM ranking response coverage below threshold: " f"{0 if parsed is None else parsed.coverage:.2f}; " f"errors={last_errors}" ) ranked = parsed.picks for i, pick in enumerate(ranked): pick.rank = i + 1 if pick.llm_score is None: pick.llm_score = 100.0 - i * (100.0 / max(len(ranked), 1)) weight = min(max(rank_weight, 0.0), 1.0) pick.final_score = pick.screen_score * (1 - weight) + (pick.llm_score or 0) * weight ranked.sort(key=lambda item: item.final_score, reverse=True) for i, pick in enumerate(ranked, start=1): pick.rank = i return LLMRankingResult( picks=ranked, ranked=True, market_view=parsed.market_view, selection_logic=parsed.selection_logic, portfolio_risk=parsed.portfolio_risk, coverage=parsed.coverage, errors=parsed.errors, ) except Exception as e: logger.warning("LLM ranking failed, falling back to screen_score: %s", e) return LLMRankingResult(picks=candidates, errors=[str(e)]) def _build_ranking_prompt( candidates: list[Pick], hints: str, context: str = "", *, max_chars: int | None = _DEFAULT_RANKING_PROMPT_MAX_CHARS, degradation: list[str] | None = None, ) -> str: hints_text = hints.strip() or "无额外排序提示。" context_text = context.strip() or "无额外上下文。只能基于候选池结构化数据和策略偏好判断。" candidates_text = "\n".join(_format_candidate_for_prompt(p) for p in candidates) prompt = _render_ranking_prompt(hints_text, context_text, candidates_text) if max_chars is None or len(prompt) <= max_chars: return prompt return _build_bounded_ranking_prompt( candidates, hints_text, context_text, max_chars=max_chars, degradation=degradation, ) def _render_ranking_prompt(hints: str, context: str, candidates_text: str) -> str: return f"""你是一个专业的股票研究员,任务是在“已经由代码硬筛过”的候选池内做相对排序。 你不能推荐候选池外股票,不能修改硬筛条件,不能给目标价或承诺收益。你的价值在于: 1. 结合策略偏好,对候选之间做跨股票比较; 2. 识别结构化数据暴露不出的潜在催化、风格匹配和风险点; 3. 对行业/概念热度和 DSA 补充的行情、基本面、新闻做语义归因,但不能把单日热度当作唯一买入理由; 4. 给出简短、可审计、可复核的排序理由。 ## 排序依据 {hints} ## 市场/情报上下文 {context} ## 候选列表 {candidates_text} ## 输出要求 只返回 JSON,不要 Markdown,不要解释 JSON 以外的文本。 格式: {{ "market_view": "一句话概括当前候选池和市场背景是否适合该策略", "selection_logic": "说明本次排序最主要的2-3个判断维度", "portfolio_risk": "说明最终名单可能存在的集中风险或共同风险", "ranked": [ {{ "code": "股票代码", "llm_score": 0-100, "confidence": 0-1, "sector": "行业/主题短标签,优先参考候选的 industry/concepts,并尽量统一,如 券商、银行、医药、AI算力", "theme": "主要交易逻辑或主题", "thesis": "该候选入选的核心投资假设", "reason": "一句话排序理由", "risk": "一句话主要风险", "catalysts": ["潜在催化1", "潜在催化2"], "risk_flags": ["风险标签1"], "tags": ["价值", "趋势", "防守", "事件", "流动性"], "style_fit": "与策略风格的匹配度说明", "watch_items": ["后续应跟踪的数据或事件"], "invalidators": ["会推翻该候选逻辑的观察点"] }} ] }} """ def _build_bounded_ranking_prompt( candidates: list[Pick], hints: str, context: str, *, max_chars: int, degradation: list[str] | None, ) -> str: trimmed: list[str] = [] identity_text = "\n".join(_format_candidate_for_prompt(p, detail="identity") for p in candidates) base_min = _render_ranking_prompt( _truncate_prompt_text(hints, 900, "hints", trimmed), "", identity_text, ) context_budget = max(int(max_chars) - len(base_min) - 80, 0) context_text = _truncate_prompt_text(context, context_budget, "context", trimmed) prompt_without_candidates = _render_ranking_prompt( _truncate_prompt_text(hints, 900, "hints", trimmed), context_text, "", ) candidate_budget = max(int(max_chars) - len(prompt_without_candidates), 0) candidates_text = _fit_candidate_prompt_lines(candidates, candidate_budget, trimmed) prompt = _render_ranking_prompt( _truncate_prompt_text(hints, 900, "hints", trimmed), context_text, candidates_text, ) if len(prompt) > max_chars: overflow = len(prompt) - int(max_chars) context_text = _truncate_prompt_text( context_text, max(len(context_text) - overflow - 80, 0), "context", trimmed, ) prompt_without_candidates = _render_ranking_prompt( _truncate_prompt_text(hints, 600, "hints", trimmed), context_text, "", ) candidate_budget = max(int(max_chars) - len(prompt_without_candidates), 0) candidates_text = _fit_candidate_prompt_lines(candidates, candidate_budget, trimmed) prompt = _render_ranking_prompt( _truncate_prompt_text(hints, 600, "hints", trimmed), context_text, candidates_text, ) if len(prompt) > max_chars: marker = f"\n...{_PROMPT_TRIM_MARKER}:hard_cap" prompt = prompt[: max(int(max_chars) - len(marker), 0)].rstrip() + marker trimmed.append("hard_cap") if trimmed and degradation is not None: labels = ",".join(dict.fromkeys(trimmed)) degradation.append(f"LLM ranking prompt truncated: trimmed={labels}") return prompt[:max_chars] def _format_candidate_for_prompt(p: Pick, *, detail: str = "full") -> str: if detail == "identity": return ( f"- {p.code} {p.name}: rank={p.rank}, " f"screen_score={p.screen_score:.1f}, final_score={p.final_score:.1f}" ) if detail == "compact": return ( f"- {p.code} {p.name}: rank={p.rank}, price={p.price}, " f"change_pct={p.change_pct}%, amount={p.amount:.0f}, " f"screen_score={p.screen_score:.1f}, industry={p.industry or 'unknown'}, " f"concepts={p.concepts or 'unknown'}, board_heat_score={p.board_heat_score}, " f"signal_score={p.signal_score}, dsa_context={_format_dsa_context_for_prompt(p)}" ) return ( f"- {p.code} {p.name}: price={p.price}, change_pct={p.change_pct}%, " f"amount={p.amount:.0f}, turnover={p.turnover_rate}, volume_ratio={p.volume_ratio}, " f"total_mv={p.total_mv}, PE={p.pe_ratio}, PB={p.pb_ratio}, " f"industry={p.industry or 'unknown'}, concepts={p.concepts or 'unknown'}, " f"industry_rank={p.industry_rank}, industry_change_pct={p.industry_change_pct}, " f"board_heat_score={p.board_heat_score}, board_heat_summary={p.board_heat_summary or 'unknown'}, " f"board_heat_latest_score={p.board_heat_latest_score}, " f"board_heat_trend_score={p.board_heat_trend_score}, " f"board_heat_persistence_score={p.board_heat_persistence_score}, " f"board_heat_cooling_score={p.board_heat_cooling_score}, " f"board_heat_observations={p.board_heat_observations}, " f"board_heat_state={p.board_heat_state or 'unknown'}, " f"change_60d={p.change_60d}, signal_score={p.signal_score}, " f"macd={p.macd_status}, rsi={p.rsi_status}, " f"breakout_20d_pct={p.breakout_20d_pct}, range_20d_pct={p.range_20d_pct}, " f"volume_ratio_20d={p.volume_ratio_20d}, body_pct={p.body_pct}, " f"pullback_to_ma20_pct={p.pullback_to_ma20_pct}, " f"consolidation_days_20d={p.consolidation_days_20d}, " f"screen_score={p.screen_score:.1f}, factor_scores={p.factor_scores}, " f"dsa_context={_format_dsa_context_for_prompt(p)}" ) def _fit_candidate_prompt_lines( candidates: list[Pick], budget: int, trimmed: list[str], ) -> str: marker = f"...{_PROMPT_TRIM_MARKER}:candidate_details" full_text = "\n".join(_format_candidate_for_prompt(p) for p in candidates) if len(full_text) <= budget: return full_text available = max(int(budget) - len(marker) - 1, 0) if available <= 0: trimmed.append("candidate_details") return marker[:budget] identity_lines = [_format_candidate_for_prompt(p, detail="identity") for p in candidates] lines: list[str] = [] used = 0 omitted = 0 for line in identity_lines: extra = len(line) + (1 if lines else 0) if used + extra > available: omitted += 1 continue lines.append(line) used += extra if omitted == 0: for idx, pick in enumerate(candidates): for detail in ("full", "compact"): replacement = _format_candidate_for_prompt(pick, detail=detail) delta = len(replacement) - len(lines[idx]) if used + delta <= available: lines[idx] = replacement used += delta break if omitted: trimmed.append("candidate_omitted") lines.append(f"...{_PROMPT_TRIM_MARKER}:candidate_omitted={omitted}") else: trimmed.append("candidate_details") lines.append(marker) return "\n".join(lines) def _truncate_prompt_text(text: str, limit: int, label: str, trimmed: list[str]) -> str: text = text.strip() if len(text) <= limit: return text marker = f"\n...{_PROMPT_TRIM_MARKER}:{label}" trimmed.append(label) if limit <= len(marker) + 8: return marker[:limit] return text[: max(limit - len(marker), 0)].rstrip() + marker def _format_dsa_context_for_prompt(p: Pick) -> str: parts: list[str] = [] if p.dsa_analysis_summary: parts.append(f"summary={_truncate_text(p.dsa_analysis_summary, 240)}") context = p.dsa_context if isinstance(p.dsa_context, dict) else {} quote = context.get("quote") if isinstance(context.get("quote"), dict) else {} if quote: parts.append( "quote=" f"price:{quote.get('price')},change_pct:{quote.get('change_pct')}," f"amount:{quote.get('amount')}" ) fundamentals = context.get("fundamentals") if isinstance(context.get("fundamentals"), dict) else {} coverage = fundamentals.get("coverage") if isinstance(fundamentals.get("coverage"), dict) else {} if coverage: available = [ str(key) for key, value in coverage.items() if str(value).lower() in {"available", "partial"} ] if available: parts.append(f"fundamental_coverage={','.join(available[:5])}") news_items = p.dsa_news if not news_items: news_payload = context.get("news") if isinstance(context.get("news"), dict) else {} raw_results = news_payload.get("results") if isinstance(news_payload, dict) else [] if isinstance(raw_results, list): news_items = [item for item in raw_results if isinstance(item, dict)] titles = [ _truncate_text(str(item.get("title") or "").strip(), 80) for item in news_items[:3] if isinstance(item, dict) and item.get("title") ] if titles: parts.append(f"news_titles={';'.join(titles)}") warnings = context.get("warnings") if isinstance(context.get("warnings"), list) else [] warning_text = [str(item) for item in warnings[:3] if item] if warning_text: parts.append(f"warnings={';'.join(warning_text)}") return "; ".join(parts) if parts else "none" def _truncate_text(value: str, limit: int) -> str: text = " ".join(value.split()) if len(text) <= limit: return text return text[: max(limit - 1, 0)] + "…" def _call_llm( prompt: str, api_key: str, model: str, base_url: str, *, fallback_models: list[str] | None = None, temperature: float = 0.2, json_mode: bool = True, silent: bool = True, channels: list[dict[str, object]] | None = None, config_path: str = "", timeout_sec: float = 60.0, max_tokens: int | None = 2048, ) -> str: """Call LLM via litellm with fallback models and channel configs.""" import litellm if silent: _silence_litellm_logs(litellm) messages = [{"role": "user", "content": prompt}] model_chain = _dedupe([model, *(fallback_models or [])]) last_error: Exception | None = None if config_path: router_result = _call_litellm_router( litellm, config_path=config_path, model_chain=model_chain, messages=messages, temperature=temperature, json_mode=json_mode, timeout_sec=timeout_sec, max_tokens=max_tokens, ) if router_result is not None: return router_result for candidate_model in model_chain: for kwargs in _build_litellm_attempts( candidate_model, api_key=api_key, base_url=base_url, channels=channels or [], ): kwargs["messages"] = messages kwargs["temperature"] = temperature kwargs["timeout"] = timeout_sec kwargs["num_retries"] = 0 if max_tokens is not None and int(max_tokens) > 0: kwargs["max_tokens"] = int(max_tokens) if json_mode: kwargs["response_format"] = {"type": "json_object"} try: response = litellm.completion(**kwargs) return response.choices[0].message.content or "" except Exception as exc: last_error = exc if _is_timeout_error(exc): raise if json_mode and "response_format" in kwargs and _is_json_mode_unsupported(exc): # Some providers do not support JSON mode. Retry the same # request without it before moving to fallback models. Do # not do this for timeout/connection failures: a local # OpenAI-compatible server may keep generating after the # client timeout, so a blind retry can duplicate expensive # work while the first request is still running. retry_kwargs = dict(kwargs) retry_kwargs.pop("response_format", None) try: response = litellm.completion(**retry_kwargs) return response.choices[0].message.content or "" except Exception as retry_exc: last_error = retry_exc continue continue if last_error is not None: raise last_error raise RuntimeError("No LLM model configured") def _is_json_mode_unsupported(exc: Exception) -> bool: """Return True only for provider errors that clearly reject JSON mode.""" if _is_timeout_error(exc): return False text = str(exc).lower() return ( "response_format" in text or "json mode" in text or "json_object" in text or ("not support" in text and "json" in text) or ("unsupported" in text and "json" in text) ) def _is_timeout_error(exc: Exception) -> bool: text = str(exc).lower() timeout_markers = ("timeout", "timed out", "readtimeout", "apitimeout") return any(marker in text for marker in timeout_markers) def _parse_ranking_response(response: str, candidates: list[Pick]) -> list[Pick]: """Parse LLM response and reorder candidates.""" return _parse_ranking_response_detail(response, candidates).picks def _parse_ranking_response_detail(response: str, candidates: list[Pick]) -> RankingParseResult: """Parse LLM response and return diagnostics.""" errors: list[str] = [] if not response or not response.strip(): errors.append("empty_response") logger.warning("Empty LLM ranking response") return RankingParseResult(candidates, 0.0, errors) parsed = _extract_ranking_json(response, errors) if parsed is None: errors.append("no_json_found") logger.warning("No JSON object or array found in LLM response") return RankingParseResult(candidates, 0.0, errors) if isinstance(parsed, dict): items = parsed.get("ranked", []) market_view = _safe_str(parsed.get("market_view"), max_len=260) selection_logic = _safe_str(parsed.get("selection_logic"), max_len=360) portfolio_risk = _safe_str(parsed.get("portfolio_risk"), max_len=360) else: items = parsed market_view = "" selection_logic = "" portfolio_risk = "" if not isinstance(items, list): errors.append("ranked_not_list") logger.warning("LLM ranking JSON has no ranked list") return RankingParseResult(candidates, 0.0, errors) code_to_pick = {_normalize_code(p.code): p for p in candidates if _normalize_code(p.code)} ranked = [] matched = 0 seen_codes = set() for item in items: if not isinstance(item, dict): errors.append("non_object_item") continue code = _normalize_code(item.get("code", "")) if code in seen_codes: errors.append(f"duplicate_code:{code}") continue seen_codes.add(code) if code in code_to_pick: pick = code_to_pick.pop(code) pick.ranking_reason = _safe_str(item.get("reason"), max_len=180) pick.risk_summary = _safe_str(item.get("risk"), max_len=180) pick.llm_score = _bounded_float(item.get("llm_score"), low=0, high=100) pick.llm_confidence = _bounded_float(item.get("confidence"), low=0, high=1) pick.llm_sector = _safe_str( item.get("sector") or item.get("industry") or item.get("sector_label"), max_len=40, ) pick.llm_theme = _safe_str(item.get("theme"), max_len=100) pick.llm_thesis = _safe_str(item.get("thesis"), max_len=220) pick.llm_catalysts = _safe_string_list(item.get("catalysts")) pick.llm_invalidators = _safe_string_list(item.get("invalidators")) pick.llm_style_fit = _safe_str(item.get("style_fit"), max_len=120) pick.llm_watch_items = _safe_string_list(item.get("watch_items")) pick.llm_risks = _safe_string_list(item.get("risk_flags")) pick.llm_tags = _safe_string_list(item.get("tags")) if pick.llm_sector: pick.llm_tags = _dedupe([*pick.llm_tags, f"sector:{pick.llm_sector}"]) if pick.llm_theme: pick.llm_tags = _dedupe([*pick.llm_tags, f"theme:{pick.llm_theme}"]) if pick.llm_style_fit: pick.llm_tags = _dedupe([*pick.llm_tags, f"style_fit:{pick.llm_style_fit}"]) ranked.append(pick) matched += 1 elif code: errors.append(f"unknown_code:{code}") # Append any candidates not mentioned by LLM ranked.extend(code_to_pick.values()) coverage = matched / max(len(candidates), 1) return RankingParseResult( ranked, coverage, errors, market_view=market_view, selection_logic=selection_logic, portfolio_risk=portfolio_risk, ) def _safe_str(value, *, max_len: int) -> str: return safe_text(value, max_len=max_len) def _try_parse_json_lenient(raw: str, errors: list[str]): """Attempt to parse LLM JSON output, tolerating common formatting drift. Steps applied in order: strict parse → strip trailing commas → balance truncated brackets → return None if all fail. Any repair that succeeds is recorded in ``errors`` for diagnostics. """ import re try: return json.loads(raw) except json.JSONDecodeError as exc: first_error = exc # Repair 1: remove trailing commas before } or ]. repaired = re.sub(r",(\s*[}\]])", r"\1", raw) if repaired != raw: try: result = json.loads(repaired) errors.append("json_repaired:trailing_comma") return result except json.JSONDecodeError: pass # Repair 2: close unbalanced brackets caused by truncated output. open_curly = repaired.count("{") - repaired.count("}") open_square = repaired.count("[") - repaired.count("]") if open_curly > 0 or open_square > 0: patched = repaired + ("]" * max(open_square, 0)) + ("}" * max(open_curly, 0)) try: result = json.loads(patched) errors.append("json_repaired:closed_brackets") return result except json.JSONDecodeError: pass errors.append(f"json_decode_error:{first_error}") logger.warning("Failed to parse LLM ranking JSON: %s", first_error) return None def _extract_ranking_json(response: str, errors: list[str]): """Extract ranking JSON from common LLM response shapes.""" for raw in _iter_json_payloads(response): parsed = _try_parse_json_lenient(raw, errors) if _looks_like_ranking_payload(parsed): return parsed partial = _extract_partial_ranking_array(response, errors) if partial is not None: return partial return None def _looks_like_ranking_payload(value: object) -> bool: if isinstance(value, dict): return isinstance(value.get("ranked"), list) if isinstance(value, list): return any(isinstance(item, dict) and "code" in item for item in value) return False def _iter_json_payloads(response: str): """Yield likely JSON payload substrings in priority order.""" import re yielded: set[str] = set() fence_pattern = re.compile(r"```(?:json|JSON)?\s*(.*?)```", re.DOTALL) for match in fence_pattern.finditer(response): payload = match.group(1).strip() if payload and payload not in yielded: yielded.add(payload) yield payload cleaned = fence_pattern.sub(lambda match: match.group(1), response) for payload in _balanced_json_values(cleaned): if payload not in yielded: yielded.add(payload) yield payload def _balanced_json_values(text: str) -> list[str]: """Return balanced top-level JSON object/array substrings.""" values: list[str] = [] stack: list[str] = [] start: int | None = None in_string = False escaped = False pairs = {"{": "}", "[": "]"} for index, char in enumerate(text): if in_string: if escaped: escaped = False elif char == "\\": escaped = True elif char == '"': in_string = False continue if char == '"': in_string = True continue if char in pairs: if not stack: start = index stack.append(pairs[char]) continue if char in ("}", "]") and stack: expected = stack.pop() if char != expected: stack.clear() start = None continue if not stack and start is not None: values.append(text[start : index + 1]) start = None return values def _extract_partial_ranking_array(response: str, errors: list[str]): """Recover a ranked list from multiple JSON objects in a noisy response.""" items = [] item_errors: list[str] = [] for raw in _balanced_json_values(response): parsed = _try_parse_json_lenient(raw, item_errors) if isinstance(parsed, dict) and "code" in parsed: items.append(parsed) if not items: return None errors.append("json_repaired:partial_array") return {"ranked": items} def _call_litellm_router( litellm, *, config_path: str, model_chain: list[str], messages: list[dict[str, str]], temperature: float, json_mode: bool, timeout_sec: float, max_tokens: int | None = 2048, ) -> str | None: try: import yaml with open(config_path, "r", encoding="utf-8") as handle: data = yaml.safe_load(handle) or {} model_list = data.get("model_list") if not isinstance(model_list, list) or not model_list: return None router = litellm.Router(model_list=model_list) for model in model_chain: kwargs = { "model": model, "messages": messages, "temperature": temperature, "timeout": timeout_sec, "num_retries": 0, } if max_tokens is not None and int(max_tokens) > 0: kwargs["max_tokens"] = int(max_tokens) if json_mode: kwargs["response_format"] = {"type": "json_object"} try: response = router.completion(**kwargs) return response.choices[0].message.content or "" except Exception as exc: if "response_format" in kwargs and _is_json_mode_unsupported(exc): kwargs.pop("response_format", None) response = router.completion(**kwargs) return response.choices[0].message.content or "" raise except Exception as exc: if _is_timeout_error(exc): raise logger.warning("LiteLLM router config failed, falling back to direct calls: %s", exc) return None def _silence_litellm_logs(litellm) -> None: os.environ.setdefault("LITELLM_LOG", "ERROR") try: litellm.set_verbose = False litellm.suppress_debug_info = True except Exception: pass for logger_name in ("LiteLLM", "litellm"): logging.getLogger(logger_name).setLevel(logging.WARNING) def _build_litellm_attempts( model: str, *, api_key: str, base_url: str, channels: list[dict[str, object]], ) -> list[dict[str, object]]: attempts = [] for channel in channels: if not _channel_matches_model(channel, model): continue api_keys = channel.get("api_keys", []) if not isinstance(api_keys, list) or not api_keys: api_keys = [api_key] if api_key else [""] for channel_key in api_keys: attempts.append(_completion_kwargs( model, api_key=str(channel_key or ""), base_url=str(channel.get("base_url", "") or base_url or ""), )) attempts.append(_completion_kwargs(model, api_key=api_key, base_url=base_url)) return _unique_attempts(attempts) def _completion_kwargs(model: str, *, api_key: str, base_url: str) -> dict[str, object]: kwargs: dict[str, object] = {"model": model} if api_key: kwargs["api_key"] = api_key if base_url: kwargs["api_base"] = base_url return kwargs def _channel_matches_model(channel: dict[str, object], model: str) -> bool: models = channel.get("models", []) if not isinstance(models, list) or not models: return False normalized = {_normalize_model_name(str(item), str(channel.get("protocol", "openai"))) for item in models} return model in normalized or model.split("/", 1)[-1] in {item.split("/", 1)[-1] for item in normalized} def _normalize_model_name(model: str, protocol: str) -> str: model = model.strip() if "/" in model: return model if protocol == "ollama": return f"ollama/{model}" if protocol == "gemini": return f"gemini/{model}" if protocol == "deepseek": return f"deepseek/{model}" return f"openai/{model}" def _unique_attempts(items: list[dict[str, object]]) -> list[dict[str, object]]: seen = set() result = [] for item in items: key = (item.get("model"), item.get("api_key"), item.get("api_base")) if key not in seen: seen.add(key) result.append(item) return result def _dedupe(items: list[str]) -> list[str]: seen = set() result = [] for item in items: key = str(item).strip() if key and key not in seen: seen.add(key) result.append(key) return result