Compare commits

...

2 Commits

7 changed files with 36 additions and 12 deletions

View File

@@ -21,6 +21,8 @@ func main() {
app.Use(middleware.ErrorHandler) app.Use(middleware.ErrorHandler)
app.Use(middleware.RealUserMiddleware)
app.Use(middleware.JwtMiddleware) app.Use(middleware.JwtMiddleware)
app.Use(middleware.FrontendMiddleware) app.Use(middleware.FrontendMiddleware)

View File

@@ -205,7 +205,7 @@ func deleteFile(c fiber.Ctx) error {
func downloadFile(c fiber.Ctx) error { func downloadFile(c fiber.Ctx) error {
cfToken := c.Query("cf_token") cfToken := c.Query("cf_token")
s, filename, err := service.DownloadFile(c.Params("id"), cfToken) s, filename, err := service.DownloadFile(c.Params("id"), cfToken, c.Locals("real_user") == true)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -56,7 +56,7 @@ func handleGetResource(c fiber.Ctx) error {
return model.NewRequestError("Invalid resource ID") return model.NewRequestError("Invalid resource ID")
} }
host := c.Hostname() host := c.Hostname()
resource, err := service.GetResource(uint(id), host) resource, err := service.GetResource(uint(id), host, c)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -79,7 +79,7 @@ func serveIndexHtml(c fiber.Ctx) error {
idStr := strings.TrimPrefix(path, "/resources/") idStr := strings.TrimPrefix(path, "/resources/")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err == nil { if err == nil {
r, err := service.GetResource(uint(id), c.Hostname()) r, err := service.GetResource(uint(id), c.Hostname(), c)
if err == nil { if err == nil {
if len(r.Images) > 0 { if len(r.Images) > 0 {
preview = fmt.Sprintf("%s/api/image/%d", serverBaseURL, r.Images[0].ID) preview = fmt.Sprintf("%s/api/image/%d", serverBaseURL, r.Images[0].ID)

View File

@@ -0,0 +1,17 @@
package middleware
import (
"strings"
"github.com/gofiber/fiber/v3"
)
func RealUserMiddleware(c fiber.Ctx) error {
userAgent := c.Get("User-Agent")
if strings.Contains(userAgent, "Mozilla") || strings.Contains(userAgent, "AppleWebKit") {
c.Locals("real_user", true)
} else {
c.Locals("real_user", false)
}
return c.Next()
}

View File

@@ -389,7 +389,7 @@ func GetFile(fid string) (*model.FileView, error) {
} }
// DownloadFile handles the file download request. Return a presigned URL or a direct file path. // DownloadFile handles the file download request. Return a presigned URL or a direct file path.
func DownloadFile(fid, cfToken string) (string, string, error) { func DownloadFile(fid, cfToken string, isRealUser bool) (string, string, error) {
file, err := dao.GetFile(fid) file, err := dao.GetFile(fid)
if err != nil { if err != nil {
log.Error("failed to get file: ", err) log.Error("failed to get file: ", err)
@@ -436,9 +436,11 @@ func DownloadFile(fid, cfToken string) (string, string, error) {
return "", "", model.NewInternalServerError("failed to download file from storage") return "", "", model.NewInternalServerError("failed to download file from storage")
} }
err = dao.AddResourceDownloadCount(file.ResourceID) if isRealUser {
if err != nil { err = dao.AddResourceDownloadCount(file.ResourceID)
log.Errorf("failed to add resource download count: %v", err) if err != nil {
log.Errorf("failed to add resource download count: %v", err)
}
} }
return path, file.Filename, nil return path, file.Filename, nil

View File

@@ -8,6 +8,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/log"
"gorm.io/gorm" "gorm.io/gorm"
@@ -122,14 +123,16 @@ func parseResourceIfPresent(line string, host string) *model.ResourceView {
return &v return &v
} }
func GetResource(id uint, host string) (*model.ResourceDetailView, error) { func GetResource(id uint, host string, ctx fiber.Ctx) (*model.ResourceDetailView, error) {
r, err := dao.GetResourceByID(id) r, err := dao.GetResourceByID(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = dao.AddResourceViewCount(id) if ctx != nil && ctx.Locals("real_user") == true {
if err != nil { err = dao.AddResourceViewCount(id)
log.Error("AddResourceViewCount error: ", err) if err != nil {
log.Error("AddResourceViewCount error: ", err)
}
} }
v := r.ToDetailView() v := r.ToDetailView()
if host != "" { if host != "" {
@@ -177,7 +180,7 @@ func DeleteResource(uid, id uint) error {
return model.NewUnAuthorizedError("You have not permission to delete this resource") return model.NewUnAuthorizedError("You have not permission to delete this resource")
} }
} }
r, err := GetResource(id, "") r, err := GetResource(id, "", nil)
if err != nil { if err != nil {
return err return err
} }