Files
piano-plan/app/config.py
T

102 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Flask 应用配置
import os
import pathlib
# 检测是否在Docker容器中
IS_DOCKER = os.environ.get("FLASK_ENV") == "production"
def load_api_config(app_config=None):
"""加载API配置"""
import json
# 优先从 app_config 获取路径,否则使用基于 __file__ 的绝对路径
if app_config and app_config.get("API_CONFIG_FILE"):
config_file = pathlib.Path(app_config["API_CONFIG_FILE"])
else:
config_file = pathlib.Path(__file__).resolve().parent.parent / "config" / "api_config.json"
default_config = {
"provider": "volcengine",
"api_key": "",
"base_url": "https://ark.cn-beijing.volces.com/api/coding/v3",
"model": "doubao-seed-2.0-pro",
"temperature": 0.7,
"prompt_template": """你是一位专业的钢琴教师,请为学员生成一份简洁的个性化练习方案报告。
## 学员信息
- 姓名:{student_name}
- 微信昵称:{wechat_nickname}
- 每日练习时间:{practice_time}
## 问题详情
{problems}
请生成一份简洁的练习方案报告,包含:
1. 方案概述
2. 每日练习安排
3. 针对每个问题的核心练习建议
4. 重点注意事项
语言要专业、简洁、有鼓励性。使用Markdown格式。""",
}
if config_file.exists():
try:
with open(config_file, "r", encoding="utf-8") as f:
loaded_config = json.load(f)
# 如果 api_keys 映射存在,根据当前 provider 自动设置 api_key
if "api_keys" in loaded_config:
provider = loaded_config.get("provider", "volcengine")
loaded_config["api_key"] = loaded_config["api_keys"].get(provider, "")
return loaded_config
except:
pass
# 返回默认配置
return default_config
def save_api_config(config, app_config=None):
"""保存API配置"""
import json
# 优先从 app_config 获取路径,否则使用基于 __file__ 的绝对路径
if app_config and app_config.get("API_CONFIG_FILE"):
config_file = pathlib.Path(app_config["API_CONFIG_FILE"])
else:
config_file = pathlib.Path(__file__).resolve().parent.parent / "config" / "api_config.json"
config_file = pathlib.Path(config_file)
config_file.parent.mkdir(parents=True, exist_ok=True)
# 先读取现有配置,保留 api_keys 映射
existing_config = {}
if config_file.exists():
try:
with open(config_file, "r", encoding="utf-8") as f:
existing_config = json.load(f)
except:
pass
# 如果 config 中有 api_key,保存到 api_keys 映射(保留其他 provider 的 key
if "api_key" in config and config["api_key"]:
if "api_keys" not in existing_config:
existing_config["api_keys"] = {}
provider = config.get("provider", "volcengine")
existing_config["api_keys"][provider] = config["api_key"]
# 合并配置:保留 api_keys,更新其他字段
existing_config["provider"] = config.get("provider", existing_config.get("provider", "volcengine"))
existing_config["base_url"] = config.get("base_url", existing_config.get("base_url", ""))
existing_config["model"] = config.get("model", existing_config.get("model", ""))
existing_config["temperature"] = config.get("temperature", existing_config.get("temperature", 0.7))
existing_config["prompt_template"] = config.get("prompt_template", existing_config.get("prompt_template", ""))
# 当前选中的 provider 的 key 直接存储(用于兼容旧逻辑)
if config.get("api_key"):
existing_config["api_key"] = config["api_key"]
with open(config_file, "w", encoding="utf-8") as f:
json.dump(existing_config, f, ensure_ascii=False, indent=2)