Initial commit: lesson-highlights generator
This commit is contained in:
@@ -0,0 +1,518 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pipeline - 核心业务逻辑
|
||||
|
||||
统一管理从视频提取到最终输出的完整流程
|
||||
UI和CLI共用同一套逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, List, Dict, Any
|
||||
|
||||
from .video import extract_clip, merge_clips, burn_dual_subtitles
|
||||
from .subtitle import SubtitlePipeline
|
||||
from .llm import LLMClient
|
||||
from .corrections import apply_all_corrections, load_term_corrections_from_config
|
||||
from .utils import ensure_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
精华视频生成流水线
|
||||
|
||||
使用方法:
|
||||
# CLI模式
|
||||
pipeline = Pipeline(config)
|
||||
pipeline.run()
|
||||
|
||||
# UI模式 (带回调)
|
||||
pipeline = Pipeline(config, progress_callback=my_callback)
|
||||
pipeline.run()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
progress_callback: Optional[Callable[[str, int, str], None]] = None,
|
||||
step_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""
|
||||
初始化流水线
|
||||
|
||||
Args:
|
||||
config: 配置字典,包含:
|
||||
- video_src: 视频路径
|
||||
- clips: [{title, start, end}, ...]
|
||||
- output_dir: 输出目录
|
||||
- api_key: LLM API密钥
|
||||
- api_host: LLM API地址
|
||||
- whisper_model_path: Whisper模型路径
|
||||
- term_corrections: 术语纠正字典
|
||||
- video_params: 视频参数
|
||||
progress_callback: 进度回调 (step, percent, message)
|
||||
step_callback: 步骤开始/完成回调 (step_name)
|
||||
"""
|
||||
self.config = config
|
||||
self.progress_callback = progress_callback if progress_callback else (lambda s, p, m: logger.info(f"[{s}] {p}%: {m}"))
|
||||
self.step_callback = step_callback if step_callback else (lambda s: None)
|
||||
|
||||
# 路径
|
||||
self.output_dir = config.get('output_dir', './output')
|
||||
self.inter_dir = ensure_dir(os.path.join(self.output_dir, 'intermediates'))
|
||||
self.subs_dir = ensure_dir(os.path.join(self.output_dir, 'subs'))
|
||||
|
||||
# 配置
|
||||
self.clips = config.get('clips', [])
|
||||
self.video_src = config.get('video_src')
|
||||
self.video_params = config.get('video_params', {})
|
||||
self.fade_duration = self.video_params.get('fade_duration', 1)
|
||||
|
||||
# LLM客户端 (延迟初始化)
|
||||
self._llm_client = None
|
||||
|
||||
# 字幕处理
|
||||
self._subtitle_pipeline = None
|
||||
|
||||
# 术语纠正
|
||||
self.term_corrections = load_term_corrections_from_config(config)
|
||||
|
||||
@property
|
||||
def llm_client(self) -> LLMClient:
|
||||
if self._llm_client is None:
|
||||
self._llm_client = LLMClient(
|
||||
api_key=self.config.get('api_key'),
|
||||
api_host=self.config.get('api_host')
|
||||
)
|
||||
return self._llm_client
|
||||
|
||||
@property
|
||||
def subtitle_pipeline(self) -> SubtitlePipeline:
|
||||
if self._subtitle_pipeline is None:
|
||||
self._subtitle_pipeline = SubtitlePipeline(self.config, self.output_dir)
|
||||
return self._subtitle_pipeline
|
||||
|
||||
# ==================== 步骤方法 ====================
|
||||
|
||||
def step_extract(self) -> List[str]:
|
||||
"""
|
||||
Step 1: 提取视频片段
|
||||
|
||||
Returns:
|
||||
clip_paths: 提取的片段路径列表
|
||||
"""
|
||||
self.step_callback('extracting')
|
||||
self.progress_callback('extracting', 0, "开始提取片段...")
|
||||
|
||||
if not self.clips:
|
||||
raise ValueError("No clips configured")
|
||||
if not self.video_src or not os.path.exists(self.video_src):
|
||||
raise ValueError(f"Video file not found: {self.video_src}")
|
||||
|
||||
clip_paths = []
|
||||
total = len(self.clips)
|
||||
|
||||
for i, clip in enumerate(self.clips, 1):
|
||||
clip_path = os.path.join(self.inter_dir, f"clip{i}.mp4")
|
||||
fade_path = os.path.join(self.inter_dir, f"clip{i}_fade.mp4")
|
||||
|
||||
# 提取片段
|
||||
success = extract_clip(
|
||||
self.video_src,
|
||||
clip['start'],
|
||||
clip['end'],
|
||||
clip_path,
|
||||
fade_duration=0 # 先不添加淡出
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to extract clip {i}")
|
||||
continue
|
||||
|
||||
# 如果需要淡入淡出
|
||||
if self.fade_duration > 0:
|
||||
duration = clip['end'] - clip['start']
|
||||
fade_out_start = max(0, duration - self.fade_duration)
|
||||
|
||||
from .constants import FFMPEG_CMD
|
||||
from .utils import run_cmd
|
||||
|
||||
cmd = f'"{FFMPEG_CMD}" -y -i "{clip_path}" '
|
||||
cmd += f'-vf "fade=t=in:st=0:d={self.fade_duration},fade=t=out:st={fade_out_start}:d={self.fade_duration}" '
|
||||
cmd += f'-c:v libx264 -crf 20 -c:a aac -y "{fade_path}"'
|
||||
|
||||
if run_cmd(cmd):
|
||||
clip_paths.append(fade_path)
|
||||
else:
|
||||
clip_paths.append(clip_path)
|
||||
else:
|
||||
clip_paths.append(clip_path)
|
||||
|
||||
percent = int((i / total) * 100)
|
||||
self.progress_callback('extracting', percent, f"提取片段 {i}/{total}")
|
||||
|
||||
self.progress_callback('extracting', 100, f"提取完成,共 {len(clip_paths)} 个片段")
|
||||
self.step_callback('extracting')
|
||||
return clip_paths
|
||||
|
||||
def step_transcribe(self, clip_paths: List[str]) -> List[str]:
|
||||
"""
|
||||
Step 2: 转录片段
|
||||
|
||||
Args:
|
||||
clip_paths: 片段路径列表
|
||||
|
||||
Returns:
|
||||
json_paths: JSON转录文件路径列表
|
||||
"""
|
||||
self.step_callback('transcribing')
|
||||
self.progress_callback('transcribing', 0, "开始转录...")
|
||||
|
||||
# 延迟导入,避免没有faster-whisper时无法import
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
except ImportError:
|
||||
logger.warning("faster-whisper not available, skipping transcription")
|
||||
self.progress_callback('transcribing', 100, "faster-whisper未安装,跳过转录")
|
||||
self.step_callback('transcribing')
|
||||
return []
|
||||
|
||||
model_path = self.config.get('whisper_model_path')
|
||||
model_name = self.config.get('whisper_model', 'large')
|
||||
|
||||
# 加载模型
|
||||
self.progress_callback('transcribing', 5, "加载Whisper模型...")
|
||||
model = WhisperModel(model_path or model_name, compute_type="float16")
|
||||
|
||||
# 通过YAML配置hash检测配置是否改变,如果改变则删除所有旧JSON
|
||||
import hashlib
|
||||
config_str = str([(c['start'], c['end'], c.get('title', '')) for c in self.clips])
|
||||
config_hash = hashlib.md5(config_str.encode()).hexdigest()
|
||||
hash_file = os.path.join(self.inter_dir, '.config_hash')
|
||||
old_hash = None
|
||||
if os.path.exists(hash_file):
|
||||
with open(hash_file, 'r') as f:
|
||||
old_hash = f.read().strip()
|
||||
if old_hash != config_hash:
|
||||
# 配置变了,删除所有旧JSON
|
||||
for f in os.listdir(self.inter_dir):
|
||||
if f.startswith('clip') and f.endswith('.json'):
|
||||
os.remove(os.path.join(self.inter_dir, f))
|
||||
logger.info(f"清理旧JSON: {f} (配置已改变)")
|
||||
with open(hash_file, 'w') as f:
|
||||
f.write(config_hash)
|
||||
logger.info("配置已更新,清除所有旧JSON,重新转录")
|
||||
|
||||
json_paths = []
|
||||
total = len(clip_paths)
|
||||
|
||||
for i, clip_path in enumerate(clip_paths, 1):
|
||||
json_path = os.path.join(self.inter_dir, f"clip{i}.json")
|
||||
json_paths.append(json_path)
|
||||
|
||||
# 如果JSON已存在,跳过
|
||||
if os.path.exists(json_path):
|
||||
logger.info(f"Clip {i}: JSON exists, skipping")
|
||||
self.progress_callback('transcribing', int((i/total)*100), f"跳过片段 {i} (已存在)")
|
||||
continue
|
||||
|
||||
# 转录
|
||||
self.progress_callback('transcribing', int((i/total)*90), f"转录片段 {i}/{total}")
|
||||
|
||||
try:
|
||||
segments, _ = model.transcribe(clip_path, language='zh', beam_size=5)
|
||||
|
||||
# 保存转录结果
|
||||
segments_data = []
|
||||
for seg in segments:
|
||||
segments_data.append({
|
||||
'start': seg.start,
|
||||
'end': seg.end,
|
||||
'text': seg.text.strip()
|
||||
})
|
||||
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({'segments': segments_data}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"Transcribed clip {i}: {json_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to transcribe clip {i}: {e}")
|
||||
|
||||
# 不手动 del model —— CUDA 上下文在 Windows 下销毁时容易触发
|
||||
# Access Violation (0xC0000005),让进程自然释放即可。
|
||||
|
||||
self.progress_callback('transcribing', 100, "转录完成")
|
||||
self.step_callback('transcribing')
|
||||
return json_paths
|
||||
|
||||
def step_correct_titles(self, json_paths: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Step 3: LLM标题纠正
|
||||
|
||||
Args:
|
||||
json_paths: JSON文件路径列表
|
||||
|
||||
Returns:
|
||||
corrected_clips: 纠正后的片段配置列表
|
||||
"""
|
||||
self.step_callback('title_correcting')
|
||||
self.progress_callback('title_correcting', 0, "开始标题纠正...")
|
||||
|
||||
corrected_clips = []
|
||||
total = len(self.clips)
|
||||
|
||||
for i, (clip, json_path) in enumerate(zip(self.clips, json_paths), 1):
|
||||
original_title = clip.get('title', f'Clip {i}')
|
||||
|
||||
# 读取转录文本
|
||||
transcript_text = ''
|
||||
if json_path and os.path.exists(json_path):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
transcript_text = ' '.join(seg.get('text', '') for seg in data.get('segments', []))
|
||||
|
||||
# LLM纠正标题
|
||||
corrected_title = original_title
|
||||
if transcript_text and self.config.get('api_key'):
|
||||
try:
|
||||
corrected_title = self.llm_client.correct_title(
|
||||
transcript_text,
|
||||
original_title,
|
||||
[c.get('title', '') for c in self.clips]
|
||||
) or original_title
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM title correction failed for clip {i}: {e}")
|
||||
|
||||
corrected_clip = {
|
||||
'index': i - 1,
|
||||
'title': corrected_title,
|
||||
'original_title': original_title,
|
||||
'start': clip['start'],
|
||||
'end': clip['end'],
|
||||
}
|
||||
corrected_clips.append(corrected_clip)
|
||||
|
||||
percent = int((i / total) * 100)
|
||||
self.progress_callback('title_correcting', percent, f"纠正标题 {i}/{total}")
|
||||
|
||||
self.progress_callback('title_correcting', 100, "标题纠正完成")
|
||||
self.step_callback('title_correcting')
|
||||
return corrected_clips
|
||||
|
||||
def step_generate_subtitles(self, corrected_clips: List[Dict], json_paths: List[str]) -> tuple:
|
||||
"""
|
||||
Step 4: 生成字幕
|
||||
|
||||
Args:
|
||||
corrected_clips: 纠正后的片段配置
|
||||
json_paths: JSON文件路径列表
|
||||
|
||||
Returns:
|
||||
(title_path, content_path): 字幕文件路径
|
||||
"""
|
||||
self.step_callback('generating_subtitles')
|
||||
self.progress_callback('generating_subtitles', 0, "开始生成字幕...")
|
||||
|
||||
# 准备clip配置
|
||||
clip_configs = []
|
||||
valid_json_paths = []
|
||||
|
||||
for i, (clip, json_path) in enumerate(zip(corrected_clips, json_paths), 1):
|
||||
clip_config = {
|
||||
'index': i - 1,
|
||||
'start': clip['start'],
|
||||
'end': clip['end'],
|
||||
'title': clip.get('title', clip.get('original_title', '')),
|
||||
}
|
||||
clip_configs.append(clip_config)
|
||||
|
||||
if json_path and os.path.exists(json_path):
|
||||
valid_json_paths.append(json_path)
|
||||
else:
|
||||
valid_json_path = os.path.join(self.inter_dir, f"clip{i}.json")
|
||||
if os.path.exists(valid_json_path):
|
||||
valid_json_paths.append(valid_json_path)
|
||||
|
||||
if not valid_json_paths:
|
||||
raise ValueError("No valid JSON files for subtitle generation")
|
||||
|
||||
# 纠错函数
|
||||
def correct(text):
|
||||
return apply_all_corrections(text, self.term_corrections)
|
||||
|
||||
self.progress_callback('generating_subtitles', 50, "生成字幕轨道...")
|
||||
|
||||
# 生成字幕
|
||||
_, _, title_path, content_path = self.subtitle_pipeline.generate_from_clips(
|
||||
clip_configs,
|
||||
valid_json_paths,
|
||||
apply_corrections=correct
|
||||
)
|
||||
|
||||
self.progress_callback('generating_subtitles', 100, "字幕生成完成")
|
||||
self.step_callback('generating_subtitles')
|
||||
return title_path, content_path
|
||||
|
||||
def step_merge(self, clip_paths: List[str]) -> str:
|
||||
"""
|
||||
Step 5: 合并视频
|
||||
|
||||
Args:
|
||||
clip_paths: 片段路径列表
|
||||
|
||||
Returns:
|
||||
merged_path: 合并后的视频路径
|
||||
"""
|
||||
self.step_callback('merging')
|
||||
self.progress_callback('merging', 0, "开始合并视频...")
|
||||
|
||||
if not clip_paths:
|
||||
raise ValueError("No clips to merge")
|
||||
|
||||
merged_path = os.path.join(self.output_dir, "concat_merged.mp4")
|
||||
|
||||
success = merge_clips(clip_paths, merged_path, self.inter_dir)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("Failed to merge clips")
|
||||
|
||||
self.progress_callback('merging', 100, f"合并完成: {merged_path}")
|
||||
self.step_callback('merging')
|
||||
return merged_path
|
||||
|
||||
def step_burn(self, merged_path: str, title_path: str, content_path: str) -> str:
|
||||
"""
|
||||
Step 6: 烧录字幕
|
||||
|
||||
Args:
|
||||
merged_path: 合并后的视频路径
|
||||
title_path: 标题字幕路径
|
||||
content_path: 正文字幕路径
|
||||
|
||||
Returns:
|
||||
final_path: 最终视频路径
|
||||
"""
|
||||
self.step_callback('burning')
|
||||
self.progress_callback('burning', 0, "开始烧录字幕...")
|
||||
|
||||
if not os.path.exists(merged_path):
|
||||
raise ValueError(f"Merged video not found: {merged_path}")
|
||||
|
||||
final_path = os.path.join(self.output_dir, "final.mp4")
|
||||
|
||||
video_params = self.config.get('video_params', {})
|
||||
|
||||
success = burn_dual_subtitles(
|
||||
merged_path,
|
||||
title_path,
|
||||
content_path,
|
||||
final_path,
|
||||
title_fontsize=video_params.get('title_fontsize', 90),
|
||||
title_color=video_params.get('title_color', 'FFFF00'),
|
||||
subtitle_fontsize=video_params.get('subtitle_fontsize', 24),
|
||||
subtitle_color=video_params.get('subtitle_color', 'FFFFFF')
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("Failed to burn subtitles")
|
||||
|
||||
self.progress_callback('burning', 100, f"完成: {final_path}")
|
||||
self.step_callback('burning')
|
||||
return final_path
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def run(self) -> str:
|
||||
"""
|
||||
运行完整流水线
|
||||
|
||||
Returns:
|
||||
final_path: 最终视频路径
|
||||
|
||||
Raises:
|
||||
ValueError: 配置错误
|
||||
RuntimeError: 处理失败
|
||||
"""
|
||||
logger.info(f"Pipeline starting: {len(self.clips)} clips, output: {self.output_dir}")
|
||||
|
||||
# Step 1: 提取
|
||||
clip_paths = self.step_extract()
|
||||
if not clip_paths:
|
||||
raise RuntimeError("No clips extracted")
|
||||
|
||||
# Step 2: 转录
|
||||
json_paths = self.step_transcribe(clip_paths)
|
||||
|
||||
# Step 3: 标题纠正
|
||||
corrected_clips = self.step_correct_titles(json_paths)
|
||||
|
||||
# Step 4: 生成字幕
|
||||
title_path, content_path = self.step_generate_subtitles(corrected_clips, json_paths)
|
||||
|
||||
# Step 5: 合并
|
||||
merged_path = self.step_merge(clip_paths)
|
||||
|
||||
# Step 6: 烧录
|
||||
final_path = self.step_burn(merged_path, title_path, content_path)
|
||||
|
||||
logger.info(f"Pipeline completed: {final_path}")
|
||||
return final_path
|
||||
|
||||
def run_with_user_confirm(self, confirmed_titles: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
运行流水线,在标题纠正后等待用户确认
|
||||
|
||||
Args:
|
||||
confirmed_titles: 用户确认后的标题列表
|
||||
|
||||
Returns:
|
||||
final_path: 最终视频路径
|
||||
"""
|
||||
logger.info(f"Pipeline starting with user confirmation: {len(self.clips)} clips")
|
||||
|
||||
# Step 1-3: 同上
|
||||
clip_paths = self.step_extract()
|
||||
if not clip_paths:
|
||||
raise RuntimeError("No clips extracted")
|
||||
|
||||
json_paths = self.step_transcribe(clip_paths)
|
||||
corrected_clips = self.step_correct_titles(json_paths)
|
||||
|
||||
# 应用用户确认的标题
|
||||
for i, confirmed in enumerate(confirmed_titles):
|
||||
if i < len(corrected_clips):
|
||||
corrected_clips[i]['title'] = confirmed.get('title', corrected_clips[i]['title'])
|
||||
|
||||
# Step 4-6: 同上
|
||||
title_path, content_path = self.step_generate_subtitles(corrected_clips, json_paths)
|
||||
merged_path = self.step_merge(clip_paths)
|
||||
final_path = self.step_burn(merged_path, title_path, content_path)
|
||||
|
||||
logger.info(f"Pipeline completed: {final_path}")
|
||||
return final_path
|
||||
|
||||
|
||||
def create_pipeline_from_yaml(config_path: str, **kwargs) -> Pipeline:
|
||||
"""
|
||||
从YAML配置文件创建Pipeline
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
**kwargs: 额外配置参数
|
||||
|
||||
Returns:
|
||||
Pipeline实例
|
||||
"""
|
||||
import yaml
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 合并额外参数
|
||||
config.update(kwargs)
|
||||
|
||||
return Pipeline(config, **kwargs)
|
||||
Reference in New Issue
Block a user