增加Rate Limit功能

This commit is contained in:
LeonspaceX
2025-11-20 18:44:32 +08:00
parent e674c321b5
commit c4fc6c0084

View File

@@ -8,6 +8,7 @@ from flask import send_file
from werkzeug.utils import secure_filename
import os
import shutil
import hashlib
# === Flask 初始化 ===
app = Flask(__name__)
@@ -97,6 +98,7 @@ DEFAULT_ADMIN_TOKEN = "Sycamore_whisper"
DEFAULT_UPLOAD_FOLDER = "img"
DEFAULT_ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "webp"}
DEFAULT_MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
DEFAULT_RATE_LIMIT = 10 # 次/分钟0为无限制
CONFIG = {}
INIT = False
@@ -108,6 +110,7 @@ UPLOAD_FOLDER = DEFAULT_UPLOAD_FOLDER
ALLOWED_EXTENSIONS = set(DEFAULT_ALLOWED_EXTENSIONS)
MAX_FILE_SIZE = DEFAULT_MAX_FILE_SIZE
BANNED_KEYWORDS = list(DEFAULT_BANNED_KEYWORDS)
RATE_LIMIT = DEFAULT_RATE_LIMIT
DB_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instance', 'database.db')
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
@@ -120,12 +123,13 @@ def allowed_backup_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_BACKUP_EXTENSIONS
def apply_config_to_globals():
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS, RATE_LIMIT
ADMIN_TOKEN = CONFIG.get('ADMIN_TOKEN', DEFAULT_ADMIN_TOKEN)
UPLOAD_FOLDER = CONFIG.get('UPLOAD_FOLDER', DEFAULT_UPLOAD_FOLDER)
ALLOWED_EXTENSIONS = set(CONFIG.get('ALLOWED_EXTENSIONS', DEFAULT_ALLOWED_EXTENSIONS))
MAX_FILE_SIZE = int(CONFIG.get('MAX_FILE_SIZE', DEFAULT_MAX_FILE_SIZE))
BANNED_KEYWORDS = list(CONFIG.get('BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS))
RATE_LIMIT = int(CONFIG.get('RATE_LIMIT', DEFAULT_RATE_LIMIT))
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
@@ -144,6 +148,7 @@ def load_config():
'ALLOWED_EXTENSIONS': set(getattr(cfg, 'ALLOWED_EXTENSIONS')),
'MAX_FILE_SIZE': int(getattr(cfg, 'MAX_FILE_SIZE')),
'BANNED_KEYWORDS': list(getattr(cfg, 'BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS)),
'RATE_LIMIT': int(getattr(cfg, 'RATE_LIMIT', DEFAULT_RATE_LIMIT)),
}
INIT = True
apply_config_to_globals()
@@ -169,7 +174,7 @@ def gate_uninitialized():
if not INIT:
return jsonify({"status": "Fail", "reason": "Uninitialized"}), 503
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None):
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None, rate_limit=DEFAULT_RATE_LIMIT):
# 归一化扩展名为小写且唯一
exts = sorted(set(str(e).strip().lower() for e in allowed_exts if str(e).strip()))
# 归一化敏感词为去空格的字符串列表
@@ -182,6 +187,7 @@ def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_ke
f"ALLOWED_EXTENSIONS = {repr(exts)}\n"
f"MAX_FILE_SIZE = {int(max_file_size)}\n"
f"BANNED_KEYWORDS = {repr(banned)}\n"
f"RATE_LIMIT = {int(rate_limit)}\n"
)
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.py')
with open(config_path, 'w', encoding='utf-8') as f:
@@ -193,7 +199,7 @@ def init_service():
if INIT:
return jsonify({"status": "Fail", "reason": "Already initialized"}), 403
data = request.get_json() or {}
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE"]
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE", "RATE_LIMIT"]
missing = [k for k in required if k not in data]
if missing:
return jsonify({"status": "Fail", "reason": f"Missing fields: {', '.join(missing)}"}), 400
@@ -214,6 +220,14 @@ def init_service():
except Exception:
return jsonify({"status": "Fail", "reason": "MAX_FILE_SIZE must be int"}), 400
# 必填的 RATE_LIMIT次/分钟0为无限制
try:
rate_limit = int(data["RATE_LIMIT"])
if rate_limit < 0:
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be >= 0"}), 400
except Exception:
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be int"}), 400
# 可选的 BANNED_KEYWORDS
bk = data.get("BANNED_KEYWORDS", DEFAULT_BANNED_KEYWORDS)
if isinstance(bk, str):
@@ -224,7 +238,7 @@ def init_service():
return jsonify({"status": "Fail", "reason": "BANNED_KEYWORDS must be list or comma string"}), 400
try:
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords)
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords, rate_limit)
load_config()
initialize_database()
try:
@@ -236,6 +250,53 @@ def init_service():
except Exception as e:
return jsonify({"status": "Fail", "reason": str(e)}), 500
# === 限流Rate Limit实现 ===
RATE_LIMIT_STORE = {}
def get_client_ip():
"""在反向代理后正确获取客户端 IP。
优先级CF-Connecting-IP > X-Forwarded-For(首个) > X-Real-IP > remote_addr
"""
ip = (
request.headers.get('CF-Connecting-IP')
or request.headers.get('X-Forwarded-For')
or request.headers.get('X-Real-IP')
or request.remote_addr
or '127.0.0.1'
)
if isinstance(ip, str):
# X-Forwarded-For 可能包含多个 IP取第一个
if ',' in ip:
ip = ip.split(',')[0].strip()
ip = ip.strip()
return ip
def rate_limit_exceeded() -> bool:
"""返回是否超过限流。0 表示无限制。窗口从首次请求开始,持续 60 秒。"""
if RATE_LIMIT == 0:
return False
ip = get_client_ip()
ip_hash = hashlib.sha256(ip.encode('utf-8')).hexdigest()
now = datetime.now(timezone.utc)
rec = RATE_LIMIT_STORE.get(ip_hash)
if rec is None:
RATE_LIMIT_STORE[ip_hash] = {'count': 1, 'start': now}
return False
# 窗口超过 60 秒则重置
if (now - rec['start']).total_seconds() >= 60:
rec['count'] = 1
rec['start'] = now
return False
# 累加计数并判断是否超过
rec['count'] += 1
return rec['count'] > RATE_LIMIT
def guard_rate_limit():
"""超过限流则返回 403否则返回 None。"""
if rate_limit_exceeded():
return jsonify({"status": "Fail", "reason": "Rate Limit Exceeded"}), 403
return None
# 在服务收到请求且已配置后,确保数据库表创建并加载审核状态
@app.before_request
@@ -285,6 +346,9 @@ def require_admin(func):
# === 路由 ===
@app.route('/post', methods=['POST'])
def submit_post():
guard = guard_rate_limit()
if guard is not None:
return guard
data = request.get_json()
if not data or "content" not in data:
return jsonify({"error": "Content not found"}), 400
@@ -313,6 +377,9 @@ def submit_post():
@app.route('/up', methods=['POST'])
def upvote():
guard = guard_rate_limit()
if guard is not None:
return guard
data = request.get_json()
if not data or "id" not in data:
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
@@ -328,6 +395,9 @@ def upvote():
@app.route('/down', methods=['POST'])
def downvote():
guard = guard_rate_limit()
if guard is not None:
return guard
data = request.get_json()
if not data or "id" not in data:
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
@@ -342,6 +412,9 @@ def downvote():
@app.route('/comment', methods=['POST'])
def post_comment():
guard = guard_rate_limit()
if guard is not None:
return guard
data = request.get_json()
required_fields = ["content", "submission_id", "parent_comment_id", "nickname"]
if not all(field in data for field in required_fields):
@@ -389,6 +462,9 @@ def random_string(length=5):
@app.route('/upload_pic', methods=['POST'])
def upload_pic():
guard = guard_rate_limit()
if guard is not None:
return guard
if 'file' not in request.files:
return jsonify({"status": "Fail", "url": None}), 400
@@ -427,6 +503,9 @@ def serve_image(filename):
@app.route('/report', methods=['POST'])
def submit_report():
guard = guard_rate_limit()
if guard is not None:
return guard
data = request.get_json()
if not data:
return jsonify({"status": "Fail", "reason": "No data provided"}), 400