diff --git a/app/routes/students.py b/app/routes/students.py index 8dba774..eceafca 100644 --- a/app/routes/students.py +++ b/app/routes/students.py @@ -380,11 +380,37 @@ def import_students(): @main_bp.route("/api/statistics/students") def get_student_statistics(): """获取学员统计数据""" - # 全体学员数量 - total_students = Student.query.count() + # 获取筛选参数 + mine_filter = request.args.get("mine", "false") == "true" + class_id_filter = request.args.get("class_id", type=int) + current_user_id = session.get("user_id") - # 全体问题记录 - problems = StudentProblem.query.all() + # 基础查询:班级 + class_query = Class.query.filter_by(active=True) + if mine_filter and current_user_id: + class_query = class_query.filter_by(teacher_id=current_user_id) + if class_id_filter: + class_query = class_query.filter_by(id=class_id_filter) + classes = class_query.all() + + # 获取班级IDs + class_ids = [c.id for c in classes] + + # 学员过滤 + student_query = Student.query + if class_ids: + student_query = student_query.filter(Student.class_id.in_(class_ids)) + else: + student_query = student_query.filter(Student.id == -1) # 无班级时返回空 + students = student_query.all() + student_ids = [s.id for s in students] + + # 全体学员数量 + total_students = len(students) + + # 问题记录过滤 + problem_query = StudentProblem.query.filter(StudentProblem.student_id.in_(student_ids)) if student_ids else StudentProblem.query.filter(StudentProblem.id == -1) + problems = problem_query.all() # 问题级别分布(来自 StudentProblem.level) levels = ["启蒙", "入门", "进阶", "熟练", "精通"] @@ -400,7 +426,6 @@ def get_student_statistics(): severity_dist[p.severity] += 1 # 各班级学员数量 - classes = Class.query.filter_by(active=True).all() class_student_count = [] for c in classes: class_student_count.append({ diff --git a/app/templates/statistics.html b/app/templates/statistics.html index d700e81..c91be16 100644 --- a/app/templates/statistics.html +++ b/app/templates/statistics.html @@ -6,6 +6,20 @@