From 0b1335fbadc1384eb43b92b78beebd948220199d Mon Sep 17 00:00:00 2001 From: LeonspaceX Date: Tue, 16 Dec 2025 22:04:06 +0800 Subject: [PATCH] =?UTF-8?q?Feat:=E5=A2=9E=E5=8A=A0SSE=E6=96=B0=E5=B8=96?= =?UTF-8?q?=E6=8E=A8=E9=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_server.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/api_server.py b/api_server.py index 7d89ea9..bbb6f85 100644 --- a/api_server.py +++ b/api_server.py @@ -1,4 +1,4 @@ -from flask import Flask, request, jsonify, abort +from flask import Flask, request, jsonify, abort, Response from flask_sqlalchemy import SQLAlchemy from datetime import datetime, timezone from flask_cors import CORS @@ -9,6 +9,9 @@ from werkzeug.utils import secure_filename import os import shutil import hashlib +import queue +import time +import threading # === Flask 初始化 === app = Flask(__name__) @@ -135,6 +138,19 @@ CONFIG = {} INIT = False NEED_AUDIT = False +# === SSE 相关变量 === +sse_clients = [] # 存储所有 SSE 客户端的队列列表 +sse_lock = threading.Lock() # 保护 sse_clients 的线程锁 + +def notify_new_post(): + """通知所有 SSE 客户端有新的审核通过的投稿""" + with sse_lock: + for client_queue in sse_clients: + try: + client_queue.put("new_post") + except Exception: + pass + # 运行时使用的变量,初始为默认值 ADMIN_TOKEN_HASH = DEFAULT_ADMIN_TOKEN_HASH UPLOAD_FOLDER = DEFAULT_UPLOAD_FOLDER @@ -400,6 +416,36 @@ def require_admin(func): return wrapper # === 路由 === +@app.route('/stream', methods=['GET']) +def stream(): + """SSE 端点:推送新投稿通知和心跳""" + def event_stream(): + # 为当前客户端创建一个队列 + client_queue = queue.Queue() + + # 将队列注册到全局客户端列表 + with sse_lock: + sse_clients.append(client_queue) + + try: + while True: + try: + # 尝试从队列获取消息,超时 15 秒(心跳间隔) + message = client_queue.get(timeout=15) + # 发送新投稿通知 + yield f"data: {message}\n\n" + except queue.Empty: + # 超时则发送心跳 + yield "data: heartbeat\n\n" + except GeneratorExit: + # 客户端断开连接时清理 + with sse_lock: + if client_queue in sse_clients: + sse_clients.remove(client_queue) + + return Response(event_stream(), mimetype='text/event-stream') + + @app.route('/post', methods=['POST']) def submit_post(): guard = guard_rate_limit() @@ -429,6 +475,10 @@ def submit_post(): db.session.add(submission) db.session.commit() + # 如果直接通过(关闭审核),通知 SSE 客户端 + if status == "Pass": + notify_new_post() + return jsonify({"id": submission.id, "status": submission.status}), 201 @app.route('/up', methods=['POST']) @@ -864,6 +914,8 @@ def admin_approve(): return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400 success, reason = admin_change_status(data["id"], "Pending", "Pass") if success: + # 审核通过,通知 SSE 客户端 + notify_new_post() return jsonify({"status": "OK"}) else: return jsonify({"status": "Fail", "reason": reason})