feat: stop server download task

This commit is contained in:
2025-07-25 16:45:45 +08:00
parent 99c69d3b7d
commit def3e8bb0f

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
@@ -13,6 +14,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync/atomic"
"time" "time"
"github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/log"
@@ -471,17 +473,22 @@ func testFileUrl(url string) (int64, error) {
return contentLength, nil return contentLength, nil
} }
func downloadFile(url string, path string) error { // downloadFile return nil if the download is successful or the context is cancelled
func downloadFile(ctx context.Context, url string, path string) error {
if _, err := os.Stat(path); err == nil { if _, err := os.Stat(path); err == nil {
_ = os.Remove(path) // Remove the file if it already exists _ = os.Remove(path) // Remove the file if it already exists
} }
client := http.Client{} req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return model.NewRequestError("failed to create HTTP request") return model.NewRequestError("failed to create HTTP request")
} }
client := http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
// Check if the error is due to context cancellation
if ctx.Err() != nil {
return nil
}
return model.NewRequestError("failed to send HTTP request") return model.NewRequestError("failed to send HTTP request")
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -493,10 +500,30 @@ func downloadFile(url string, path string) error {
return model.NewInternalServerError("failed to open file for writing") return model.NewInternalServerError("failed to open file for writing")
} }
defer file.Close() defer file.Close()
if _, err := io.Copy(file, resp.Body); err != nil {
return model.NewInternalServerError("failed to copy response body to file") buf := make([]byte, 64*1024)
for {
select {
case <-ctx.Done():
return nil
default:
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := file.Write(buf[:n]); writeErr != nil {
return model.NewInternalServerError("failed to write to file")
}
}
if readErr != nil {
if readErr == io.EOF {
return nil // Download completed successfully
}
if ctx.Err() != nil {
return nil // Context cancelled, return nil
}
return model.NewInternalServerError("failed to read response body")
}
}
} }
return nil
} }
func CreateServerDownloadTask(uid uint, url, filename, description string, resourceID, storageID uint) (*model.FileView, error) { func CreateServerDownloadTask(uid uint, url, filename, description string, resourceID, storageID uint) (*model.FileView, error) {
@@ -529,19 +556,53 @@ func CreateServerDownloadTask(uid uint, url, filename, description string, resou
updateUploadingSize(contentLength) updateUploadingSize(contentLength)
go func() { go func() {
ctx, cancel := context.WithCancel(context.Background())
done := atomic.Bool{}
go func() {
for {
time.Sleep(10 * time.Second)
if done.Load() {
return
}
// Stop the task if the file is deleted
if _, err := dao.GetFileByID(file.ID); err != nil {
log.Info("File deleted by user, stopping download task: ", file.UUID)
done.Store(true)
cancel()
return
}
}
}()
defer func() {
done.Store(true)
}()
defer func() { defer func() {
updateUploadingSize(-contentLength) updateUploadingSize(-contentLength)
}() }()
tempPath := filepath.Join(utils.GetStoragePath(), uuid.NewString())
tempDir := filepath.Join(utils.GetStoragePath(), "temp")
if err := os.MkdirAll(tempDir, os.ModePerm); err != nil {
log.Error("failed to create temp dir: ", err)
_ = dao.DeleteFile(file.UUID)
return
}
tempPath := filepath.Join(utils.GetStoragePath(), "temp", uuid.NewString())
defer func() { defer func() {
if err := os.Remove(tempPath); err != nil { if err := os.Remove(tempPath); err != nil {
log.Error("failed to remove temp file: ", err) log.Error("failed to remove temp file: ", err)
} }
}() }()
for i := 0; i < 3; i++ { for i := range 3 {
if err := downloadFile(url, tempPath); err != nil { if done.Load() {
return
}
if err := downloadFile(ctx, url, tempPath); err != nil {
log.Error("failed to download file: ", err) log.Error("failed to download file: ", err)
if i == 2 { if i == 2 {
_ = dao.DeleteFile(file.UUID) _ = dao.DeleteFile(file.UUID)
@@ -557,6 +618,10 @@ func CreateServerDownloadTask(uid uint, url, filename, description string, resou
} }
} }
if done.Load() {
return
}
stat, err := os.Stat(tempPath) stat, err := os.Stat(tempPath)
if err != nil { if err != nil {
log.Error("failed to get temp file info: ", err) log.Error("failed to get temp file info: ", err)