Files
piano-plan/app/config.py
T

163 lines
5.5 KiB
Python

# Flask 应用配置
import os
import pathlib
# Deployment flag. Despite its name this does not probe the container runtime:
# it is True exactly when the FLASK_ENV environment variable equals
# "production" (presumably only the Docker image sets that — confirm).
IS_DOCKER = os.getenv("FLASK_ENV") == "production"
# Built-in LLM providers, keyed by provider id. Every entry carries a display
# "name", the API "endpoint" base URL and the list of selectable "models"
# (the first model of the default provider is the initial model choice).
DEFAULT_PROVIDERS = {
    "minimax": dict(
        name="MiniMax",
        endpoint="https://api.minimaxi.com/anthropic/v1",
        models=["MiniMax-M2.7-highspeed"],
    ),
    "volcengine": dict(
        name="火山引擎",
        endpoint="https://ark.cn-beijing.volces.com/api/coding/v3",
        models=["doubao-seed-2.0-pro", "doubao-seed-code", "doubao-seed-2.0-lite"],
    ),
    "deepseek": dict(
        name="DeepSeek",
        endpoint="https://api.deepseek.com/v1",
        models=["deepseek-chat"],
    ),
    "openai": dict(
        name="OpenAI",
        endpoint="https://api.openai.com/v1",
        models=["gpt-4o-mini", "gpt-4o"],
    ),
    "openrouter": dict(
        name="OpenRouter",
        endpoint="https://openrouter.ai/api/v1",
        models=["anthropic/claude-3-haiku"],
    ),
    "opencodego": dict(
        name="OpenCode Go",
        endpoint="https://opencode.ai/zen/go/v1",
        models=["deepseek-v4-pro", "deepseek-v4-flash", "qwen3.6-plus"],
    ),
}
# Backward compatibility: older config files stored `default_model` instead
# of a `models` list.
def _migrate_provider_models(providers):
    """Upgrade legacy provider entries in place so each one has ``models``.

    For every provider entry that lacks a ``models`` key:

    * a non-empty ``default_model`` becomes ``models = [default_model]``;
    * otherwise ``models`` becomes ``[]``.

    In both cases the legacy ``default_model`` key is removed, so a migrated
    entry never keeps two competing sources of truth (previously an empty
    ``default_model`` was left behind and written back to disk). Entries that
    already have ``models`` are left untouched.

    Args:
        providers: Mapping of provider id -> provider dict; mutated in place.

    Returns:
        The same ``providers`` mapping, for call chaining.
    """
    # Only the values are touched and the mapping is never resized, so there
    # is no need to snapshot the items first.
    for pdata in providers.values():
        if "models" in pdata:
            continue
        legacy_model = pdata.pop("default_model", None)
        pdata["models"] = [legacy_model] if legacy_model else []
    return providers
# Default LLM prompt (in Chinese): asks the model to act as a professional piano
# teacher and write a concise, encouraging, Markdown-formatted practice-plan
# report (overview, daily schedule, per-problem exercises, key cautions).
# Placeholders: {student_name}, {wechat_nickname}, {practice_time}, {problems}.
# Used as the initial "prompt_template" of the default API config.
# NOTE(review): presumably filled with str.format() by the caller — if so, any
# literal braces added to this text must be doubled ({{ / }}). Confirm.
DEFAULT_PROMPT_TEMPLATE = """你是一位专业的钢琴教师,请为学员生成一份简洁的个性化练习方案报告。
## 学员信息
- 姓名:{student_name}
- 微信昵称:{wechat_nickname}
- 每日练习时间:{practice_time}
## 问题详情
{problems}
请生成一份简洁的练习方案报告,包含:
1. 方案概述
2. 每日练习安排
3. 针对每个问题的核心练习建议
4. 重点注意事项
语言要专业、简洁、有鼓励性。使用Markdown格式。"""
def _get_config_path(app_config=None):
    """Return the location of the JSON file that stores the API settings.

    A non-empty ``API_CONFIG_FILE`` entry in *app_config* takes precedence.
    Otherwise the path is ``config/api_config.json`` inside the directory two
    levels above this module (the parent of this module's own directory).
    """
    override = app_config.get("API_CONFIG_FILE") if app_config else None
    if override:
        return pathlib.Path(override)
    project_root = pathlib.Path(__file__).resolve().parents[1]
    return project_root.joinpath("config", "api_config.json")
def _build_default_config():
    """Return a fresh default API configuration dict.

    The default provider is ``volcengine``; its endpoint and first model are
    used as the initial ``base_url`` / ``model`` so those values cannot drift
    out of sync with ``DEFAULT_PROVIDERS``.

    The provider table is deep-copied. A shallow ``dict(DEFAULT_PROVIDERS)``
    would share the inner provider dicts and ``models`` lists with the
    module-level constant, so a caller editing the returned config (e.g.
    appending a model) would silently alter the defaults for the rest of the
    process.
    """
    import copy

    default_provider = "volcengine"
    providers = copy.deepcopy(DEFAULT_PROVIDERS)
    return {
        "provider": default_provider,
        "providers": providers,
        "api_key": "",
        "base_url": providers[default_provider]["endpoint"],
        "model": providers[default_provider]["models"][0],
        "temperature": 0.7,
        "watermark_text": "",
        "prompt_template": DEFAULT_PROMPT_TEMPLATE,
    }
def load_api_config(app_config=None):
    """Load the API configuration, falling back to defaults.

    Reads the JSON file given by ``_get_config_path(app_config)`` and
    normalizes it:

    * keys missing from the file (e.g. settings introduced after the file was
      written) are filled in from ``_build_default_config()``;
    * a missing or malformed ``providers`` table is replaced by the built-in
      providers, non-dict provider entries are dropped, legacy
      ``default_model`` entries are migrated to ``models``, and any built-in
      provider absent from the file is added (as an independent copy);
    * when an ``api_keys`` mapping exists, ``api_key`` is set to the key stored
      for the currently selected ``provider`` (``""`` if there is none).

    The function is best-effort. If the file does not exist, the defaults are
    returned. If it cannot be read, is not valid JSON, or does not hold a JSON
    object, a warning is logged and the defaults are returned.

    Args:
        app_config: Optional Flask config mapping; its ``API_CONFIG_FILE``
            entry overrides the default file location.

    Returns:
        dict: The merged configuration.
    """
    import copy
    import json
    import logging

    logger = logging.getLogger(__name__)
    config_file = _get_config_path(app_config)
    default_config = _build_default_config()
    if not config_file.exists():
        return default_config

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            loaded_config = json.load(f)
    except (OSError, ValueError) as exc:  # ValueError covers JSONDecodeError
        # A broken file must not take the app down, but it must not vanish
        # silently either.
        logger.warning("Ignoring unreadable API config %s: %s", config_file, exc)
        return default_config
    if not isinstance(loaded_config, dict):
        logger.warning(
            "Ignoring API config %s: top-level JSON value is not an object",
            config_file,
        )
        return default_config

    # Overlay the file on the defaults so settings missing from older files
    # still exist in the result.
    config = {**default_config, **loaded_config}

    # Provider table: ensure a mapping of provider dicts in the `models` format.
    providers = config.get("providers")
    if not isinstance(providers, dict):
        providers = copy.deepcopy(DEFAULT_PROVIDERS)
    providers = {
        pid: pdata for pid, pdata in providers.items() if isinstance(pdata, dict)
    }
    providers = _migrate_provider_models(providers)
    for pid, pdata in DEFAULT_PROVIDERS.items():
        if pid not in providers:
            # Copy, so edits to the returned config never leak into the
            # module-level DEFAULT_PROVIDERS.
            providers[pid] = copy.deepcopy(pdata)
    config["providers"] = providers

    # Per-provider key store: expose the selected provider's key as api_key.
    api_keys = config.get("api_keys")
    if isinstance(api_keys, dict):
        provider = config.get("provider", "volcengine")
        config["api_key"] = api_keys.get(provider, "") if isinstance(provider, str) else ""
    return config
def _atomic_write_json(path, data):
    """Write *data* as indented UTF-8 JSON to *path*, replacing it atomically.

    The JSON goes to a temp file in the same directory, which is then swapped
    in with ``os.replace``. A crash mid-write therefore never leaves a
    truncated file behind. An existing file's permission bits are carried
    over; a newly created file is owner-read/write only (``mkstemp`` default).
    """
    import contextlib
    import json
    import tempfile

    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent, prefix=path.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        if path.exists():
            os.chmod(tmp_name, path.stat().st_mode & 0o777)
        os.replace(tmp_name, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise


def save_api_config(config, app_config=None):
    """Merge *config* into the on-disk API configuration and write it back.

    Only keys present in *config* are changed. Everything else already in the
    file, notably the per-provider ``api_keys`` store, is preserved.

    * ``providers`` (if given) replaces the stored provider table wholesale.
      A missing or malformed stored table is first replaced by the built-in
      providers, and legacy ``default_model`` entries are migrated.
    * A non-empty ``api_key`` is stored as the top-level ``api_key`` and in
      ``api_keys[provider]``. Here ``provider`` is ``config["provider"]``, or
      the provider already selected in the file when *config* omits it
      (``"volcengine"`` if neither exists).
    * ``provider``, ``base_url``, ``model``, ``temperature``,
      ``prompt_template`` and ``watermark_text`` are copied when present.

    An unreadable or non-object existing file is logged and then overwritten.
    The write itself is atomic.

    Args:
        config: Mapping of settings to update.
        app_config: Optional Flask config mapping; its ``API_CONFIG_FILE``
            entry overrides the default file location.

    Raises:
        OSError: If the directory or the file cannot be written.
    """
    import copy
    import json
    import logging

    logger = logging.getLogger(__name__)
    config_file = _get_config_path(app_config)
    config_file.parent.mkdir(parents=True, exist_ok=True)

    existing_config = {}
    if config_file.exists():
        try:
            with open(config_file, "r", encoding="utf-8") as f:
                existing_config = json.load(f)
        except (OSError, ValueError) as exc:  # ValueError covers JSONDecodeError
            # The save still goes ahead, but losing the old contents (e.g.
            # other providers' api_keys) must be visible in the logs.
            logger.warning(
                "Overwriting unreadable API config %s: %s", config_file, exc
            )
            existing_config = {}
        if not isinstance(existing_config, dict):
            logger.warning(
                "Overwriting API config %s: top-level JSON value is not an object",
                config_file,
            )
            existing_config = {}

    # Make sure a providers table exists and uses the `models` list format.
    if not isinstance(existing_config.get("providers"), dict):
        existing_config["providers"] = copy.deepcopy(DEFAULT_PROVIDERS)
    existing_config["providers"] = _migrate_provider_models(existing_config["providers"])

    # A provider table sent by the caller replaces the stored one wholesale.
    if "providers" in config:
        existing_config["providers"] = config["providers"]

    # File a new api_key under the provider it belongs to: the one selected by
    # this call, otherwise the one already selected in the file. (Falling back
    # straight to "volcengine" stored it under the wrong provider, and loading
    # then replaced api_key with api_keys[<selected provider>].)
    api_key = config.get("api_key")
    if api_key:
        provider = config.get("provider", existing_config.get("provider", "volcengine"))
        api_keys = existing_config.get("api_keys")
        if not isinstance(api_keys, dict):
            api_keys = existing_config["api_keys"] = {}
        api_keys[provider] = api_key

    # Copy over the currently selected settings that the caller supplied.
    for key in ("provider", "base_url", "model", "temperature", "prompt_template", "watermark_text"):
        if key in config:
            existing_config[key] = config[key]
    if api_key:
        existing_config["api_key"] = api_key

    _atomic_write_json(config_file, existing_config)