From b568b234c45a92cecf60e1fda7bae947cd5fb149 Mon Sep 17 00:00:00 2001 From: nyne Date: Wed, 9 Jul 2025 17:00:39 +0800 Subject: [PATCH] Move request limiter to middleware. --- server/api/comment.go | 6 +++-- server/api/file.go | 12 ++++++---- server/api/image.go | 7 ++++-- server/api/user.go | 4 +++- server/middleware/error_handler.go | 10 +++++++- server/middleware/request_limiter.go | 34 ++++++++++++++++++++++++++++ server/service/comment.go | 11 --------- server/service/file.go | 31 +------------------------ server/service/image.go | 18 ++++----------- server/service/user.go | 5 ++-- server/utils/request_limit.go | 6 ++--- 11 files changed, 73 insertions(+), 71 deletions(-) create mode 100644 server/middleware/request_limiter.go diff --git a/server/api/comment.go b/server/api/comment.go index c5a5fda..e1829bf 100644 --- a/server/api/comment.go +++ b/server/api/comment.go @@ -2,17 +2,19 @@ package api import ( "net/url" + "nysoure/server/middleware" "nysoure/server/model" "nysoure/server/service" "strconv" + "time" "github.com/gofiber/fiber/v3" ) func AddCommentRoutes(router fiber.Router) { api := router.Group("/comments") - api.Post("/resource/:resourceID", createResourceComment) - api.Post("/reply/:commentID", createReplyComment) + api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/resource/:resourceID", createResourceComment) + api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/reply/:commentID", createReplyComment) api.Get("/resource/:resourceID", listResourceComments) api.Get("/reply/:commentID", listReplyComments) api.Get("/user/:username", listCommentsByUser) diff --git a/server/api/file.go b/server/api/file.go index 8b60019..b6d3c4a 100644 --- a/server/api/file.go +++ b/server/api/file.go @@ -4,11 +4,14 @@ import ( "encoding/json" "fmt" "net/url" + "nysoure/server/config" + "nysoure/server/middleware" "nysoure/server/model" "nysoure/server/service" "nysoure/server/utils" "strconv" "strings" + "time" "github.com/gofiber/fiber/v3" ) @@ -16,17 +19,17 @@ import ( func AddFileRoutes(router fiber.Router) { fileGroup := router.Group("/files") { - fileGroup.Post("/upload/init", initUpload) + fileGroup.Use(middleware.NewRequestLimiter(10, time.Hour)).Post("/upload/init", initUpload) fileGroup.Post("/upload/block/:id/:index", uploadBlock) fileGroup.Post("/upload/finish/:id", finishUpload) fileGroup.Post("/upload/cancel/:id", cancelUpload) - fileGroup.Post("/redirect", createRedirectFile) + fileGroup.Use(middleware.NewRequestLimiter(50, time.Hour)).Post("/redirect", createRedirectFile) fileGroup.Post("/upload/url", createServerDownloadTask) fileGroup.Get("/:id", getFile) fileGroup.Put("/:id", updateFile) fileGroup.Delete("/:id", deleteFile) fileGroup.Get("/download/local", downloadLocalFile) - fileGroup.Get("/download/:id", downloadFile) + fileGroup.Use(middleware.NewDynamicRequestLimiter(config.MaxDownloadsPerDayForSingleIP, 24*time.Hour)).Get("/download/:id", downloadFile) } } @@ -201,8 +204,7 @@ func deleteFile(c fiber.Ctx) error { func downloadFile(c fiber.Ctx) error { cfToken := c.Query("cf_token") - ip := c.IP() - s, filename, err := service.DownloadFile(ip, c.Params("id"), cfToken) + s, filename, err := service.DownloadFile(c.Params("id"), cfToken) if err != nil { return err } diff --git a/server/api/image.go b/server/api/image.go index e392b87..8d0cf9f 100644 --- a/server/api/image.go +++ b/server/api/image.go @@ -1,12 +1,15 @@ package api import ( - "github.com/gofiber/fiber/v3" "net/http" + "nysoure/server/middleware" "nysoure/server/model" "nysoure/server/service" "strconv" "strings" + "time" + + "github.com/gofiber/fiber/v3" ) func handleUploadImage(c fiber.Ctx) error { @@ -94,7 +97,7 @@ func handleGetResampledImage(c fiber.Ctx) error { func AddImageRoutes(api fiber.Router) { image := api.Group("/image") { - image.Put("/", handleUploadImage) + image.Use(middleware.NewRequestLimiter(50, time.Hour)).Put("/", handleUploadImage) image.Get("/resampled/:id", handleGetResampledImage) image.Get("/:id", handleGetImage) image.Delete("/:id", handleDeleteImage) diff --git a/server/api/user.go b/server/api/user.go index a363d56..c46bb8c 100644 --- a/server/api/user.go +++ b/server/api/user.go @@ -4,9 +4,11 @@ import ( "io" "net/http" "net/url" + "nysoure/server/middleware" "nysoure/server/model" "nysoure/server/service" "strconv" + "time" "github.com/gofiber/fiber/v3" ) @@ -342,7 +344,7 @@ func handleGetMe(c fiber.Ctx) error { func AddUserRoutes(r fiber.Router) { u := r.Group("user") - u.Post("/register", handleUserRegister) + u.Use(middleware.NewRequestLimiter(5, time.Hour)).Post("/register", handleUserRegister) u.Post("/login", handleUserLogin) u.Put("/avatar", handleUserChangeAvatar) u.Post("/password", handleUserChangePassword) diff --git a/server/middleware/error_handler.go b/server/middleware/error_handler.go index 19b4e27..446a23d 100644 --- a/server/middleware/error_handler.go +++ b/server/middleware/error_handler.go @@ -2,9 +2,10 @@ package middleware import ( "errors" - "github.com/gofiber/fiber/v3/log" "nysoure/server/model" + "github.com/gofiber/fiber/v3/log" + "github.com/gofiber/fiber/v3" ) @@ -14,6 +15,7 @@ func ErrorHandler(c fiber.Ctx) error { var requestErr *model.RequestError var unauthorizedErr *model.UnAuthorizedError var notFoundErr *model.NotFoundError + var fiberErr *fiber.Error if errors.As(err, &requestErr) { log.Error("Request Error: ", err) return c.Status(fiber.StatusBadRequest).JSON(model.Response[any]{ @@ -47,6 +49,12 @@ func ErrorHandler(c fiber.Ctx) error { Data: nil, Message: "Method not allowed", }) + } else if errors.As(err, &fiberErr) && fiberErr.Message != "" { + return c.Status(fiberErr.Code).JSON(model.Response[any]{ + Success: false, + Data: nil, + Message: fiberErr.Message, + }) } else { var fiberErr *fiber.Error if errors.As(err, &fiberErr) { diff --git a/server/middleware/request_limiter.go b/server/middleware/request_limiter.go new file mode 100644 index 0000000..44d826a --- /dev/null +++ b/server/middleware/request_limiter.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "nysoure/server/utils" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" +) + +func NewRequestLimiter(maxRequests int, duration time.Duration) func(c fiber.Ctx) error { + limiter := utils.NewRequestLimiter(func() int { + return maxRequests + }, duration) + + return func(c fiber.Ctx) error { + if !limiter.AllowRequest(c.IP()) { + log.Warnf("IP %s has exceeded the request limit of %d requests in %s", c.IP(), maxRequests, duration) + return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests") + } + return c.Next() + } +} + +func NewDynamicRequestLimiter(maxRequestsFunc func() int, duration time.Duration) func(c fiber.Ctx) error { + limiter := utils.NewRequestLimiter(maxRequestsFunc, duration) + + return func(c fiber.Ctx) error { + if !limiter.AllowRequest(c.IP()) { + return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests") + } + return c.Next() + } +} diff --git a/server/service/comment.go b/server/service/comment.go index 4e04c25..bef505c 100644 --- a/server/service/comment.go +++ b/server/service/comment.go @@ -3,11 +3,9 @@ package service import ( "nysoure/server/dao" "nysoure/server/model" - "nysoure/server/utils" "regexp" "strconv" "strings" - "time" "github.com/gofiber/fiber/v3/log" ) @@ -19,10 +17,6 @@ const ( maxCommentBriefLength = 256 // Maximum length of a comment brief ) -var ( - commentsLimiter = utils.NewRequestLimiter(maxCommentsPerIP, 24*time.Hour) -) - type CommentRequest struct { Content string `json:"content"` // markdown // Images []uint `json:"images"` // Unrequired after new design @@ -62,11 +56,6 @@ func findImagesInContent(content string, host string) []uint { } func CreateComment(req CommentRequest, userID uint, refID uint, ip string, cType model.CommentType, host string) (*model.CommentView, error) { - if !commentsLimiter.AllowRequest(ip) { - log.Warnf("IP %s has exceeded the comment limit of %d comments per day", ip, maxCommentsPerIP) - return nil, model.NewRequestError("Too many comments from this IP address, please try again later") - } - if len(req.Content) == 0 { return nil, model.NewRequestError("Content cannot be empty") } diff --git a/server/service/file.go b/server/service/file.go index 4276b6b..25bf6c2 100644 --- a/server/service/file.go +++ b/server/service/file.go @@ -13,7 +13,6 @@ import ( "os" "path/filepath" "strconv" - "sync" "time" "github.com/gofiber/fiber/v3/log" @@ -25,23 +24,6 @@ const ( storageKeyUnavailable = "storage_key_unavailable" // Placeholder for unavailable storage key ) -var ( - ipDownloads = sync.Map{} -) - -func init() { - go func() { - for { - // Clean up old IP download records every 24 hours - time.Sleep(24 * time.Hour) - ipDownloads.Range(func(key, value interface{}) bool { - ipDownloads.Delete(key) - return true - }) - } - }() -} - func getUploadingSize() int64 { return dao.GetStatistic("uploading_size") } @@ -405,7 +387,7 @@ func GetFile(fid string) (*model.FileView, error) { } // DownloadFile handles the file download request. Return a presigned URL or a direct file path. -func DownloadFile(ip, fid, cfToken string) (string, string, error) { +func DownloadFile(fid, cfToken string) (string, string, error) { passed, err := verifyCfToken(cfToken) if err != nil { log.Error("failed to verify cf token: ", err) @@ -415,17 +397,6 @@ func DownloadFile(ip, fid, cfToken string) (string, string, error) { log.Info("cf token verification failed") return "", "", model.NewRequestError("cf token verification failed") } - log.Info("File download request from: " + ip) - downloads, _ := ipDownloads.Load(ip) - if downloads == nil { - ipDownloads.Store(ip, 1) - } else { - count := downloads.(int) - if count >= config.MaxDownloadsPerDayForSingleIP() { - return "", "", model.NewRequestError("Too many requests, please try again later") - } - ipDownloads.Store(ip, count+1) - } file, err := dao.GetFile(fid) if err != nil { log.Error("failed to get file: ", err) diff --git a/server/service/image.go b/server/service/image.go index 232d17d..bf3db3b 100644 --- a/server/service/image.go +++ b/server/service/image.go @@ -3,7 +3,6 @@ package service import ( "bytes" "errors" - "github.com/disintegration/imaging" "image" "math" "net/http" @@ -14,14 +13,17 @@ import ( "strconv" "time" + "github.com/disintegration/imaging" + "github.com/gofiber/fiber/v3/log" "github.com/google/uuid" - _ "golang.org/x/image/bmp" _ "image/gif" _ "image/jpeg" _ "image/png" + _ "golang.org/x/image/bmp" + _ "golang.org/x/image/webp" "github.com/chai2010/webp" @@ -54,24 +56,12 @@ func init() { }() } -var ( - imageLimiter = utils.NewRequestLimiter(maxUploadsPerIP, 24*time.Hour) -) - -const maxUploadsPerIP = 100 - func CreateImage(uid uint, ip string, data []byte) (uint, error) { canUpload, err := checkUserCanUpload(uid) if err != nil { log.Error("Error checking user upload permission:", err) return 0, model.NewInternalServerError("Error checking user upload permission") } - if !canUpload { - // For a normal user, check the IP upload limit - if !imageLimiter.AllowRequest(ip) { - return 0, model.NewUnAuthorizedError("You have reached the maximum upload limit") - } - } if len(data) == 0 { return 0, model.NewRequestError("Image data is empty") diff --git a/server/service/user.go b/server/service/user.go index d92eddf..ef7102e 100644 --- a/server/service/user.go +++ b/server/service/user.go @@ -3,7 +3,6 @@ package service import ( "errors" "fmt" - "github.com/gofiber/fiber/v3/log" "nysoure/server/config" "nysoure/server/dao" "nysoure/server/model" @@ -15,6 +14,8 @@ import ( "time" "unicode" + "github.com/gofiber/fiber/v3/log" + "golang.org/x/crypto/bcrypt" ) @@ -376,7 +377,7 @@ func validateUsername(username string) error { if usernameLen < 3 || usernameLen > 20 { return model.NewRequestError("Username must be between 3 and 20 characters") } - for _, r := range []rune(username) { + for _, r := range username { if r == ' ' || r == '\n' || r == '\r' || r == '\t' || r == '\v' || r == '\f' { return model.NewRequestError("Username cannot contain whitespace characters") } diff --git a/server/utils/request_limit.go b/server/utils/request_limit.go index bbab0b0..ea0a4cd 100644 --- a/server/utils/request_limit.go +++ b/server/utils/request_limit.go @@ -6,12 +6,12 @@ import ( ) type RequestLimiter struct { - limit int + limit func() int requestsByIP map[string]int mu sync.Mutex } -func NewRequestLimiter(limit int, duration time.Duration) *RequestLimiter { +func NewRequestLimiter(limit func() int, duration time.Duration) *RequestLimiter { l := &RequestLimiter{ limit: limit, requestsByIP: make(map[string]int), @@ -38,7 +38,7 @@ func (rl *RequestLimiter) AllowRequest(ip string) bool { count = 0 } - if count >= rl.limit { + if count >= rl.limit() { return false // Exceeded request limit for this IP }