增加Rate Limit功能
This commit is contained in:
@@ -8,6 +8,7 @@ from flask import send_file
|
||||
from werkzeug.utils import secure_filename
|
||||
import os
|
||||
import shutil
|
||||
import hashlib
|
||||
|
||||
# === Flask 初始化 ===
|
||||
app = Flask(__name__)
|
||||
@@ -97,6 +98,7 @@ DEFAULT_ADMIN_TOKEN = "Sycamore_whisper"
|
||||
DEFAULT_UPLOAD_FOLDER = "img"
|
||||
DEFAULT_ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "webp"}
|
||||
DEFAULT_MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
DEFAULT_RATE_LIMIT = 10 # 次/分钟,0为无限制
|
||||
|
||||
CONFIG = {}
|
||||
INIT = False
|
||||
@@ -108,6 +110,7 @@ UPLOAD_FOLDER = DEFAULT_UPLOAD_FOLDER
|
||||
ALLOWED_EXTENSIONS = set(DEFAULT_ALLOWED_EXTENSIONS)
|
||||
MAX_FILE_SIZE = DEFAULT_MAX_FILE_SIZE
|
||||
BANNED_KEYWORDS = list(DEFAULT_BANNED_KEYWORDS)
|
||||
RATE_LIMIT = DEFAULT_RATE_LIMIT
|
||||
|
||||
DB_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instance', 'database.db')
|
||||
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
||||
@@ -120,12 +123,13 @@ def allowed_backup_file(filename):
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_BACKUP_EXTENSIONS
|
||||
|
||||
def apply_config_to_globals():
|
||||
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS
|
||||
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS, RATE_LIMIT
|
||||
ADMIN_TOKEN = CONFIG.get('ADMIN_TOKEN', DEFAULT_ADMIN_TOKEN)
|
||||
UPLOAD_FOLDER = CONFIG.get('UPLOAD_FOLDER', DEFAULT_UPLOAD_FOLDER)
|
||||
ALLOWED_EXTENSIONS = set(CONFIG.get('ALLOWED_EXTENSIONS', DEFAULT_ALLOWED_EXTENSIONS))
|
||||
MAX_FILE_SIZE = int(CONFIG.get('MAX_FILE_SIZE', DEFAULT_MAX_FILE_SIZE))
|
||||
BANNED_KEYWORDS = list(CONFIG.get('BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS))
|
||||
RATE_LIMIT = int(CONFIG.get('RATE_LIMIT', DEFAULT_RATE_LIMIT))
|
||||
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
@@ -144,6 +148,7 @@ def load_config():
|
||||
'ALLOWED_EXTENSIONS': set(getattr(cfg, 'ALLOWED_EXTENSIONS')),
|
||||
'MAX_FILE_SIZE': int(getattr(cfg, 'MAX_FILE_SIZE')),
|
||||
'BANNED_KEYWORDS': list(getattr(cfg, 'BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS)),
|
||||
'RATE_LIMIT': int(getattr(cfg, 'RATE_LIMIT', DEFAULT_RATE_LIMIT)),
|
||||
}
|
||||
INIT = True
|
||||
apply_config_to_globals()
|
||||
@@ -169,7 +174,7 @@ def gate_uninitialized():
|
||||
if not INIT:
|
||||
return jsonify({"status": "Fail", "reason": "Uninitialized"}), 503
|
||||
|
||||
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None):
|
||||
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None, rate_limit=DEFAULT_RATE_LIMIT):
|
||||
# 归一化扩展名为小写且唯一
|
||||
exts = sorted(set(str(e).strip().lower() for e in allowed_exts if str(e).strip()))
|
||||
# 归一化敏感词为去空格的字符串列表
|
||||
@@ -182,6 +187,7 @@ def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_ke
|
||||
f"ALLOWED_EXTENSIONS = {repr(exts)}\n"
|
||||
f"MAX_FILE_SIZE = {int(max_file_size)}\n"
|
||||
f"BANNED_KEYWORDS = {repr(banned)}\n"
|
||||
f"RATE_LIMIT = {int(rate_limit)}\n"
|
||||
)
|
||||
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.py')
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
@@ -193,7 +199,7 @@ def init_service():
|
||||
if INIT:
|
||||
return jsonify({"status": "Fail", "reason": "Already initialized"}), 403
|
||||
data = request.get_json() or {}
|
||||
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE"]
|
||||
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE", "RATE_LIMIT"]
|
||||
missing = [k for k in required if k not in data]
|
||||
if missing:
|
||||
return jsonify({"status": "Fail", "reason": f"Missing fields: {', '.join(missing)}"}), 400
|
||||
@@ -214,6 +220,14 @@ def init_service():
|
||||
except Exception:
|
||||
return jsonify({"status": "Fail", "reason": "MAX_FILE_SIZE must be int"}), 400
|
||||
|
||||
# 必填的 RATE_LIMIT(次/分钟,0为无限制)
|
||||
try:
|
||||
rate_limit = int(data["RATE_LIMIT"])
|
||||
if rate_limit < 0:
|
||||
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be >= 0"}), 400
|
||||
except Exception:
|
||||
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be int"}), 400
|
||||
|
||||
# 可选的 BANNED_KEYWORDS
|
||||
bk = data.get("BANNED_KEYWORDS", DEFAULT_BANNED_KEYWORDS)
|
||||
if isinstance(bk, str):
|
||||
@@ -224,7 +238,7 @@ def init_service():
|
||||
return jsonify({"status": "Fail", "reason": "BANNED_KEYWORDS must be list or comma string"}), 400
|
||||
|
||||
try:
|
||||
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords)
|
||||
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords, rate_limit)
|
||||
load_config()
|
||||
initialize_database()
|
||||
try:
|
||||
@@ -236,6 +250,53 @@ def init_service():
|
||||
except Exception as e:
|
||||
return jsonify({"status": "Fail", "reason": str(e)}), 500
|
||||
|
||||
# === 限流(Rate Limit)实现 ===
|
||||
RATE_LIMIT_STORE = {}
|
||||
|
||||
def get_client_ip():
|
||||
"""在反向代理后正确获取客户端 IP。
|
||||
优先级:CF-Connecting-IP > X-Forwarded-For(首个) > X-Real-IP > remote_addr
|
||||
"""
|
||||
ip = (
|
||||
request.headers.get('CF-Connecting-IP')
|
||||
or request.headers.get('X-Forwarded-For')
|
||||
or request.headers.get('X-Real-IP')
|
||||
or request.remote_addr
|
||||
or '127.0.0.1'
|
||||
)
|
||||
if isinstance(ip, str):
|
||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
||||
if ',' in ip:
|
||||
ip = ip.split(',')[0].strip()
|
||||
ip = ip.strip()
|
||||
return ip
|
||||
|
||||
def rate_limit_exceeded() -> bool:
|
||||
"""返回是否超过限流。0 表示无限制。窗口从首次请求开始,持续 60 秒。"""
|
||||
if RATE_LIMIT == 0:
|
||||
return False
|
||||
ip = get_client_ip()
|
||||
ip_hash = hashlib.sha256(ip.encode('utf-8')).hexdigest()
|
||||
now = datetime.now(timezone.utc)
|
||||
rec = RATE_LIMIT_STORE.get(ip_hash)
|
||||
if rec is None:
|
||||
RATE_LIMIT_STORE[ip_hash] = {'count': 1, 'start': now}
|
||||
return False
|
||||
# 窗口超过 60 秒则重置
|
||||
if (now - rec['start']).total_seconds() >= 60:
|
||||
rec['count'] = 1
|
||||
rec['start'] = now
|
||||
return False
|
||||
# 累加计数并判断是否超过
|
||||
rec['count'] += 1
|
||||
return rec['count'] > RATE_LIMIT
|
||||
|
||||
def guard_rate_limit():
|
||||
"""超过限流则返回 403,否则返回 None。"""
|
||||
if rate_limit_exceeded():
|
||||
return jsonify({"status": "Fail", "reason": "Rate Limit Exceeded"}), 403
|
||||
return None
|
||||
|
||||
|
||||
# 在服务收到请求且已配置后,确保数据库表创建并加载审核状态
|
||||
@app.before_request
|
||||
@@ -285,6 +346,9 @@ def require_admin(func):
|
||||
# === 路由 ===
|
||||
@app.route('/post', methods=['POST'])
|
||||
def submit_post():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
data = request.get_json()
|
||||
if not data or "content" not in data:
|
||||
return jsonify({"error": "Content not found"}), 400
|
||||
@@ -313,6 +377,9 @@ def submit_post():
|
||||
|
||||
@app.route('/up', methods=['POST'])
|
||||
def upvote():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
data = request.get_json()
|
||||
if not data or "id" not in data:
|
||||
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
||||
@@ -328,6 +395,9 @@ def upvote():
|
||||
|
||||
@app.route('/down', methods=['POST'])
|
||||
def downvote():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
data = request.get_json()
|
||||
if not data or "id" not in data:
|
||||
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
||||
@@ -342,6 +412,9 @@ def downvote():
|
||||
|
||||
@app.route('/comment', methods=['POST'])
|
||||
def post_comment():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "submission_id", "parent_comment_id", "nickname"]
|
||||
if not all(field in data for field in required_fields):
|
||||
@@ -389,6 +462,9 @@ def random_string(length=5):
|
||||
|
||||
@app.route('/upload_pic', methods=['POST'])
|
||||
def upload_pic():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
if 'file' not in request.files:
|
||||
return jsonify({"status": "Fail", "url": None}), 400
|
||||
|
||||
@@ -427,6 +503,9 @@ def serve_image(filename):
|
||||
|
||||
@app.route('/report', methods=['POST'])
|
||||
def submit_report():
|
||||
guard = guard_rate_limit()
|
||||
if guard is not None:
|
||||
return guard
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({"status": "Fail", "reason": "No data provided"}), 400
|
||||
|
||||
Reference in New Issue
Block a user