mirror of
https://github.com/wgh136/nysoure.git
synced 2025-09-27 04:17:23 +00:00
Move request limiter to middleware.
This commit is contained in:
@@ -2,17 +2,19 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"nysoure/server/middleware"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
"nysoure/server/service"
|
"nysoure/server/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3"
|
"github.com/gofiber/fiber/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddCommentRoutes(router fiber.Router) {
|
func AddCommentRoutes(router fiber.Router) {
|
||||||
api := router.Group("/comments")
|
api := router.Group("/comments")
|
||||||
api.Post("/resource/:resourceID", createResourceComment)
|
api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/resource/:resourceID", createResourceComment)
|
||||||
api.Post("/reply/:commentID", createReplyComment)
|
api.Use(middleware.NewRequestLimiter(500, 24*time.Hour)).Post("/reply/:commentID", createReplyComment)
|
||||||
api.Get("/resource/:resourceID", listResourceComments)
|
api.Get("/resource/:resourceID", listResourceComments)
|
||||||
api.Get("/reply/:commentID", listReplyComments)
|
api.Get("/reply/:commentID", listReplyComments)
|
||||||
api.Get("/user/:username", listCommentsByUser)
|
api.Get("/user/:username", listCommentsByUser)
|
||||||
|
@@ -4,11 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"nysoure/server/config"
|
||||||
|
"nysoure/server/middleware"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
"nysoure/server/service"
|
"nysoure/server/service"
|
||||||
"nysoure/server/utils"
|
"nysoure/server/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3"
|
"github.com/gofiber/fiber/v3"
|
||||||
)
|
)
|
||||||
@@ -16,17 +19,17 @@ import (
|
|||||||
func AddFileRoutes(router fiber.Router) {
|
func AddFileRoutes(router fiber.Router) {
|
||||||
fileGroup := router.Group("/files")
|
fileGroup := router.Group("/files")
|
||||||
{
|
{
|
||||||
fileGroup.Post("/upload/init", initUpload)
|
fileGroup.Use(middleware.NewRequestLimiter(10, time.Hour)).Post("/upload/init", initUpload)
|
||||||
fileGroup.Post("/upload/block/:id/:index", uploadBlock)
|
fileGroup.Post("/upload/block/:id/:index", uploadBlock)
|
||||||
fileGroup.Post("/upload/finish/:id", finishUpload)
|
fileGroup.Post("/upload/finish/:id", finishUpload)
|
||||||
fileGroup.Post("/upload/cancel/:id", cancelUpload)
|
fileGroup.Post("/upload/cancel/:id", cancelUpload)
|
||||||
fileGroup.Post("/redirect", createRedirectFile)
|
fileGroup.Use(middleware.NewRequestLimiter(50, time.Hour)).Post("/redirect", createRedirectFile)
|
||||||
fileGroup.Post("/upload/url", createServerDownloadTask)
|
fileGroup.Post("/upload/url", createServerDownloadTask)
|
||||||
fileGroup.Get("/:id", getFile)
|
fileGroup.Get("/:id", getFile)
|
||||||
fileGroup.Put("/:id", updateFile)
|
fileGroup.Put("/:id", updateFile)
|
||||||
fileGroup.Delete("/:id", deleteFile)
|
fileGroup.Delete("/:id", deleteFile)
|
||||||
fileGroup.Get("/download/local", downloadLocalFile)
|
fileGroup.Get("/download/local", downloadLocalFile)
|
||||||
fileGroup.Get("/download/:id", downloadFile)
|
fileGroup.Use(middleware.NewDynamicRequestLimiter(config.MaxDownloadsPerDayForSingleIP, 24*time.Hour)).Get("/download/:id", downloadFile)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,8 +204,7 @@ func deleteFile(c fiber.Ctx) error {
|
|||||||
|
|
||||||
func downloadFile(c fiber.Ctx) error {
|
func downloadFile(c fiber.Ctx) error {
|
||||||
cfToken := c.Query("cf_token")
|
cfToken := c.Query("cf_token")
|
||||||
ip := c.IP()
|
s, filename, err := service.DownloadFile(c.Params("id"), cfToken)
|
||||||
s, filename, err := service.DownloadFile(ip, c.Params("id"), cfToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -1,12 +1,15 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gofiber/fiber/v3"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"nysoure/server/middleware"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
"nysoure/server/service"
|
"nysoure/server/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleUploadImage(c fiber.Ctx) error {
|
func handleUploadImage(c fiber.Ctx) error {
|
||||||
@@ -94,7 +97,7 @@ func handleGetResampledImage(c fiber.Ctx) error {
|
|||||||
func AddImageRoutes(api fiber.Router) {
|
func AddImageRoutes(api fiber.Router) {
|
||||||
image := api.Group("/image")
|
image := api.Group("/image")
|
||||||
{
|
{
|
||||||
image.Put("/", handleUploadImage)
|
image.Use(middleware.NewRequestLimiter(50, time.Hour)).Put("/", handleUploadImage)
|
||||||
image.Get("/resampled/:id", handleGetResampledImage)
|
image.Get("/resampled/:id", handleGetResampledImage)
|
||||||
image.Get("/:id", handleGetImage)
|
image.Get("/:id", handleGetImage)
|
||||||
image.Delete("/:id", handleDeleteImage)
|
image.Delete("/:id", handleDeleteImage)
|
||||||
|
@@ -4,9 +4,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"nysoure/server/middleware"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
"nysoure/server/service"
|
"nysoure/server/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3"
|
"github.com/gofiber/fiber/v3"
|
||||||
)
|
)
|
||||||
@@ -342,7 +344,7 @@ func handleGetMe(c fiber.Ctx) error {
|
|||||||
|
|
||||||
func AddUserRoutes(r fiber.Router) {
|
func AddUserRoutes(r fiber.Router) {
|
||||||
u := r.Group("user")
|
u := r.Group("user")
|
||||||
u.Post("/register", handleUserRegister)
|
u.Use(middleware.NewRequestLimiter(5, time.Hour)).Post("/register", handleUserRegister)
|
||||||
u.Post("/login", handleUserLogin)
|
u.Post("/login", handleUserLogin)
|
||||||
u.Put("/avatar", handleUserChangeAvatar)
|
u.Put("/avatar", handleUserChangeAvatar)
|
||||||
u.Post("/password", handleUserChangePassword)
|
u.Post("/password", handleUserChangePassword)
|
||||||
|
@@ -2,9 +2,10 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/gofiber/fiber/v3/log"
|
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3/log"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3"
|
"github.com/gofiber/fiber/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ func ErrorHandler(c fiber.Ctx) error {
|
|||||||
var requestErr *model.RequestError
|
var requestErr *model.RequestError
|
||||||
var unauthorizedErr *model.UnAuthorizedError
|
var unauthorizedErr *model.UnAuthorizedError
|
||||||
var notFoundErr *model.NotFoundError
|
var notFoundErr *model.NotFoundError
|
||||||
|
var fiberErr *fiber.Error
|
||||||
if errors.As(err, &requestErr) {
|
if errors.As(err, &requestErr) {
|
||||||
log.Error("Request Error: ", err)
|
log.Error("Request Error: ", err)
|
||||||
return c.Status(fiber.StatusBadRequest).JSON(model.Response[any]{
|
return c.Status(fiber.StatusBadRequest).JSON(model.Response[any]{
|
||||||
@@ -47,6 +49,12 @@ func ErrorHandler(c fiber.Ctx) error {
|
|||||||
Data: nil,
|
Data: nil,
|
||||||
Message: "Method not allowed",
|
Message: "Method not allowed",
|
||||||
})
|
})
|
||||||
|
} else if errors.As(err, &fiberErr) && fiberErr.Message != "" {
|
||||||
|
return c.Status(fiberErr.Code).JSON(model.Response[any]{
|
||||||
|
Success: false,
|
||||||
|
Data: nil,
|
||||||
|
Message: fiberErr.Message,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
var fiberErr *fiber.Error
|
var fiberErr *fiber.Error
|
||||||
if errors.As(err, &fiberErr) {
|
if errors.As(err, &fiberErr) {
|
||||||
|
34
server/middleware/request_limiter.go
Normal file
34
server/middleware/request_limiter.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"nysoure/server/utils"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3"
|
||||||
|
"github.com/gofiber/fiber/v3/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewRequestLimiter(maxRequests int, duration time.Duration) func(c fiber.Ctx) error {
|
||||||
|
limiter := utils.NewRequestLimiter(func() int {
|
||||||
|
return maxRequests
|
||||||
|
}, duration)
|
||||||
|
|
||||||
|
return func(c fiber.Ctx) error {
|
||||||
|
if !limiter.AllowRequest(c.IP()) {
|
||||||
|
log.Warnf("IP %s has exceeded the request limit of %d requests in %s", c.IP(), maxRequests, duration)
|
||||||
|
return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests")
|
||||||
|
}
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDynamicRequestLimiter(maxRequestsFunc func() int, duration time.Duration) func(c fiber.Ctx) error {
|
||||||
|
limiter := utils.NewRequestLimiter(maxRequestsFunc, duration)
|
||||||
|
|
||||||
|
return func(c fiber.Ctx) error {
|
||||||
|
if !limiter.AllowRequest(c.IP()) {
|
||||||
|
return fiber.NewError(fiber.StatusTooManyRequests, "Too many requests")
|
||||||
|
}
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
}
|
@@ -3,11 +3,9 @@ package service
|
|||||||
import (
|
import (
|
||||||
"nysoure/server/dao"
|
"nysoure/server/dao"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
"nysoure/server/utils"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3/log"
|
"github.com/gofiber/fiber/v3/log"
|
||||||
)
|
)
|
||||||
@@ -19,10 +17,6 @@ const (
|
|||||||
maxCommentBriefLength = 256 // Maximum length of a comment brief
|
maxCommentBriefLength = 256 // Maximum length of a comment brief
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
commentsLimiter = utils.NewRequestLimiter(maxCommentsPerIP, 24*time.Hour)
|
|
||||||
)
|
|
||||||
|
|
||||||
type CommentRequest struct {
|
type CommentRequest struct {
|
||||||
Content string `json:"content"` // markdown
|
Content string `json:"content"` // markdown
|
||||||
// Images []uint `json:"images"` // Unrequired after new design
|
// Images []uint `json:"images"` // Unrequired after new design
|
||||||
@@ -62,11 +56,6 @@ func findImagesInContent(content string, host string) []uint {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateComment(req CommentRequest, userID uint, refID uint, ip string, cType model.CommentType, host string) (*model.CommentView, error) {
|
func CreateComment(req CommentRequest, userID uint, refID uint, ip string, cType model.CommentType, host string) (*model.CommentView, error) {
|
||||||
if !commentsLimiter.AllowRequest(ip) {
|
|
||||||
log.Warnf("IP %s has exceeded the comment limit of %d comments per day", ip, maxCommentsPerIP)
|
|
||||||
return nil, model.NewRequestError("Too many comments from this IP address, please try again later")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(req.Content) == 0 {
|
if len(req.Content) == 0 {
|
||||||
return nil, model.NewRequestError("Content cannot be empty")
|
return nil, model.NewRequestError("Content cannot be empty")
|
||||||
}
|
}
|
||||||
|
@@ -13,7 +13,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3/log"
|
"github.com/gofiber/fiber/v3/log"
|
||||||
@@ -25,23 +24,6 @@ const (
|
|||||||
storageKeyUnavailable = "storage_key_unavailable" // Placeholder for unavailable storage key
|
storageKeyUnavailable = "storage_key_unavailable" // Placeholder for unavailable storage key
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ipDownloads = sync.Map{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
// Clean up old IP download records every 24 hours
|
|
||||||
time.Sleep(24 * time.Hour)
|
|
||||||
ipDownloads.Range(func(key, value interface{}) bool {
|
|
||||||
ipDownloads.Delete(key)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUploadingSize() int64 {
|
func getUploadingSize() int64 {
|
||||||
return dao.GetStatistic("uploading_size")
|
return dao.GetStatistic("uploading_size")
|
||||||
}
|
}
|
||||||
@@ -405,7 +387,7 @@ func GetFile(fid string) (*model.FileView, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DownloadFile handles the file download request. Return a presigned URL or a direct file path.
|
// DownloadFile handles the file download request. Return a presigned URL or a direct file path.
|
||||||
func DownloadFile(ip, fid, cfToken string) (string, string, error) {
|
func DownloadFile(fid, cfToken string) (string, string, error) {
|
||||||
passed, err := verifyCfToken(cfToken)
|
passed, err := verifyCfToken(cfToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("failed to verify cf token: ", err)
|
log.Error("failed to verify cf token: ", err)
|
||||||
@@ -415,17 +397,6 @@ func DownloadFile(ip, fid, cfToken string) (string, string, error) {
|
|||||||
log.Info("cf token verification failed")
|
log.Info("cf token verification failed")
|
||||||
return "", "", model.NewRequestError("cf token verification failed")
|
return "", "", model.NewRequestError("cf token verification failed")
|
||||||
}
|
}
|
||||||
log.Info("File download request from: " + ip)
|
|
||||||
downloads, _ := ipDownloads.Load(ip)
|
|
||||||
if downloads == nil {
|
|
||||||
ipDownloads.Store(ip, 1)
|
|
||||||
} else {
|
|
||||||
count := downloads.(int)
|
|
||||||
if count >= config.MaxDownloadsPerDayForSingleIP() {
|
|
||||||
return "", "", model.NewRequestError("Too many requests, please try again later")
|
|
||||||
}
|
|
||||||
ipDownloads.Store(ip, count+1)
|
|
||||||
}
|
|
||||||
file, err := dao.GetFile(fid)
|
file, err := dao.GetFile(fid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("failed to get file: ", err)
|
log.Error("failed to get file: ", err)
|
||||||
|
@@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/disintegration/imaging"
|
|
||||||
"image"
|
"image"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -14,14 +13,17 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/disintegration/imaging"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3/log"
|
"github.com/gofiber/fiber/v3/log"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
_ "golang.org/x/image/bmp"
|
|
||||||
_ "image/gif"
|
_ "image/gif"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
|
|
||||||
|
_ "golang.org/x/image/bmp"
|
||||||
|
|
||||||
_ "golang.org/x/image/webp"
|
_ "golang.org/x/image/webp"
|
||||||
|
|
||||||
"github.com/chai2010/webp"
|
"github.com/chai2010/webp"
|
||||||
@@ -54,24 +56,12 @@ func init() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
imageLimiter = utils.NewRequestLimiter(maxUploadsPerIP, 24*time.Hour)
|
|
||||||
)
|
|
||||||
|
|
||||||
const maxUploadsPerIP = 100
|
|
||||||
|
|
||||||
func CreateImage(uid uint, ip string, data []byte) (uint, error) {
|
func CreateImage(uid uint, ip string, data []byte) (uint, error) {
|
||||||
canUpload, err := checkUserCanUpload(uid)
|
canUpload, err := checkUserCanUpload(uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error checking user upload permission:", err)
|
log.Error("Error checking user upload permission:", err)
|
||||||
return 0, model.NewInternalServerError("Error checking user upload permission")
|
return 0, model.NewInternalServerError("Error checking user upload permission")
|
||||||
}
|
}
|
||||||
if !canUpload {
|
|
||||||
// For a normal user, check the IP upload limit
|
|
||||||
if !imageLimiter.AllowRequest(ip) {
|
|
||||||
return 0, model.NewUnAuthorizedError("You have reached the maximum upload limit")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return 0, model.NewRequestError("Image data is empty")
|
return 0, model.NewRequestError("Image data is empty")
|
||||||
|
@@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gofiber/fiber/v3/log"
|
|
||||||
"nysoure/server/config"
|
"nysoure/server/config"
|
||||||
"nysoure/server/dao"
|
"nysoure/server/dao"
|
||||||
"nysoure/server/model"
|
"nysoure/server/model"
|
||||||
@@ -15,6 +14,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v3/log"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -376,7 +377,7 @@ func validateUsername(username string) error {
|
|||||||
if usernameLen < 3 || usernameLen > 20 {
|
if usernameLen < 3 || usernameLen > 20 {
|
||||||
return model.NewRequestError("Username must be between 3 and 20 characters")
|
return model.NewRequestError("Username must be between 3 and 20 characters")
|
||||||
}
|
}
|
||||||
for _, r := range []rune(username) {
|
for _, r := range username {
|
||||||
if r == ' ' || r == '\n' || r == '\r' || r == '\t' || r == '\v' || r == '\f' {
|
if r == ' ' || r == '\n' || r == '\r' || r == '\t' || r == '\v' || r == '\f' {
|
||||||
return model.NewRequestError("Username cannot contain whitespace characters")
|
return model.NewRequestError("Username cannot contain whitespace characters")
|
||||||
}
|
}
|
||||||
|
@@ -6,12 +6,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RequestLimiter struct {
|
type RequestLimiter struct {
|
||||||
limit int
|
limit func() int
|
||||||
requestsByIP map[string]int
|
requestsByIP map[string]int
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRequestLimiter(limit int, duration time.Duration) *RequestLimiter {
|
func NewRequestLimiter(limit func() int, duration time.Duration) *RequestLimiter {
|
||||||
l := &RequestLimiter{
|
l := &RequestLimiter{
|
||||||
limit: limit,
|
limit: limit,
|
||||||
requestsByIP: make(map[string]int),
|
requestsByIP: make(map[string]int),
|
||||||
@@ -38,7 +38,7 @@ func (rl *RequestLimiter) AllowRequest(ip string) bool {
|
|||||||
count = 0
|
count = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if count >= rl.limit {
|
if count >= rl.limit() {
|
||||||
return false // Exceeded request limit for this IP
|
return false // Exceeded request limit for this IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user