From def3e8bb0f74f77ff8a64b3fec2441ab29ceb9e4 Mon Sep 17 00:00:00 2001 From: nyne Date: Fri, 25 Jul 2025 16:45:45 +0800 Subject: [PATCH] feat: stop server download task --- server/service/file.go | 83 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 74 insertions(+), 9 deletions(-) diff --git a/server/service/file.go b/server/service/file.go index a6ab9b7..a903e1e 100644 --- a/server/service/file.go +++ b/server/service/file.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -13,6 +14,7 @@ import ( "os" "path/filepath" "strconv" + "sync/atomic" "time" "github.com/gofiber/fiber/v3/log" @@ -471,17 +473,22 @@ func testFileUrl(url string) (int64, error) { return contentLength, nil } -func downloadFile(url string, path string) error { +// downloadFile return nil if the download is successful or the context is cancelled +func downloadFile(ctx context.Context, url string, path string) error { if _, err := os.Stat(path); err == nil { _ = os.Remove(path) // Remove the file if it already exists } - client := http.Client{} - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return model.NewRequestError("failed to create HTTP request") } + client := http.Client{} resp, err := client.Do(req) if err != nil { + // Check if the error is due to context cancellation + if ctx.Err() != nil { + return nil + } return model.NewRequestError("failed to send HTTP request") } defer resp.Body.Close() @@ -493,10 +500,30 @@ func downloadFile(url string, path string) error { return model.NewInternalServerError("failed to open file for writing") } defer file.Close() - if _, err := io.Copy(file, resp.Body); err != nil { - return model.NewInternalServerError("failed to copy response body to file") + + buf := make([]byte, 64*1024) + for { + select { + case <-ctx.Done(): + return nil + default: + n, readErr := resp.Body.Read(buf) + if n > 0 { + if _, writeErr := file.Write(buf[:n]); writeErr != nil { + return model.NewInternalServerError("failed to write to file") + } + } + if readErr != nil { + if readErr == io.EOF { + return nil // Download completed successfully + } + if ctx.Err() != nil { + return nil // Context cancelled, return nil + } + return model.NewInternalServerError("failed to read response body") + } + } } - return nil } func CreateServerDownloadTask(uid uint, url, filename, description string, resourceID, storageID uint) (*model.FileView, error) { @@ -529,19 +556,53 @@ func CreateServerDownloadTask(uid uint, url, filename, description string, resou updateUploadingSize(contentLength) go func() { + ctx, cancel := context.WithCancel(context.Background()) + + done := atomic.Bool{} + + go func() { + for { + time.Sleep(10 * time.Second) + if done.Load() { + return + } + // Stop the task if the file is deleted + if _, err := dao.GetFileByID(file.ID); err != nil { + log.Info("File deleted by user, stopping download task: ", file.UUID) + done.Store(true) + cancel() + return + } + } + }() + + defer func() { + done.Store(true) + }() + defer func() { updateUploadingSize(-contentLength) }() - tempPath := filepath.Join(utils.GetStoragePath(), uuid.NewString()) + tempDir := filepath.Join(utils.GetStoragePath(), "temp") + if err := os.MkdirAll(tempDir, os.ModePerm); err != nil { + log.Error("failed to create temp dir: ", err) + _ = dao.DeleteFile(file.UUID) + return + } + + tempPath := filepath.Join(utils.GetStoragePath(), "temp", uuid.NewString()) defer func() { if err := os.Remove(tempPath); err != nil { log.Error("failed to remove temp file: ", err) } }() - for i := 0; i < 3; i++ { - if err := downloadFile(url, tempPath); err != nil { + for i := range 3 { + if done.Load() { + return + } + if err := downloadFile(ctx, url, tempPath); err != nil { log.Error("failed to download file: ", err) if i == 2 { _ = dao.DeleteFile(file.UUID) @@ -557,6 +618,10 @@ func CreateServerDownloadTask(uid uint, url, filename, description string, resou } } + if done.Load() { + return + } + stat, err := os.Stat(tempPath) if err != nil { log.Error("failed to get temp file info: ", err)