diff --git a/.github/workflows/docker-ghcr.yml b/.github/workflows/docker-ghcr.yml index c6649d3..6607cb8 100644 --- a/.github/workflows/docker-ghcr.yml +++ b/.github/workflows/docker-ghcr.yml @@ -46,7 +46,7 @@ jobs: - name: Build and push Docker image run: | - cd ghproxy + cd src docker buildx build --push \ --platform linux/amd64,linux/arm64 \ --tag ghcr.io/${{ env.REPO_LOWER }}:${{ env.VERSION }} \ diff --git a/Caddyfile b/Caddyfile deleted file mode 100644 index 0648456..0000000 --- a/Caddyfile +++ /dev/null @@ -1,15 +0,0 @@ -hub.{$DOMAIN} { - reverse_proxy * ghproxy:5000 -} - -docker.{$DOMAIN} { - @v2_manifest_blob path_regexp v2_rewrite ^/v2/([^/]+)/(manifests|blobs)/(.*)$ - handle @v2_manifest_blob { - rewrite * /v2/library/{re.v2_rewrite.1}/{re.v2_rewrite.2}/{re.v2_rewrite.3} - } - reverse_proxy * docker:5000 -} - -ghcr.{$DOMAIN} { - reverse_proxy * ghcr:5000 -} \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 2dfca85..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,31 +0,0 @@ -services: - caddy: - image: caddy:alpine - container_name: caddy - ports: - - "80:80" - - "443:443" - volumes: - - ./Caddyfile:/etc/caddy/Caddyfile - environment: - - DOMAIN=example.com # 修改为你的根域名 - restart: always - - ghcr: - image: "registry:2.8.3" - container_name: "ghcr" - restart: "always" - volumes: - - "./ghcr/config.yml:/etc/docker/registry/config.yml" - - docker: - image: "registry:2.8.3" - container_name: "docker" - restart: "always" - volumes: - - "./docker/config.yml:/etc/docker/registry/config.yml" - - ghproxy: - image: "ghcr.io/sky22333/hubproxy" - container_name: "ghproxy" - restart: "always" \ No newline at end of file diff --git a/docker/config.yml b/docker/config.yml deleted file mode 100644 index e04b212..0000000 --- a/docker/config.yml +++ /dev/null @@ -1,16 +0,0 @@ -version: 0.1 -storage: - filesystem: - rootdirectory: /var/lib/registry - delete: - enabled: true - maintenance: - uploadpurging: - enabled: true - age: 72h - dryrun: false - interval: 1m -http: - addr: 0.0.0.0:5000 -proxy: - remoteurl: https://registry-1.docker.io \ No newline at end of file diff --git a/ghcr/config.yml b/ghcr/config.yml deleted file mode 100644 index 6a9b296..0000000 --- a/ghcr/config.yml +++ /dev/null @@ -1,16 +0,0 @@ -version: 0.1 -storage: - filesystem: - rootdirectory: /var/lib/registry - delete: - enabled: true - maintenance: - uploadpurging: - enabled: true - age: 72h - dryrun: false - interval: 1m -http: - addr: 0.0.0.0:5000 -proxy: - remoteurl: https://ghcr.io \ No newline at end of file diff --git a/ghproxy/config.json b/ghproxy/config.json deleted file mode 100644 index 19fb671..0000000 --- a/ghproxy/config.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "whiteList": [ - ], - "blackList": [ - "example1", - "login" - ] -} diff --git a/ghproxy/docker-compose.yml b/ghproxy/docker-compose.yml deleted file mode 100644 index ce1c7e7..0000000 --- a/ghproxy/docker-compose.yml +++ /dev/null @@ -1,6 +0,0 @@ -services: - ghproxy: - build: . - restart: always - ports: - - '5000:5000' \ No newline at end of file diff --git a/ghproxy/Dockerfile b/src/Dockerfile similarity index 63% rename from ghproxy/Dockerfile rename to src/Dockerfile index de6f3fa..025179f 100644 --- a/ghproxy/Dockerfile +++ b/src/Dockerfile @@ -1,11 +1,11 @@ -FROM golang:1.23-alpine AS builder +FROM golang:1.24-alpine AS builder WORKDIR /app COPY go.mod go.sum ./ RUN go mod download COPY . . -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -trimpath -o ghproxy . +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -trimpath -o hubproxy . FROM alpine @@ -14,8 +14,8 @@ WORKDIR /root/ # 安装skopeo RUN apk add --no-cache skopeo && mkdir -p temp && chmod 700 temp -COPY --from=builder /app/ghproxy . -COPY --from=builder /app/config.json . +COPY --from=builder /app/hubproxy . +COPY --from=builder /app/config.toml . COPY --from=builder /app/public ./public -CMD ["./ghproxy"] +CMD ["./hubproxy"] diff --git a/ghproxy/LICENSE b/src/LICENSE similarity index 100% rename from ghproxy/LICENSE rename to src/LICENSE diff --git a/src/access_control.go b/src/access_control.go new file mode 100644 index 0000000..412e52f --- /dev/null +++ b/src/access_control.go @@ -0,0 +1,226 @@ +package main + +import ( + "strings" + "sync" +) + +// ResourceType 资源类型 +type ResourceType string + +const ( + ResourceTypeGitHub ResourceType = "github" + ResourceTypeDocker ResourceType = "docker" +) + +// AccessController 统一访问控制器 +type AccessController struct { + mu sync.RWMutex +} + +// DockerImageInfo Docker镜像信息 +type DockerImageInfo struct { + Namespace string + Repository string + Tag string + FullName string +} + +// 全局访问控制器实例 +var GlobalAccessController = &AccessController{} + +// ParseDockerImage 解析Docker镜像名称 +func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { + // 移除可能的协议前缀 + image = strings.TrimPrefix(image, "docker://") + + // 分离标签 + var tag string + if idx := strings.LastIndex(image, ":"); idx != -1 { + // 检查是否是端口号而不是标签(包含斜杠) + part := image[idx+1:] + if !strings.Contains(part, "/") { + tag = part + image = image[:idx] + } + } + if tag == "" { + tag = "latest" + } + + // 分离命名空间和仓库名 + var namespace, repository string + if strings.Contains(image, "/") { + // 处理自定义registry的情况,如 registry.com/user/repo + parts := strings.Split(image, "/") + if len(parts) >= 2 { + // 检查第一部分是否是域名(包含.) + if strings.Contains(parts[0], ".") { + // 跳过registry域名,取用户名和仓库名 + if len(parts) >= 3 { + namespace = parts[1] + repository = parts[2] + } else { + namespace = "library" + repository = parts[1] + } + } else { + // 标准格式:user/repo + namespace = parts[0] + repository = parts[1] + } + } + } else { + // 官方镜像,如 nginx + namespace = "library" + repository = image + } + + fullName := namespace + "/" + repository + + return DockerImageInfo{ + Namespace: namespace, + Repository: repository, + Tag: tag, + FullName: fullName, + } +} + +// CheckDockerAccess 检查Docker镜像访问权限 +func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { + cfg := GetConfig() + + // 解析镜像名称 + imageInfo := ac.ParseDockerImage(image) + + // 检查白名单(如果配置了白名单,则只允许白名单中的镜像) + if len(cfg.Proxy.WhiteList) > 0 { + if !ac.matchImageInList(imageInfo, cfg.Proxy.WhiteList) { + return false, "不在Docker镜像白名单内" + } + } + + // 检查黑名单 + if len(cfg.Proxy.BlackList) > 0 { + if ac.matchImageInList(imageInfo, cfg.Proxy.BlackList) { + return false, "Docker镜像在黑名单内" + } + } + + return true, "" +} + +// CheckGitHubAccess 检查GitHub仓库访问权限 +func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) { + if len(matches) < 2 { + return false, "无效的GitHub仓库格式" + } + + cfg := GetConfig() + + // 检查白名单 + if len(cfg.Proxy.WhiteList) > 0 && !ac.checkList(matches, cfg.Proxy.WhiteList) { + return false, "不在GitHub仓库白名单内" + } + + // 检查黑名单 + if len(cfg.Proxy.BlackList) > 0 && ac.checkList(matches, cfg.Proxy.BlackList) { + return false, "GitHub仓库在黑名单内" + } + + return true, "" +} + +// matchImageInList 检查Docker镜像是否在指定列表中 +func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool { + fullName := strings.ToLower(imageInfo.FullName) + namespace := strings.ToLower(imageInfo.Namespace) + + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + + if fullName == item { + return true + } + + if item == namespace || item == namespace+"/*" { + return true + } + + if strings.HasSuffix(item, "*") { + prefix := strings.TrimSuffix(item, "*") + if strings.HasPrefix(fullName, prefix) { + return true + } + } + + if strings.HasPrefix(item, "*/") { + repoPattern := strings.TrimPrefix(item, "*/") + if strings.HasSuffix(repoPattern, "*") { + repoPrefix := strings.TrimSuffix(repoPattern, "*") + if strings.HasPrefix(imageInfo.Repository, repoPrefix) { + return true + } + } else { + if strings.ToLower(imageInfo.Repository) == repoPattern { + return true + } + } + } + + // 5. 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) + if strings.HasPrefix(fullName, item+"/") { + return true + } + } + return false +} + +// checkList GitHub仓库检查逻辑 +func (ac *AccessController) checkList(matches, list []string) bool { + if len(matches) < 2 { + return false + } + + // 组合用户名和仓库名,处理.git后缀 + username := strings.ToLower(strings.TrimSpace(matches[0])) + repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git"))) + fullRepo := username + "/" + repoName + + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + + // 支持多种匹配模式: + // 1. 精确匹配: "vaxilu/x-ui" + // 2. 用户级匹配: "vaxilu/*" 或 "vaxilu" + // 3. 前缀匹配: "vaxilu/x-ui-*" + if fullRepo == item { + return true + } + + // 用户级匹配 + if item == username || item == username+"/*" { + return true + } + + // 前缀匹配(支持通配符) + if strings.HasSuffix(item, "*") { + prefix := strings.TrimSuffix(item, "*") + if strings.HasPrefix(fullRepo, prefix) { + return true + } + } + + // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) + if strings.HasPrefix(fullRepo, item+"/") { + return true + } + } + return false +} \ No newline at end of file diff --git a/src/config.go b/src/config.go new file mode 100644 index 0000000..43b7099 --- /dev/null +++ b/src/config.go @@ -0,0 +1,195 @@ +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "sync" + + "github.com/pelletier/go-toml/v2" +) + +// AppConfig 应用配置结构体 +type AppConfig struct { + Server struct { + Host string `toml:"host"` // 监听地址 + Port int `toml:"port"` // 监听端口 + FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) + } `toml:"server"` + + RateLimit struct { + RequestLimit int `toml:"requestLimit"` // 每小时请求限制 + PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) + } `toml:"rateLimit"` + + Security struct { + WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 + BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 + } `toml:"security"` + + Proxy struct { + WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) + BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) + } `toml:"proxy"` + + Download struct { + MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 + } `toml:"download"` +} + +var ( + appConfig *AppConfig + appConfigLock sync.RWMutex +) + +// DefaultConfig 返回默认配置 +func DefaultConfig() *AppConfig { + return &AppConfig{ + Server: struct { + Host string `toml:"host"` + Port int `toml:"port"` + FileSize int64 `toml:"fileSize"` + }{ + Host: "0.0.0.0", + Port: 5000, + FileSize: 2 * 1024 * 1024 * 1024, // 2GB + }, + RateLimit: struct { + RequestLimit int `toml:"requestLimit"` + PeriodHours float64 `toml:"periodHours"` + }{ + RequestLimit: 20, + PeriodHours: 1.0, + }, + Security: struct { + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` + }{ + WhiteList: []string{}, + BlackList: []string{}, + }, + Proxy: struct { + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` + }{ + WhiteList: []string{}, + BlackList: []string{}, + }, + Download: struct { + MaxImages int `toml:"maxImages"` + }{ + MaxImages: 10, // 默认值:最多同时下载10个镜像 + }, + } +} + +// GetConfig 安全地获取配置副本 +func GetConfig() *AppConfig { + appConfigLock.RLock() + defer appConfigLock.RUnlock() + + if appConfig == nil { + return DefaultConfig() + } + + // 返回配置的深拷贝 + configCopy := *appConfig + configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) + configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) + configCopy.Proxy.WhiteList = append([]string(nil), appConfig.Proxy.WhiteList...) + configCopy.Proxy.BlackList = append([]string(nil), appConfig.Proxy.BlackList...) + + return &configCopy +} + +// setConfig 安全地设置配置 +func setConfig(cfg *AppConfig) { + appConfigLock.Lock() + defer appConfigLock.Unlock() + appConfig = cfg +} + +// LoadConfig 加载配置文件 +func LoadConfig() error { + // 首先使用默认配置 + cfg := DefaultConfig() + + // 尝试加载TOML配置文件 + if data, err := os.ReadFile("config.toml"); err == nil { + if err := toml.Unmarshal(data, cfg); err != nil { + return fmt.Errorf("解析配置文件失败: %v", err) + } + } else { + fmt.Println("未找到config.toml,使用默认配置") + } + + // 从环境变量覆盖配置 + overrideFromEnv(cfg) + + // 设置配置 + setConfig(cfg) + + fmt.Printf("配置加载成功: 监听 %s:%d, 文件大小限制 %d MB, 限流 %d请求/%g小时, 离线镜像并发数 %d\n", + cfg.Server.Host, cfg.Server.Port, cfg.Server.FileSize/(1024*1024), + cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours, cfg.Download.MaxImages) + + return nil +} + +// overrideFromEnv 从环境变量覆盖配置 +func overrideFromEnv(cfg *AppConfig) { + // 服务器配置 + if val := os.Getenv("SERVER_HOST"); val != "" { + cfg.Server.Host = val + } + if val := os.Getenv("SERVER_PORT"); val != "" { + if port, err := strconv.Atoi(val); err == nil && port > 0 { + cfg.Server.Port = port + } + } + if val := os.Getenv("MAX_FILE_SIZE"); val != "" { + if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { + cfg.Server.FileSize = size + } + } + + // 限流配置 + if val := os.Getenv("RATE_LIMIT"); val != "" { + if limit, err := strconv.Atoi(val); err == nil && limit > 0 { + cfg.RateLimit.RequestLimit = limit + } + } + if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" { + if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 { + cfg.RateLimit.PeriodHours = period + } + } + + // IP限制配置 + if val := os.Getenv("IP_WHITELIST"); val != "" { + cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) + } + if val := os.Getenv("IP_BLACKLIST"); val != "" { + cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) + } + + // 下载限制配置 + if val := os.Getenv("MAX_IMAGES"); val != "" { + if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { + cfg.Download.MaxImages = maxImages + } + } +} + +// CreateDefaultConfigFile 创建默认配置文件 +func CreateDefaultConfigFile() error { + cfg := DefaultConfig() + + data, err := toml.Marshal(cfg) + if err != nil { + return fmt.Errorf("序列化默认配置失败: %v", err) + } + + return os.WriteFile("config.toml", data, 0644) +} \ No newline at end of file diff --git a/src/config.toml b/src/config.toml new file mode 100644 index 0000000..aac1c94 --- /dev/null +++ b/src/config.toml @@ -0,0 +1,45 @@ +[server] +# 监听地址,默认监听所有接口 +host = "0.0.0.0" +# 监听端口 +port = 5000 +# 文件大小限制(字节),默认2GB +fileSize = 2147483648 + +[rateLimit] +# 每个IP每小时允许的请求数 +requestLimit = 200 +# 限流周期(小时) +periodHours = 1.0 + +[security] +# IP白名单,支持单个IP或CIDR格式 +# 白名单中的IP不受限流限制 +whiteList = [ + "127.0.0.1", + "192.168.1.0/24" +] + +# IP黑名单,支持单个IP或CIDR格式 +# 黑名单中的IP将被直接拒绝访问 +blackList = [ + "192.168.100.1" +] + +[proxy] +# 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符) +# 只允许访问白名单中的仓库/镜像,为空时不限制 +whiteList = [] + +# 代理服务黑名单(支持GitHub仓库和Docker镜像,支持通配符) +# 禁止访问黑名单中的仓库/镜像 +blackList = [ + "baduser/malicious-repo", + "thesadboy/x-ui", + "vaxilu/x-ui", + "vaxilu/*" +] + +[download] +# 单次并发下载离线镜像数量限制 +maxImages = 10 diff --git a/src/docker-compose.yml b/src/docker-compose.yml new file mode 100644 index 0000000..6f6e8c3 --- /dev/null +++ b/src/docker-compose.yml @@ -0,0 +1,8 @@ +services: + ghproxy: + build: . + restart: always + ports: + - '5000:5000' + volumes: + - ./config.toml:/root/config.toml \ No newline at end of file diff --git a/src/docker.go b/src/docker.go new file mode 100644 index 0000000..65e700a --- /dev/null +++ b/src/docker.go @@ -0,0 +1,323 @@ +package main + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote" +) + +// DockerProxy Docker代理配置 +type DockerProxy struct { + registry name.Registry + options []remote.Option +} + +var dockerProxy *DockerProxy + +// 初始化Docker代理 +func initDockerProxy() { + // 创建目标registry + registry, err := name.NewRegistry("registry-1.docker.io") + if err != nil { + fmt.Printf("创建Docker registry失败: %v\n", err) + return + } + + // 配置代理选项 + options := []remote.Option{ + remote.WithAuth(authn.Anonymous), + remote.WithUserAgent("ghproxy/go-containerregistry"), + } + + dockerProxy = &DockerProxy{ + registry: registry, + options: options, + } + + fmt.Printf("Docker代理已初始化\n") +} + +// ProxyDockerRegistryGin 标准Docker Registry API v2代理 +func ProxyDockerRegistryGin(c *gin.Context) { + path := c.Request.URL.Path + + // 处理 /v2/ API版本检查 + if path == "/v2/" { + c.JSON(http.StatusOK, gin.H{}) + return + } + + // 处理不同的API端点 + if strings.HasPrefix(path, "/v2/") { + handleRegistryRequest(c, path) + } else { + c.String(http.StatusNotFound, "Docker Registry API v2 only") + } +} + +// handleRegistryRequest 处理Registry请求 +func handleRegistryRequest(c *gin.Context, path string) { + // 移除 /v2/ 前缀 + pathWithoutV2 := strings.TrimPrefix(path, "/v2/") + + // 解析路径 + imageName, apiType, reference := parseRegistryPath(pathWithoutV2) + if imageName == "" || apiType == "" { + c.String(http.StatusBadRequest, "Invalid path format") + return + } + + // 自动处理官方镜像的library命名空间 + if !strings.Contains(imageName, "/") { + imageName = "library/" + imageName + } + + // Docker镜像访问控制检查 + if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed { + fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason) + c.String(http.StatusForbidden, "镜像访问被限制") + return + } + + // 构建完整的镜像引用 + imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName) + + switch apiType { + case "manifests": + handleManifestRequest(c, imageRef, reference) + case "blobs": + handleBlobRequest(c, imageRef, reference) + case "tags": + handleTagsRequest(c, imageRef) + default: + c.String(http.StatusNotFound, "API endpoint not found") + } +} + +// parseRegistryPath 解析Registry路径 +func parseRegistryPath(path string) (imageName, apiType, reference string) { + // 查找API端点关键字 + if idx := strings.Index(path, "/manifests/"); idx != -1 { + imageName = path[:idx] + apiType = "manifests" + reference = path[idx+len("/manifests/"):] + return + } + + if idx := strings.Index(path, "/blobs/"); idx != -1 { + imageName = path[:idx] + apiType = "blobs" + reference = path[idx+len("/blobs/"):] + return + } + + if idx := strings.Index(path, "/tags/list"); idx != -1 { + imageName = path[:idx] + apiType = "tags" + reference = "list" + return + } + + return "", "", "" +} + +// handleManifestRequest 处理manifest请求 +func handleManifestRequest(c *gin.Context, imageRef, reference string) { + var ref name.Reference + var err error + + // 判断reference是digest还是tag + if strings.HasPrefix(reference, "sha256:") { + // 是digest + ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) + } else { + // 是tag + ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) + } + + if err != nil { + fmt.Printf("解析镜像引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid reference") + return + } + + // 根据请求方法选择操作 + if c.Request.Method == http.MethodHead { + // HEAD请求,使用remote.Head + desc, err := remote.Head(ref, dockerProxy.options...) + if err != nil { + fmt.Printf("HEAD请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + // 设置响应头 + c.Header("Content-Type", string(desc.MediaType)) + c.Header("Docker-Content-Digest", desc.Digest.String()) + c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) + c.Status(http.StatusOK) + } else { + // GET请求,使用remote.Get + desc, err := remote.Get(ref, dockerProxy.options...) + if err != nil { + fmt.Printf("GET请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + // 设置响应头 + c.Header("Content-Type", string(desc.MediaType)) + c.Header("Docker-Content-Digest", desc.Digest.String()) + c.Header("Content-Length", fmt.Sprintf("%d", len(desc.Manifest))) + + // 返回manifest内容 + c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) + } +} + +// handleBlobRequest 处理blob请求 +func handleBlobRequest(c *gin.Context, imageRef, digest string) { + // 构建digest引用 + digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) + if err != nil { + fmt.Printf("解析digest引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid digest reference") + return + } + + // 使用remote.Layer获取layer + layer, err := remote.Layer(digestRef, dockerProxy.options...) + if err != nil { + fmt.Printf("获取layer失败: %v\n", err) + c.String(http.StatusNotFound, "Layer not found") + return + } + + // 获取layer信息 + size, err := layer.Size() + if err != nil { + fmt.Printf("获取layer大小失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer size") + return + } + + // 获取layer内容 + reader, err := layer.Compressed() + if err != nil { + fmt.Printf("获取layer内容失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer content") + return + } + defer reader.Close() + + // 设置响应头 + c.Header("Content-Type", "application/octet-stream") + c.Header("Content-Length", fmt.Sprintf("%d", size)) + c.Header("Docker-Content-Digest", digest) + + // 流式传输blob内容 + c.Status(http.StatusOK) + io.Copy(c.Writer, reader) +} + +// handleTagsRequest 处理tags列表请求 +func handleTagsRequest(c *gin.Context, imageRef string) { + // 解析repository + repo, err := name.NewRepository(imageRef) + if err != nil { + fmt.Printf("解析repository失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid repository") + return + } + + // 使用remote.List获取tags + tags, err := remote.List(repo, dockerProxy.options...) + if err != nil { + fmt.Printf("获取tags失败: %v\n", err) + c.String(http.StatusNotFound, "Tags not found") + return + } + + // 构建响应 + response := map[string]interface{}{ + "name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"), + "tags": tags, + } + + c.JSON(http.StatusOK, response) +} + +// ProxyDockerAuthGin Docker认证代理 +func ProxyDockerAuthGin(c *gin.Context) { + // 构建认证URL + authURL := "https://auth.docker.io" + c.Request.URL.Path + if c.Request.URL.RawQuery != "" { + authURL += "?" + c.Request.URL.RawQuery + } + + // 创建HTTP客户端 + client := &http.Client{ + Timeout: 30 * time.Second, + } + + // 创建请求 + req, err := http.NewRequestWithContext( + context.Background(), + c.Request.Method, + authURL, + c.Request.Body, + ) + if err != nil { + c.String(http.StatusInternalServerError, "Failed to create request") + return + } + + // 复制请求头 + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // 执行请求 + resp, err := client.Do(req) + if err != nil { + c.String(http.StatusBadGateway, "Auth request failed") + return + } + defer resp.Body.Close() + + // 获取当前代理的Host地址 + proxyHost := c.Request.Host + if proxyHost == "" { + // 使用配置中的服务器地址和端口 + cfg := GetConfig() + proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) + if cfg.Server.Host == "0.0.0.0" { + proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port) + } + } + + // 复制响应头并重写认证URL + for key, values := range resp.Header { + for _, value := range values { + // 重写WWW-Authenticate头中的realm URL + if key == "Www-Authenticate" && strings.Contains(value, "auth.docker.io") { + value = strings.ReplaceAll(value, "https://auth.docker.io", "http://"+proxyHost) + } + c.Header(key, value) + } + } + + // 返回响应 + c.Status(resp.StatusCode) + io.Copy(c.Writer, resp.Body) +} diff --git a/ghproxy/go.mod b/src/go.mod similarity index 60% rename from ghproxy/go.mod rename to src/go.mod index b9b8187..8755e7a 100644 --- a/ghproxy/go.mod +++ b/src/go.mod @@ -1,41 +1,51 @@ -module ghproxy - -go 1.23.0 - -toolchain go1.24.1 - -require ( - github.com/gin-gonic/gin v1.10.0 - github.com/gorilla/websocket v1.5.1 - golang.org/x/sync v0.14.0 - golang.org/x/time v0.11.0 -) - -require ( - github.com/bytedance/sonic v1.11.6 // indirect - github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cloudwego/base64x v0.1.4 // indirect - github.com/cloudwego/iasm v0.2.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect - github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-playground/locales v0.14.1 // indirect - github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.20.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect - github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/leodido/go-urn v1.4.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.12 // indirect - golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect - google.golang.org/protobuf v1.34.1 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +module hubproxy + +go 1.24.0 + +require ( + github.com/gin-gonic/gin v1.10.0 + github.com/google/go-containerregistry v0.20.5 + github.com/gorilla/websocket v1.5.1 + github.com/pelletier/go-toml/v2 v2.2.2 + golang.org/x/sync v0.14.0 + golang.org/x/time v0.11.0 +) + +require ( + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect + github.com/docker/cli v28.1.1+incompatible // indirect + github.com/docker/distribution v2.8.3+incompatible // indirect + github.com/docker/docker-credential-helpers v0.9.3 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + github.com/vbatts/tar-split v0.12.1 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.15.0 // indirect + google.golang.org/protobuf v1.36.3 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/ghproxy/go.sum b/src/go.sum similarity index 71% rename from ghproxy/go.sum rename to src/go.sum index ca2682a..6910408 100644 --- a/ghproxy/go.sum +++ b/src/go.sum @@ -1,95 +1,120 @@ -github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= -github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= -github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= -github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= -github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= -github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= -github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= -github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= -github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= -github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= -github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= -golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= -golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= +github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/cli v28.1.1+incompatible h1:eyUemzeI45DY7eDPuwUcmDyDj1pM98oD5MdSpiItp8k= +github.com/docker/cli v28.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= +github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8= +github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-containerregistry v0.20.5 h1:4RnlYcDs5hoA++CeFjlbZ/U9Yp1EuWr+UhhTyYQjOP0= +github.com/google/go-containerregistry v0.20.5/go.mod h1:Q14vdOOzug02bwnhMkZKD4e30pDaD9W65qzXpyzF49E= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= +github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= +google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= +gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/src/http_client.go b/src/http_client.go new file mode 100644 index 0000000..a336521 --- /dev/null +++ b/src/http_client.go @@ -0,0 +1,59 @@ +package main + +import ( + "net" + "net/http" + "time" +) + +var ( + // 全局HTTP客户端 - 用于代理请求(长超时) + globalHTTPClient *http.Client + // 搜索HTTP客户端 - 用于API请求(短超时) + searchHTTPClient *http.Client +) + +// initHTTPClients 初始化HTTP客户端 +func initHTTPClients() { + // 代理客户端配置 - 适用于大文件传输 + globalHTTPClient = &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 1000, + MaxIdleConnsPerHost: 1000, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: 300 * time.Second, + }, + } + + // 搜索客户端配置 - 适用于API调用 + searchHTTPClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 5 * time.Second, + DisableCompression: false, + }, + } +} + +// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理) +func GetGlobalHTTPClient() *http.Client { + return globalHTTPClient +} + +// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用) +func GetSearchHTTPClient() *http.Client { + return searchHTTPClient +} \ No newline at end of file diff --git a/ghproxy/main.go b/src/main.go similarity index 64% rename from ghproxy/main.go rename to src/main.go index 1c2fb21..e76f7ea 100644 --- a/ghproxy/main.go +++ b/src/main.go @@ -1,280 +1,246 @@ -package main - -import ( - "encoding/json" - "fmt" - "github.com/gin-gonic/gin" - "io" - "net" - "net/http" - "os" - "regexp" - "strconv" - "strings" - "sync" - "time" -) - -const ( - sizeLimit = 1024 * 1024 * 1024 * 2 // 允许的文件大小,默认2GB - host = "0.0.0.0" // 监听地址 - port = 5000 // 监听端口 -) - -var ( - exps = []*regexp.Regexp{ - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), - regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), - regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), - regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), - regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), - regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), - regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), - regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), - } - httpClient *http.Client - config *Config - configLock sync.RWMutex -) - -type Config struct { - WhiteList []string `json:"whiteList"` - BlackList []string `json:"blackList"` -} - -func main() { - gin.SetMode(gin.ReleaseMode) - router := gin.Default() - - httpClient = &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - MaxIdleConns: 1000, - MaxIdleConnsPerHost: 1000, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 300 * time.Second, - }, - } - - loadConfig() - go func() { - for { - time.Sleep(10 * time.Minute) - loadConfig() - } - }() - - // 初始化Skopeo相关路由 - 在任何通配符路由之前注册 - initSkopeoRoutes(router) - - // 单独处理根路径请求,避免冲突 - router.GET("/", func(c *gin.Context) { - c.File("./public/index.html") - }) - - // 指定具体的静态文件路径,避免使用通配符 - router.Static("/public", "./public") - - // 对于.html等特定文件注册 - router.GET("/skopeo.html", func(c *gin.Context) { - c.File("./public/skopeo.html") - }) - router.GET("/search.html", func(c *gin.Context) { - c.File("./public/search.html") - }) - - // 图标文件 - router.GET("/favicon.ico", func(c *gin.Context) { - c.File("./public/favicon.ico") - }) - - // 注册dockerhub搜索路由 - RegisterSearchRoute(router) - // 创建GitHub文件下载专用的限流器 - githubLimiter := NewIPRateLimiter() - - // 注册NoRoute处理器,应用限流中间件 - router.NoRoute(RateLimitMiddleware(githubLimiter), handler) - - err := router.Run(fmt.Sprintf("%s:%d", host, port)) - if err != nil { - fmt.Printf("Error starting server: %v\n", err) - } -} - -func handler(c *gin.Context) { - rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") - - for strings.HasPrefix(rawPath, "/") { - rawPath = strings.TrimPrefix(rawPath, "/") - } - - if !strings.HasPrefix(rawPath, "http") { - c.String(http.StatusForbidden, "无效输入") - return - } - - matches := checkURL(rawPath) - if matches != nil { - if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) { - c.String(http.StatusForbidden, "不在白名单内,限制访问。") - return - } - if len(config.BlackList) > 0 && checkList(matches, config.BlackList) { - c.String(http.StatusForbidden, "黑名单限制访问") - return - } - } else { - c.String(http.StatusForbidden, "无效输入") - return - } - - if exps[1].MatchString(rawPath) { - rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) - } - - proxy(c, rawPath) -} - -func proxy(c *gin.Context, u string) { - req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - - for key, values := range c.Request.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - req.Header.Del("Host") - - resp, err := httpClient.Do(req) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - - } - }(resp.Body) - - // 检查文件大小限制 - if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { - if size, err := strconv.Atoi(contentLength); err == nil && size > sizeLimit { - c.String(http.StatusRequestEntityTooLarge, "File too large.") - return - } - } - - // 清理安全相关的头 - resp.Header.Del("Content-Security-Policy") - resp.Header.Del("Referrer-Policy") - resp.Header.Del("Strict-Transport-Security") - - // 对于需要处理的shell文件,使用chunked传输 - isShellFile := strings.HasSuffix(strings.ToLower(u), ".sh") - if isShellFile { - resp.Header.Del("Content-Length") - resp.Header.Set("Transfer-Encoding", "chunked") - } - - // 复制其他响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - if location := resp.Header.Get("Location"); location != "" { - if checkURL(location) != nil { - c.Header("Location", "/"+location) - } else { - proxy(c, location) - return - } - } - - c.Status(resp.StatusCode) - - // 处理响应体 - if isShellFile { - // 获取真实域名 - realHost := c.Request.Header.Get("X-Forwarded-Host") - if realHost == "" { - realHost = c.Request.Host - } - // 如果域名中没有协议前缀,添加https:// - if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { - realHost = "https://" + realHost - } - // 使用ProcessGitHubURLs处理.sh文件 - processedBody, _, err := ProcessGitHubURLs(resp.Body, resp.Header.Get("Content-Encoding") == "gzip", realHost, true) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("处理shell文件时发生错误: %v", err)) - return - } - if _, err := io.Copy(c.Writer, processedBody); err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("写入响应时发生错误: %v", err)) - return - } - } else { - // 对于非.sh文件,直接复制响应体 - if _, err := io.Copy(c.Writer, resp.Body); err != nil { - return - } - } -} - -func loadConfig() { - file, err := os.Open("config.json") - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - defer func(file *os.File) { - err := file.Close() - if err != nil { - - } - }(file) - - var newConfig Config - decoder := json.NewDecoder(file) - if err := decoder.Decode(&newConfig); err != nil { - fmt.Printf("Error decoding config: %v\n", err) - return - } - - configLock.Lock() - config = &newConfig - configLock.Unlock() -} - -func checkURL(u string) []string { - for _, exp := range exps { - if matches := exp.FindStringSubmatch(u); matches != nil { - return matches[1:] - } - } - return nil -} - -func checkList(matches, list []string) bool { - for _, item := range list { - if strings.HasPrefix(matches[0], item) { - return true - } - } - return false -} +package main + +import ( + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "regexp" + "strconv" + "strings" +) + +var ( + exps = []*regexp.Regexp{ + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), + regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), + regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), + regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), + regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), + regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), + regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), + regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), + } + globalLimiter *IPRateLimiter +) + +func main() { + // 加载配置 + if err := LoadConfig(); err != nil { + fmt.Printf("配置加载失败: %v\n", err) + return + } + + // 初始化HTTP客户端 + initHTTPClients() + + // 初始化限流器 + initLimiter() + + // 初始化Docker流式代理 + initDockerProxy() + + gin.SetMode(gin.ReleaseMode) + router := gin.Default() + + // 初始化skopeo路由(静态文件和API路由) + initSkopeoRoutes(router) + + // 单独处理根路径请求 + router.GET("/", func(c *gin.Context) { + c.File("./public/index.html") + }) + + // 指定具体的静态文件路径 + router.Static("/public", "./public") + router.GET("/skopeo.html", func(c *gin.Context) { + c.File("./public/skopeo.html") + }) + router.GET("/search.html", func(c *gin.Context) { + c.File("./public/search.html") + }) + router.GET("/favicon.ico", func(c *gin.Context) { + c.File("./public/favicon.ico") + }) + + // 注册dockerhub搜索路由 + RegisterSearchRoute(router) + + // 注册Docker认证路由(/token*) + router.Any("/token", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin) + router.Any("/token/*path", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin) + + // 注册Docker Registry代理路由 + router.Any("/v2/*path", RateLimitMiddleware(globalLimiter), ProxyDockerRegistryGin) + + + // 注册NoRoute处理器,应用限流中间件 + router.NoRoute(RateLimitMiddleware(globalLimiter), handler) + + cfg := GetConfig() + fmt.Printf("启动成功,项目地址:https://github.com/sky22333/hubproxy \n") + + err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) + if err != nil { + fmt.Printf("启动服务失败: %v\n", err) + } +} + +func handler(c *gin.Context) { + rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") + + for strings.HasPrefix(rawPath, "/") { + rawPath = strings.TrimPrefix(rawPath, "/") + } + + if !strings.HasPrefix(rawPath, "http") { + c.String(http.StatusForbidden, "无效输入") + return + } + + matches := checkURL(rawPath) + if matches != nil { + // GitHub仓库访问控制检查 + if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed { + // 构建仓库名用于日志 + var repoPath string + if len(matches) >= 2 { + username := matches[0] + repoName := strings.TrimSuffix(matches[1], ".git") + repoPath = username + "/" + repoName + } + fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) + c.String(http.StatusForbidden, reason) + return + } + } else { + c.String(http.StatusForbidden, "无效输入") + return + } + + if exps[1].MatchString(rawPath) { + rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) + } + + proxy(c, rawPath) +} + + +func proxy(c *gin.Context, u string) { + proxyWithRedirect(c, u, 0) +} + + +func proxyWithRedirect(c *gin.Context, u string, redirectCount int) { + // 限制最大重定向次数,防止无限递归 + const maxRedirects = 20 + if redirectCount > maxRedirects { + c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") + return + } + req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + req.Header.Del("Host") + + resp, err := GetGlobalHTTPClient().Do(req) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭响应体失败: %v\n", err) + } + }() + + // 检查文件大小限制 + cfg := GetConfig() + if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { + if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { + c.String(http.StatusRequestEntityTooLarge, + fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) + return + } + } + + // 清理安全相关的头 + resp.Header.Del("Content-Security-Policy") + resp.Header.Del("Referrer-Policy") + resp.Header.Del("Strict-Transport-Security") + + // 对于需要处理的shell文件,使用chunked传输 + isShellFile := strings.HasSuffix(strings.ToLower(u), ".sh") + if isShellFile { + resp.Header.Del("Content-Length") + resp.Header.Set("Transfer-Encoding", "chunked") + } + + // 复制其他响应头 + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + if location := resp.Header.Get("Location"); location != "" { + if checkURL(location) != nil { + c.Header("Location", "/"+location) + } else { + // 递归处理重定向,增加计数防止无限循环 + proxyWithRedirect(c, location, redirectCount+1) + return + } + } + + c.Status(resp.StatusCode) + + // 处理响应体 + if isShellFile { + // 获取真实域名 + realHost := c.Request.Header.Get("X-Forwarded-Host") + if realHost == "" { + realHost = c.Request.Host + } + // 如果域名中没有协议前缀,添加https:// + if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { + realHost = "https://" + realHost + } + // 使用ProcessGitHubURLs处理.sh文件 + processedBody, _, err := ProcessGitHubURLs(resp.Body, resp.Header.Get("Content-Encoding") == "gzip", realHost, true) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("处理shell文件时发生错误: %v", err)) + return + } + if _, err := io.Copy(c.Writer, processedBody); err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("写入响应时发生错误: %v", err)) + return + } + } else { + // 对于非.sh文件,直接复制响应体 + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return + } + } +} + +func checkURL(u string) []string { + for _, exp := range exps { + if matches := exp.FindStringSubmatch(u); matches != nil { + return matches[1:] + } + } + return nil +} + + diff --git a/ghproxy/proxysh.go b/src/proxysh.go similarity index 95% rename from ghproxy/proxysh.go rename to src/proxysh.go index d8a5661..f262134 100644 --- a/ghproxy/proxysh.go +++ b/src/proxysh.go @@ -1,192 +1,192 @@ -package main - -import ( - "bufio" - "compress/gzip" - "fmt" - "io" - "regexp" - "strings" -) - -var ( - // gitHubDomains 定义所有支持的GitHub相关域名 - gitHubDomains = []string{ - "github.com", - "raw.githubusercontent.com", - "raw.github.com", - "gist.githubusercontent.com", - "gist.github.com", - "api.github.com", - } - - // urlPattern 使用gitHubDomains构建正则表达式 - urlPattern = regexp.MustCompile(`https?://(?:` + strings.Join(gitHubDomains, "|") + `)[^\s'"]+`) - - // 是否启用脚本嵌套代理的调试日志 - DebugLog = true -) - -// 打印调试日志的辅助函数 -func debugPrintf(format string, args ...interface{}) { - if DebugLog { - fmt.Printf(format, args...) - } -} - -// ProcessGitHubURLs 处理数据流中的GitHub URL,将其替换为代理URL。 -// 此处思路借鉴了 https://github.com/WJQSERVER-STUDIO/ghproxy/blob/main/proxy/nest.go - -func ProcessGitHubURLs(input io.ReadCloser, isCompressed bool, host string, isShellFile bool) (io.Reader, int64, error) { - debugPrintf("开始处理文件: isCompressed=%v, host=%s, isShellFile=%v\n", isCompressed, host, isShellFile) - - if !isShellFile { - debugPrintf("非shell文件,跳过处理\n") - return input, 0, nil - } - - // 使用更大的缓冲区以提高性能 - pipeReader, pipeWriter := io.Pipe() - var written int64 - - go func() { - var err error - defer func() { - if err != nil { - debugPrintf("处理过程中发生错误: %v\n", err) - _ = pipeWriter.CloseWithError(err) - } else { - _ = pipeWriter.Close() - } - }() - - defer input.Close() - - var reader io.Reader = input - if isCompressed { - debugPrintf("检测到压缩文件,进行解压处理\n") - gzipReader, gzipErr := gzip.NewReader(input) - if gzipErr != nil { - err = gzipErr - return - } - defer gzipReader.Close() - reader = gzipReader - } - - // 使用更大的缓冲区 - bufReader := bufio.NewReaderSize(reader, 32*1024) // 32KB buffer - var writer io.Writer = pipeWriter - - if isCompressed { - gzipWriter := gzip.NewWriter(writer) - defer gzipWriter.Close() - writer = gzipWriter - } - - bufWriter := bufio.NewWriterSize(writer, 32*1024) // 32KB buffer - defer bufWriter.Flush() - - written, err = processContent(bufReader, bufWriter, host) - if err != nil { - debugPrintf("处理内容时发生错误: %v\n", err) - return - } - - debugPrintf("文件处理完成,共处理 %d 字节\n", written) - }() - - return pipeReader, written, nil -} - -// processContent 优化处理文件内容的函数 -func processContent(reader *bufio.Reader, writer *bufio.Writer, host string) (int64, error) { - var written int64 - lineNum := 0 - - for { - lineNum++ - line, err := reader.ReadString('\n') - if err != nil && err != io.EOF { - return written, fmt.Errorf("读取行时发生错误: %w", err) - } - - if line != "" { - // 在处理前先检查是否包含GitHub URL - if strings.Contains(line, "github.com") || - strings.Contains(line, "raw.githubusercontent.com") { - matches := urlPattern.FindAllString(line, -1) - if len(matches) > 0 { - debugPrintf("\n在第 %d 行发现 %d 个GitHub URL:\n", lineNum, len(matches)) - for _, match := range matches { - debugPrintf("原始URL: %s\n", match) - } - } - - modifiedLine := processLine(line, host, lineNum) - n, writeErr := writer.WriteString(modifiedLine) - if writeErr != nil { - return written, fmt.Errorf("写入修改后的行时发生错误: %w", writeErr) - } - written += int64(n) - } else { - // 如果行中没有GitHub URL,直接写入 - n, writeErr := writer.WriteString(line) - if writeErr != nil { - return written, fmt.Errorf("写入原始行时发生错误: %w", writeErr) - } - written += int64(n) - } - } - - if err == io.EOF { - break - } - } - - // 确保所有数据都被写入 - if err := writer.Flush(); err != nil { - return written, fmt.Errorf("刷新缓冲区时发生错误: %w", err) - } - - return written, nil -} - -// processLine 处理单行文本,替换所有匹配的GitHub URL -func processLine(line string, host string, lineNum int) string { - return urlPattern.ReplaceAllStringFunc(line, func(url string) string { - newURL := modifyGitHubURL(url, host) - if newURL != url { - debugPrintf("第 %d 行URL替换:\n 原始: %s\n 替换后: %s\n", lineNum, url, newURL) - } - return newURL - }) -} - -// modifyGitHubURL 修改GitHub URL,添加代理域名前缀 -func modifyGitHubURL(url string, host string) string { - for _, domain := range gitHubDomains { - hasHttps := strings.HasPrefix(url, "https://"+domain) - hasHttp := strings.HasPrefix(url, "http://"+domain) - - if hasHttps || hasHttp || strings.HasPrefix(url, domain) { - if !hasHttps && !hasHttp { - url = "https://" + url - } - if hasHttp { - url = "https://" + strings.TrimPrefix(url, "http://") - } - // 移除host开头的协议头(如果有) - host = strings.TrimPrefix(host, "https://") - host = strings.TrimPrefix(host, "http://") - // 返回组合后的URL - return host + "/" + url - } - } - return url -} - -// IsShellFile 检查文件是否为shell文件(基于文件名) -func IsShellFile(filename string) bool { - return strings.HasSuffix(filename, ".sh") +package main + +import ( + "bufio" + "compress/gzip" + "fmt" + "io" + "regexp" + "strings" +) + +var ( + // gitHubDomains 定义所有支持的GitHub相关域名 + gitHubDomains = []string{ + "github.com", + "raw.githubusercontent.com", + "raw.github.com", + "gist.githubusercontent.com", + "gist.github.com", + "api.github.com", + } + + // urlPattern 使用gitHubDomains构建正则表达式 + urlPattern = regexp.MustCompile(`https?://(?:` + strings.Join(gitHubDomains, "|") + `)[^\s'"]+`) + + // 是否启用脚本嵌套代理的调试日志 + DebugLog = true +) + +// 打印调试日志的辅助函数 +func debugPrintf(format string, args ...interface{}) { + if DebugLog { + fmt.Printf(format, args...) + } +} + +// ProcessGitHubURLs 处理数据流中的GitHub URL,将其替换为代理URL。 +// 此处思路借鉴了 https://github.com/WJQSERVER-STUDIO/ghproxy/blob/main/proxy/nest.go + +func ProcessGitHubURLs(input io.ReadCloser, isCompressed bool, host string, isShellFile bool) (io.Reader, int64, error) { + debugPrintf("开始处理文件: isCompressed=%v, host=%s, isShellFile=%v\n", isCompressed, host, isShellFile) + + if !isShellFile { + debugPrintf("非shell文件,跳过处理\n") + return input, 0, nil + } + + // 使用更大的缓冲区以提高性能 + pipeReader, pipeWriter := io.Pipe() + var written int64 + + go func() { + var err error + defer func() { + if err != nil { + debugPrintf("处理过程中发生错误: %v\n", err) + _ = pipeWriter.CloseWithError(err) + } else { + _ = pipeWriter.Close() + } + }() + + defer input.Close() + + var reader io.Reader = input + if isCompressed { + debugPrintf("检测到压缩文件,进行解压处理\n") + gzipReader, gzipErr := gzip.NewReader(input) + if gzipErr != nil { + err = gzipErr + return + } + defer gzipReader.Close() + reader = gzipReader + } + + // 使用更大的缓冲区 + bufReader := bufio.NewReaderSize(reader, 32*1024) // 32KB buffer + var writer io.Writer = pipeWriter + + if isCompressed { + gzipWriter := gzip.NewWriter(writer) + defer gzipWriter.Close() + writer = gzipWriter + } + + bufWriter := bufio.NewWriterSize(writer, 32*1024) // 32KB buffer + defer bufWriter.Flush() + + written, err = processContent(bufReader, bufWriter, host) + if err != nil { + debugPrintf("处理内容时发生错误: %v\n", err) + return + } + + debugPrintf("文件处理完成,共处理 %d 字节\n", written) + }() + + return pipeReader, written, nil +} + +// processContent 优化处理文件内容的函数 +func processContent(reader *bufio.Reader, writer *bufio.Writer, host string) (int64, error) { + var written int64 + lineNum := 0 + + for { + lineNum++ + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + return written, fmt.Errorf("读取行时发生错误: %w", err) + } + + if line != "" { + // 在处理前先检查是否包含GitHub URL + if strings.Contains(line, "github.com") || + strings.Contains(line, "raw.githubusercontent.com") { + matches := urlPattern.FindAllString(line, -1) + if len(matches) > 0 { + debugPrintf("\n在第 %d 行发现 %d 个GitHub URL:\n", lineNum, len(matches)) + for _, match := range matches { + debugPrintf("原始URL: %s\n", match) + } + } + + modifiedLine := processLine(line, host, lineNum) + n, writeErr := writer.WriteString(modifiedLine) + if writeErr != nil { + return written, fmt.Errorf("写入修改后的行时发生错误: %w", writeErr) + } + written += int64(n) + } else { + // 如果行中没有GitHub URL,直接写入 + n, writeErr := writer.WriteString(line) + if writeErr != nil { + return written, fmt.Errorf("写入原始行时发生错误: %w", writeErr) + } + written += int64(n) + } + } + + if err == io.EOF { + break + } + } + + // 确保所有数据都被写入 + if err := writer.Flush(); err != nil { + return written, fmt.Errorf("刷新缓冲区时发生错误: %w", err) + } + + return written, nil +} + +// processLine 处理单行文本,替换所有匹配的GitHub URL +func processLine(line string, host string, lineNum int) string { + return urlPattern.ReplaceAllStringFunc(line, func(url string) string { + newURL := modifyGitHubURL(url, host) + if newURL != url { + debugPrintf("第 %d 行URL替换:\n 原始: %s\n 替换后: %s\n", lineNum, url, newURL) + } + return newURL + }) +} + +// 判断代理域名前缀 +func modifyGitHubURL(url string, host string) string { + for _, domain := range gitHubDomains { + hasHttps := strings.HasPrefix(url, "https://"+domain) + hasHttp := strings.HasPrefix(url, "http://"+domain) + + if hasHttps || hasHttp || strings.HasPrefix(url, domain) { + if !hasHttps && !hasHttp { + url = "https://" + url + } + if hasHttp { + url = "https://" + strings.TrimPrefix(url, "http://") + } + // 移除host开头的协议头(如果有) + host = strings.TrimPrefix(host, "https://") + host = strings.TrimPrefix(host, "http://") + // 返回组合后的URL + return host + "/" + url + } + } + return url +} + +// IsShellFile 检查文件是否为shell文件(基于文件名) +func IsShellFile(filename string) bool { + return strings.HasSuffix(filename, ".sh") } \ No newline at end of file diff --git a/ghproxy/public/favicon.ico b/src/public/favicon.ico similarity index 100% rename from ghproxy/public/favicon.ico rename to src/public/favicon.ico diff --git a/ghproxy/public/index.html b/src/public/index.html similarity index 96% rename from ghproxy/public/index.html rename to src/public/index.html index 9f5bc15..4708e0e 100644 --- a/ghproxy/public/index.html +++ b/src/public/index.html @@ -1,542 +1,542 @@ - - - - - - - - - - Github文件加速 - - - - - - - - hosts - 镜像包下载 - 镜像搜索 -
-

Github文件加速

-
- -
- - -
- - -

-        
-
-
-

支持release、archive文件,支持git clone、wget、curl等等操作
支持Al模型库Hugging Face


-
-
- - -
- - - - - - - - - + + + + + + + + + + Github文件加速 + + + + + + + + hosts + 镜像包下载 + 镜像搜索 +
+

Github文件加速

+
+ +
+ + +
+ + +

+        
+
+
+

支持release、archive文件,支持git clone、wget、curl等等操作
支持Al模型库Hugging Face


+
+
+ + +
+ + + + + + + + + diff --git a/ghproxy/public/search.html b/src/public/search.html similarity index 97% rename from ghproxy/public/search.html rename to src/public/search.html index 443a471..91612bd 100644 --- a/ghproxy/public/search.html +++ b/src/public/search.html @@ -1,1071 +1,1071 @@ - - - - - - - - - Docker镜像搜索 - - - - - - - 返回 -
-

Docker镜像搜索

- -
-
- -
- -
-
-
- -
-
-

正在加载...

-
- - - -
-
- - -
- -
-
- -
- - - + + + + + + + + + Docker镜像搜索 + + + + + + + 返回 +
+

Docker镜像搜索

+ +
+
+ +
+ +
+
+
+ +
+
+

正在加载...

+
+ + + +
+
+ + +
+ +
+
+ +
+ + + \ No newline at end of file diff --git a/ghproxy/public/skopeo.html b/src/public/skopeo.html similarity index 97% rename from ghproxy/public/skopeo.html rename to src/public/skopeo.html index ffab573..df5e135 100644 --- a/ghproxy/public/skopeo.html +++ b/src/public/skopeo.html @@ -1,505 +1,505 @@ - - - - - - - - - Docker镜像批量下载 - - - - - - - - 返回 -
-

Docker离线镜像包下载

- -
-
每行输入一个镜像,跟docker pull的格式一样,多个镜像会自动打包到一起为zip包,单个镜像为tar包。导入镜像后需要手动为镜像添加名称和标签,例如:docker tag 1856948a5aa7 镜像名称:标签
- -
- -
-
镜像架构,默认为 amd64
- -
- - - -
-
0/0 - 0%
- -
- -
- - -
-
- - - - - - + + + + + + + + + Docker镜像批量下载 + + + + + + + + 返回 +
+

Docker离线镜像包下载

+ +
+
每行输入一个镜像,跟docker pull的格式一样,多个镜像会自动打包到一起为zip包,单个镜像为tar包。导入镜像后需要手动为镜像添加名称和标签,例如:docker tag 1856948a5aa7 镜像名称:标签
+ +
+ +
+
镜像架构,默认为 amd64
+ +
+ + + +
+
0/0 - 0%
+ +
+ +
+ + +
+
+ + + + + + diff --git a/ghproxy/ratelimiter.go b/src/ratelimiter.go similarity index 64% rename from ghproxy/ratelimiter.go rename to src/ratelimiter.go index df1b58d..620c41f 100644 --- a/ghproxy/ratelimiter.go +++ b/src/ratelimiter.go @@ -3,8 +3,6 @@ package main import ( "fmt" "net" - "os" - "strconv" "strings" "sync" "time" @@ -13,33 +11,14 @@ import ( "golang.org/x/time/rate" ) -// IP限流配置 -var ( - // 默认限流:每个IP每1小时允许20个请求 - DefaultRateLimit = 20.0 // 默认限制请求数 - DefaultRatePeriodHours = 1.0 // 默认时间周期(小时) - - // 白名单列表,支持IP和CIDR格式,如:"192.168.1.1", "10.0.0.0/8" - WhitelistIPs = []string{ - "127.0.0.1", // 本地回环地址 - "10.0.0.0/8", // 内网地址段 - "172.16.0.0/12", // 内网地址段 - "192.168.0.0/16", // 内网地址段 - } - - // 黑名单列表,支持IP和CIDR格式 - BlacklistIPs = []string{ - // 示例: "1.2.3.4", "5.6.7.0/24" - } - - // 清理间隔:多久清理一次过期的限流器 - CleanupInterval = 1 * time.Hour - - // IP限流器缓存上限,超过此数量将触发清理 +const ( + // 清理间隔 + CleanupInterval = 10 * time.Minute + // 最大IP缓存数量,防止内存过度占用 MaxIPCacheSize = 10000 ) -// IPRateLimiter 定义IP限流器结构 +// IPRateLimiter IP限流器结构体 type IPRateLimiter struct { ips map[string]*rateLimiterEntry // IP到限流器的映射 mu *sync.RWMutex // 读写锁,保证并发安全 @@ -49,45 +28,20 @@ type IPRateLimiter struct { blacklist []*net.IPNet // 黑名单IP段 } -// rateLimiterEntry 限流器条目,包含限流器和最后访问时间 +// rateLimiterEntry 限流器条目 type rateLimiterEntry struct { limiter *rate.Limiter // 限流器 lastAccess time.Time // 最后访问时间 } -// NewIPRateLimiter 创建新的IP限流器 -func NewIPRateLimiter() *IPRateLimiter { - // 从环境变量读取限流配置(如果有) - rateLimit := DefaultRateLimit - ratePeriod := DefaultRatePeriodHours - - if val, exists := os.LookupEnv("RATE_LIMIT"); exists { - if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 { - rateLimit = parsed - } - } - - if val, exists := os.LookupEnv("RATE_PERIOD_HOURS"); exists { - if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 { - ratePeriod = parsed - } - } - - // 从环境变量读取白名单(如果有) - whitelistIPs := WhitelistIPs - if val, exists := os.LookupEnv("IP_WHITELIST"); exists && val != "" { - whitelistIPs = append(whitelistIPs, strings.Split(val, ",")...) - } - - // 从环境变量读取黑名单(如果有) - blacklistIPs := BlacklistIPs - if val, exists := os.LookupEnv("IP_BLACKLIST"); exists && val != "" { - blacklistIPs = append(blacklistIPs, strings.Split(val, ",")...) - } +// initGlobalLimiter 初始化全局限流器 +func initGlobalLimiter() *IPRateLimiter { + // 获取配置 + cfg := GetConfig() // 解析白名单IP段 - whitelist := make([]*net.IPNet, 0, len(whitelistIPs)) - for _, item := range whitelistIPs { + whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) + for _, item := range cfg.Security.WhiteList { if item = strings.TrimSpace(item); item != "" { if !strings.Contains(item, "/") { item = item + "/32" // 单个IP转为CIDR格式 @@ -95,13 +49,15 @@ func NewIPRateLimiter() *IPRateLimiter { _, ipnet, err := net.ParseCIDR(item) if err == nil { whitelist = append(whitelist, ipnet) + } else { + fmt.Printf("警告: 无效的白名单IP格式: %s\n", item) } } } // 解析黑名单IP段 - blacklist := make([]*net.IPNet, 0, len(blacklistIPs)) - for _, item := range blacklistIPs { + blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) + for _, item := range cfg.Security.BlackList { if item = strings.TrimSpace(item); item != "" { if !strings.Contains(item, "/") { item = item + "/32" // 单个IP转为CIDR格式 @@ -109,19 +65,26 @@ func NewIPRateLimiter() *IPRateLimiter { _, ipnet, err := net.ParseCIDR(item) if err == nil { blacklist = append(blacklist, ipnet) + } else { + fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item) } } } // 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求" - // rate.Limit的单位是每秒允许的请求数 - ratePerSecond := rate.Limit(rateLimit / (ratePeriod * 3600)) + ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) + + // 令牌桶容量设置为最大突发请求数,建议设为限制值的一半以允许合理突发 + burstSize := cfg.RateLimit.RequestLimit + if burstSize < 1 { + burstSize = 1 // 至少允许1个请求 + } limiter := &IPRateLimiter{ ips: make(map[string]*rateLimiterEntry), mu: &sync.RWMutex{}, r: ratePerSecond, - b: int(rateLimit), // 令牌桶容量设为允许的请求总数 + b: burstSize, whitelist: whitelist, blacklist: blacklist, } @@ -129,9 +92,17 @@ func NewIPRateLimiter() *IPRateLimiter { // 启动定期清理goroutine go limiter.cleanupRoutine() + fmt.Printf("限流器初始化: %d请求/%g小时, 白名单 %d个, 黑名单 %d个\n", + cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours, len(whitelist), len(blacklist)) + return limiter } +// initLimiter 初始化限流器(保持向后兼容) +func initLimiter() { + globalLimiter = initGlobalLimiter() +} + // cleanupRoutine 定期清理过期的限流器 func (i *IPRateLimiter) cleanupRoutine() { ticker := time.NewTicker(CleanupInterval) @@ -168,9 +139,29 @@ func (i *IPRateLimiter) cleanupRoutine() { } } +// extractIPFromAddress 从地址中提取纯IP,去除端口号 +func extractIPFromAddress(address string) string { + // 处理IPv6地址 [::1]:8080 格式 + if strings.HasPrefix(address, "[") { + if endIndex := strings.Index(address, "]"); endIndex != -1 { + return address[1:endIndex] + } + } + + // 处理IPv4地址 192.168.1.1:8080 格式 + if lastColon := strings.LastIndex(address, ":"); lastColon != -1 { + return address[:lastColon] + } + + // 如果没有端口号,直接返回 + return address +} + // isIPInCIDRList 检查IP是否在CIDR列表中 func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { - parsedIP := net.ParseIP(ip) + // 先提取纯IP地址 + cleanIP := extractIPFromAddress(ip) + parsedIP := net.ParseIP(cleanIP) if parsedIP == nil { return false } @@ -185,19 +176,22 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { // GetLimiter 获取指定IP的限流器,同时返回是否允许访问 func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { + // 提取纯IP地址 + cleanIP := extractIPFromAddress(ip) + // 检查是否在黑名单中 - if isIPInCIDRList(ip, i.blacklist) { + if isIPInCIDRList(cleanIP, i.blacklist) { return nil, false // 黑名单中的IP不允许访问 } // 检查是否在白名单中 - if isIPInCIDRList(ip, i.whitelist) { + if isIPInCIDRList(cleanIP, i.whitelist) { return rate.NewLimiter(rate.Inf, i.b), true // 白名单中的IP不受限制 } - // 从缓存获取限流器 + // 使用纯IP作为缓存键 i.mu.RLock() - entry, exists := i.ips[ip] + entry, exists := i.ips[cleanIP] i.mu.RUnlock() now := time.Now() @@ -209,7 +203,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { limiter: rate.NewLimiter(i.r, i.b), lastAccess: now, } - i.ips[ip] = entry + i.ips[cleanIP] = entry i.mu.Unlock() } else { // 更新最后访问时间 @@ -244,14 +238,18 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { ip = c.ClientIP() } - // 日志记录请求IP和头信息(调试用) - fmt.Printf("请求IP: %s, X-Forwarded-For: %s, X-Real-IP: %s\n", + // 提取纯IP地址(去除端口号) + cleanIP := extractIPFromAddress(ip) + + // 日志记录请求IP和头信息 + fmt.Printf("请求IP: %s (去除端口后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", ip, + cleanIP, c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Real-IP")) // 获取限流器并检查是否允许访问 - ipLimiter, allowed := limiter.GetLimiter(ip) + ipLimiter, allowed := limiter.GetLimiter(cleanIP) // 如果IP在黑名单中 if !allowed { @@ -278,8 +276,11 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { // ApplyRateLimit 应用限流到特定路由 func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) { - // 创建限流器(如果未创建) - limiter := NewIPRateLimiter() + // 使用全局限流器 + limiter := globalLimiter + if limiter == nil { + limiter = initGlobalLimiter() + } // 根据HTTP方法应用限流 switch method { diff --git a/ghproxy/search.go b/src/search.go similarity index 64% rename from ghproxy/search.go rename to src/search.go index d5f2478..4d7651d 100644 --- a/ghproxy/search.go +++ b/src/search.go @@ -1,554 +1,498 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "sort" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" -) - -// SearchResult Docker Hub搜索结果 -type SearchResult struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` -} - -// Repository 仓库信息 -type Repository struct { - Name string `json:"repo_name"` - Description string `json:"short_description"` - IsOfficial bool `json:"is_official"` - IsAutomated bool `json:"is_automated"` - StarCount int `json:"star_count"` - PullCount int `json:"pull_count"` - RepoOwner string `json:"repo_owner"` - LastUpdated string `json:"last_updated"` - Status int `json:"status"` - Organization string `json:"affiliation"` - PullsLastWeek int `json:"pulls_last_week"` - Namespace string `json:"namespace"` -} - -// TagInfo 标签信息 -type TagInfo struct { - Name string `json:"name"` - FullSize int64 `json:"full_size"` - LastUpdated time.Time `json:"last_updated"` - LastPusher string `json:"last_pusher"` - Images []Image `json:"images"` - Vulnerabilities struct { - Critical int `json:"critical"` - High int `json:"high"` - Medium int `json:"medium"` - Low int `json:"low"` - Unknown int `json:"unknown"` - } `json:"vulnerabilities"` -} - -// Image 镜像信息 -type Image struct { - Architecture string `json:"architecture"` - Features string `json:"features"` - Variant string `json:"variant,omitempty"` - Digest string `json:"digest"` - OS string `json:"os"` - OSFeatures string `json:"os_features"` - Size int64 `json:"size"` -} - -type cacheEntry struct { - data interface{} - timestamp time.Time -} - -const ( - maxCacheSize = 1000 // 最大缓存条目数 - cacheTTL = 30 * time.Minute -) - -type Cache struct { - data map[string]cacheEntry - mu sync.RWMutex - maxSize int -} - -var ( - searchCache = &Cache{ - data: make(map[string]cacheEntry), - maxSize: maxCacheSize, - } -) - -// 添加全局HTTP客户端配置 -var defaultHTTPClient = &http.Client{ - Timeout: 10 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - DisableCompression: true, - DisableKeepAlives: false, - MaxIdleConnsPerHost: 10, - }, -} - -func (c *Cache) Get(key string) (interface{}, bool) { - c.mu.RLock() - entry, exists := c.data[key] - c.mu.RUnlock() - - if !exists { - return nil, false - } - - if time.Since(entry.timestamp) > cacheTTL { - c.mu.Lock() - delete(c.data, key) - c.mu.Unlock() - return nil, false - } - - return entry.data, true -} - -func (c *Cache) Set(key string, data interface{}) { - c.mu.Lock() - defer c.mu.Unlock() - - // 如果缓存已满,删除最旧的条目 - if len(c.data) >= c.maxSize { - oldest := time.Now() - var oldestKey string - for k, v := range c.data { - if v.timestamp.Before(oldest) { - oldest = v.timestamp - oldestKey = k - } - } - delete(c.data, oldestKey) - } - - c.data[key] = cacheEntry{ - data: data, - timestamp: time.Now(), - } -} - -func (c *Cache) Cleanup() { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - for key, entry := range c.data { - if now.Sub(entry.timestamp) > cacheTTL { - delete(c.data, key) - } - } -} - -// 定期清理过期缓存 -func init() { - go func() { - ticker := time.NewTicker(5 * time.Minute) - for range ticker.C { - searchCache.Cleanup() - } - }() -} - -// 改进的搜索结果过滤函数 -func filterSearchResults(results []Repository, query string) []Repository { - searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) - filtered := make([]Repository, 0) - - for _, repo := range results { - // 标准化仓库名称 - repoName := strings.ToLower(repo.Name) - repoDesc := strings.ToLower(repo.Description) - - // 计算相关性得分 - score := 0 - - // 完全匹配 - if repoName == searchTerm { - score += 100 - } - - // 前缀匹配 - if strings.HasPrefix(repoName, searchTerm) { - score += 50 - } - - // 包含匹配 - if strings.Contains(repoName, searchTerm) { - score += 30 - } - - // 描述匹配 - if strings.Contains(repoDesc, searchTerm) { - score += 10 - } - - // 官方镜像加分 - if repo.IsOfficial { - score += 20 - } - - // 分数达到阈值的结果才保留 - if score > 0 { - filtered = append(filtered, repo) - } - } - - // 按相关性排序 - sort.Slice(filtered, func(i, j int) bool { - // 优先考虑官方镜像 - if filtered[i].IsOfficial != filtered[j].IsOfficial { - return filtered[i].IsOfficial - } - // 其次考虑拉取次数 - return filtered[i].PullCount > filtered[j].PullCount - }) - - return filtered -} - -// searchDockerHub 搜索镜像 -func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { - cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) - - // 尝试从缓存获取 - if cached, ok := searchCache.Get(cacheKey); ok { - return cached.(*SearchResult), nil - } - - // 判断是否是用户/仓库格式的搜索 - isUserRepo := strings.Contains(query, "/") - var namespace, repoName string - - if isUserRepo { - parts := strings.Split(query, "/") - if len(parts) == 2 { - namespace = parts[0] - repoName = parts[1] - } - } - - // 构建搜索URL - baseURL := "https://registry.hub.docker.com/v2" - var fullURL string - var params url.Values - - if isUserRepo && namespace != "" { - // 如果是用户/仓库格式,使用repositories接口 - fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) - params = url.Values{ - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } else { - // 普通搜索 - fullURL = baseURL + "/search/repositories/" - params = url.Values{ - "query": {query}, - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } - - fullURL = fullURL + "?" + params.Encode() - - // 创建请求 - req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) - if err != nil { - return nil, fmt.Errorf("创建请求失败: %v", err) - } - - // 设置请求头 - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") - - // 使用全局HTTP客户端 - client := defaultHTTPClient - - var result *SearchResult - var lastErr error - - // 重试逻辑 - for retries := 3; retries > 0; retries-- { - resp, err := client.Do(req) - if err != nil { - lastErr = fmt.Errorf("发送请求失败: %v", err) - if !isRetryableError(err) { - break - } - time.Sleep(time.Second * time.Duration(4-retries)) - continue - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - lastErr = fmt.Errorf("读取响应失败: %v", err) - if !isRetryableError(err) { - break - } - time.Sleep(time.Second * time.Duration(4-retries)) - continue - } - - if resp.StatusCode != http.StatusOK { - switch resp.StatusCode { - case http.StatusTooManyRequests: - lastErr = fmt.Errorf("请求过于频繁,请稍后重试") - case http.StatusNotFound: - if isUserRepo && namespace != "" { - // 如果用户仓库搜索失败,尝试普通搜索 - return searchDockerHub(ctx, repoName, page, pageSize) - } - lastErr = fmt.Errorf("未找到相关镜像") - case http.StatusBadGateway, http.StatusServiceUnavailable: - lastErr = fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") - default: - lastErr = fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) - } - if !isRetryableError(lastErr) { - break - } - time.Sleep(time.Second * time.Duration(4-retries)) - continue - } - - // 解析响应 - if isUserRepo && namespace != "" { - // 解析用户仓库列表响应 - var userRepos struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` - } - if err := json.Unmarshal(body, &userRepos); err != nil { - lastErr = fmt.Errorf("解析响应失败: %v", err) - break - } - - // 转换为SearchResult格式 - result = &SearchResult{ - Count: userRepos.Count, - Next: userRepos.Next, - Previous: userRepos.Previous, - Results: make([]Repository, 0), - } - - // 处理结果 - for _, repo := range userRepos.Results { - // 如果指定了仓库名,只保留匹配的结果 - if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { - // 确保设置正确的命名空间和名称 - repo.Namespace = namespace - if !strings.Contains(repo.Name, "/") { - repo.Name = fmt.Sprintf("%s/%s", namespace, repo.Name) - } - result.Results = append(result.Results, repo) - } - } - - // 如果没有找到结果,尝试普通搜索 - if len(result.Results) == 0 { - return searchDockerHub(ctx, repoName, page, pageSize) - } - - result.Count = len(result.Results) - } else { - // 解析普通搜索响应 - result = &SearchResult{} - if err := json.Unmarshal(body, &result); err != nil { - lastErr = fmt.Errorf("解析响应失败: %v", err) - break - } - - // 处理搜索结果 - for i := range result.Results { - if result.Results[i].IsOfficial { - if !strings.Contains(result.Results[i].Name, "/") { - result.Results[i].Name = "library/" + result.Results[i].Name - } - result.Results[i].Namespace = "library" - } else { - parts := strings.Split(result.Results[i].Name, "/") - if len(parts) > 1 { - result.Results[i].Namespace = parts[0] - result.Results[i].Name = parts[1] - } else if result.Results[i].RepoOwner != "" { - result.Results[i].Namespace = result.Results[i].RepoOwner - result.Results[i].Name = fmt.Sprintf("%s/%s", result.Results[i].RepoOwner, result.Results[i].Name) - } - } - } - - // 如果是用户/仓库搜索,过滤结果 - if isUserRepo && namespace != "" { - filteredResults := make([]Repository, 0) - for _, repo := range result.Results { - if strings.EqualFold(repo.Namespace, namespace) { - filteredResults = append(filteredResults, repo) - } - } - result.Results = filteredResults - result.Count = len(filteredResults) - } - } - - // 成功获取结果,跳出重试循环 - lastErr = nil - break - } - - if lastErr != nil { - return nil, fmt.Errorf("搜索失败: %v", lastErr) - } - - // 缓存结果 - searchCache.Set(cacheKey, result) - return result, nil -} - -// 判断错误是否可重试 -func isRetryableError(err error) bool { - if err == nil { - return false - } - - // 网络错误、超时等可以重试 - if strings.Contains(err.Error(), "timeout") || - strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "too many requests") { - return true - } - - return false -} - -// getRepositoryTags 获取仓库标签信息 -func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo, error) { - if namespace == "" || name == "" { - return nil, fmt.Errorf("无效输入:命名空间和名称不能为空") - } - - cacheKey := fmt.Sprintf("tags:%s:%s", namespace, name) - if cached, ok := searchCache.Get(cacheKey); ok { - return cached.([]TagInfo), nil - } - - // 构建API URL - baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) - params := url.Values{} - params.Set("page_size", "100") - params.Set("ordering", "last_updated") - - fullURL := baseURL + "?" + params.Encode() - - // 使用全局HTTP客户端 - client := defaultHTTPClient - - req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) - if err != nil { - return nil, fmt.Errorf("创建请求失败: %v", err) - } - - // 添加必要的请求头 - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("发送请求失败: %v", err) - } - defer resp.Body.Close() - - // 读取响应体 - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %v", err) - } - - // 检查响应状态码 - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) - } - - // 解析响应 - var result struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []TagInfo `json:"results"` - } - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 缓存结果 - searchCache.Set(cacheKey, result.Results) - return result.Results, nil -} - -// RegisterSearchRoute 注册搜索相关路由 -func RegisterSearchRoute(r *gin.Engine) { - // 搜索镜像 - r.GET("/search", func(c *gin.Context) { - query := c.Query("q") - if query == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "搜索关键词不能为空"}) - return - } - - page := 1 - pageSize := 25 - if p := c.Query("page"); p != "" { - fmt.Sscanf(p, "%d", &page) - } - if ps := c.Query("page_size"); ps != "" { - fmt.Sscanf(ps, "%d", &pageSize) - } - - result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, result) - }) - - // 获取标签信息 - r.GET("/tags/:namespace/:name", func(c *gin.Context) { - namespace := c.Param("namespace") - name := c.Param("name") - - if namespace == "" || name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "命名空间和名称不能为空"}) - return - } - - tags, err := getRepositoryTags(c.Request.Context(), namespace, name) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, tags) - }) -} +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// SearchResult Docker Hub搜索结果 +type SearchResult struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` +} + +// Repository 仓库信息 +type Repository struct { + Name string `json:"repo_name"` + Description string `json:"short_description"` + IsOfficial bool `json:"is_official"` + IsAutomated bool `json:"is_automated"` + StarCount int `json:"star_count"` + PullCount int `json:"pull_count"` + RepoOwner string `json:"repo_owner"` + LastUpdated string `json:"last_updated"` + Status int `json:"status"` + Organization string `json:"affiliation"` + PullsLastWeek int `json:"pulls_last_week"` + Namespace string `json:"namespace"` +} + +// TagInfo 标签信息 +type TagInfo struct { + Name string `json:"name"` + FullSize int64 `json:"full_size"` + LastUpdated time.Time `json:"last_updated"` + LastPusher string `json:"last_pusher"` + Images []Image `json:"images"` + Vulnerabilities struct { + Critical int `json:"critical"` + High int `json:"high"` + Medium int `json:"medium"` + Low int `json:"low"` + Unknown int `json:"unknown"` + } `json:"vulnerabilities"` +} + +// Image 镜像信息 +type Image struct { + Architecture string `json:"architecture"` + Features string `json:"features"` + Variant string `json:"variant,omitempty"` + Digest string `json:"digest"` + OS string `json:"os"` + OSFeatures string `json:"os_features"` + Size int64 `json:"size"` +} + +type cacheEntry struct { + data interface{} + timestamp time.Time +} + +const ( + maxCacheSize = 1000 // 最大缓存条目数 + cacheTTL = 30 * time.Minute +) + +type Cache struct { + data map[string]cacheEntry + mu sync.RWMutex + maxSize int +} + +var ( + searchCache = &Cache{ + data: make(map[string]cacheEntry), + maxSize: maxCacheSize, + } +) + +// HTTP客户端配置在 http_client.go 中统一管理 + +func (c *Cache) Get(key string) (interface{}, bool) { + c.mu.RLock() + entry, exists := c.data[key] + c.mu.RUnlock() + + if !exists { + return nil, false + } + + if time.Since(entry.timestamp) > cacheTTL { + c.mu.Lock() + delete(c.data, key) + c.mu.Unlock() + return nil, false + } + + return entry.data, true +} + +func (c *Cache) Set(key string, data interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + // 如果缓存已满,删除最旧的条目 + if len(c.data) >= c.maxSize { + oldest := time.Now() + var oldestKey string + for k, v := range c.data { + if v.timestamp.Before(oldest) { + oldest = v.timestamp + oldestKey = k + } + } + delete(c.data, oldestKey) + } + + c.data[key] = cacheEntry{ + data: data, + timestamp: time.Now(), + } +} + +func (c *Cache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for key, entry := range c.data { + if now.Sub(entry.timestamp) > cacheTTL { + delete(c.data, key) + } + } +} + +// 定期清理过期缓存 +func init() { + go func() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + searchCache.Cleanup() + } + }() +} + +// 改进的搜索结果过滤函数 +func filterSearchResults(results []Repository, query string) []Repository { + searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) + filtered := make([]Repository, 0) + + for _, repo := range results { + // 标准化仓库名称 + repoName := strings.ToLower(repo.Name) + repoDesc := strings.ToLower(repo.Description) + + // 计算相关性得分 + score := 0 + + // 完全匹配 + if repoName == searchTerm { + score += 100 + } + + // 前缀匹配 + if strings.HasPrefix(repoName, searchTerm) { + score += 50 + } + + // 包含匹配 + if strings.Contains(repoName, searchTerm) { + score += 30 + } + + // 描述匹配 + if strings.Contains(repoDesc, searchTerm) { + score += 10 + } + + // 官方镜像加分 + if repo.IsOfficial { + score += 20 + } + + // 分数达到阈值的结果才保留 + if score > 0 { + filtered = append(filtered, repo) + } + } + + // 按相关性排序 + sort.Slice(filtered, func(i, j int) bool { + // 优先考虑官方镜像 + if filtered[i].IsOfficial != filtered[j].IsOfficial { + return filtered[i].IsOfficial + } + // 其次考虑拉取次数 + return filtered[i].PullCount > filtered[j].PullCount + }) + + return filtered +} + +// searchDockerHub 搜索镜像 +func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { + cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) + + // 尝试从缓存获取 + if cached, ok := searchCache.Get(cacheKey); ok { + return cached.(*SearchResult), nil + } + + // 判断是否是用户/仓库格式的搜索 + isUserRepo := strings.Contains(query, "/") + var namespace, repoName string + + if isUserRepo { + parts := strings.Split(query, "/") + if len(parts) == 2 { + namespace = parts[0] + repoName = parts[1] + } + } + + // 构建搜索URL + baseURL := "https://registry.hub.docker.com/v2" + var fullURL string + var params url.Values + + if isUserRepo && namespace != "" { + // 如果是用户/仓库格式,使用repositories接口 + fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) + params = url.Values{ + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } else { + // 普通搜索 + fullURL = baseURL + "/search/repositories/" + params = url.Values{ + "query": {query}, + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } + + fullURL = fullURL + "?" + params.Encode() + + // 使用统一的搜索HTTP客户端 + resp, err := GetSearchHTTPClient().Get(fullURL) + if err != nil { + return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭搜索响应体失败: %v\n", err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusTooManyRequests: + return nil, fmt.Errorf("请求过于频繁,请稍后重试") + case http.StatusNotFound: + if isUserRepo && namespace != "" { + // 如果用户仓库搜索失败,尝试普通搜索 + return searchDockerHub(ctx, repoName, page, pageSize) + } + return nil, fmt.Errorf("未找到相关镜像") + case http.StatusBadGateway, http.StatusServiceUnavailable: + return nil, fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") + default: + return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) + } + } + + // 解析响应 + var result *SearchResult + if isUserRepo && namespace != "" { + // 解析用户仓库列表响应 + var userRepos struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` + } + if err := json.Unmarshal(body, &userRepos); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 转换为SearchResult格式 + result = &SearchResult{ + Count: userRepos.Count, + Next: userRepos.Next, + Previous: userRepos.Previous, + Results: make([]Repository, 0), + } + + // 处理结果 + for _, repo := range userRepos.Results { + // 如果指定了仓库名,只保留匹配的结果 + if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { + // 确保设置正确的命名空间和名称 + repo.Namespace = namespace + if !strings.Contains(repo.Name, "/") { + repo.Name = fmt.Sprintf("%s/%s", namespace, repo.Name) + } + result.Results = append(result.Results, repo) + } + } + + // 如果没有找到结果,尝试普通搜索 + if len(result.Results) == 0 { + return searchDockerHub(ctx, repoName, page, pageSize) + } + + result.Count = len(result.Results) + } else { + // 解析普通搜索响应 + result = &SearchResult{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 处理搜索结果 + for i := range result.Results { + if result.Results[i].IsOfficial { + if !strings.Contains(result.Results[i].Name, "/") { + result.Results[i].Name = "library/" + result.Results[i].Name + } + result.Results[i].Namespace = "library" + } else { + parts := strings.Split(result.Results[i].Name, "/") + if len(parts) > 1 { + result.Results[i].Namespace = parts[0] + result.Results[i].Name = parts[1] + } else if result.Results[i].RepoOwner != "" { + result.Results[i].Namespace = result.Results[i].RepoOwner + result.Results[i].Name = fmt.Sprintf("%s/%s", result.Results[i].RepoOwner, result.Results[i].Name) + } + } + } + + // 如果是用户/仓库搜索,过滤结果 + if isUserRepo && namespace != "" { + filteredResults := make([]Repository, 0) + for _, repo := range result.Results { + if strings.EqualFold(repo.Namespace, namespace) { + filteredResults = append(filteredResults, repo) + } + } + result.Results = filteredResults + result.Count = len(filteredResults) + } + } + + // 缓存结果 + searchCache.Set(cacheKey, result) + return result, nil +} + +// 判断错误是否可重试 +func isRetryableError(err error) bool { + if err == nil { + return false + } + + // 网络错误、超时等可以重试 + if strings.Contains(err.Error(), "timeout") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "no such host") || + strings.Contains(err.Error(), "too many requests") { + return true + } + + return false +} + +// getRepositoryTags 获取仓库标签信息 +func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo, error) { + if namespace == "" || name == "" { + return nil, fmt.Errorf("无效输入:命名空间和名称不能为空") + } + + cacheKey := fmt.Sprintf("tags:%s:%s", namespace, name) + if cached, ok := searchCache.Get(cacheKey); ok { + return cached.([]TagInfo), nil + } + + // 构建API URL + baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) + params := url.Values{} + params.Set("page_size", "100") + params.Set("ordering", "last_updated") + + fullURL := baseURL + "?" + params.Encode() + + // 使用统一的搜索HTTP客户端 + resp, err := GetSearchHTTPClient().Get(fullURL) + if err != nil { + return nil, fmt.Errorf("发送请求失败: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭搜索响应体失败: %v\n", err) + } + }() + + // 读取响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) + } + + // 解析响应 + var result struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []TagInfo `json:"results"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 缓存结果 + searchCache.Set(cacheKey, result.Results) + return result.Results, nil +} + +// RegisterSearchRoute 注册搜索相关路由 +func RegisterSearchRoute(r *gin.Engine) { + // 搜索镜像 + r.GET("/search", func(c *gin.Context) { + query := c.Query("q") + if query == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "搜索关键词不能为空"}) + return + } + + page := 1 + pageSize := 25 + if p := c.Query("page"); p != "" { + fmt.Sscanf(p, "%d", &page) + } + if ps := c.Query("page_size"); ps != "" { + fmt.Sscanf(ps, "%d", &pageSize) + } + + result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) + }) + + // 获取标签信息 + r.GET("/tags/:namespace/:name", func(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + + if namespace == "" || name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "命名空间和名称不能为空"}) + return + } + + tags, err := getRepositoryTags(c.Request.Context(), namespace, name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, tags) + }) +} diff --git a/ghproxy/skopeo_service.go b/src/skopeo_service.go similarity index 77% rename from ghproxy/skopeo_service.go rename to src/skopeo_service.go index 011800b..d048cc9 100644 --- a/ghproxy/skopeo_service.go +++ b/src/skopeo_service.go @@ -1,1120 +1,1371 @@ -package main - -import ( - "archive/zip" - "bufio" - "context" - "crypto/rand" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" -) - -// 任务状态 -type TaskStatus string - -const ( - StatusPending TaskStatus = "pending" - StatusRunning TaskStatus = "running" - StatusCompleted TaskStatus = "completed" - StatusFailed TaskStatus = "failed" -) - -// 镜像下载任务 -type ImageTask struct { - Image string `json:"image"` - Progress float64 `json:"progress"` - Status string `json:"status"` - Error string `json:"error,omitempty"` - OutputPath string `json:"-"` // 输出文件路径,不发送给客户端 - lock sync.Mutex `json:"-"` // 镜像任务自己的锁 -} - -// 下载任务 -type DownloadTask struct { - ID string `json:"id"` - Images []*ImageTask `json:"images"` - CompletedCount int `json:"completedCount"` // 已完成任务数 - TotalCount int `json:"totalCount"` // 总任务数 - Status TaskStatus `json:"status"` - OutputFile string `json:"-"` // 最终输出文件 - TempDir string `json:"-"` // 临时目录 - StatusLock sync.RWMutex `json:"-"` // 状态锁,使用读写锁提高并发性 - ProgressLock sync.RWMutex `json:"-"` // 进度锁 - ImageLock sync.RWMutex `json:"-"` // 镜像列表锁 - updateChan chan *ProgressUpdate `json:"-"` // 进度更新通道 -} - -// 进度更新消息 -type ProgressUpdate struct { - TaskID string - ImageIndex int - Progress float64 - Status string - Error string -} - -// WebSocket客户端 -type Client struct { - Conn *websocket.Conn - TaskID string - Send chan []byte - CloseOnce sync.Once -} - -var ( - // WebSocket升级器 - upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // 允许所有源 - }, - } - - // 活跃任务映射 - tasks = make(map[string]*DownloadTask) - tasksLock sync.Mutex - clients = make(map[string]*Client) - clientLock sync.Mutex -) - -// 初始化Skopeo相关路由 -func initSkopeoRoutes(router *gin.Engine) { - // 创建临时目录 - os.MkdirAll("./temp", 0755) - - // WebSocket路由 - 用于实时获取进度 - router.GET("/ws/:taskId", handleWebSocket) - - // 创建下载任务,应用限流中间件 - ApplyRateLimit(router, "/api/download", "POST", handleDownload) - - // 获取任务状态 - router.GET("/api/task/:taskId", getTaskStatus) - - // 下载文件 - router.GET("/api/files/:filename", serveFile) - - // 启动清理过期文件的goroutine - go cleanupTempFiles() -} - -// 处理WebSocket连接 -func handleWebSocket(c *gin.Context) { - taskID := c.Param("taskId") - - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - fmt.Printf("WebSocket升级失败: %v\n", err) - return - } - - client := &Client{ - Conn: conn, - TaskID: taskID, - Send: make(chan []byte, 256), - } - - // 注册客户端 - clientLock.Lock() - clients[taskID] = client - clientLock.Unlock() - - // 启动goroutine处理消息发送 - go client.writePump() - - // 如果任务已存在,立即发送当前状态 - tasksLock.Lock() - if task, exists := tasks[taskID]; exists { - tasksLock.Unlock() - taskJSON, _ := json.Marshal(task) - client.Send <- taskJSON - } else { - tasksLock.Unlock() - } - - // 处理WebSocket关闭 - conn.SetCloseHandler(func(code int, text string) error { - client.CloseOnce.Do(func() { - close(client.Send) - clientLock.Lock() - delete(clients, taskID) - clientLock.Unlock() - }) - return nil - }) -} - -// 客户端消息发送loop -func (c *Client) writePump() { - defer func() { - c.Conn.Close() - }() - - for message := range c.Send { - err := c.Conn.WriteMessage(websocket.TextMessage, message) - if err != nil { - fmt.Printf("发送WS消息失败: %v\n", err) - break - } - } -} - -// 获取任务状态 -func getTaskStatus(c *gin.Context) { - taskID := c.Param("taskId") - - tasksLock.Lock() - task, exists := tasks[taskID] - tasksLock.Unlock() - - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"}) - return - } - - // 创建任务状态副本以避免序列化过程中的锁 - taskCopy := &DownloadTask{ - ID: task.ID, - CompletedCount: 0, - TotalCount: len(task.Images), - Status: TaskStatus(""), - Images: nil, - } - - // 复制状态信息 - task.StatusLock.RLock() - taskCopy.Status = task.Status - task.StatusLock.RUnlock() - - task.ProgressLock.RLock() - taskCopy.CompletedCount = task.CompletedCount - task.ProgressLock.RUnlock() - - // 复制镜像信息 - task.ImageLock.RLock() - taskCopy.Images = make([]*ImageTask, len(task.Images)) - for i, img := range task.Images { - img.lock.Lock() - taskCopy.Images[i] = &ImageTask{ - Image: img.Image, - Progress: img.Progress, - Status: img.Status, - Error: img.Error, - } - img.lock.Unlock() - } - task.ImageLock.RUnlock() - - c.JSON(http.StatusOK, taskCopy) -} - -// 生成随机任务ID -func generateTaskID() string { - b := make([]byte, 16) - rand.Read(b) - return hex.EncodeToString(b) -} - -// 初始化任务并启动进度处理器 -func initTask(task *DownloadTask) { - // 创建进度更新通道 - task.updateChan = make(chan *ProgressUpdate, 100) - - // 启动进度处理goroutine - go func() { - for update := range task.updateChan { - if update == nil { - // 通道关闭信号 - break - } - - // 获取更新的镜像 - task.ImageLock.RLock() - if update.ImageIndex < 0 || update.ImageIndex >= len(task.Images) { - task.ImageLock.RUnlock() - continue - } - imgTask := task.Images[update.ImageIndex] - task.ImageLock.RUnlock() - - statusChanged := false - prevStatus := "" - - // 更新镜像进度和状态 - imgTask.lock.Lock() - if update.Progress > 0 { - imgTask.Progress = update.Progress - } - if update.Status != "" && update.Status != imgTask.Status { - prevStatus = imgTask.Status - imgTask.Status = update.Status - statusChanged = true - } - if update.Error != "" { - imgTask.Error = update.Error - } - imgTask.lock.Unlock() - - // 检查状态变化并更新完成计数 - if statusChanged { - task.ProgressLock.Lock() - // 如果之前不是Completed,现在是Completed,增加计数 - if prevStatus != string(StatusCompleted) && update.Status == string(StatusCompleted) { - task.CompletedCount++ - fmt.Printf("任务 %s: 镜像 %d 完成,当前完成数: %d/%d\n", - task.ID, update.ImageIndex, task.CompletedCount, task.TotalCount) - } - // 如果之前是Completed,现在不是,减少计数 - if prevStatus == string(StatusCompleted) && update.Status != string(StatusCompleted) { - task.CompletedCount-- - if task.CompletedCount < 0 { - task.CompletedCount = 0 - } - } - task.ProgressLock.Unlock() - } - - // 发送更新到客户端 - sendTaskUpdate(task) - } - }() -} - -// 发送进度更新 -func sendProgressUpdate(task *DownloadTask, index int, progress float64, status string, errorMsg string) { - select { - case task.updateChan <- &ProgressUpdate{ - TaskID: task.ID, - ImageIndex: index, - Progress: progress, - Status: status, - Error: errorMsg, - }: - // 成功发送 - default: - // 通道已满,丢弃更新 - fmt.Printf("Warning: Update channel for task %s is full\n", task.ID) - } -} - -// 更新总进度 - 重新计算已完成任务数 -func updateTaskTotalProgress(task *DownloadTask) { - task.ProgressLock.Lock() - defer task.ProgressLock.Unlock() - - completedCount := 0 - - task.ImageLock.RLock() - totalCount := len(task.Images) - task.ImageLock.RUnlock() - - if totalCount == 0 { - return - } - - task.ImageLock.RLock() - for _, img := range task.Images { - img.lock.Lock() - if img.Status == string(StatusCompleted) { - completedCount++ - } - img.lock.Unlock() - } - task.ImageLock.RUnlock() - - task.CompletedCount = completedCount - task.TotalCount = totalCount - - fmt.Printf("任务 %s: 进度更新 %d/%d 已完成\n", task.ID, completedCount, totalCount) -} - -// 处理下载请求 -func handleDownload(c *gin.Context) { - type DownloadRequest struct { - Images []string `json:"images"` - Platform string `json:"platform"` // 平台: amd64, arm64等 - } - - var req DownloadRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数"}) - return - } - - if len(req.Images) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "请提供至少一个镜像"}) - return - } - - // 添加镜像数量限制10个,防止恶意刷流量 - if len(req.Images) > 10 { - c.JSON(http.StatusBadRequest, gin.H{"error": "您下载的数量太多,宝宝承受不住"}) - return - } - - // 创建新任务 - taskID := generateTaskID() - tempDir := filepath.Join("./temp", taskID) - os.MkdirAll(tempDir, 0755) - - // 初始化任务 - imageTasks := make([]*ImageTask, len(req.Images)) - for i, image := range req.Images { - imageTasks[i] = &ImageTask{ - Image: image, - Progress: 0, - Status: string(StatusPending), - } - } - - task := &DownloadTask{ - ID: taskID, - Images: imageTasks, - CompletedCount: 0, - TotalCount: len(imageTasks), - Status: StatusPending, - TempDir: tempDir, - } - - // 初始化任务通道和处理器 - initTask(task) - - // 保存任务 - tasksLock.Lock() - tasks[taskID] = task - tasksLock.Unlock() - - // 异步处理下载 - go func() { - processDownloadTask(task, req.Platform) - // 任务完成后关闭更新通道 - close(task.updateChan) - }() - - c.JSON(http.StatusOK, gin.H{ - "taskId": taskID, - "status": "started", - "totalCount": len(imageTasks), - }) -} - -// 处理下载任务 -func processDownloadTask(task *DownloadTask, platform string) { - // 设置任务状态为运行中 - task.StatusLock.Lock() - task.Status = StatusRunning - task.StatusLock.Unlock() - - // 初始化总任务数 - task.ImageLock.RLock() - imageCount := len(task.Images) - task.ImageLock.RUnlock() - - task.ProgressLock.Lock() - task.TotalCount = imageCount - task.CompletedCount = 0 - task.ProgressLock.Unlock() - - // 通知客户端任务已开始 - sendTaskUpdate(task) - - // 创建错误组,用于管理所有下载goroutine - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // 确保资源被释放 - - g, ctx := errgroup.WithContext(ctx) - - // 启动并发下载 - task.ImageLock.RLock() - imageCount = len(task.Images) - task.ImageLock.RUnlock() - - // 创建工作池限制并发数 - const maxConcurrent = 5 - semaphore := make(chan struct{}, maxConcurrent) - - // 添加下载任务 - for i := 0; i < imageCount; i++ { - index := i // 捕获循环变量 - - g.Go(func() error { - // 获取信号量,限制并发 - semaphore <- struct{}{} - defer func() { <-semaphore }() - - task.ImageLock.RLock() - imgTask := task.Images[index] - task.ImageLock.RUnlock() - - // 下载镜像 - err := downloadImageWithContext(ctx, task, index, imgTask, platform) - if err != nil { - fmt.Printf("镜像 %s 下载失败: %v\n", imgTask.Image, err) - return err - } - return nil - }) - } - - // 等待所有下载完成 - err := g.Wait() - - // 再次计算已完成任务数,确保正确 - updateTaskTotalProgress(task) - - // 检查是否有错误发生 - if err != nil { - task.StatusLock.Lock() - task.Status = StatusFailed - task.StatusLock.Unlock() - sendTaskUpdate(task) - return - } - - // 判断是单个tar还是需要打包 - var finalFilePath string - - task.StatusLock.Lock() - - // 检查是否所有镜像都下载成功 - allSuccess := true - task.ImageLock.RLock() - for _, img := range task.Images { - img.lock.Lock() - if img.Status != string(StatusCompleted) { - allSuccess = false - } - img.lock.Unlock() - } - task.ImageLock.RUnlock() - - if !allSuccess { - task.Status = StatusFailed - task.StatusLock.Unlock() - sendTaskUpdate(task) - return - } - - // 如果只有一个文件,直接使用它 - task.ImageLock.RLock() - if imageCount == 1 { - imgTask := task.Images[0] - imgTask.lock.Lock() - if imgTask.Status == string(StatusCompleted) { - finalFilePath = imgTask.OutputPath - // 重命名为更友好的名称 - imageName := strings.ReplaceAll(imgTask.Image, "/", "_") - imageName = strings.ReplaceAll(imageName, ":", "_") - newPath := filepath.Join(task.TempDir, imageName+".tar") - os.Rename(finalFilePath, newPath) - finalFilePath = newPath - } - imgTask.lock.Unlock() - } else { - // 多个文件打包成zip - task.ImageLock.RUnlock() - var zipErr error - finalFilePath, zipErr = createZipArchive(task) - if zipErr != nil { - task.Status = StatusFailed - task.StatusLock.Unlock() - sendTaskUpdate(task) - return - } - } - - if imageCount == 1 { - task.ImageLock.RUnlock() - } - - task.OutputFile = finalFilePath - task.Status = StatusCompleted - - // 设置完成计数为总任务数 - task.ProgressLock.Lock() - task.CompletedCount = task.TotalCount - task.ProgressLock.Unlock() - - task.StatusLock.Unlock() - - // 发送最终状态更新 - sendTaskUpdate(task) - - // 确保所有进度都达到100% - ensureTaskCompletion(task) - - fmt.Printf("任务 %s 全部完成: %d/%d\n", task.ID, task.CompletedCount, task.TotalCount) -} - -// 下载单个镜像(带上下文控制) -func downloadImageWithContext(ctx context.Context, task *DownloadTask, index int, imgTask *ImageTask, platform string) error { - // 更新状态为运行中 - sendProgressUpdate(task, index, 0, string(StatusRunning), "") - - // 创建输出文件名 - outputFileName := fmt.Sprintf("image_%d.tar", index) - outputPath := filepath.Join(task.TempDir, outputFileName) - - imgTask.lock.Lock() - imgTask.OutputPath = outputPath - imgTask.lock.Unlock() - - // 创建skopeo命令 - platformArg := "" - if platform != "" { - // 支持手动输入完整的平台参数 - if strings.Contains(platform, "--") { - platformArg = platform - } else { - // 处理特殊架构格式,如 arm/v7 - if strings.Contains(platform, "/") { - parts := strings.Split(platform, "/") - if len(parts) == 2 { - // 适用于arm/v7这样的格式 - platformArg = fmt.Sprintf("--override-os linux --override-arch %s --override-variant %s", parts[0], parts[1]) - } else { - // 对于其他带/的格式,直接按原格式处理 - platformArg = fmt.Sprintf("--override-os linux --override-arch %s", platform) - } - } else { - // 仅指定架构名称的情况 - platformArg = fmt.Sprintf("--override-os linux --override-arch %s", platform) - } - } - } - - // 构建命令 - cmdStr := fmt.Sprintf("skopeo copy %s docker://%s docker-archive:%s", - platformArg, imgTask.Image, outputPath) - - fmt.Printf("执行命令: %s\n", cmdStr) - - // 创建可取消的命令 - cmd := exec.CommandContext(ctx, "sh", "-c", cmdStr) - - // 获取命令输出 - stderr, err := cmd.StderrPipe() - if err != nil { - errMsg := fmt.Sprintf("无法创建输出管道: %v", err) - sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) - return fmt.Errorf(errMsg) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - errMsg := fmt.Sprintf("无法创建标准输出管道: %v", err) - sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) - return fmt.Errorf(errMsg) - } - - if err := cmd.Start(); err != nil { - errMsg := fmt.Sprintf("启动命令失败: %v", err) - sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) - return fmt.Errorf(errMsg) - } - - // 使用进度通道传递进度信息 - outputChan := make(chan string, 20) - done := make(chan struct{}) - - // 初始进度 - sendProgressUpdate(task, index, 5, "", "") - - // 进度聚合器 - go func() { - // 镜像获取阶段的进度标记 - downloadStages := map[string]float64{ - "Getting image source signatures": 10, - "Copying blob": 30, - "Copying config": 70, - "Writing manifest": 90, - } - - // 进度增长的定时器 - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - lastProgress := 5.0 - stagnantTime := 0 - - for { - select { - case <-ctx.Done(): - // 上下文取消 - return - - case <-done: - // 命令完成,强制更新到100% - if lastProgress < 100 { - fmt.Printf("镜像 %s 下载完成,强制更新进度到100%%\n", imgTask.Image) - sendProgressUpdate(task, index, 100, string(StatusCompleted), "") - } - return - - case output := <-outputChan: - // 解析输出更新进度 - for marker, progress := range downloadStages { - if strings.Contains(output, marker) && progress > lastProgress { - lastProgress = progress - sendProgressUpdate(task, index, progress, "", "") - stagnantTime = 0 - break - } - } - - // 解析百分比 - if strings.Contains(output, "%") { - parts := strings.Split(output, "%") - if len(parts) > 0 { - numStr := strings.TrimSpace(parts[0]) - fields := strings.Fields(numStr) - if len(fields) > 0 { - lastField := fields[len(fields)-1] - parsedProgress := 0.0 - _, err := fmt.Sscanf(lastField, "%f", &parsedProgress) - if err == nil && parsedProgress > 0 && parsedProgress <= 100 { - // 根据当前阶段调整进度值 - var adjustedProgress float64 - if lastProgress < 30 { - // Copying blob阶段,进度在10-30%之间 - adjustedProgress = 10 + (parsedProgress / 100) * 20 - } else if lastProgress < 70 { - // Copying config阶段,进度在30-70%之间 - adjustedProgress = 30 + (parsedProgress / 100) * 40 - } else if lastProgress < 90 { - // Writing manifest阶段,进度在70-90%之间 - adjustedProgress = 70 + (parsedProgress / 100) * 20 - } - - if adjustedProgress > lastProgress { - lastProgress = adjustedProgress - sendProgressUpdate(task, index, adjustedProgress, "", "") - stagnantTime = 0 - } - } - } - } - } - - // 如果发现完成标记,立即更新到100% - if checkForCompletionMarkers(output) { - fmt.Printf("镜像 %s 检测到完成标记\n", imgTask.Image) - lastProgress = 100 - sendProgressUpdate(task, index, 100, string(StatusCompleted), "") - stagnantTime = 0 - } - - case <-ticker.C: - // 如果进度长时间无变化,缓慢增加 - stagnantTime += 100 // 100ms - if stagnantTime >= 10000 && lastProgress < 95 { // 10秒无变化 - // 每10秒增加5%进度,确保不超过95% - newProgress := lastProgress + 5 - if newProgress > 95 { - newProgress = 95 - } - lastProgress = newProgress - sendProgressUpdate(task, index, newProgress, "", "") - stagnantTime = 0 - } - } - } - }() - - // 读取标准输出 - go func() { - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - output := scanner.Text() - fmt.Printf("镜像 %s 标准输出: %s\n", imgTask.Image, output) - select { - case outputChan <- output: - default: - // 通道已满,丢弃 - } - } - }() - - // 读取错误输出 - go func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - output := scanner.Text() - fmt.Printf("镜像 %s 错误输出: %s\n", imgTask.Image, output) - select { - case outputChan <- output: - default: - // 通道已满,丢弃 - } - } - }() - - // 等待命令完成 - cmdErr := cmd.Wait() - close(done) // 通知进度聚合器退出 - - if cmdErr != nil { - errMsg := fmt.Sprintf("命令执行失败: %v", cmdErr) - sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) - return fmt.Errorf(errMsg) - } - - // 检查文件是否成功创建 - if _, err := os.Stat(outputPath); os.IsNotExist(err) { - errMsg := "文件未成功创建" - sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) - return fmt.Errorf(errMsg) - } - - // 确保更新状态为已完成,进度为100% - sendProgressUpdate(task, index, 100, string(StatusCompleted), "") - return nil -} - -// 创建ZIP归档 -func createZipArchive(task *DownloadTask) (string, error) { - zipFilePath := filepath.Join(task.TempDir, "images.zip") - zipFile, err := os.Create(zipFilePath) - if err != nil { - return "", fmt.Errorf("创建ZIP文件失败: %w", err) - } - defer zipFile.Close() - - zipWriter := zip.NewWriter(zipFile) - defer zipWriter.Close() - - task.ImageLock.RLock() - images := make([]*ImageTask, len(task.Images)) - copy(images, task.Images) // 创建副本避免长时间持有锁 - task.ImageLock.RUnlock() - - for _, img := range images { - img.lock.Lock() - status := img.Status - outputPath := img.OutputPath - image := img.Image - img.lock.Unlock() - - if status != string(StatusCompleted) || outputPath == "" { - continue - } - - // 创建ZIP条目 - imgFile, err := os.Open(outputPath) - if err != nil { - return "", fmt.Errorf("无法打开镜像文件 %s: %w", outputPath, err) - } - - // 使用镜像名作为文件名 - imageName := strings.ReplaceAll(image, "/", "_") - imageName = strings.ReplaceAll(imageName, ":", "_") - fileName := imageName + ".tar" - - fileInfo, err := imgFile.Stat() - if err != nil { - imgFile.Close() - return "", fmt.Errorf("无法获取文件信息: %w", err) - } - - header, err := zip.FileInfoHeader(fileInfo) - if err != nil { - imgFile.Close() - return "", fmt.Errorf("创建ZIP头信息失败: %w", err) - } - - header.Name = fileName - header.Method = zip.Deflate - - writer, err := zipWriter.CreateHeader(header) - if err != nil { - imgFile.Close() - return "", fmt.Errorf("添加文件到ZIP失败: %w", err) - } - - _, err = io.Copy(writer, imgFile) - imgFile.Close() - if err != nil { - return "", fmt.Errorf("写入ZIP文件失败: %w", err) - } - } - - return zipFilePath, nil -} - -// 发送任务更新到WebSocket -func sendTaskUpdate(task *DownloadTask) { - // 复制任务状态避免序列化时锁定 - taskCopy := &DownloadTask{ - ID: task.ID, - CompletedCount: 0, - TotalCount: len(task.Images), - Status: TaskStatus(""), - Images: nil, - } - - // 复制状态信息 - task.StatusLock.RLock() - taskCopy.Status = task.Status - task.StatusLock.RUnlock() - - task.ProgressLock.RLock() - taskCopy.CompletedCount = task.CompletedCount - task.ProgressLock.RUnlock() - - // 复制镜像信息 - task.ImageLock.RLock() - taskCopy.Images = make([]*ImageTask, len(task.Images)) - for i, img := range task.Images { - img.lock.Lock() - taskCopy.Images[i] = &ImageTask{ - Image: img.Image, - Progress: img.Progress, - Status: img.Status, - Error: img.Error, - } - img.lock.Unlock() - } - task.ImageLock.RUnlock() - - // 序列化并发送 - taskJSON, err := json.Marshal(taskCopy) - if err != nil { - fmt.Printf("序列化任务失败: %v\n", err) - return - } - - clientLock.Lock() - client, exists := clients[task.ID] - clientLock.Unlock() - - if exists { - select { - case client.Send <- taskJSON: - // 成功发送 - default: - // 通道已满或关闭,忽略 - } - } -} - -// 发送单个镜像更新 - 保持兼容性 -func sendImageUpdate(task *DownloadTask, imageIndex int) { - sendTaskUpdate(task) -} - -// 提供文件下载 -func serveFile(c *gin.Context) { - filename := c.Param("filename") - - // 安全检查,防止任意文件访问 - if strings.Contains(filename, "..") { - c.JSON(http.StatusForbidden, gin.H{"error": "无效的文件名"}) - return - } - - // 根据任务ID和文件名查找文件 - parts := strings.Split(filename, "_") - if len(parts) < 2 { - c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件名格式"}) - return - } - - taskID := parts[0] - - tasksLock.Lock() - task, exists := tasks[taskID] - tasksLock.Unlock() - - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"}) - return - } - - // 确保任务状态为已完成,并且所有进度都是100% - task.StatusLock.RLock() - isCompleted := task.Status == StatusCompleted - task.StatusLock.RUnlock() - - if isCompleted { - // 确保所有进度达到100% - ensureTaskCompletion(task) - } - - // 检查文件是否存在 - filePath := task.OutputFile - if filePath == "" || !fileExists(filePath) { - c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) - return - } - - // 获取文件信息 - fileInfo, err := os.Stat(filePath) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取文件信息"}) - return - } - - // 设置文件名 - 提取有意义的文件名 - downloadName := filepath.Base(filePath) - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", downloadName)) - c.Header("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) - - // 返回文件 - c.File(filePath) -} - -// 检查文件是否存在 -func fileExists(path string) bool { - _, err := os.Stat(path) - return !os.IsNotExist(err) -} - -// 清理过期临时文件 -func cleanupTempFiles() { - // 创建两个定时器 - hourlyTicker := time.NewTicker(1 * time.Hour) - fiveMinTicker := time.NewTicker(5 * time.Minute) - - // 清理所有文件的函数 - cleanAll := func() { - fmt.Printf("执行清理所有临时文件\n") - entries, err := os.ReadDir("./temp") - if err == nil { - for _, entry := range entries { - entryPath := filepath.Join("./temp", entry.Name()) - info, err := entry.Info() - if err == nil { - if info.IsDir() { - os.RemoveAll(entryPath) - } else { - os.Remove(entryPath) - } - } - } - } else { - fmt.Printf("清理临时文件失败: %v\n", err) - } - } - - // 检查文件大小并在需要时清理 - checkSizeAndClean := func() { - var totalSize int64 = 0 - err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // 跳过根目录 - if path == "./temp" { - return nil - } - - if !info.IsDir() { - totalSize += info.Size() - } - - return nil - }) - - if err != nil { - fmt.Printf("计算临时文件总大小失败: %v\n", err) - return - } - - // 如果总大小超过10GB,清理所有文件,防止恶意下载导致磁盘爆满 - if totalSize > 10*1024*1024*1024 { - fmt.Printf("临时文件总大小超过10GB (当前: %.2f GB),清理所有文件\n", float64(totalSize)/(1024*1024*1024)) - cleanAll() - } else { - fmt.Printf("临时文件总大小: %.2f GB\n", float64(totalSize)/(1024*1024*1024)) - } - } - - // 主循环 - for { - select { - case <-hourlyTicker.C: - // 每小时清理所有文件 - cleanAll() - case <-fiveMinTicker.C: - // 每5分钟检查一次总文件大小 - checkSizeAndClean() - } - } -} - -// 完成任务处理函数,确保进度是100% -func ensureTaskCompletion(task *DownloadTask) { - // 重新检查一遍所有镜像的进度 - task.ImageLock.RLock() - completedCount := 0 - totalCount := len(task.Images) - - for i, img := range task.Images { - img.lock.Lock() - if img.Status == string(StatusCompleted) { - // 确保进度为100% - if img.Progress < 100 { - img.Progress = 100 - fmt.Printf("确保镜像 %d 进度为100%%\n", i) - } - completedCount++ - } - img.lock.Unlock() - } - task.ImageLock.RUnlock() - - // 更新完成计数 - task.ProgressLock.Lock() - task.CompletedCount = completedCount - task.TotalCount = totalCount - task.ProgressLock.Unlock() - - // 如果任务状态为已完成,但计数不匹配,修正计数 - task.StatusLock.RLock() - isCompleted := task.Status == StatusCompleted - task.StatusLock.RUnlock() - - if isCompleted && completedCount != totalCount { - task.ProgressLock.Lock() - task.CompletedCount = totalCount - task.ProgressLock.Unlock() - fmt.Printf("任务 %s 状态已完成,强制设置计数为 %d/%d\n", task.ID, totalCount, totalCount) - } - - // 发送最终更新 - sendTaskUpdate(task) -} - -// 处理下载单个镜像的输出中的完成标记 -func checkForCompletionMarkers(output string) bool { - // 已知的完成标记 - completionMarkers := []string{ - "Writing manifest to image destination", - "Copying config complete", - "Storing signatures", - "Writing manifest complete", - } - - for _, marker := range completionMarkers { - if strings.Contains(output, marker) { - return true - } - } - - return false +package main + +import ( + "archive/zip" + "bufio" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" +) + +// 任务状态 +type TaskStatus string + +const ( + StatusPending TaskStatus = "pending" + StatusRunning TaskStatus = "running" + StatusCompleted TaskStatus = "completed" + StatusFailed TaskStatus = "failed" +) + +// 镜像下载任务 +type ImageTask struct { + Image string `json:"image"` + Progress float64 `json:"progress"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + OutputPath string `json:"-"` // 输出文件路径,不发送给客户端 + lock sync.Mutex `json:"-"` // 镜像任务自己的锁 +} + +// 下载任务 +type DownloadTask struct { + ID string `json:"id"` + Images []*ImageTask `json:"images"` + CompletedCount int `json:"completedCount"` // 已完成任务数 + TotalCount int `json:"totalCount"` // 总任务数 + Status TaskStatus `json:"status"` + OutputFile string `json:"-"` // 最终输出文件 + TempDir string `json:"-"` // 临时目录 + StatusLock sync.RWMutex `json:"-"` // 状态锁,使用读写锁提高并发性 + ProgressLock sync.RWMutex `json:"-"` // 进度锁 + ImageLock sync.RWMutex `json:"-"` // 镜像列表锁 + updateChan chan *ProgressUpdate `json:"-"` // 进度更新通道 + done chan struct{} `json:"-"` // 用于安全关闭goroutine + once sync.Once `json:"-"` // 确保只关闭一次 + createTime time.Time `json:"-"` // 创建时间,用于清理 +} + +// 进度更新消息 +type ProgressUpdate struct { + TaskID string + ImageIndex int + Progress float64 + Status string + Error string +} + +// WebSocket客户端 +type Client struct { + Conn *websocket.Conn + TaskID string + Send chan []byte + CloseOnce sync.Once +} + +var ( + // WebSocket升级器 + upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // 允许所有源 + }, + } + + // 活跃任务映射 + tasks = make(map[string]*DownloadTask) + tasksLock sync.Mutex + clients = make(map[string]*Client) + clientLock sync.Mutex +) + +// 初始化Skopeo相关路由 +func initSkopeoRoutes(router *gin.Engine) { + // 创建临时目录 + os.MkdirAll("./temp", 0755) + + // WebSocket路由 - 用于实时获取进度 + router.GET("/ws/:taskId", handleWebSocket) + + // 创建下载任务,应用限流中间件 + ApplyRateLimit(router, "/api/download", "POST", handleDownload) + + // 获取任务状态 + router.GET("/api/task/:taskId", getTaskStatus) + + // 下载文件 + router.GET("/api/files/:filename", serveFile) + + // 启动清理过期文件的goroutine + go cleanupTempFiles() + + // 启动WebSocket连接清理goroutine + go cleanupWebSocketConnections() + + // 启动过期任务清理goroutine + go cleanupExpiredTasks() +} + +// 处理WebSocket连接 +func handleWebSocket(c *gin.Context) { + taskID := c.Param("taskId") + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + fmt.Printf("WebSocket升级失败: %v\n", err) + return + } + + client := &Client{ + Conn: conn, + TaskID: taskID, + Send: make(chan []byte, 256), + } + + // 注册客户端 + clientLock.Lock() + clients[taskID] = client + clientLock.Unlock() + + // 启动goroutine处理消息发送 + go client.writePump() + + // 如果任务已存在,立即发送当前状态 + tasksLock.Lock() + if task, exists := tasks[taskID]; exists { + tasksLock.Unlock() + taskJSON, _ := json.Marshal(task) + client.Send <- taskJSON + } else { + tasksLock.Unlock() + } + + // 设置WebSocket超时 + conn.SetReadDeadline(time.Now().Add(120 * time.Second)) + conn.SetWriteDeadline(time.Now().Add(60 * time.Second)) + + // 处理WebSocket关闭 + conn.SetCloseHandler(func(code int, text string) error { + client.CloseOnce.Do(func() { + close(client.Send) + clientLock.Lock() + delete(clients, taskID) + clientLock.Unlock() + }) + return nil + }) +} + +// 客户端消息发送loop +func (c *Client) writePump() { + defer func() { + c.Conn.Close() + }() + + for message := range c.Send { + err := c.Conn.WriteMessage(websocket.TextMessage, message) + if err != nil { + fmt.Printf("发送WS消息失败: %v\n", err) + break + } + } +} + +// 获取任务状态 +func getTaskStatus(c *gin.Context) { + taskID := c.Param("taskId") + + tasksLock.Lock() + task, exists := tasks[taskID] + tasksLock.Unlock() + + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"}) + return + } + + // 创建任务状态副本以避免序列化过程中的锁 + taskCopy := &DownloadTask{ + ID: task.ID, + CompletedCount: 0, + TotalCount: len(task.Images), + Status: TaskStatus(""), + Images: nil, + } + + // 复制状态信息 + task.StatusLock.RLock() + taskCopy.Status = task.Status + task.StatusLock.RUnlock() + + task.ProgressLock.RLock() + taskCopy.CompletedCount = task.CompletedCount + task.ProgressLock.RUnlock() + + // 复制镜像信息 + task.ImageLock.RLock() + taskCopy.Images = make([]*ImageTask, len(task.Images)) + for i, img := range task.Images { + img.lock.Lock() + taskCopy.Images[i] = &ImageTask{ + Image: img.Image, + Progress: img.Progress, + Status: img.Status, + Error: img.Error, + } + img.lock.Unlock() + } + task.ImageLock.RUnlock() + + c.JSON(http.StatusOK, taskCopy) +} + +// 生成随机任务ID +func generateTaskID() string { + b := make([]byte, 16) + rand.Read(b) + return hex.EncodeToString(b) +} + +// 初始化任务并启动进度处理器 +func initTask(task *DownloadTask) { + // 创建进度更新通道和控制通道 + task.updateChan = make(chan *ProgressUpdate, 100) + task.done = make(chan struct{}) + task.createTime = time.Now() + + // 启动进度处理goroutine + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("任务 %s 进度处理goroutine异常: %v\n", task.ID, r) + } + }() + + // 处理消息的函数 + processUpdate := func(update *ProgressUpdate) { + if update == nil { + return + } + + // 获取更新的镜像 + task.ImageLock.RLock() + if update.ImageIndex < 0 || update.ImageIndex >= len(task.Images) { + task.ImageLock.RUnlock() + return + } + imgTask := task.Images[update.ImageIndex] + task.ImageLock.RUnlock() + + statusChanged := false + prevStatus := "" + + // 更新镜像进度和状态 + imgTask.lock.Lock() + if update.Progress > 0 { + imgTask.Progress = update.Progress + } + if update.Status != "" && update.Status != imgTask.Status { + prevStatus = imgTask.Status + imgTask.Status = update.Status + statusChanged = true + } + if update.Error != "" { + imgTask.Error = update.Error + } + imgTask.lock.Unlock() + + // 检查状态变化并更新完成计数 + if statusChanged { + task.ProgressLock.Lock() + // 如果之前不是Completed,现在是Completed,增加计数 + if prevStatus != string(StatusCompleted) && update.Status == string(StatusCompleted) { + task.CompletedCount++ + fmt.Printf("任务 %s: 镜像 %d 完成,当前完成数: %d/%d\n", + task.ID, update.ImageIndex, task.CompletedCount, task.TotalCount) + } + // 如果之前是Completed,现在不是,减少计数 + if prevStatus == string(StatusCompleted) && update.Status != string(StatusCompleted) { + task.CompletedCount-- + if task.CompletedCount < 0 { + task.CompletedCount = 0 + } + } + task.ProgressLock.Unlock() + } + + // 发送更新到客户端 + sendTaskUpdate(task) + } + + // 主处理循环 + for { + select { + case update := <-task.updateChan: + if update == nil { + // 通道关闭信号,直接退出 + return + } + processUpdate(update) + + case <-task.done: + // 收到关闭信号,进入drain模式处理剩余消息 + goto drainMode + } + } + + drainMode: + // 处理通道中剩余的所有消息,确保不丢失任何更新 + for { + select { + case update := <-task.updateChan: + if update == nil { + // 通道关闭,安全退出 + return + } + processUpdate(update) + default: + // 没有更多待处理的消息,安全退出 + return + } + } + }() +} + +// 安全关闭任务的goroutine和通道 +func (task *DownloadTask) Close() { + task.once.Do(func() { + close(task.done) + // 给一点时间让goroutine退出,然后安全关闭updateChan + time.AfterFunc(100*time.Millisecond, func() { + task.safeCloseUpdateChan() + }) + }) +} + +// 安全关闭updateChan,防止重复关闭 +func (task *DownloadTask) safeCloseUpdateChan() { + defer func() { + if r := recover(); r != nil { + // 捕获关闭已关闭channel的panic,忽略它 + fmt.Printf("任务 %s: updateChan已经关闭\n", task.ID) + } + }() + close(task.updateChan) +} + +// 发送进度更新 +func sendProgressUpdate(task *DownloadTask, index int, progress float64, status string, errorMsg string) { + // 检查任务是否已经关闭 + select { + case <-task.done: + // 任务已关闭,不发送更新 + return + default: + } + + // 安全发送进度更新 + select { + case task.updateChan <- &ProgressUpdate{ + TaskID: task.ID, + ImageIndex: index, + Progress: progress, + Status: status, + Error: errorMsg, + }: + // 成功发送 + case <-task.done: + // 在发送过程中任务被关闭 + return + default: + // 通道已满,丢弃更新 + fmt.Printf("Warning: Update channel for task %s is full\n", task.ID) + } +} + +// 更新总进度 - 重新计算已完成任务数 +func updateTaskTotalProgress(task *DownloadTask) { + task.ProgressLock.Lock() + defer task.ProgressLock.Unlock() + + completedCount := 0 + + task.ImageLock.RLock() + totalCount := len(task.Images) + task.ImageLock.RUnlock() + + if totalCount == 0 { + return + } + + task.ImageLock.RLock() + for _, img := range task.Images { + img.lock.Lock() + if img.Status == string(StatusCompleted) { + completedCount++ + } + img.lock.Unlock() + } + task.ImageLock.RUnlock() + + task.CompletedCount = completedCount + task.TotalCount = totalCount + + fmt.Printf("任务 %s: 进度更新 %d/%d 已完成\n", task.ID, completedCount, totalCount) +} + +// 处理下载请求 +func handleDownload(c *gin.Context) { + type DownloadRequest struct { + Images []string `json:"images"` + Platform string `json:"platform"` // 平台: amd64, arm64等 + } + + var req DownloadRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数"}) + return + } + + if len(req.Images) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "请提供至少一个镜像"}) + return + } + + // Docker镜像访问控制检查 + for _, image := range req.Images { + if allowed, reason := GlobalAccessController.CheckDockerAccess(image); !allowed { + fmt.Printf("Docker镜像 %s 下载被拒绝: %s\n", image, reason) + c.JSON(http.StatusForbidden, gin.H{ + "error": fmt.Sprintf("镜像 %s 访问被限制: %s", image, reason), + }) + return + } + } + + // 获取配置中的镜像数量限制 + cfg := GetConfig() + maxImages := cfg.Download.MaxImages + if maxImages <= 0 { + maxImages = 10 // 安全默认值,防止配置错误 + } + + // 检查镜像数量限制,防止恶意刷流量 + if len(req.Images) > maxImages { + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("单次下载镜像数量超过限制,最多允许 %d 个镜像", maxImages), + }) + return + } + + // 创建新任务 + taskID := generateTaskID() + tempDir := filepath.Join("./temp", taskID) + os.MkdirAll(tempDir, 0755) + + // 初始化任务 + imageTasks := make([]*ImageTask, len(req.Images)) + for i, image := range req.Images { + imageTasks[i] = &ImageTask{ + Image: image, + Progress: 0, + Status: string(StatusPending), + } + } + + task := &DownloadTask{ + ID: taskID, + Images: imageTasks, + CompletedCount: 0, + TotalCount: len(imageTasks), + Status: StatusPending, + TempDir: tempDir, + } + + // 初始化任务通道和处理器 + initTask(task) + + // 保存任务 + tasksLock.Lock() + tasks[taskID] = task + tasksLock.Unlock() + + // 异步处理下载 + go func() { + defer func() { + // 任务完成后安全关闭更新通道 + task.safeCloseUpdateChan() + }() + processDownloadTask(task, req.Platform) + }() + + c.JSON(http.StatusOK, gin.H{ + "taskId": taskID, + "status": "started", + "totalCount": len(imageTasks), + }) +} + +// 处理下载任务 +func processDownloadTask(task *DownloadTask, platform string) { + // 设置任务状态为运行中 + task.StatusLock.Lock() + task.Status = StatusRunning + task.StatusLock.Unlock() + + // 初始化总任务数 + task.ImageLock.RLock() + imageCount := len(task.Images) + task.ImageLock.RUnlock() + + task.ProgressLock.Lock() + task.TotalCount = imageCount + task.CompletedCount = 0 + task.ProgressLock.Unlock() + + // 通知客户端任务已开始 + sendTaskUpdate(task) + + // 创建错误组,用于管理所有下载goroutine + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // 确保资源被释放 + + g, ctx := errgroup.WithContext(ctx) + + // 启动并发下载 + task.ImageLock.RLock() + imageCount = len(task.Images) + task.ImageLock.RUnlock() + + // 创建工作池限制并发数 + const maxConcurrent = 5 + semaphore := make(chan struct{}, maxConcurrent) + + // 添加下载任务 + for i := 0; i < imageCount; i++ { + index := i // 捕获循环变量 + + g.Go(func() error { + // 获取信号量,限制并发 + semaphore <- struct{}{} + defer func() { <-semaphore }() + + task.ImageLock.RLock() + imgTask := task.Images[index] + task.ImageLock.RUnlock() + + // 下载镜像 + err := downloadImageWithContext(ctx, task, index, imgTask, platform) + if err != nil { + fmt.Printf("镜像 %s 下载失败: %v\n", imgTask.Image, err) + return err + } + return nil + }) + } + + // 等待所有下载完成 + err := g.Wait() + + // 再次计算已完成任务数,确保正确 + updateTaskTotalProgress(task) + + // 检查是否有错误发生 + if err != nil { + task.StatusLock.Lock() + task.Status = StatusFailed + task.StatusLock.Unlock() + sendTaskUpdate(task) + // 任务失败时关闭goroutine + task.Close() + return + } + + // 判断是单个tar还是需要打包 + var finalFilePath string + + task.StatusLock.Lock() + + // 检查是否所有镜像都下载成功 + allSuccess := true + task.ImageLock.RLock() + for _, img := range task.Images { + img.lock.Lock() + if img.Status != string(StatusCompleted) { + allSuccess = false + } + img.lock.Unlock() + } + task.ImageLock.RUnlock() + + if !allSuccess { + task.Status = StatusFailed + task.StatusLock.Unlock() + sendTaskUpdate(task) + return + } + + // 如果只有一个文件,直接使用它 + task.ImageLock.RLock() + if imageCount == 1 { + imgTask := task.Images[0] + imgTask.lock.Lock() + if imgTask.Status == string(StatusCompleted) { + finalFilePath = imgTask.OutputPath + // 重命名为更友好的名称 + imageName := strings.ReplaceAll(imgTask.Image, "/", "_") + imageName = strings.ReplaceAll(imageName, ":", "_") + newPath := filepath.Join(task.TempDir, imageName+".tar") + os.Rename(finalFilePath, newPath) + finalFilePath = newPath + } + imgTask.lock.Unlock() + } else { + // 多个文件打包成zip + task.ImageLock.RUnlock() + var zipErr error + finalFilePath, zipErr = createZipArchive(task) + if zipErr != nil { + task.Status = StatusFailed + task.StatusLock.Unlock() + sendTaskUpdate(task) + return + } + } + + if imageCount == 1 { + task.ImageLock.RUnlock() + } + + task.OutputFile = finalFilePath + task.Status = StatusCompleted + + // 设置完成计数为总任务数 + task.ProgressLock.Lock() + task.CompletedCount = task.TotalCount + task.ProgressLock.Unlock() + + task.StatusLock.Unlock() + + // 发送最终状态更新 + sendTaskUpdate(task) + + // 确保所有进度都达到100% + ensureTaskCompletion(task) + + // 任务完成时关闭goroutine + task.Close() + + fmt.Printf("任务 %s 全部完成: %d/%d\n", task.ID, task.CompletedCount, task.TotalCount) +} + +// 下载单个镜像(带上下文控制) +func downloadImageWithContext(ctx context.Context, task *DownloadTask, index int, imgTask *ImageTask, platform string) error { + // 更新状态为运行中 + sendProgressUpdate(task, index, 0, string(StatusRunning), "") + + // 创建输出文件名 + outputFileName := fmt.Sprintf("image_%d.tar", index) + outputPath := filepath.Join(task.TempDir, outputFileName) + + imgTask.lock.Lock() + imgTask.OutputPath = outputPath + imgTask.lock.Unlock() + + // 创建skopeo命令 + platformArg := "" + if platform != "" { + // 支持手动输入完整的平台参数 + if strings.Contains(platform, "--") { + platformArg = platform + } else { + // 处理特殊架构格式,如 arm/v7 + if strings.Contains(platform, "/") { + parts := strings.Split(platform, "/") + if len(parts) == 2 { + // 适用于arm/v7这样的格式 + platformArg = fmt.Sprintf("--override-os linux --override-arch %s --override-variant %s", parts[0], parts[1]) + } else { + // 对于其他带/的格式,直接按原格式处理 + platformArg = fmt.Sprintf("--override-os linux --override-arch %s", platform) + } + } else { + // 仅指定架构名称的情况 + platformArg = fmt.Sprintf("--override-os linux --override-arch %s", platform) + } + } + } + + // 构建命令 + cmdStr := fmt.Sprintf("skopeo copy %s docker://%s docker-archive:%s", + platformArg, imgTask.Image, outputPath) + + fmt.Printf("执行命令: %s\n", cmdStr) + + // 创建可取消的命令 + cmd := exec.CommandContext(ctx, "sh", "-c", cmdStr) + + // 获取命令输出 + stderr, err := cmd.StderrPipe() + if err != nil { + errMsg := fmt.Sprintf("无法创建输出管道: %v", err) + sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) + return fmt.Errorf(errMsg) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + errMsg := fmt.Sprintf("无法创建标准输出管道: %v", err) + sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) + return fmt.Errorf(errMsg) + } + + if err := cmd.Start(); err != nil { + errMsg := fmt.Sprintf("启动命令失败: %v", err) + sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) + return fmt.Errorf(errMsg) + } + + // 使用进度通道传递进度信息 + outputChan := make(chan string, 20) + done := make(chan struct{}) + + // 初始进度 + sendProgressUpdate(task, index, 5, "", "") + + // 进度聚合器 + go func() { + // 镜像获取阶段的进度标记 + downloadStages := map[string]float64{ + "Getting image source signatures": 10, + "Copying blob": 30, + "Copying config": 70, + "Writing manifest": 90, + } + + // 进度增长的定时器 + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + lastProgress := 5.0 + stagnantTime := 0 + + for { + select { + case <-ctx.Done(): + // 上下文取消 + return + + case <-done: + // 命令完成,强制更新到100% + if lastProgress < 100 { + fmt.Printf("镜像 %s 下载完成,强制更新进度到100%%\n", imgTask.Image) + sendProgressUpdate(task, index, 100, string(StatusCompleted), "") + } + return + + case output := <-outputChan: + // 解析输出更新进度 + for marker, progress := range downloadStages { + if strings.Contains(output, marker) && progress > lastProgress { + lastProgress = progress + sendProgressUpdate(task, index, progress, "", "") + stagnantTime = 0 + break + } + } + + // 解析百分比 + if strings.Contains(output, "%") { + parts := strings.Split(output, "%") + if len(parts) > 0 { + numStr := strings.TrimSpace(parts[0]) + fields := strings.Fields(numStr) + if len(fields) > 0 { + lastField := fields[len(fields)-1] + parsedProgress := 0.0 + _, err := fmt.Sscanf(lastField, "%f", &parsedProgress) + if err == nil && parsedProgress > 0 && parsedProgress <= 100 { + // 根据当前阶段调整进度值 + var adjustedProgress float64 + if lastProgress < 30 { + // Copying blob阶段,进度在10-30%之间 + adjustedProgress = 10 + (parsedProgress / 100) * 20 + } else if lastProgress < 70 { + // Copying config阶段,进度在30-70%之间 + adjustedProgress = 30 + (parsedProgress / 100) * 40 + } else if lastProgress < 90 { + // Writing manifest阶段,进度在70-90%之间 + adjustedProgress = 70 + (parsedProgress / 100) * 20 + } + + if adjustedProgress > lastProgress { + lastProgress = adjustedProgress + sendProgressUpdate(task, index, adjustedProgress, "", "") + stagnantTime = 0 + } + } + } + } + } + + // 如果发现完成标记,立即更新到100% + if checkForCompletionMarkers(output) { + fmt.Printf("镜像 %s 检测到完成标记\n", imgTask.Image) + lastProgress = 100 + sendProgressUpdate(task, index, 100, string(StatusCompleted), "") + stagnantTime = 0 + } + + case <-ticker.C: + // 如果进度长时间无变化,缓慢增加 + stagnantTime += 100 // 100ms + if stagnantTime >= 10000 && lastProgress < 95 { // 10秒无变化 + // 每10秒增加5%进度,确保不超过95% + newProgress := lastProgress + 5 + if newProgress > 95 { + newProgress = 95 + } + lastProgress = newProgress + sendProgressUpdate(task, index, newProgress, "", "") + stagnantTime = 0 + } + } + } + }() + + // 读取标准输出 + go func() { + defer func() { + // 确保pipe在goroutine退出时关闭 + stdout.Close() + }() + scanner := bufio.NewScanner(stdout) + for { + // 检查context是否已取消 + select { + case <-ctx.Done(): + return + default: + } + + if !scanner.Scan() { + break // EOF或错误,正常退出 + } + + output := scanner.Text() + fmt.Printf("镜像 %s 标准输出: %s\n", imgTask.Image, output) + select { + case outputChan <- output: + case <-ctx.Done(): + return + default: + // 通道已满,丢弃 + } + } + }() + + // 读取错误输出 + go func() { + defer func() { + // 确保pipe在goroutine退出时关闭 + stderr.Close() + }() + scanner := bufio.NewScanner(stderr) + for { + // 检查context是否已取消 + select { + case <-ctx.Done(): + return + default: + } + + if !scanner.Scan() { + break // EOF或错误,正常退出 + } + + output := scanner.Text() + fmt.Printf("镜像 %s 错误输出: %s\n", imgTask.Image, output) + select { + case outputChan <- output: + case <-ctx.Done(): + return + default: + // 通道已满,丢弃 + } + } + }() + + // 等待命令完成 + cmdErr := cmd.Wait() + close(done) // 通知进度聚合器退出 + + if cmdErr != nil { + errMsg := fmt.Sprintf("命令执行失败: %v", cmdErr) + sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) + return fmt.Errorf(errMsg) + } + + // 检查文件是否成功创建 + if _, err := os.Stat(outputPath); os.IsNotExist(err) { + errMsg := "文件未成功创建" + sendProgressUpdate(task, index, 0, string(StatusFailed), errMsg) + return fmt.Errorf(errMsg) + } + + // 确保更新状态为已完成,进度为100% + sendProgressUpdate(task, index, 100, string(StatusCompleted), "") + return nil +} + +// 创建ZIP归档 +func createZipArchive(task *DownloadTask) (string, error) { + zipFilePath := filepath.Join(task.TempDir, "images.zip") + zipFile, err := os.Create(zipFilePath) + if err != nil { + return "", fmt.Errorf("创建ZIP文件失败: %w", err) + } + defer zipFile.Close() + + zipWriter := zip.NewWriter(zipFile) + defer zipWriter.Close() + + task.ImageLock.RLock() + images := make([]*ImageTask, len(task.Images)) + copy(images, task.Images) // 创建副本避免长时间持有锁 + task.ImageLock.RUnlock() + + for _, img := range images { + img.lock.Lock() + status := img.Status + outputPath := img.OutputPath + image := img.Image + img.lock.Unlock() + + if status != string(StatusCompleted) || outputPath == "" { + continue + } + + // 创建ZIP条目 + imgFile, err := os.Open(outputPath) + if err != nil { + return "", fmt.Errorf("无法打开镜像文件 %s: %w", outputPath, err) + } + + // 使用镜像名作为文件名 + imageName := strings.ReplaceAll(image, "/", "_") + imageName = strings.ReplaceAll(imageName, ":", "_") + fileName := imageName + ".tar" + + fileInfo, err := imgFile.Stat() + if err != nil { + imgFile.Close() + return "", fmt.Errorf("无法获取文件信息: %w", err) + } + + header, err := zip.FileInfoHeader(fileInfo) + if err != nil { + imgFile.Close() + return "", fmt.Errorf("创建ZIP头信息失败: %w", err) + } + + header.Name = fileName + header.Method = zip.Deflate + + writer, err := zipWriter.CreateHeader(header) + if err != nil { + imgFile.Close() + return "", fmt.Errorf("添加文件到ZIP失败: %w", err) + } + + _, err = io.Copy(writer, imgFile) + imgFile.Close() + if err != nil { + return "", fmt.Errorf("写入ZIP文件失败: %w", err) + } + } + + return zipFilePath, nil +} + +// 发送任务更新到WebSocket +func sendTaskUpdate(task *DownloadTask) { + // 复制任务状态避免序列化时锁定 + taskCopy := &DownloadTask{ + ID: task.ID, + CompletedCount: 0, + TotalCount: len(task.Images), + Status: TaskStatus(""), + Images: nil, + } + + // 复制状态信息 + task.StatusLock.RLock() + taskCopy.Status = task.Status + task.StatusLock.RUnlock() + + task.ProgressLock.RLock() + taskCopy.CompletedCount = task.CompletedCount + task.ProgressLock.RUnlock() + + // 复制镜像信息 + task.ImageLock.RLock() + taskCopy.Images = make([]*ImageTask, len(task.Images)) + for i, img := range task.Images { + img.lock.Lock() + taskCopy.Images[i] = &ImageTask{ + Image: img.Image, + Progress: img.Progress, + Status: img.Status, + Error: img.Error, + } + img.lock.Unlock() + } + task.ImageLock.RUnlock() + + // 序列化并发送 + taskJSON, err := json.Marshal(taskCopy) + if err != nil { + fmt.Printf("序列化任务失败: %v\n", err) + return + } + + clientLock.Lock() + client, exists := clients[task.ID] + clientLock.Unlock() + + if exists { + select { + case client.Send <- taskJSON: + // 成功发送 + default: + // 通道已满或关闭,忽略 + } + } +} + +// 提供文件下载 +func serveFile(c *gin.Context) { + filename := c.Param("filename") + + // 增强安全检查,防止路径遍历攻击 + if strings.Contains(filename, "..") || + strings.Contains(filename, "/") || + strings.Contains(filename, "\\") || + strings.Contains(filename, "\x00") { + c.JSON(http.StatusForbidden, gin.H{"error": "无效的文件名"}) + return + } + + // 根据任务ID和文件名查找文件 + parts := strings.Split(filename, "_") + if len(parts) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件名格式"}) + return + } + + taskID := parts[0] + + tasksLock.Lock() + task, exists := tasks[taskID] + tasksLock.Unlock() + + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"}) + return + } + + // 确保任务状态为已完成,并且所有进度都是100% + task.StatusLock.RLock() + isCompleted := task.Status == StatusCompleted + task.StatusLock.RUnlock() + + if isCompleted { + // 确保所有进度达到100% + ensureTaskCompletion(task) + } + + // 检查文件是否存在 + filePath := task.OutputFile + if filePath == "" || !fileExists(filePath) { + c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"}) + return + } + + // 获取文件信息 + fileInfo, err := os.Stat(filePath) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取文件信息"}) + return + } + + // 设置文件名 - 提取有意义的文件名 + downloadName := filepath.Base(filePath) + c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", downloadName)) + c.Header("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) + + // 返回文件 + c.File(filePath) +} + +// 检查文件是否存在 +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +// 清理过期临时文件 +func cleanupTempFiles() { + // 创建两个定时器 + hourlyTicker := time.NewTicker(1 * time.Hour) + fiveMinTicker := time.NewTicker(5 * time.Minute) + + // 清理所有文件的函数 + cleanAll := func() { + fmt.Printf("执行清理所有临时文件\n") + entries, err := os.ReadDir("./temp") + if err == nil { + for _, entry := range entries { + entryPath := filepath.Join("./temp", entry.Name()) + info, err := entry.Info() + if err == nil { + if info.IsDir() { + os.RemoveAll(entryPath) + } else { + os.Remove(entryPath) + } + } + } + } else { + fmt.Printf("清理临时文件失败: %v\n", err) + } + } + + // 检查文件大小并在需要时清理 + checkSizeAndClean := func() { + var totalSize int64 = 0 + err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // 跳过根目录 + if path == "./temp" { + return nil + } + + if !info.IsDir() { + totalSize += info.Size() + } + + return nil + }) + + if err != nil { + fmt.Printf("计算临时文件总大小失败: %v\n", err) + return + } + + // 如果总大小超过10GB,清理所有文件,防止恶意下载导致磁盘爆满 + if totalSize > 10*1024*1024*1024 { + fmt.Printf("临时文件总大小超过10GB (当前: %.2f GB),清理所有文件\n", float64(totalSize)/(1024*1024*1024)) + cleanAll() + } else { + fmt.Printf("临时文件总大小: %.2f GB\n", float64(totalSize)/(1024*1024*1024)) + } + } + + // 主循环 + for { + select { + case <-hourlyTicker.C: + // 每小时清理所有文件 + cleanAll() + case <-fiveMinTicker.C: + // 每5分钟检查一次总文件大小 + checkSizeAndClean() + } + } +} + +// 完成任务处理函数,确保进度是100% +func ensureTaskCompletion(task *DownloadTask) { + // 重新检查一遍所有镜像的进度 + task.ImageLock.RLock() + completedCount := 0 + totalCount := len(task.Images) + + for i, img := range task.Images { + img.lock.Lock() + if img.Status == string(StatusCompleted) { + // 确保进度为100% + if img.Progress < 100 { + img.Progress = 100 + fmt.Printf("确保镜像 %d 进度为100%%\n", i) + } + completedCount++ + } + img.lock.Unlock() + } + task.ImageLock.RUnlock() + + // 更新完成计数 + task.ProgressLock.Lock() + task.CompletedCount = completedCount + task.TotalCount = totalCount + task.ProgressLock.Unlock() + + // 如果任务状态为已完成,但计数不匹配,修正计数 + task.StatusLock.RLock() + isCompleted := task.Status == StatusCompleted + task.StatusLock.RUnlock() + + if isCompleted && completedCount != totalCount { + task.ProgressLock.Lock() + task.CompletedCount = totalCount + task.ProgressLock.Unlock() + fmt.Printf("任务 %s 状态已完成,强制设置计数为 %d/%d\n", task.ID, totalCount, totalCount) + } + + // 发送最终更新 + sendTaskUpdate(task) +} + +// 处理下载单个镜像的输出中的完成标记 +func checkForCompletionMarkers(output string) bool { + // 已知的完成标记 + completionMarkers := []string{ + "Writing manifest to image destination", + "Copying config complete", + "Storing signatures", + "Writing manifest complete", + } + + for _, marker := range completionMarkers { + if strings.Contains(output, marker) { + return true + } + } + + return false +} + +// cleanupWebSocketConnections 定期清理无效的WebSocket连接 +func cleanupWebSocketConnections() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + clientLock.Lock() + disconnectedClients := make([]string, 0) + + for taskID, client := range clients { + // 检查连接是否仍然活跃 + if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + // 连接已断开,标记待清理 + disconnectedClients = append(disconnectedClients, taskID) + } + } + + // 清理断开的连接 + for _, taskID := range disconnectedClients { + if client, exists := clients[taskID]; exists { + client.CloseOnce.Do(func() { + close(client.Send) + client.Conn.Close() + }) + delete(clients, taskID) + } + } + + clientLock.Unlock() + + if len(disconnectedClients) > 0 { + fmt.Printf("清理了 %d 个断开的WebSocket连接\n", len(disconnectedClients)) + } + } +} + +// cleanupExpiredTasks 清理过期任务 +func cleanupExpiredTasks() { + ticker := time.NewTicker(30 * time.Minute) // 每30分钟清理一次 + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + expiredTasks := make([]string, 0) + + tasksLock.Lock() + for taskID, task := range tasks { + // 清理超过2小时的已完成任务,或超过6小时的任何任务 + isExpired := false + + task.StatusLock.RLock() + taskStatus := task.Status + task.StatusLock.RUnlock() + + // 已完成或失败的任务:2小时后清理 + if (taskStatus == StatusCompleted || taskStatus == StatusFailed) && + now.Sub(task.createTime) > 2*time.Hour { + isExpired = true + } + // 任何任务:6小时后强制清理 + if now.Sub(task.createTime) > 6*time.Hour { + isExpired = true + } + + if isExpired { + expiredTasks = append(expiredTasks, taskID) + } + } + + // 清理过期任务 + for _, taskID := range expiredTasks { + if task, exists := tasks[taskID]; exists { + // 安全关闭任务的goroutine + task.Close() + + // 清理临时文件 + if task.TempDir != "" { + os.RemoveAll(task.TempDir) + } + if task.OutputFile != "" && fileExists(task.OutputFile) { + os.Remove(task.OutputFile) + } + + delete(tasks, taskID) + } + } + tasksLock.Unlock() + + if len(expiredTasks) > 0 { + fmt.Printf("清理了 %d 个过期任务\n", len(expiredTasks)) + } + + // 输出统计信息 + tasksLock.Lock() + activeTaskCount := len(tasks) + tasksLock.Unlock() + + clientLock.Lock() + activeClientCount := len(clients) + clientLock.Unlock() + + fmt.Printf("当前活跃任务: %d, 活跃WebSocket连接: %d\n", activeTaskCount, activeClientCount) + } } \ No newline at end of file