增加Rate Limit功能
This commit is contained in:
@@ -8,6 +8,7 @@ from flask import send_file
|
|||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import hashlib
|
||||||
|
|
||||||
# === Flask 初始化 ===
|
# === Flask 初始化 ===
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
@@ -97,6 +98,7 @@ DEFAULT_ADMIN_TOKEN = "Sycamore_whisper"
|
|||||||
DEFAULT_UPLOAD_FOLDER = "img"
|
DEFAULT_UPLOAD_FOLDER = "img"
|
||||||
DEFAULT_ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "webp"}
|
DEFAULT_ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif", "webp"}
|
||||||
DEFAULT_MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
DEFAULT_MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||||
|
DEFAULT_RATE_LIMIT = 10 # 次/分钟,0为无限制
|
||||||
|
|
||||||
CONFIG = {}
|
CONFIG = {}
|
||||||
INIT = False
|
INIT = False
|
||||||
@@ -108,6 +110,7 @@ UPLOAD_FOLDER = DEFAULT_UPLOAD_FOLDER
|
|||||||
ALLOWED_EXTENSIONS = set(DEFAULT_ALLOWED_EXTENSIONS)
|
ALLOWED_EXTENSIONS = set(DEFAULT_ALLOWED_EXTENSIONS)
|
||||||
MAX_FILE_SIZE = DEFAULT_MAX_FILE_SIZE
|
MAX_FILE_SIZE = DEFAULT_MAX_FILE_SIZE
|
||||||
BANNED_KEYWORDS = list(DEFAULT_BANNED_KEYWORDS)
|
BANNED_KEYWORDS = list(DEFAULT_BANNED_KEYWORDS)
|
||||||
|
RATE_LIMIT = DEFAULT_RATE_LIMIT
|
||||||
|
|
||||||
DB_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instance', 'database.db')
|
DB_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instance', 'database.db')
|
||||||
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
||||||
@@ -120,12 +123,13 @@ def allowed_backup_file(filename):
|
|||||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_BACKUP_EXTENSIONS
|
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_BACKUP_EXTENSIONS
|
||||||
|
|
||||||
def apply_config_to_globals():
|
def apply_config_to_globals():
|
||||||
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS
|
global ADMIN_TOKEN, UPLOAD_FOLDER, ALLOWED_EXTENSIONS, MAX_FILE_SIZE, IMG_FOLDER, BANNED_KEYWORDS, RATE_LIMIT
|
||||||
ADMIN_TOKEN = CONFIG.get('ADMIN_TOKEN', DEFAULT_ADMIN_TOKEN)
|
ADMIN_TOKEN = CONFIG.get('ADMIN_TOKEN', DEFAULT_ADMIN_TOKEN)
|
||||||
UPLOAD_FOLDER = CONFIG.get('UPLOAD_FOLDER', DEFAULT_UPLOAD_FOLDER)
|
UPLOAD_FOLDER = CONFIG.get('UPLOAD_FOLDER', DEFAULT_UPLOAD_FOLDER)
|
||||||
ALLOWED_EXTENSIONS = set(CONFIG.get('ALLOWED_EXTENSIONS', DEFAULT_ALLOWED_EXTENSIONS))
|
ALLOWED_EXTENSIONS = set(CONFIG.get('ALLOWED_EXTENSIONS', DEFAULT_ALLOWED_EXTENSIONS))
|
||||||
MAX_FILE_SIZE = int(CONFIG.get('MAX_FILE_SIZE', DEFAULT_MAX_FILE_SIZE))
|
MAX_FILE_SIZE = int(CONFIG.get('MAX_FILE_SIZE', DEFAULT_MAX_FILE_SIZE))
|
||||||
BANNED_KEYWORDS = list(CONFIG.get('BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS))
|
BANNED_KEYWORDS = list(CONFIG.get('BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS))
|
||||||
|
RATE_LIMIT = int(CONFIG.get('RATE_LIMIT', DEFAULT_RATE_LIMIT))
|
||||||
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
IMG_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
|
||||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||||
|
|
||||||
@@ -144,6 +148,7 @@ def load_config():
|
|||||||
'ALLOWED_EXTENSIONS': set(getattr(cfg, 'ALLOWED_EXTENSIONS')),
|
'ALLOWED_EXTENSIONS': set(getattr(cfg, 'ALLOWED_EXTENSIONS')),
|
||||||
'MAX_FILE_SIZE': int(getattr(cfg, 'MAX_FILE_SIZE')),
|
'MAX_FILE_SIZE': int(getattr(cfg, 'MAX_FILE_SIZE')),
|
||||||
'BANNED_KEYWORDS': list(getattr(cfg, 'BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS)),
|
'BANNED_KEYWORDS': list(getattr(cfg, 'BANNED_KEYWORDS', DEFAULT_BANNED_KEYWORDS)),
|
||||||
|
'RATE_LIMIT': int(getattr(cfg, 'RATE_LIMIT', DEFAULT_RATE_LIMIT)),
|
||||||
}
|
}
|
||||||
INIT = True
|
INIT = True
|
||||||
apply_config_to_globals()
|
apply_config_to_globals()
|
||||||
@@ -169,7 +174,7 @@ def gate_uninitialized():
|
|||||||
if not INIT:
|
if not INIT:
|
||||||
return jsonify({"status": "Fail", "reason": "Uninitialized"}), 503
|
return jsonify({"status": "Fail", "reason": "Uninitialized"}), 503
|
||||||
|
|
||||||
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None):
|
def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords=None, rate_limit=DEFAULT_RATE_LIMIT):
|
||||||
# 归一化扩展名为小写且唯一
|
# 归一化扩展名为小写且唯一
|
||||||
exts = sorted(set(str(e).strip().lower() for e in allowed_exts if str(e).strip()))
|
exts = sorted(set(str(e).strip().lower() for e in allowed_exts if str(e).strip()))
|
||||||
# 归一化敏感词为去空格的字符串列表
|
# 归一化敏感词为去空格的字符串列表
|
||||||
@@ -182,6 +187,7 @@ def write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_ke
|
|||||||
f"ALLOWED_EXTENSIONS = {repr(exts)}\n"
|
f"ALLOWED_EXTENSIONS = {repr(exts)}\n"
|
||||||
f"MAX_FILE_SIZE = {int(max_file_size)}\n"
|
f"MAX_FILE_SIZE = {int(max_file_size)}\n"
|
||||||
f"BANNED_KEYWORDS = {repr(banned)}\n"
|
f"BANNED_KEYWORDS = {repr(banned)}\n"
|
||||||
|
f"RATE_LIMIT = {int(rate_limit)}\n"
|
||||||
)
|
)
|
||||||
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.py')
|
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.py')
|
||||||
with open(config_path, 'w', encoding='utf-8') as f:
|
with open(config_path, 'w', encoding='utf-8') as f:
|
||||||
@@ -193,7 +199,7 @@ def init_service():
|
|||||||
if INIT:
|
if INIT:
|
||||||
return jsonify({"status": "Fail", "reason": "Already initialized"}), 403
|
return jsonify({"status": "Fail", "reason": "Already initialized"}), 403
|
||||||
data = request.get_json() or {}
|
data = request.get_json() or {}
|
||||||
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE"]
|
required = ["ADMIN_TOKEN", "UPLOAD_FOLDER", "ALLOWED_EXTENSIONS", "MAX_FILE_SIZE", "RATE_LIMIT"]
|
||||||
missing = [k for k in required if k not in data]
|
missing = [k for k in required if k not in data]
|
||||||
if missing:
|
if missing:
|
||||||
return jsonify({"status": "Fail", "reason": f"Missing fields: {', '.join(missing)}"}), 400
|
return jsonify({"status": "Fail", "reason": f"Missing fields: {', '.join(missing)}"}), 400
|
||||||
@@ -214,6 +220,14 @@ def init_service():
|
|||||||
except Exception:
|
except Exception:
|
||||||
return jsonify({"status": "Fail", "reason": "MAX_FILE_SIZE must be int"}), 400
|
return jsonify({"status": "Fail", "reason": "MAX_FILE_SIZE must be int"}), 400
|
||||||
|
|
||||||
|
# 必填的 RATE_LIMIT(次/分钟,0为无限制)
|
||||||
|
try:
|
||||||
|
rate_limit = int(data["RATE_LIMIT"])
|
||||||
|
if rate_limit < 0:
|
||||||
|
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be >= 0"}), 400
|
||||||
|
except Exception:
|
||||||
|
return jsonify({"status": "Fail", "reason": "RATE_LIMIT must be int"}), 400
|
||||||
|
|
||||||
# 可选的 BANNED_KEYWORDS
|
# 可选的 BANNED_KEYWORDS
|
||||||
bk = data.get("BANNED_KEYWORDS", DEFAULT_BANNED_KEYWORDS)
|
bk = data.get("BANNED_KEYWORDS", DEFAULT_BANNED_KEYWORDS)
|
||||||
if isinstance(bk, str):
|
if isinstance(bk, str):
|
||||||
@@ -224,7 +238,7 @@ def init_service():
|
|||||||
return jsonify({"status": "Fail", "reason": "BANNED_KEYWORDS must be list or comma string"}), 400
|
return jsonify({"status": "Fail", "reason": "BANNED_KEYWORDS must be list or comma string"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords)
|
write_config_py(token, upload_folder, allowed_exts, max_file_size, banned_keywords, rate_limit)
|
||||||
load_config()
|
load_config()
|
||||||
initialize_database()
|
initialize_database()
|
||||||
try:
|
try:
|
||||||
@@ -236,6 +250,53 @@ def init_service():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"status": "Fail", "reason": str(e)}), 500
|
return jsonify({"status": "Fail", "reason": str(e)}), 500
|
||||||
|
|
||||||
|
# === 限流(Rate Limit)实现 ===
|
||||||
|
RATE_LIMIT_STORE = {}
|
||||||
|
|
||||||
|
def get_client_ip():
|
||||||
|
"""在反向代理后正确获取客户端 IP。
|
||||||
|
优先级:CF-Connecting-IP > X-Forwarded-For(首个) > X-Real-IP > remote_addr
|
||||||
|
"""
|
||||||
|
ip = (
|
||||||
|
request.headers.get('CF-Connecting-IP')
|
||||||
|
or request.headers.get('X-Forwarded-For')
|
||||||
|
or request.headers.get('X-Real-IP')
|
||||||
|
or request.remote_addr
|
||||||
|
or '127.0.0.1'
|
||||||
|
)
|
||||||
|
if isinstance(ip, str):
|
||||||
|
# X-Forwarded-For 可能包含多个 IP,取第一个
|
||||||
|
if ',' in ip:
|
||||||
|
ip = ip.split(',')[0].strip()
|
||||||
|
ip = ip.strip()
|
||||||
|
return ip
|
||||||
|
|
||||||
|
def rate_limit_exceeded() -> bool:
|
||||||
|
"""返回是否超过限流。0 表示无限制。窗口从首次请求开始,持续 60 秒。"""
|
||||||
|
if RATE_LIMIT == 0:
|
||||||
|
return False
|
||||||
|
ip = get_client_ip()
|
||||||
|
ip_hash = hashlib.sha256(ip.encode('utf-8')).hexdigest()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
rec = RATE_LIMIT_STORE.get(ip_hash)
|
||||||
|
if rec is None:
|
||||||
|
RATE_LIMIT_STORE[ip_hash] = {'count': 1, 'start': now}
|
||||||
|
return False
|
||||||
|
# 窗口超过 60 秒则重置
|
||||||
|
if (now - rec['start']).total_seconds() >= 60:
|
||||||
|
rec['count'] = 1
|
||||||
|
rec['start'] = now
|
||||||
|
return False
|
||||||
|
# 累加计数并判断是否超过
|
||||||
|
rec['count'] += 1
|
||||||
|
return rec['count'] > RATE_LIMIT
|
||||||
|
|
||||||
|
def guard_rate_limit():
|
||||||
|
"""超过限流则返回 403,否则返回 None。"""
|
||||||
|
if rate_limit_exceeded():
|
||||||
|
return jsonify({"status": "Fail", "reason": "Rate Limit Exceeded"}), 403
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# 在服务收到请求且已配置后,确保数据库表创建并加载审核状态
|
# 在服务收到请求且已配置后,确保数据库表创建并加载审核状态
|
||||||
@app.before_request
|
@app.before_request
|
||||||
@@ -285,6 +346,9 @@ def require_admin(func):
|
|||||||
# === 路由 ===
|
# === 路由 ===
|
||||||
@app.route('/post', methods=['POST'])
|
@app.route('/post', methods=['POST'])
|
||||||
def submit_post():
|
def submit_post():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or "content" not in data:
|
if not data or "content" not in data:
|
||||||
return jsonify({"error": "Content not found"}), 400
|
return jsonify({"error": "Content not found"}), 400
|
||||||
@@ -313,6 +377,9 @@ def submit_post():
|
|||||||
|
|
||||||
@app.route('/up', methods=['POST'])
|
@app.route('/up', methods=['POST'])
|
||||||
def upvote():
|
def upvote():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or "id" not in data:
|
if not data or "id" not in data:
|
||||||
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
||||||
@@ -328,6 +395,9 @@ def upvote():
|
|||||||
|
|
||||||
@app.route('/down', methods=['POST'])
|
@app.route('/down', methods=['POST'])
|
||||||
def downvote():
|
def downvote():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or "id" not in data:
|
if not data or "id" not in data:
|
||||||
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
return jsonify({"status": "Fail", "reason": "Value ID not found"}), 400
|
||||||
@@ -342,6 +412,9 @@ def downvote():
|
|||||||
|
|
||||||
@app.route('/comment', methods=['POST'])
|
@app.route('/comment', methods=['POST'])
|
||||||
def post_comment():
|
def post_comment():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
required_fields = ["content", "submission_id", "parent_comment_id", "nickname"]
|
required_fields = ["content", "submission_id", "parent_comment_id", "nickname"]
|
||||||
if not all(field in data for field in required_fields):
|
if not all(field in data for field in required_fields):
|
||||||
@@ -389,6 +462,9 @@ def random_string(length=5):
|
|||||||
|
|
||||||
@app.route('/upload_pic', methods=['POST'])
|
@app.route('/upload_pic', methods=['POST'])
|
||||||
def upload_pic():
|
def upload_pic():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return jsonify({"status": "Fail", "url": None}), 400
|
return jsonify({"status": "Fail", "url": None}), 400
|
||||||
|
|
||||||
@@ -427,6 +503,9 @@ def serve_image(filename):
|
|||||||
|
|
||||||
@app.route('/report', methods=['POST'])
|
@app.route('/report', methods=['POST'])
|
||||||
def submit_report():
|
def submit_report():
|
||||||
|
guard = guard_rate_limit()
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data:
|
if not data:
|
||||||
return jsonify({"status": "Fail", "reason": "No data provided"}), 400
|
return jsonify({"status": "Fail", "reason": "No data provided"}), 400
|
||||||
|
|||||||
Reference in New Issue
Block a user