Move request limiter to middleware.

This commit is contained in:
2025-07-09 17:00:39 +08:00
parent 0021a73951
commit b568b234c4
11 changed files with 73 additions and 71 deletions

View File

@@ -2,17 +2,19 @@ package api
import ( import (
"net/url" "net/url"
"nysoure/server/middleware"
"nysoure/server/model" "nysoure/server/model"
"nysoure/server/service" "nysoure/server/service"
"strconv" "strconv"
"time"
"github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3"
) )
func AddCommentRoutes(router fiber.Router) { func AddCommentRoutes(router fiber.Router) {
api := router.Group("/comments") api := router.Group("/comments")
api.Post("/resource/:resourceID", createResourceComment) api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/resource/:resourceID", createResourceComment)
api.Post("/reply/:commentID", createReplyComment) api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/reply/:commentID", createReplyComment)
api.Get("/resource/:resourceID", listResourceComments) api.Get("/resource/:resourceID", listResourceComments)
api.Get("/reply/:commentID", listReplyComments) api.Get("/reply/:commentID", listReplyComments)
api.Get("/user/:username", listCommentsByUser) api.Get("/user/:username", listCommentsByUser)

View File

@@ -4,11 +4,14 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"nysoure/server/config"
"nysoure/server/middleware"
"nysoure/server/model" "nysoure/server/model"
"nysoure/server/service" "nysoure/server/service"
"nysoure/server/utils" "nysoure/server/utils"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3"
) )
@@ -16,17 +19,17 @@ import (
func AddFileRoutes(router fiber.Router) { func AddFileRoutes(router fiber.Router) {
fileGroup := router.Group("/files") fileGroup := router.Group("/files")
{ {
fileGroup.Post("/upload/init", initUpload) fileGroup.Use(middleware.NewRequestLimiter(10, time.Hour)).Post("/upload/init", initUpload)
fileGroup.Post("/upload/block/:id/:index", uploadBlock) fileGroup.Post("/upload/block/:id/:index", uploadBlock)
fileGroup.Post("/upload/finish/:id", finishUpload) fileGroup.Post("/upload/finish/:id", finishUpload)
fileGroup.Post("/upload/cancel/:id", cancelUpload) fileGroup.Post("/upload/cancel/:id", cancelUpload)
fileGroup.Post("/redirect", createRedirectFile) fileGroup.Use(middleware.NewRequestLimiter(50, time.Hour)).Post("/redirect", createRedirectFile)
fileGroup.Post("/upload/url", createServerDownloadTask) fileGroup.Post("/upload/url", createServerDownloadTask)
fileGroup.Get("/:id", getFile) fileGroup.Get("/:id", getFile)
fileGroup.Put("/:id", updateFile) fileGroup.Put("/:id", updateFile)
fileGroup.Delete("/:id", deleteFile) fileGroup.Delete("/:id", deleteFile)
fileGroup.Get("/download/local", downloadLocalFile) fileGroup.Get("/download/local", downloadLocalFile)
fileGroup.Get("/download/:id", downloadFile) fileGroup.Use(middleware.NewDynamicRequestLimiter(config.MaxDownloadsPerDayForSingleIP, 24*time.Hour)).Get("/download/:id", downloadFile)
} }
} }
@@ -201,8 +204,7 @@ func deleteFile(c fiber.Ctx) error {
func downloadFile(c fiber.Ctx) error { func downloadFile(c fiber.Ctx) error {
cfToken := c.Query("cf_token") cfToken := c.Query("cf_token")
ip := c.IP() s, filename, err := service.DownloadFile(c.Params("id"), cfToken)
s, filename, err := service.DownloadFile(ip, c.Params("id"), cfToken)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,12 +1,15 @@
package api package api
import ( import (
"github.com/gofiber/fiber/v3"
"net/http" "net/http"
"nysoure/server/middleware"
"nysoure/server/model" "nysoure/server/model"
"nysoure/server/service" "nysoure/server/service"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/gofiber/fiber/v3"
) )
func handleUploadImage(c fiber.Ctx) error { func handleUploadImage(c fiber.Ctx) error {
@@ -94,7 +97,7 @@ func handleGetResampledImage(c fiber.Ctx) error {
func AddImageRoutes(api fiber.Router) { func AddImageRoutes(api fiber.Router) {
image := api.Group("/image") image := api.Group("/image")
{ {
image.Put("/", handleUploadImage) image.Use(middleware.NewRequestLimiter(50, time.Hour)).Put("/", handleUploadImage)
image.Get("/resampled/:id", handleGetResampledImage) image.Get("/resampled/:id", handleGetResampledImage)
image.Get("/:id", handleGetImage) image.Get("/:id", handleGetImage)
image.Delete("/:id", handleDeleteImage) image.Delete("/:id", handleDeleteImage)

View File

@@ -4,9 +4,11 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"nysoure/server/middleware"
"nysoure/server/model" "nysoure/server/model"
"nysoure/server/service" "nysoure/server/service"
"strconv" "strconv"
"time"
"github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3"
) )
@@ -342,7 +344,7 @@ func handleGetMe(c fiber.Ctx) error {
func AddUserRoutes(r fiber.Router) { func AddUserRoutes(r fiber.Router) {
u := r.Group("user") u := r.Group("user")
u.Post("/register", handleUserRegister) u.Use(middleware.NewRequestLimiter(5, time.Hour)).Post("/register", handleUserRegister)
u.Post("/login", handleUserLogin) u.Post("/login", handleUserLogin)
u.Put("/avatar", handleUserChangeAvatar) u.Put("/avatar", handleUserChangeAvatar)
u.Post("/password", handleUserChangePassword) u.Post("/password", handleUserChangePassword)

View File

@@ -2,9 +2,10 @@ package middleware
import ( import (
"errors" "errors"
"github.com/gofiber/fiber/v3/log"
"nysoure/server/model" "nysoure/server/model"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3"
) )
@@ -14,6 +15,7 @@ func ErrorHandler(c fiber.Ctx) error {
var requestErr *model.RequestError var requestErr *model.RequestError
var unauthorizedErr *model.UnAuthorizedError var unauthorizedErr *model.UnAuthorizedError
var notFoundErr *model.NotFoundError var notFoundErr *model.NotFoundError
var fiberErr *fiber.Error
if errors.As(err, &requestErr) { if errors.As(err, &requestErr) {
log.Error("Request Error: ", err) log.Error("Request Error: ", err)
return c.Status(fiber.StatusBadRequest).JSON(model.Response[any]{ return c.Status(fiber.StatusBadRequest).JSON(model.Response[any]{
@@ -47,6 +49,12 @@ func ErrorHandler(c fiber.Ctx) error {
Data: nil, Data: nil,
Message: "Method not allowed", Message: "Method not allowed",
}) })
} else if errors.As(err, &fiberErr) && fiberErr.Message != "" {
return c.Status(fiberErr.Code).JSON(model.Response[any]{
Success: false,
Data: nil,
Message: fiberErr.Message,
})
} else { } else {
var fiberErr *fiber.Error var fiberErr *fiber.Error
if errors.As(err, &fiberErr) { if errors.As(err, &fiberErr) {

View File

@@ -0,0 +1,34 @@
package middleware
import (
"nysoure/server/utils"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)
func NewRequestLimiter(maxRequests int, duration time.Duration) func(c fiber.Ctx) error {
limiter := utils.NewRequestLimiter(func() int {
return maxRequests
}, duration)
return func(c fiber.Ctx) error {
if !limiter.AllowRequest(c.IP()) {
log.Warnf("IP %s has exceeded the request limit of %d requests in %s", c.IP(), maxRequests, duration)
return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests")
}
return c.Next()
}
}
func NewDynamicRequestLimiter(maxRequestsFunc func() int, duration time.Duration) func(c fiber.Ctx) error {
limiter := utils.NewRequestLimiter(maxRequestsFunc, duration)
return func(c fiber.Ctx) error {
if !limiter.AllowRequest(c.IP()) {
return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests")
}
return c.Next()
}
}

View File

@@ -3,11 +3,9 @@ package service
import ( import (
"nysoure/server/dao" "nysoure/server/dao"
"nysoure/server/model" "nysoure/server/model"
"nysoure/server/utils"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/log"
) )
@@ -19,10 +17,6 @@ const (
maxCommentBriefLength = 256 // Maximum length of a comment brief maxCommentBriefLength = 256 // Maximum length of a comment brief
) )
var (
commentsLimiter = utils.NewRequestLimiter(maxCommentsPerIP, 24*time.Hour)
)
type CommentRequest struct { type CommentRequest struct {
Content string `json:"content"` // markdown Content string `json:"content"` // markdown
// Images []uint `json:"images"` // Unrequired after new design // Images []uint `json:"images"` // Unrequired after new design
@@ -62,11 +56,6 @@ func findImagesInContent(content string, host string) []uint {
} }
func CreateComment(req CommentRequest, userID uint, refID uint, ip string, cType model.CommentType, host string) (*model.CommentView, error) { func CreateComment(req CommentRequest, userID uint, refID uint, ip string, cType model.CommentType, host string) (*model.CommentView, error) {
if !commentsLimiter.AllowRequest(ip) {
log.Warnf("IP %s has exceeded the comment limit of %d comments per day", ip, maxCommentsPerIP)
return nil, model.NewRequestError("Too many comments from this IP address, please try again later")
}
if len(req.Content) == 0 { if len(req.Content) == 0 {
return nil, model.NewRequestError("Content cannot be empty") return nil, model.NewRequestError("Content cannot be empty")
} }

View File

@@ -13,7 +13,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/log"
@@ -25,23 +24,6 @@ const (
storageKeyUnavailable = "storage_key_unavailable" // Placeholder for unavailable storage key storageKeyUnavailable = "storage_key_unavailable" // Placeholder for unavailable storage key
) )
var (
ipDownloads = sync.Map{}
)
func init() {
go func() {
for {
// Clean up old IP download records every 24 hours
time.Sleep(24 * time.Hour)
ipDownloads.Range(func(key, value interface{}) bool {
ipDownloads.Delete(key)
return true
})
}
}()
}
func getUploadingSize() int64 { func getUploadingSize() int64 {
return dao.GetStatistic("uploading_size") return dao.GetStatistic("uploading_size")
} }
@@ -405,7 +387,7 @@ func GetFile(fid string) (*model.FileView, error) {
} }
// DownloadFile handles the file download request. Return a presigned URL or a direct file path. // DownloadFile handles the file download request. Return a presigned URL or a direct file path.
func DownloadFile(ip, fid, cfToken string) (string, string, error) { func DownloadFile(fid, cfToken string) (string, string, error) {
passed, err := verifyCfToken(cfToken) passed, err := verifyCfToken(cfToken)
if err != nil { if err != nil {
log.Error("failed to verify cf token: ", err) log.Error("failed to verify cf token: ", err)
@@ -415,17 +397,6 @@ func DownloadFile(ip, fid, cfToken string) (string, string, error) {
log.Info("cf token verification failed") log.Info("cf token verification failed")
return "", "", model.NewRequestError("cf token verification failed") return "", "", model.NewRequestError("cf token verification failed")
} }
log.Info("File download request from: " + ip)
downloads, _ := ipDownloads.Load(ip)
if downloads == nil {
ipDownloads.Store(ip, 1)
} else {
count := downloads.(int)
if count >= config.MaxDownloadsPerDayForSingleIP() {
return "", "", model.NewRequestError("Too many requests, please try again later")
}
ipDownloads.Store(ip, count+1)
}
file, err := dao.GetFile(fid) file, err := dao.GetFile(fid)
if err != nil { if err != nil {
log.Error("failed to get file: ", err) log.Error("failed to get file: ", err)

View File

@@ -3,7 +3,6 @@ package service
import ( import (
"bytes" "bytes"
"errors" "errors"
"github.com/disintegration/imaging"
"image" "image"
"math" "math"
"net/http" "net/http"
@@ -14,14 +13,17 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/disintegration/imaging"
"github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/log"
"github.com/google/uuid" "github.com/google/uuid"
_ "golang.org/x/image/bmp"
_ "image/gif" _ "image/gif"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
"github.com/chai2010/webp" "github.com/chai2010/webp"
@@ -54,24 +56,12 @@ func init() {
}() }()
} }
var (
imageLimiter = utils.NewRequestLimiter(maxUploadsPerIP, 24*time.Hour)
)
const maxUploadsPerIP = 100
func CreateImage(uid uint, ip string, data []byte) (uint, error) { func CreateImage(uid uint, ip string, data []byte) (uint, error) {
canUpload, err := checkUserCanUpload(uid) canUpload, err := checkUserCanUpload(uid)
if err != nil { if err != nil {
log.Error("Error checking user upload permission:", err) log.Error("Error checking user upload permission:", err)
return 0, model.NewInternalServerError("Error checking user upload permission") return 0, model.NewInternalServerError("Error checking user upload permission")
} }
if !canUpload {
// For a normal user, check the IP upload limit
if !imageLimiter.AllowRequest(ip) {
return 0, model.NewUnAuthorizedError("You have reached the maximum upload limit")
}
}
if len(data) == 0 { if len(data) == 0 {
return 0, model.NewRequestError("Image data is empty") return 0, model.NewRequestError("Image data is empty")

View File

@@ -3,7 +3,6 @@ package service
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gofiber/fiber/v3/log"
"nysoure/server/config" "nysoure/server/config"
"nysoure/server/dao" "nysoure/server/dao"
"nysoure/server/model" "nysoure/server/model"
@@ -15,6 +14,8 @@ import (
"time" "time"
"unicode" "unicode"
"github.com/gofiber/fiber/v3/log"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -376,7 +377,7 @@ func validateUsername(username string) error {
if usernameLen < 3 || usernameLen > 20 { if usernameLen < 3 || usernameLen > 20 {
return model.NewRequestError("Username must be between 3 and 20 characters") return model.NewRequestError("Username must be between 3 and 20 characters")
} }
for _, r := range []rune(username) { for _, r := range username {
if r == ' ' || r == '\n' || r == '\r' || r == '\t' || r == '\v' || r == '\f' { if r == ' ' || r == '\n' || r == '\r' || r == '\t' || r == '\v' || r == '\f' {
return model.NewRequestError("Username cannot contain whitespace characters") return model.NewRequestError("Username cannot contain whitespace characters")
} }

View File

@@ -6,12 +6,12 @@ import (
) )
type RequestLimiter struct { type RequestLimiter struct {
limit int limit func() int
requestsByIP map[string]int requestsByIP map[string]int
mu sync.Mutex mu sync.Mutex
} }
func NewRequestLimiter(limit int, duration time.Duration) *RequestLimiter { func NewRequestLimiter(limit func() int, duration time.Duration) *RequestLimiter {
l := &RequestLimiter{ l := &RequestLimiter{
limit: limit, limit: limit,
requestsByIP: make(map[string]int), requestsByIP: make(map[string]int),
@@ -38,7 +38,7 @@ func (rl *RequestLimiter) AllowRequest(ip string) bool {
count = 0 count = 0
} }
if count >= rl.limit { if count >= rl.limit() {
return false // Exceeded request limit for this IP return false // Exceeded request limit for this IP
} }