Tighten upload validation and transactions

This commit is contained in:
LeonspaceX
2026-01-31 17:28:05 +08:00
parent 3a30271fe6
commit b788c04f1d

View File

@@ -3,9 +3,11 @@
from flask import Flask, jsonify, request, abort, send_from_directory from flask import Flask, jsonify, request, abort, send_from_directory
from flask_cors import CORS from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import foreign
import os import os
import uuid import uuid
import json import json
import imghdr
from datetime import datetime from datetime import datetime
app = Flask(__name__) app = Flask(__name__)
@@ -62,9 +64,10 @@ class Submission(db.Model):
comments = db.relationship('Comment', backref='submission', lazy=True, cascade='all, delete-orphan') comments = db.relationship('Comment', backref='submission', lazy=True, cascade='all, delete-orphan')
hashtags = db.relationship( hashtags = db.relationship(
'Hashtag', 'Hashtag',
primaryjoin="and_(Hashtag.type==0, Hashtag.target_id==Submission.id)", primaryjoin="and_(Hashtag.type==0, foreign(Hashtag.target_id)==Submission.id)",
cascade='all, delete-orphan', cascade='all, delete-orphan',
lazy=True lazy=True,
overlaps="hashtags"
) )
class Comment(db.Model): class Comment(db.Model):
@@ -78,9 +81,10 @@ class Comment(db.Model):
parent_comment_id = db.Column(db.Integer, db.ForeignKey('comments.id'), nullable=True) parent_comment_id = db.Column(db.Integer, db.ForeignKey('comments.id'), nullable=True)
hashtags = db.relationship( hashtags = db.relationship(
'Hashtag', 'Hashtag',
primaryjoin="and_(Hashtag.type==1, Hashtag.target_id==Comment.id)", primaryjoin="and_(Hashtag.type==1, foreign(Hashtag.target_id)==Comment.id)",
cascade='all, delete-orphan', cascade='all, delete-orphan',
lazy=True lazy=True,
overlaps="hashtags"
) )
class Report(db.Model): class Report(db.Model):
@@ -207,7 +211,6 @@ def save_hashtags(tag_type, target_id, hashtopic):
name=tag name=tag
) )
db.session.add(new_tag) db.session.add(new_tag)
db.session.commit()
# --- 用户普通api端点 --- # --- 用户普通api端点 ---
@app.route('/api/settings', methods=['GET']) @app.route('/api/settings', methods=['GET'])
@@ -345,14 +348,16 @@ def submit_post():
updated_at=now, updated_at=now,
) )
db.session.add(new_post) db.session.add(new_post)
db.session.commit() db.session.flush()
# 保存 Hashtags # 保存 Hashtags
save_hashtags(0, new_post.id, hashtopic) save_hashtags(0, new_post.id, hashtopic)
db.session.commit()
code = 1002 if new_post.status == 'Pending' else 1001 code = 1002 if new_post.status == 'Pending' else 1001
return jsonify({"code": code, "data": {"id": new_post.id}}) return jsonify({"code": code, "data": {"id": new_post.id}})
except Exception as e: except Exception as e:
db.session.rollback()
return jsonify({"code": 2003, "data": f"投稿失败: {str(e)}"}) return jsonify({"code": 2003, "data": f"投稿失败: {str(e)}"})
@app.route('/api/comment', methods=['POST']) @app.route('/api/comment', methods=['POST'])
@@ -398,13 +403,15 @@ def submit_comment():
parent_comment_id=None if parent_comment_id == 0 else parent_comment_id parent_comment_id=None if parent_comment_id == 0 else parent_comment_id
) )
db.session.add(new_comment) db.session.add(new_comment)
db.session.commit() db.session.flush()
# 保存 Hashtags # 保存 Hashtags
save_hashtags(1, new_comment.id, hashtopic) save_hashtags(1, new_comment.id, hashtopic)
db.session.commit()
return jsonify({"code": 1001, "data": {"id": new_comment.id}}) return jsonify({"code": 1001, "data": {"id": new_comment.id}})
except Exception as e: except Exception as e:
db.session.rollback()
return jsonify({"code": 2003, "data": f"评论失败: {str(e)}"}) return jsonify({"code": 2003, "data": f"评论失败: {str(e)}"})
@app.route('/api/report', methods=['POST']) @app.route('/api/report', methods=['POST'])
@@ -498,6 +505,20 @@ def upload_pic():
if not ext or (FILE_FORMATS and ext not in FILE_FORMATS): if not ext or (FILE_FORMATS and ext not in FILE_FORMATS):
return jsonify({"code": 2007, "data": "上传的文件类型不支持"}) return jsonify({"code": 2007, "data": "上传的文件类型不支持"})
header = file.read(512)
file.seek(0)
detected = imghdr.what(None, header)
ext_map = {
'jpg': 'jpeg',
'jpeg': 'jpeg',
'png': 'png',
'gif': 'gif',
'webp': 'webp',
}
expected = ext_map.get(ext)
if not expected or detected != expected:
return jsonify({"code": 2007, "data": "上传的文件类型不支持"})
filename = f"{uuid.uuid4().hex}.{ext}" filename = f"{uuid.uuid4().hex}.{ext}"
filepath = os.path.join(IMG_DIR, filename) filepath = os.path.join(IMG_DIR, filename)
file.save(filepath) file.save(filepath)