修复进度和潜在的锁问题

This commit is contained in:
NewName
2025-05-17 13:30:47 +08:00
parent 3a669a2db8
commit b37a0af3d4

View File

@@ -3,6 +3,7 @@ package main
import (
"archive/zip"
"bufio"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
@@ -14,10 +15,12 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
)
// 任务状态
@@ -37,6 +40,7 @@ type ImageTask struct {
Status string `json:"status"`
Error string `json:"error,omitempty"`
OutputPath string `json:"-"` // 输出文件路径,不发送给客户端
lock sync.Mutex `json:"-"` // 镜像任务自己的锁
}
// 下载任务
@@ -47,7 +51,19 @@ type DownloadTask struct {
Status TaskStatus `json:"status"`
OutputFile string `json:"-"` // 最终输出文件
TempDir string `json:"-"` // 临时目录
Lock sync.Mutex `json:"-"` // 锁,防止并发冲突
StatusLock sync.RWMutex `json:"-"` // 状态锁,使用读写锁提高并发性
ProgressLock sync.RWMutex `json:"-"` // 进度锁
ImageLock sync.RWMutex `json:"-"` // 镜像列表锁
updateChan chan *ProgressUpdate `json:"-"` // 进度更新通道
}
// 进度更新消息
type ProgressUpdate struct {
TaskID string
ImageIndex int
Progress float64
Status string
Error string
}
// WebSocket客户端
@@ -170,7 +186,39 @@ func getTaskStatus(c *gin.Context) {
return
}
c.JSON(http.StatusOK, task)
// 创建任务状态副本以避免序列化过程中的锁
taskCopy := &DownloadTask{
ID: task.ID,
TotalProgress: 0,
Status: TaskStatus(""),
Images: nil,
}
// 复制状态信息
task.StatusLock.RLock()
taskCopy.Status = task.Status
task.StatusLock.RUnlock()
task.ProgressLock.RLock()
taskCopy.TotalProgress = task.TotalProgress
task.ProgressLock.RUnlock()
// 复制镜像信息
task.ImageLock.RLock()
taskCopy.Images = make([]*ImageTask, len(task.Images))
for i, img := range task.Images {
img.lock.Lock()
taskCopy.Images[i] = &ImageTask{
Image: img.Image,
Progress: img.Progress,
Status: img.Status,
Error: img.Error,
}
img.lock.Unlock()
}
task.ImageLock.RUnlock()
c.JSON(http.StatusOK, taskCopy)
}
// 生成随机任务ID
@@ -180,6 +228,93 @@ func generateTaskID() string {
return hex.EncodeToString(b)
}
// 初始化任务并启动进度处理器
func initTask(task *DownloadTask) {
// 创建进度更新通道
task.updateChan = make(chan *ProgressUpdate, 100)
// 启动进度处理goroutine
go func() {
for update := range task.updateChan {
if update == nil {
// 通道关闭信号
break
}
// 获取更新的镜像
task.ImageLock.RLock()
if update.ImageIndex < 0 || update.ImageIndex >= len(task.Images) {
task.ImageLock.RUnlock()
continue
}
imgTask := task.Images[update.ImageIndex]
task.ImageLock.RUnlock()
// 更新镜像进度和状态
imgTask.lock.Lock()
if update.Progress > 0 {
imgTask.Progress = update.Progress
}
if update.Status != "" {
imgTask.Status = update.Status
}
if update.Error != "" {
imgTask.Error = update.Error
}
imgTask.lock.Unlock()
// 更新总进度
updateTaskTotalProgress(task)
// 发送更新到客户端
sendTaskUpdate(task)
}
}()
}
// 发送进度更新
func sendProgressUpdate(task *DownloadTask, index int, progress float64, status string, errorMsg string) {
select {
case task.updateChan <- &ProgressUpdate{
TaskID: task.ID,
ImageIndex: index,
Progress: progress,
Status: status,
Error: errorMsg,
}:
// 成功发送
default:
// 通道已满,丢弃更新
fmt.Printf("Warning: Update channel for task %s is full\n", task.ID)
}
}
// 更新总进度 - 不需要外部锁
func updateTaskTotalProgress(task *DownloadTask) {
task.ProgressLock.Lock()
defer task.ProgressLock.Unlock()
totalProgress := 0.0
task.ImageLock.RLock()
imageCount := len(task.Images)
task.ImageLock.RUnlock()
if imageCount == 0 {
return
}
task.ImageLock.RLock()
for _, img := range task.Images {
img.lock.Lock()
totalProgress += img.Progress
img.lock.Unlock()
}
task.ImageLock.RUnlock()
task.TotalProgress = totalProgress / float64(imageCount)
}
// 处理下载请求
func handleDownload(c *gin.Context) {
type DownloadRequest struct {
@@ -220,6 +355,9 @@ func handleDownload(c *gin.Context) {
Status: StatusPending,
TempDir: tempDir,
}
// 初始化任务通道和处理器
initTask(task)
// 保存任务
tasksLock.Lock()
@@ -229,6 +367,8 @@ func handleDownload(c *gin.Context) {
// 异步处理下载
go func() {
processDownloadTask(task, req.Platform)
// 任务完成后关闭更新通道
close(task.updateChan)
}()
c.JSON(http.StatusOK, gin.H{
@@ -239,86 +379,146 @@ func handleDownload(c *gin.Context) {
// 处理下载任务
func processDownloadTask(task *DownloadTask, platform string) {
task.Lock.Lock()
// 设置任务状态为运行中
task.StatusLock.Lock()
task.Status = StatusRunning
task.Lock.Unlock()
task.StatusLock.Unlock()
// 通知客户端任务已开始
sendTaskUpdate(task)
// 使用WaitGroup等待所有镜像下载完成
var wg sync.WaitGroup
wg.Add(len(task.Images))
// 使用并发下载镜像
for i, imgTask := range task.Images {
go func(idx int, imgTask *ImageTask) {
defer wg.Done()
downloadImage(task, idx, imgTask, platform)
}(i, imgTask)
// 创建错误组用于管理所有下载goroutine
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // 确保资源被释放
g, ctx := errgroup.WithContext(ctx)
// 启动并发下载
task.ImageLock.RLock()
imageCount := len(task.Images)
task.ImageLock.RUnlock()
// 创建工作池限制并发数
const maxConcurrent = 5
semaphore := make(chan struct{}, maxConcurrent)
// 添加下载任务
for i := 0; i < imageCount; i++ {
index := i // 捕获循环变量
g.Go(func() error {
// 获取信号量,限制并发
semaphore <- struct{}{}
defer func() { <-semaphore }()
task.ImageLock.RLock()
imgTask := task.Images[index]
task.ImageLock.RUnlock()
// 下载镜像
err := downloadImageWithContext(ctx, task, index, imgTask, platform)
if err != nil {
fmt.Printf("镜像 %s 下载失败: %v\n", imgTask.Image, err)
return err
}
return nil
})
}
// 等待所有下载完成
wg.Wait()
err := g.Wait()
// 检查是否有错误发生
if err != nil {
task.StatusLock.Lock()
task.Status = StatusFailed
task.StatusLock.Unlock()
sendTaskUpdate(task)
return
}
// 判断是单个tar还是需要打包
var finalFilePath string
var err error
task.Lock.Lock()
task.StatusLock.Lock()
// 检查是否所有镜像都下载成功
allSuccess := true
task.ImageLock.RLock()
for _, img := range task.Images {
if img.Status == string(StatusFailed) {
img.lock.Lock()
if img.Status != string(StatusCompleted) {
allSuccess = false
break
}
img.lock.Unlock()
}
task.ImageLock.RUnlock()
if !allSuccess {
task.Status = StatusFailed
task.Lock.Unlock()
task.StatusLock.Unlock()
sendTaskUpdate(task)
return
}
// 如果只有一个文件,直接使用它
if len(task.Images) == 1 && task.Images[0].Status == string(StatusCompleted) {
finalFilePath = task.Images[0].OutputPath
// 重命名为更友好的名称
imageName := strings.ReplaceAll(task.Images[0].Image, "/", "_")
imageName = strings.ReplaceAll(imageName, ":", "_")
newPath := filepath.Join(task.TempDir, imageName+".tar")
os.Rename(finalFilePath, newPath)
finalFilePath = newPath
task.ImageLock.RLock()
if imageCount == 1 {
imgTask := task.Images[0]
imgTask.lock.Lock()
if imgTask.Status == string(StatusCompleted) {
finalFilePath = imgTask.OutputPath
// 重命名为更友好的名称
imageName := strings.ReplaceAll(imgTask.Image, "/", "_")
imageName = strings.ReplaceAll(imageName, ":", "_")
newPath := filepath.Join(task.TempDir, imageName+".tar")
os.Rename(finalFilePath, newPath)
finalFilePath = newPath
}
imgTask.lock.Unlock()
} else {
// 多个文件打包成zip
finalFilePath, err = createZipArchive(task)
if err != nil {
task.ImageLock.RUnlock()
var zipErr error
finalFilePath, zipErr = createZipArchive(task)
if zipErr != nil {
task.Status = StatusFailed
task.Lock.Unlock()
task.StatusLock.Unlock()
sendTaskUpdate(task)
return
}
}
if imageCount == 1 {
task.ImageLock.RUnlock()
}
task.OutputFile = finalFilePath
task.Status = StatusCompleted
// 设置总进度为100%
task.ProgressLock.Lock()
task.TotalProgress = 100
task.Lock.Unlock()
task.ProgressLock.Unlock()
task.StatusLock.Unlock()
// 发送最终状态更新
sendTaskUpdate(task)
}
// 下载单个镜像
func downloadImage(task *DownloadTask, index int, imgTask *ImageTask, platform string) {
imgTask.Status = string(StatusRunning)
sendImageUpdate(task, index)
// 下载单个镜像(带上下文控制)
func downloadImageWithContext(ctx context.Context, task *DownloadTask, index int, imgTask *ImageTask, platform string) error {
// 更新状态为运行中
sendProgressUpdate(task, index, 0, string(StatusRunning), "")
// 创建输出文件名
outputFileName := fmt.Sprintf("image_%d.tar", index)
outputPath := filepath.Join(task.TempDir, outputFileName)
imgTask.lock.Lock()
imgTask.OutputPath = outputPath
imgTask.lock.Unlock()
// 创建skopeo命令
platformArg := ""
@@ -345,158 +545,180 @@ func downloadImage(task *DownloadTask, index int, imgTask *ImageTask, platform s
}
// 构建命令
cmd := fmt.Sprintf("skopeo copy %s docker://%s docker-archive:%s",
cmdStr := fmt.Sprintf("skopeo copy %s docker://%s docker-archive:%s",
platformArg, imgTask.Image, outputPath)
fmt.Printf("执行命令: %s\n", cmd)
fmt.Printf("执行命令: %s\n", cmdStr)
// 执行命令
command := exec.Command("sh", "-c", cmd)
// 创建可取消的命令
cmd := exec.CommandContext(ctx, "sh", "-c", cmdStr)
// 获取命令输出
stderr, err := command.StderrPipe()
stderr, err := cmd.StderrPipe()
if err != nil {
imgTask.Status = string(StatusFailed)
imgTask.Error = fmt.Sprintf("无法创建输出管道: %v", err)
sendImageUpdate(task, index)
return
errMsg := fmt.Sprintf("无法创建输出管道: %v", err)
sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg)
return fmt.Errorf(errMsg)
}
stdout, err := command.StdoutPipe()
stdout, err := cmd.StdoutPipe()
if err != nil {
imgTask.Status = string(StatusFailed)
imgTask.Error = fmt.Sprintf("无法创建标准输出管道: %v", err)
sendImageUpdate(task, index)
return
errMsg := fmt.Sprintf("无法创建标准输出管道: %v", err)
sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg)
return fmt.Errorf(errMsg)
}
if err := command.Start(); err != nil {
imgTask.Status = string(StatusFailed)
imgTask.Error = fmt.Sprintf("启动命令失败: %v", err)
sendImageUpdate(task, index)
return
if err := cmd.Start(); err != nil {
errMsg := fmt.Sprintf("启动命令失败: %v", err)
sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg)
return fmt.Errorf(errMsg)
}
// 模拟逐步进度增加,确保用户体验更好
// 使用进度通道传递进度信息
progressChan := make(chan float64, 10)
outputChan := make(chan string, 20)
done := make(chan struct{})
// 初始进度
sendProgressUpdate(task, index, 5, "", "")
// 进度聚合器
go func() {
// 每500ms检查一次进度如果进度没有变化则稍微增加一点
ticker := time.NewTicker(500 * time.Millisecond)
// 镜像获取阶段的进度标记
downloadStages := map[string]float64{
"Getting image source signatures": 10,
"Copying blob": 30,
"Copying config": 70,
"Writing manifest": 90,
}
// 进度增长的定时器
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
lastProgress := 0.0
stagnantCount := 0
lastProgress := 5.0
stagnantTime := 0
for {
select {
case <-ticker.C:
// 检查命令是否还在运行
if command.ProcessState != nil && command.ProcessState.Exited() {
return
case <-ctx.Done():
// 上下文取消
return
case <-done:
// 命令完成
return
case output := <-outputChan:
// 解析输出更新进度
for marker, progress := range downloadStages {
if strings.Contains(output, marker) && progress > lastProgress {
lastProgress = progress
sendProgressUpdate(task, index, progress, "", "")
stagnantTime = 0
break
}
}
// 如果进度停滞,小幅增加进度,提高用户体验
task.Lock.Lock()
currentProgress := imgTask.Progress
if currentProgress == lastProgress {
stagnantCount++
if stagnantCount > 5 && currentProgress < 90 { // 连续5次无变化且未接近完成
// 缓慢增加进度但不超过95%
newProgress := currentProgress + 0.5
if newProgress > 95 {
newProgress = 95
// 解析百分比
if strings.Contains(output, "%") {
parts := strings.Split(output, "%")
if len(parts) > 0 {
numStr := strings.TrimSpace(parts[0])
fields := strings.Fields(numStr)
if len(fields) > 0 {
lastField := fields[len(fields)-1]
parsedProgress := 0.0
_, err := fmt.Sscanf(lastField, "%f", &parsedProgress)
if err == nil && parsedProgress > 0 && parsedProgress <= 100 {
// 根据当前阶段调整进度值
var adjustedProgress float64
if lastProgress < 30 {
// Copying blob阶段进度在10-30%之间
adjustedProgress = 10 + (parsedProgress / 100) * 20
} else if lastProgress < 70 {
// Copying config阶段进度在30-70%之间
adjustedProgress = 30 + (parsedProgress / 100) * 40
} else if lastProgress < 90 {
// Writing manifest阶段进度在70-90%之间
adjustedProgress = 70 + (parsedProgress / 100) * 20
}
if adjustedProgress > lastProgress {
lastProgress = adjustedProgress
sendProgressUpdate(task, index, adjustedProgress, "", "")
stagnantTime = 0
}
}
}
imgTask.Progress = newProgress
updateTaskProgress(task)
sendImageUpdate(task, index)
}
} else {
stagnantCount = 0
lastProgress = currentProgress
}
task.Lock.Unlock()
}
}
}()
// 读取stderr以获取进度信息
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
output := scanner.Text()
fmt.Printf("镜像 %s 进度输出: %s\n", imgTask.Image, output)
// 解析进度信息
if strings.Contains(output, "%") {
parts := strings.Split(output, "%")
if len(parts) > 0 {
numStr := strings.TrimSpace(parts[0])
numStr = strings.TrimLeft(numStr, "Copying blob ")
numStr = strings.TrimLeft(numStr, "Copying config ")
numStr = strings.TrimRight(numStr, " / ")
numStr = strings.TrimSpace(numStr)
// 尝试提取最后一个数字作为进度
fields := strings.Fields(numStr)
if len(fields) > 0 {
lastField := fields[len(fields)-1]
progress := 0.0
fmt.Sscanf(lastField, "%f", &progress)
if progress > 0 && progress <= 100 {
task.Lock.Lock()
imgTask.Progress = progress
task.Lock.Unlock()
updateTaskProgress(task)
sendImageUpdate(task, index)
}
case <-ticker.C:
// 如果进度长时间无变化,缓慢增加
stagnantTime += 100 // 100ms
if stagnantTime >= 10000 && lastProgress < 95 { // 10秒无变化
// 每10秒增加5%进度确保不超过95%
newProgress := lastProgress + 5
if newProgress > 95 {
newProgress = 95
}
lastProgress = newProgress
sendProgressUpdate(task, index, newProgress, "", "")
stagnantTime = 0
}
}
}
}()
// 读取stdout
// 读取标准输出
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
output := scanner.Text()
fmt.Printf("镜像 %s 标准输出: %s\n", imgTask.Image, output)
select {
case outputChan <- output:
default:
// 通道已满,丢弃
}
}
}()
if err := command.Wait(); err != nil {
imgTask.Status = string(StatusFailed)
imgTask.Error = fmt.Sprintf("命令执行失败: %v", err)
sendImageUpdate(task, index)
return
// 读取错误输出
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
output := scanner.Text()
fmt.Printf("镜像 %s 错误输出: %s\n", imgTask.Image, output)
select {
case outputChan <- output:
default:
// 通道已满,丢弃
}
}
}()
// 等待命令完成
cmdErr := cmd.Wait()
close(done) // 通知进度聚合器退出
if cmdErr != nil {
errMsg := fmt.Sprintf("命令执行失败: %v", cmdErr)
sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg)
return fmt.Errorf(errMsg)
}
// 检查文件是否成功创建
if _, err := os.Stat(outputPath); os.IsNotExist(err) {
imgTask.Status = string(StatusFailed)
imgTask.Error = "文件未成功创建"
sendImageUpdate(task, index)
return
errMsg := "文件未成功创建"
sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg)
return fmt.Errorf(errMsg)
}
// 更新状态为已完成
task.Lock.Lock()
imgTask.Status = string(StatusCompleted)
imgTask.Progress = 100
task.Lock.Unlock()
updateTaskProgress(task)
sendImageUpdate(task, index)
}
// 更新任务总进度
func updateTaskProgress(task *DownloadTask) {
task.Lock.Lock()
defer task.Lock.Unlock()
totalProgress := 0.0
for _, img := range task.Images {
totalProgress += img.Progress
}
task.TotalProgress = totalProgress / float64(len(task.Images))
sendProgressUpdate(task, index, 100, string(StatusCompleted), "")
return nil
}
// 创建ZIP归档
@@ -504,39 +726,50 @@ func createZipArchive(task *DownloadTask) (string, error) {
zipFilePath := filepath.Join(task.TempDir, "images.zip")
zipFile, err := os.Create(zipFilePath)
if err != nil {
return "", err
return "", fmt.Errorf("创建ZIP文件失败: %w", err)
}
defer zipFile.Close()
zipWriter := zip.NewWriter(zipFile)
defer zipWriter.Close()
for _, img := range task.Images {
if img.Status != string(StatusCompleted) || img.OutputPath == "" {
task.ImageLock.RLock()
images := make([]*ImageTask, len(task.Images))
copy(images, task.Images) // 创建副本避免长时间持有锁
task.ImageLock.RUnlock()
for _, img := range images {
img.lock.Lock()
status := img.Status
outputPath := img.OutputPath
image := img.Image
img.lock.Unlock()
if status != string(StatusCompleted) || outputPath == "" {
continue
}
// 创建ZIP条目
imgFile, err := os.Open(img.OutputPath)
imgFile, err := os.Open(outputPath)
if err != nil {
return "", err
return "", fmt.Errorf("无法打开镜像文件 %s: %w", outputPath, err)
}
// 使用镜像名作为文件名
imageName := strings.ReplaceAll(img.Image, "/", "_")
imageName := strings.ReplaceAll(image, "/", "_")
imageName = strings.ReplaceAll(imageName, ":", "_")
fileName := imageName + ".tar"
fileInfo, err := imgFile.Stat()
if err != nil {
imgFile.Close()
return "", err
return "", fmt.Errorf("无法获取文件信息: %w", err)
}
header, err := zip.FileInfoHeader(fileInfo)
if err != nil {
imgFile.Close()
return "", err
return "", fmt.Errorf("创建ZIP头信息失败: %w", err)
}
header.Name = fileName
@@ -545,13 +778,13 @@ func createZipArchive(task *DownloadTask) (string, error) {
writer, err := zipWriter.CreateHeader(header)
if err != nil {
imgFile.Close()
return "", err
return "", fmt.Errorf("添加文件到ZIP失败: %w", err)
}
_, err = io.Copy(writer, imgFile)
imgFile.Close()
if err != nil {
return "", err
return "", fmt.Errorf("写入ZIP文件失败: %w", err)
}
}
@@ -560,7 +793,40 @@ func createZipArchive(task *DownloadTask) (string, error) {
// 发送任务更新到WebSocket
func sendTaskUpdate(task *DownloadTask) {
taskJSON, err := json.Marshal(task)
// 复制任务状态避免序列化时锁定
taskCopy := &DownloadTask{
ID: task.ID,
TotalProgress: 0,
Status: TaskStatus(""),
Images: nil,
}
// 复制状态信息
task.StatusLock.RLock()
taskCopy.Status = task.Status
task.StatusLock.RUnlock()
task.ProgressLock.RLock()
taskCopy.TotalProgress = task.TotalProgress
task.ProgressLock.RUnlock()
// 复制镜像信息
task.ImageLock.RLock()
taskCopy.Images = make([]*ImageTask, len(task.Images))
for i, img := range task.Images {
img.lock.Lock()
taskCopy.Images[i] = &ImageTask{
Image: img.Image,
Progress: img.Progress,
Status: img.Status,
Error: img.Error,
}
img.lock.Unlock()
}
task.ImageLock.RUnlock()
// 序列化并发送
taskJSON, err := json.Marshal(taskCopy)
if err != nil {
fmt.Printf("序列化任务失败: %v\n", err)
return
@@ -573,13 +839,14 @@ func sendTaskUpdate(task *DownloadTask) {
if exists {
select {
case client.Send <- taskJSON:
// 成功发送
default:
// 通道已满或关闭,忽略
}
}
}
// 发送单个镜像更新
// 发送单个镜像更新 - 保持兼容性
func sendImageUpdate(task *DownloadTask, imageIndex int) {
sendTaskUpdate(task)
}
@@ -657,8 +924,8 @@ func cleanupTempFiles() {
return nil
}
// 如果文件或目录超过24小时未修改,则删除
if time.Since(info.ModTime()) > 24*time.Hour {
// 如果文件或目录超过2小时未修改则删除
if time.Since(info.ModTime()) > 2*time.Hour {
if info.IsDir() {
os.RemoveAll(path)
return filepath.SkipDir