Feat:增加SSE新帖推送
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from flask import Flask, request, jsonify, abort
|
||||
from flask import Flask, request, jsonify, abort, Response
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from datetime import datetime, timezone
|
||||
from flask_cors import CORS
|
||||
@@ -9,6 +9,9 @@ from werkzeug.utils import secure_filename
|
||||
import os
|
||||
import shutil
|
||||
import hashlib
|
||||
import queue
|
||||
import time
|
||||
import threading
|
||||
|
||||
# === Flask 初始化 ===
|
||||
app = Flask(__name__)
|
||||
@@ -135,6 +138,19 @@ CONFIG = {}
|
||||
INIT = False
|
||||
NEED_AUDIT = False
|
||||
|
||||
# === SSE 相关变量 ===
|
||||
sse_clients = [] # 存储所有 SSE 客户端的队列列表
|
||||
sse_lock = threading.Lock() # 保护 sse_clients 的线程锁
|
||||
|
||||
def notify_new_post():
|
||||
"""通知所有 SSE 客户端有新的审核通过的投稿"""
|
||||
with sse_lock:
|
||||
for client_queue in sse_clients:
|
||||
try:
|
||||
client_queue.put("new_post")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 运行时使用的变量,初始为默认值
|
||||
ADMIN_TOKEN_HASH = DEFAULT_ADMIN_TOKEN_HASH
|
||||
UPLOAD_FOLDER = DEFAULT_UPLOAD_FOLDER
|
||||
@@ -400,6 +416,36 @@ def require_admin(func):
|
||||
return wrapper
|
||||
|
||||
# === 路由 ===
|
||||
@app.route('/stream', methods=['GET'])
|
||||
def stream():
|
||||
"""SSE 端点:推送新投稿通知和心跳"""
|
||||
def event_stream():
|
||||
# 为当前客户端创建一个队列
|
||||
client_queue = queue.Queue()
|
||||
|
||||
# 将队列注册到全局客户端列表
|
||||
with sse_lock:
|
||||
sse_clients.append(client_queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 尝试从队列获取消息,超时 15 秒(心跳间隔)
|
||||
message = client_queue.get(timeout=15)
|
||||
# 发送新投稿通知
|
||||
yield f"data: {message}\n\n"
|
||||
except queue.Empty:
|
||||
# 超时则发送心跳
|
||||
yield "data: heartbeat\n\n"
|
||||
except GeneratorExit:
|
||||
# 客户端断开连接时清理
|
||||
with sse_lock:
|
||||
if client_queue in sse_clients:
|
||||
sse_clients.remove(client_queue)
|
||||
|
||||
return Response(event_stream(), mimetype='text/event-stream')
|
||||
|
||||
|
||||
@app.route('/post', methods=['POST'])
|
||||
def submit_post():
|
||||
guard = guard_rate_limit()
|
||||
@@ -429,6 +475,10 @@ def submit_post():
|
||||
db.session.add(submission)
|
||||
db.session.commit()
|
||||
|
||||
# 如果直接通过(关闭审核),通知 SSE 客户端
|
||||
if status == "Pass":
|
||||
notify_new_post()
|
||||
|
||||
return jsonify({"id": submission.id, "status": submission.status}), 201
|
||||
|
||||
@app.route('/up', methods=['POST'])
|
||||
@@ -864,6 +914,8 @@ def admin_approve():
|
||||
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
||||
success, reason = admin_change_status(data["id"], "Pending", "Pass")
|
||||
if success:
|
||||
# 审核通过,通知 SSE 客户端
|
||||
notify_new_post()
|
||||
return jsonify({"status": "OK"})
|
||||
else:
|
||||
return jsonify({"status": "Fail", "reason": reason})
|
||||
|
||||
Reference in New Issue
Block a user