diff --git a/src/access_control.go b/src/access_control.go index b8c6ab1..7d71a14 100644 --- a/src/access_control.go +++ b/src/access_control.go @@ -1,212 +1,212 @@ -package main - -import ( - "strings" - "sync" -) - -// ResourceType 资源类型 -type ResourceType string - -const ( - ResourceTypeGitHub ResourceType = "github" - ResourceTypeDocker ResourceType = "docker" -) - -// AccessController 统一访问控制器 -type AccessController struct { - mu sync.RWMutex -} - -// DockerImageInfo Docker镜像信息 -type DockerImageInfo struct { - Namespace string - Repository string - Tag string - FullName string -} - -// 全局访问控制器实例 -var GlobalAccessController = &AccessController{} - -// ParseDockerImage 解析Docker镜像名称 -func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { - image = strings.TrimPrefix(image, "docker://") - - var tag string - if idx := strings.LastIndex(image, ":"); idx != -1 { - part := image[idx+1:] - if !strings.Contains(part, "/") { - tag = part - image = image[:idx] - } - } - if tag == "" { - tag = "latest" - } - - var namespace, repository string - if strings.Contains(image, "/") { - parts := strings.Split(image, "/") - if len(parts) >= 2 { - if strings.Contains(parts[0], ".") { - if len(parts) >= 3 { - namespace = parts[1] - repository = parts[2] - } else { - namespace = "library" - repository = parts[1] - } - } else { - namespace = parts[0] - repository = parts[1] - } - } - } else { - namespace = "library" - repository = image - } - - fullName := namespace + "/" + repository - - return DockerImageInfo{ - Namespace: namespace, - Repository: repository, - Tag: tag, - FullName: fullName, - } -} - -// CheckDockerAccess 检查Docker镜像访问权限 -func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { - cfg := GetConfig() - - // 解析镜像名称 - imageInfo := ac.ParseDockerImage(image) - - // 检查白名单(如果配置了白名单,则只允许白名单中的镜像) - if len(cfg.Access.WhiteList) > 0 { - if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) { - return false, "不在Docker镜像白名单内" - } - } - - // 检查黑名单 - if len(cfg.Access.BlackList) > 0 { - if ac.matchImageInList(imageInfo, cfg.Access.BlackList) { - return false, "Docker镜像在黑名单内" - } - } - - return true, "" -} - -// CheckGitHubAccess 检查GitHub仓库访问权限 -func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) { - if len(matches) < 2 { - return false, "无效的GitHub仓库格式" - } - - cfg := GetConfig() - - // 检查白名单 - if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) { - return false, "不在GitHub仓库白名单内" - } - - // 检查黑名单 - if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) { - return false, "GitHub仓库在黑名单内" - } - - return true, "" -} - -// matchImageInList 检查Docker镜像是否在指定列表中 -func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool { - fullName := strings.ToLower(imageInfo.FullName) - namespace := strings.ToLower(imageInfo.Namespace) - - for _, item := range list { - item = strings.ToLower(strings.TrimSpace(item)) - if item == "" { - continue - } - - if fullName == item { - return true - } - - if item == namespace || item == namespace+"/*" { - return true - } - - if strings.HasSuffix(item, "*") { - prefix := strings.TrimSuffix(item, "*") - if strings.HasPrefix(fullName, prefix) { - return true - } - } - - if strings.HasPrefix(item, "*/") { - repoPattern := strings.TrimPrefix(item, "*/") - if strings.HasSuffix(repoPattern, "*") { - repoPrefix := strings.TrimSuffix(repoPattern, "*") - if strings.HasPrefix(imageInfo.Repository, repoPrefix) { - return true - } - } else { - if strings.ToLower(imageInfo.Repository) == repoPattern { - return true - } - } - } - - if strings.HasPrefix(fullName, item+"/") { - return true - } - } - return false -} - -// checkList GitHub仓库检查逻辑 -func (ac *AccessController) checkList(matches, list []string) bool { - if len(matches) < 2 { - return false - } - - username := strings.ToLower(strings.TrimSpace(matches[0])) - repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git"))) - fullRepo := username + "/" + repoName - - for _, item := range list { - item = strings.ToLower(strings.TrimSpace(item)) - if item == "" { - continue - } - - // 支持多种匹配模式 - if fullRepo == item { - return true - } - - // 用户级匹配 - if item == username || item == username+"/*" { - return true - } - - // 前缀匹配(支持通配符) - if strings.HasSuffix(item, "*") { - prefix := strings.TrimSuffix(item, "*") - if strings.HasPrefix(fullRepo, prefix) { - return true - } - } - - // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) - if strings.HasPrefix(fullRepo, item+"/") { - return true - } - } - return false -} +package main + +import ( + "strings" + "sync" +) + +// ResourceType 资源类型 +type ResourceType string + +const ( + ResourceTypeGitHub ResourceType = "github" + ResourceTypeDocker ResourceType = "docker" +) + +// AccessController 统一访问控制器 +type AccessController struct { + mu sync.RWMutex +} + +// DockerImageInfo Docker镜像信息 +type DockerImageInfo struct { + Namespace string + Repository string + Tag string + FullName string +} + +// 全局访问控制器实例 +var GlobalAccessController = &AccessController{} + +// ParseDockerImage 解析Docker镜像名称 +func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { + image = strings.TrimPrefix(image, "docker://") + + var tag string + if idx := strings.LastIndex(image, ":"); idx != -1 { + part := image[idx+1:] + if !strings.Contains(part, "/") { + tag = part + image = image[:idx] + } + } + if tag == "" { + tag = "latest" + } + + var namespace, repository string + if strings.Contains(image, "/") { + parts := strings.Split(image, "/") + if len(parts) >= 2 { + if strings.Contains(parts[0], ".") { + if len(parts) >= 3 { + namespace = parts[1] + repository = parts[2] + } else { + namespace = "library" + repository = parts[1] + } + } else { + namespace = parts[0] + repository = parts[1] + } + } + } else { + namespace = "library" + repository = image + } + + fullName := namespace + "/" + repository + + return DockerImageInfo{ + Namespace: namespace, + Repository: repository, + Tag: tag, + FullName: fullName, + } +} + +// CheckDockerAccess 检查Docker镜像访问权限 +func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { + cfg := GetConfig() + + // 解析镜像名称 + imageInfo := ac.ParseDockerImage(image) + + // 检查白名单(如果配置了白名单,则只允许白名单中的镜像) + if len(cfg.Access.WhiteList) > 0 { + if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) { + return false, "不在Docker镜像白名单内" + } + } + + // 检查黑名单 + if len(cfg.Access.BlackList) > 0 { + if ac.matchImageInList(imageInfo, cfg.Access.BlackList) { + return false, "Docker镜像在黑名单内" + } + } + + return true, "" +} + +// CheckGitHubAccess 检查GitHub仓库访问权限 +func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) { + if len(matches) < 2 { + return false, "无效的GitHub仓库格式" + } + + cfg := GetConfig() + + // 检查白名单 + if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) { + return false, "不在GitHub仓库白名单内" + } + + // 检查黑名单 + if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) { + return false, "GitHub仓库在黑名单内" + } + + return true, "" +} + +// matchImageInList 检查Docker镜像是否在指定列表中 +func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool { + fullName := strings.ToLower(imageInfo.FullName) + namespace := strings.ToLower(imageInfo.Namespace) + + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + + if fullName == item { + return true + } + + if item == namespace || item == namespace+"/*" { + return true + } + + if strings.HasSuffix(item, "*") { + prefix := strings.TrimSuffix(item, "*") + if strings.HasPrefix(fullName, prefix) { + return true + } + } + + if strings.HasPrefix(item, "*/") { + repoPattern := strings.TrimPrefix(item, "*/") + if strings.HasSuffix(repoPattern, "*") { + repoPrefix := strings.TrimSuffix(repoPattern, "*") + if strings.HasPrefix(imageInfo.Repository, repoPrefix) { + return true + } + } else { + if strings.ToLower(imageInfo.Repository) == repoPattern { + return true + } + } + } + + if strings.HasPrefix(fullName, item+"/") { + return true + } + } + return false +} + +// checkList GitHub仓库检查逻辑 +func (ac *AccessController) checkList(matches, list []string) bool { + if len(matches) < 2 { + return false + } + + username := strings.ToLower(strings.TrimSpace(matches[0])) + repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git"))) + fullRepo := username + "/" + repoName + + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + + // 支持多种匹配模式 + if fullRepo == item { + return true + } + + // 用户级匹配 + if item == username || item == username+"/*" { + return true + } + + // 前缀匹配(支持通配符) + if strings.HasSuffix(item, "*") { + prefix := strings.TrimSuffix(item, "*") + if strings.HasPrefix(fullRepo, prefix) { + return true + } + } + + // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) + if strings.HasPrefix(fullRepo, item+"/") { + return true + } + } + return false +} diff --git a/src/docker.go b/src/docker.go index 80ffd1f..2760fec 100644 --- a/src/docker.go +++ b/src/docker.go @@ -1,676 +1,676 @@ -package main - -import ( - "context" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/google/go-containerregistry/pkg/authn" - "github.com/google/go-containerregistry/pkg/name" - "github.com/google/go-containerregistry/pkg/v1/remote" -) - -// DockerProxy Docker代理配置 -type DockerProxy struct { - registry name.Registry - options []remote.Option -} - -var dockerProxy *DockerProxy - -// RegistryDetector Registry检测器 -type RegistryDetector struct{} - -// detectRegistryDomain 检测Registry域名并返回域名和剩余路径 -func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) { - cfg := GetConfig() - - // 检查路径是否以已知Registry域名开头 - for domain := range cfg.Registries { - if strings.HasPrefix(path, domain+"/") { - // 找到匹配的域名,返回域名和剩余路径 - remainingPath := strings.TrimPrefix(path, domain+"/") - return domain, remainingPath - } - } - - return "", path -} - -// isRegistryEnabled 检查Registry是否启用 -func (rd *RegistryDetector) isRegistryEnabled(domain string) bool { - cfg := GetConfig() - if mapping, exists := cfg.Registries[domain]; exists { - return mapping.Enabled - } - return false -} - -// getRegistryMapping 获取Registry映射配置 -func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) { - cfg := GetConfig() - mapping, exists := cfg.Registries[domain] - return mapping, exists && mapping.Enabled -} - -var registryDetector = &RegistryDetector{} - -// 初始化Docker代理 -func initDockerProxy() { - // 创建目标registry - registry, err := name.NewRegistry("registry-1.docker.io") - if err != nil { - fmt.Printf("创建Docker registry失败: %v\n", err) - return - } - - // 配置代理选项 - options := []remote.Option{ - remote.WithAuth(authn.Anonymous), - remote.WithUserAgent("hubproxy/go-containerregistry"), - remote.WithTransport(GetGlobalHTTPClient().Transport), - } - - dockerProxy = &DockerProxy{ - registry: registry, - options: options, - } -} - -// ProxyDockerRegistryGin 标准Docker Registry API v2代理 -func ProxyDockerRegistryGin(c *gin.Context) { - path := c.Request.URL.Path - - // 处理 /v2/ API版本检查 - if path == "/v2/" { - c.JSON(http.StatusOK, gin.H{}) - return - } - - // 处理不同的API端点 - if strings.HasPrefix(path, "/v2/") { - handleRegistryRequest(c, path) - } else { - c.String(http.StatusNotFound, "Docker Registry API v2 only") - } -} - -// handleRegistryRequest 处理Registry请求 -func handleRegistryRequest(c *gin.Context, path string) { - // 移除 /v2/ 前缀 - pathWithoutV2 := strings.TrimPrefix(path, "/v2/") - - if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" { - if registryDetector.isRegistryEnabled(registryDomain) { - // 设置目标Registry信息到Context - c.Set("target_registry_domain", registryDomain) - c.Set("target_path", remainingPath) - - // 处理多Registry请求 - handleMultiRegistryRequest(c, registryDomain, remainingPath) - return - } - } - - imageName, apiType, reference := parseRegistryPath(pathWithoutV2) - if imageName == "" || apiType == "" { - c.String(http.StatusBadRequest, "Invalid path format") - return - } - - // 自动处理官方镜像的library命名空间 - if !strings.Contains(imageName, "/") { - imageName = "library/" + imageName - } - - // Docker镜像访问控制检查 - if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed { - fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason) - c.String(http.StatusForbidden, "镜像访问被限制") - return - } - - // 构建完整的镜像引用 - imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName) - - switch apiType { - case "manifests": - handleManifestRequest(c, imageRef, reference) - case "blobs": - handleBlobRequest(c, imageRef, reference) - case "tags": - handleTagsRequest(c, imageRef) - default: - c.String(http.StatusNotFound, "API endpoint not found") - } -} - -// parseRegistryPath 解析Registry路径 -func parseRegistryPath(path string) (imageName, apiType, reference string) { - // 查找API端点关键字 - if idx := strings.Index(path, "/manifests/"); idx != -1 { - imageName = path[:idx] - apiType = "manifests" - reference = path[idx+len("/manifests/"):] - return - } - - if idx := strings.Index(path, "/blobs/"); idx != -1 { - imageName = path[:idx] - apiType = "blobs" - reference = path[idx+len("/blobs/"):] - return - } - - if idx := strings.Index(path, "/tags/list"); idx != -1 { - imageName = path[:idx] - apiType = "tags" - reference = "list" - return - } - - return "", "", "" -} - -// handleManifestRequest 处理manifest请求 -func handleManifestRequest(c *gin.Context, imageRef, reference string) { - // Manifest缓存逻辑(仅对GET请求缓存) - if isCacheEnabled() && c.Request.Method == http.MethodGet { - cacheKey := buildManifestCacheKey(imageRef, reference) - - // 优先从缓存获取 - if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { - writeCachedResponse(c, cachedItem) - return - } - } - - var ref name.Reference - var err error - - // 判断reference是digest还是tag - if strings.HasPrefix(reference, "sha256:") { - // 是digest - ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) - } else { - // 是tag - ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) - } - - if err != nil { - fmt.Printf("解析镜像引用失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid reference") - return - } - - // 根据请求方法选择操作 - if c.Request.Method == http.MethodHead { - // HEAD请求,使用remote.Head - desc, err := remote.Head(ref, dockerProxy.options...) - if err != nil { - fmt.Printf("HEAD请求失败: %v\n", err) - c.String(http.StatusNotFound, "Manifest not found") - return - } - - // 设置响应头 - c.Header("Content-Type", string(desc.MediaType)) - c.Header("Docker-Content-Digest", desc.Digest.String()) - c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) - c.Status(http.StatusOK) - } else { - // GET请求,使用remote.Get - desc, err := remote.Get(ref, dockerProxy.options...) - if err != nil { - fmt.Printf("GET请求失败: %v\n", err) - c.String(http.StatusNotFound, "Manifest not found") - return - } - - // 设置响应头 - headers := map[string]string{ - "Docker-Content-Digest": desc.Digest.String(), - "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), - } - - // 缓存响应 - if isCacheEnabled() { - cacheKey := buildManifestCacheKey(imageRef, reference) - ttl := getManifestTTL(reference) - globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) - } - - // 设置响应头 - c.Header("Content-Type", string(desc.MediaType)) - for key, value := range headers { - c.Header(key, value) - } - - // 返回manifest内容 - c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) - } -} - -// handleBlobRequest 处理blob请求 -func handleBlobRequest(c *gin.Context, imageRef, digest string) { - // 构建digest引用 - digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) - if err != nil { - fmt.Printf("解析digest引用失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid digest reference") - return - } - - // 使用remote.Layer获取layer - layer, err := remote.Layer(digestRef, dockerProxy.options...) - if err != nil { - fmt.Printf("获取layer失败: %v\n", err) - c.String(http.StatusNotFound, "Layer not found") - return - } - - // 获取layer信息 - size, err := layer.Size() - if err != nil { - fmt.Printf("获取layer大小失败: %v\n", err) - c.String(http.StatusInternalServerError, "Failed to get layer size") - return - } - - // 获取layer内容 - reader, err := layer.Compressed() - if err != nil { - fmt.Printf("获取layer内容失败: %v\n", err) - c.String(http.StatusInternalServerError, "Failed to get layer content") - return - } - defer reader.Close() - - // 设置响应头 - c.Header("Content-Type", "application/octet-stream") - c.Header("Content-Length", fmt.Sprintf("%d", size)) - c.Header("Docker-Content-Digest", digest) - - // 流式传输blob内容 - c.Status(http.StatusOK) - io.Copy(c.Writer, reader) -} - -// handleTagsRequest 处理tags列表请求 -func handleTagsRequest(c *gin.Context, imageRef string) { - // 解析repository - repo, err := name.NewRepository(imageRef) - if err != nil { - fmt.Printf("解析repository失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid repository") - return - } - - // 使用remote.List获取tags - tags, err := remote.List(repo, dockerProxy.options...) - if err != nil { - fmt.Printf("获取tags失败: %v\n", err) - c.String(http.StatusNotFound, "Tags not found") - return - } - - // 构建响应 - response := map[string]interface{}{ - "name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"), - "tags": tags, - } - - c.JSON(http.StatusOK, response) -} - -// ProxyDockerAuthGin Docker认证代理(带缓存优化) -func ProxyDockerAuthGin(c *gin.Context) { - // 检查是否启用token缓存 - if isTokenCacheEnabled() { - proxyDockerAuthWithCache(c) - } else { - proxyDockerAuthOriginal(c) - } -} - -// proxyDockerAuthWithCache 带缓存的认证代理 -func proxyDockerAuthWithCache(c *gin.Context) { - // 1. 构建缓存key(基于完整的查询参数) - cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery) - - // 2. 尝试从缓存获取token - if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" { - writeTokenResponse(c, cachedToken) - return - } - - // 3. 缓存未命中,创建响应记录器 - recorder := &ResponseRecorder{ - ResponseWriter: c.Writer, - statusCode: 200, - } - c.Writer = recorder - - // 4. 调用原有认证逻辑 - proxyDockerAuthOriginal(c) - - // 5. 如果认证成功,缓存响应 - if recorder.statusCode == 200 && len(recorder.body) > 0 { - ttl := extractTTLFromResponse(recorder.body) - globalCache.SetToken(cacheKey, string(recorder.body), ttl) - } - - // 6. 写入实际响应 - c.Writer = recorder.ResponseWriter - c.Data(recorder.statusCode, "application/json", recorder.body) -} - -// ResponseRecorder HTTP响应记录器 -type ResponseRecorder struct { - gin.ResponseWriter - statusCode int - body []byte -} - -func (r *ResponseRecorder) WriteHeader(code int) { - r.statusCode = code -} - -func (r *ResponseRecorder) Write(data []byte) (int, error) { - r.body = append(r.body, data...) - return len(data), nil -} - -func proxyDockerAuthOriginal(c *gin.Context) { - var authURL string - if targetDomain, exists := c.Get("target_registry_domain"); exists { - if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found { - // 使用Registry特定的认证服务器 - authURL = "https://" + mapping.AuthHost + c.Request.URL.Path - } else { - // fallback到默认Docker认证 - authURL = "https://auth.docker.io" + c.Request.URL.Path - } - } else { - // 构建默认Docker认证URL - authURL = "https://auth.docker.io" + c.Request.URL.Path - } - - if c.Request.URL.RawQuery != "" { - authURL += "?" + c.Request.URL.RawQuery - } - - // 创建HTTP客户端,复用全局传输配置(包含代理设置) - client := &http.Client{ - Timeout: 30 * time.Second, - Transport: GetGlobalHTTPClient().Transport, - } - - // 创建请求 - req, err := http.NewRequestWithContext( - context.Background(), - c.Request.Method, - authURL, - c.Request.Body, - ) - if err != nil { - c.String(http.StatusInternalServerError, "Failed to create request") - return - } - - // 复制请求头 - for key, values := range c.Request.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - - // 执行请求 - resp, err := client.Do(req) - if err != nil { - c.String(http.StatusBadGateway, "Auth request failed") - return - } - defer resp.Body.Close() - - // 获取当前代理的Host地址 - proxyHost := c.Request.Host - if proxyHost == "" { - // 使用配置中的服务器地址和端口 - cfg := GetConfig() - proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) - if cfg.Server.Host == "0.0.0.0" { - proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port) - } - } - - // 复制响应头并重写认证URL - for key, values := range resp.Header { - for _, value := range values { - // 重写WWW-Authenticate头中的realm URL - if key == "Www-Authenticate" { - // 支持多Registry的URL重写 - value = rewriteAuthHeader(value, proxyHost) - } - c.Header(key, value) - } - } - - // 返回响应 - c.Status(resp.StatusCode) - io.Copy(c.Writer, resp.Body) -} - -// rewriteAuthHeader 重写认证头 -func rewriteAuthHeader(authHeader, proxyHost string) string { - // 重写各种Registry的认证URL - authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost) - authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost) - authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost) - authHeader = strings.ReplaceAll(authHeader, "https://quay.io", "http://"+proxyHost) - - return authHeader -} - -// handleMultiRegistryRequest 处理多Registry请求 -func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) { - // 获取Registry映射配置 - mapping, exists := registryDetector.getRegistryMapping(registryDomain) - if !exists { - c.String(http.StatusBadRequest, "Registry not configured") - return - } - - // 解析剩余路径 - imageName, apiType, reference := parseRegistryPath(remainingPath) - if imageName == "" || apiType == "" { - c.String(http.StatusBadRequest, "Invalid path format") - return - } - - // 访问控制检查(使用完整的镜像路径) - fullImageName := registryDomain + "/" + imageName - if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed { - fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason) - c.String(http.StatusForbidden, "镜像访问被限制") - return - } - - // 构建上游Registry引用 - upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName) - - // 根据API类型处理请求 - switch apiType { - case "manifests": - handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping) - case "blobs": - handleUpstreamBlobRequest(c, upstreamImageRef, reference, mapping) - case "tags": - handleUpstreamTagsRequest(c, upstreamImageRef, mapping) - default: - c.String(http.StatusNotFound, "API endpoint not found") - } -} - -// handleUpstreamManifestRequest 处理上游Registry的manifest请求 -func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) { - // Manifest缓存逻辑(仅对GET请求缓存) - if isCacheEnabled() && c.Request.Method == http.MethodGet { - cacheKey := buildManifestCacheKey(imageRef, reference) - - // 优先从缓存获取 - if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { - writeCachedResponse(c, cachedItem) - return - } - } - - var ref name.Reference - var err error - - // 判断reference是digest还是tag - if strings.HasPrefix(reference, "sha256:") { - ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) - } else { - ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) - } - - if err != nil { - fmt.Printf("解析镜像引用失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid reference") - return - } - - // 创建针对上游Registry的选项 - options := createUpstreamOptions(mapping) - - // 根据请求方法选择操作 - if c.Request.Method == http.MethodHead { - desc, err := remote.Head(ref, options...) - if err != nil { - fmt.Printf("HEAD请求失败: %v\n", err) - c.String(http.StatusNotFound, "Manifest not found") - return - } - - c.Header("Content-Type", string(desc.MediaType)) - c.Header("Docker-Content-Digest", desc.Digest.String()) - c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) - c.Status(http.StatusOK) - } else { - desc, err := remote.Get(ref, options...) - if err != nil { - fmt.Printf("GET请求失败: %v\n", err) - c.String(http.StatusNotFound, "Manifest not found") - return - } - - // 设置响应头 - headers := map[string]string{ - "Docker-Content-Digest": desc.Digest.String(), - "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), - } - - // 缓存响应 - if isCacheEnabled() { - cacheKey := buildManifestCacheKey(imageRef, reference) - ttl := getManifestTTL(reference) - globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) - } - - // 设置响应头 - c.Header("Content-Type", string(desc.MediaType)) - for key, value := range headers { - c.Header(key, value) - } - - c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) - } -} - -// handleUpstreamBlobRequest 处理上游Registry的blob请求 -func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) { - digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) - if err != nil { - fmt.Printf("解析digest引用失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid digest reference") - return - } - - options := createUpstreamOptions(mapping) - layer, err := remote.Layer(digestRef, options...) - if err != nil { - fmt.Printf("获取layer失败: %v\n", err) - c.String(http.StatusNotFound, "Layer not found") - return - } - - size, err := layer.Size() - if err != nil { - fmt.Printf("获取layer大小失败: %v\n", err) - c.String(http.StatusInternalServerError, "Failed to get layer size") - return - } - - reader, err := layer.Compressed() - if err != nil { - fmt.Printf("获取layer内容失败: %v\n", err) - c.String(http.StatusInternalServerError, "Failed to get layer content") - return - } - defer reader.Close() - - c.Header("Content-Type", "application/octet-stream") - c.Header("Content-Length", fmt.Sprintf("%d", size)) - c.Header("Docker-Content-Digest", digest) - - c.Status(http.StatusOK) - io.Copy(c.Writer, reader) -} - -// handleUpstreamTagsRequest 处理上游Registry的tags请求 -func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) { - repo, err := name.NewRepository(imageRef) - if err != nil { - fmt.Printf("解析repository失败: %v\n", err) - c.String(http.StatusBadRequest, "Invalid repository") - return - } - - options := createUpstreamOptions(mapping) - tags, err := remote.List(repo, options...) - if err != nil { - fmt.Printf("获取tags失败: %v\n", err) - c.String(http.StatusNotFound, "Tags not found") - return - } - - response := map[string]interface{}{ - "name": strings.TrimPrefix(imageRef, mapping.Upstream+"/"), - "tags": tags, - } - - c.JSON(http.StatusOK, response) -} - -// createUpstreamOptions 创建上游Registry选项 -func createUpstreamOptions(mapping RegistryMapping) []remote.Option { - options := []remote.Option{ - remote.WithAuth(authn.Anonymous), - remote.WithUserAgent("hubproxy/go-containerregistry"), - remote.WithTransport(GetGlobalHTTPClient().Transport), - } - - // 根据Registry类型添加特定的认证选项(方便后续扩展) - switch mapping.AuthType { - case "github": - case "google": - case "quay": - } - - return options -} +package main + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote" +) + +// DockerProxy Docker代理配置 +type DockerProxy struct { + registry name.Registry + options []remote.Option +} + +var dockerProxy *DockerProxy + +// RegistryDetector Registry检测器 +type RegistryDetector struct{} + +// detectRegistryDomain 检测Registry域名并返回域名和剩余路径 +func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) { + cfg := GetConfig() + + // 检查路径是否以已知Registry域名开头 + for domain := range cfg.Registries { + if strings.HasPrefix(path, domain+"/") { + // 找到匹配的域名,返回域名和剩余路径 + remainingPath := strings.TrimPrefix(path, domain+"/") + return domain, remainingPath + } + } + + return "", path +} + +// isRegistryEnabled 检查Registry是否启用 +func (rd *RegistryDetector) isRegistryEnabled(domain string) bool { + cfg := GetConfig() + if mapping, exists := cfg.Registries[domain]; exists { + return mapping.Enabled + } + return false +} + +// getRegistryMapping 获取Registry映射配置 +func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) { + cfg := GetConfig() + mapping, exists := cfg.Registries[domain] + return mapping, exists && mapping.Enabled +} + +var registryDetector = &RegistryDetector{} + +// 初始化Docker代理 +func initDockerProxy() { + // 创建目标registry + registry, err := name.NewRegistry("registry-1.docker.io") + if err != nil { + fmt.Printf("创建Docker registry失败: %v\n", err) + return + } + + // 配置代理选项 + options := []remote.Option{ + remote.WithAuth(authn.Anonymous), + remote.WithUserAgent("hubproxy/go-containerregistry"), + remote.WithTransport(GetGlobalHTTPClient().Transport), + } + + dockerProxy = &DockerProxy{ + registry: registry, + options: options, + } +} + +// ProxyDockerRegistryGin 标准Docker Registry API v2代理 +func ProxyDockerRegistryGin(c *gin.Context) { + path := c.Request.URL.Path + + // 处理 /v2/ API版本检查 + if path == "/v2/" { + c.JSON(http.StatusOK, gin.H{}) + return + } + + // 处理不同的API端点 + if strings.HasPrefix(path, "/v2/") { + handleRegistryRequest(c, path) + } else { + c.String(http.StatusNotFound, "Docker Registry API v2 only") + } +} + +// handleRegistryRequest 处理Registry请求 +func handleRegistryRequest(c *gin.Context, path string) { + // 移除 /v2/ 前缀 + pathWithoutV2 := strings.TrimPrefix(path, "/v2/") + + if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" { + if registryDetector.isRegistryEnabled(registryDomain) { + // 设置目标Registry信息到Context + c.Set("target_registry_domain", registryDomain) + c.Set("target_path", remainingPath) + + // 处理多Registry请求 + handleMultiRegistryRequest(c, registryDomain, remainingPath) + return + } + } + + imageName, apiType, reference := parseRegistryPath(pathWithoutV2) + if imageName == "" || apiType == "" { + c.String(http.StatusBadRequest, "Invalid path format") + return + } + + // 自动处理官方镜像的library命名空间 + if !strings.Contains(imageName, "/") { + imageName = "library/" + imageName + } + + // Docker镜像访问控制检查 + if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed { + fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason) + c.String(http.StatusForbidden, "镜像访问被限制") + return + } + + // 构建完整的镜像引用 + imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName) + + switch apiType { + case "manifests": + handleManifestRequest(c, imageRef, reference) + case "blobs": + handleBlobRequest(c, imageRef, reference) + case "tags": + handleTagsRequest(c, imageRef) + default: + c.String(http.StatusNotFound, "API endpoint not found") + } +} + +// parseRegistryPath 解析Registry路径 +func parseRegistryPath(path string) (imageName, apiType, reference string) { + // 查找API端点关键字 + if idx := strings.Index(path, "/manifests/"); idx != -1 { + imageName = path[:idx] + apiType = "manifests" + reference = path[idx+len("/manifests/"):] + return + } + + if idx := strings.Index(path, "/blobs/"); idx != -1 { + imageName = path[:idx] + apiType = "blobs" + reference = path[idx+len("/blobs/"):] + return + } + + if idx := strings.Index(path, "/tags/list"); idx != -1 { + imageName = path[:idx] + apiType = "tags" + reference = "list" + return + } + + return "", "", "" +} + +// handleManifestRequest 处理manifest请求 +func handleManifestRequest(c *gin.Context, imageRef, reference string) { + // Manifest缓存逻辑(仅对GET请求缓存) + if isCacheEnabled() && c.Request.Method == http.MethodGet { + cacheKey := buildManifestCacheKey(imageRef, reference) + + // 优先从缓存获取 + if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { + writeCachedResponse(c, cachedItem) + return + } + } + + var ref name.Reference + var err error + + // 判断reference是digest还是tag + if strings.HasPrefix(reference, "sha256:") { + // 是digest + ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) + } else { + // 是tag + ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) + } + + if err != nil { + fmt.Printf("解析镜像引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid reference") + return + } + + // 根据请求方法选择操作 + if c.Request.Method == http.MethodHead { + // HEAD请求,使用remote.Head + desc, err := remote.Head(ref, dockerProxy.options...) + if err != nil { + fmt.Printf("HEAD请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + // 设置响应头 + c.Header("Content-Type", string(desc.MediaType)) + c.Header("Docker-Content-Digest", desc.Digest.String()) + c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) + c.Status(http.StatusOK) + } else { + // GET请求,使用remote.Get + desc, err := remote.Get(ref, dockerProxy.options...) + if err != nil { + fmt.Printf("GET请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + // 设置响应头 + headers := map[string]string{ + "Docker-Content-Digest": desc.Digest.String(), + "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), + } + + // 缓存响应 + if isCacheEnabled() { + cacheKey := buildManifestCacheKey(imageRef, reference) + ttl := getManifestTTL(reference) + globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) + } + + // 设置响应头 + c.Header("Content-Type", string(desc.MediaType)) + for key, value := range headers { + c.Header(key, value) + } + + // 返回manifest内容 + c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) + } +} + +// handleBlobRequest 处理blob请求 +func handleBlobRequest(c *gin.Context, imageRef, digest string) { + // 构建digest引用 + digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) + if err != nil { + fmt.Printf("解析digest引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid digest reference") + return + } + + // 使用remote.Layer获取layer + layer, err := remote.Layer(digestRef, dockerProxy.options...) + if err != nil { + fmt.Printf("获取layer失败: %v\n", err) + c.String(http.StatusNotFound, "Layer not found") + return + } + + // 获取layer信息 + size, err := layer.Size() + if err != nil { + fmt.Printf("获取layer大小失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer size") + return + } + + // 获取layer内容 + reader, err := layer.Compressed() + if err != nil { + fmt.Printf("获取layer内容失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer content") + return + } + defer reader.Close() + + // 设置响应头 + c.Header("Content-Type", "application/octet-stream") + c.Header("Content-Length", fmt.Sprintf("%d", size)) + c.Header("Docker-Content-Digest", digest) + + // 流式传输blob内容 + c.Status(http.StatusOK) + io.Copy(c.Writer, reader) +} + +// handleTagsRequest 处理tags列表请求 +func handleTagsRequest(c *gin.Context, imageRef string) { + // 解析repository + repo, err := name.NewRepository(imageRef) + if err != nil { + fmt.Printf("解析repository失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid repository") + return + } + + // 使用remote.List获取tags + tags, err := remote.List(repo, dockerProxy.options...) + if err != nil { + fmt.Printf("获取tags失败: %v\n", err) + c.String(http.StatusNotFound, "Tags not found") + return + } + + // 构建响应 + response := map[string]interface{}{ + "name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"), + "tags": tags, + } + + c.JSON(http.StatusOK, response) +} + +// ProxyDockerAuthGin Docker认证代理(带缓存优化) +func ProxyDockerAuthGin(c *gin.Context) { + // 检查是否启用token缓存 + if isTokenCacheEnabled() { + proxyDockerAuthWithCache(c) + } else { + proxyDockerAuthOriginal(c) + } +} + +// proxyDockerAuthWithCache 带缓存的认证代理 +func proxyDockerAuthWithCache(c *gin.Context) { + // 1. 构建缓存key(基于完整的查询参数) + cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery) + + // 2. 尝试从缓存获取token + if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" { + writeTokenResponse(c, cachedToken) + return + } + + // 3. 缓存未命中,创建响应记录器 + recorder := &ResponseRecorder{ + ResponseWriter: c.Writer, + statusCode: 200, + } + c.Writer = recorder + + // 4. 调用原有认证逻辑 + proxyDockerAuthOriginal(c) + + // 5. 如果认证成功,缓存响应 + if recorder.statusCode == 200 && len(recorder.body) > 0 { + ttl := extractTTLFromResponse(recorder.body) + globalCache.SetToken(cacheKey, string(recorder.body), ttl) + } + + // 6. 写入实际响应 + c.Writer = recorder.ResponseWriter + c.Data(recorder.statusCode, "application/json", recorder.body) +} + +// ResponseRecorder HTTP响应记录器 +type ResponseRecorder struct { + gin.ResponseWriter + statusCode int + body []byte +} + +func (r *ResponseRecorder) WriteHeader(code int) { + r.statusCode = code +} + +func (r *ResponseRecorder) Write(data []byte) (int, error) { + r.body = append(r.body, data...) + return len(data), nil +} + +func proxyDockerAuthOriginal(c *gin.Context) { + var authURL string + if targetDomain, exists := c.Get("target_registry_domain"); exists { + if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found { + // 使用Registry特定的认证服务器 + authURL = "https://" + mapping.AuthHost + c.Request.URL.Path + } else { + // fallback到默认Docker认证 + authURL = "https://auth.docker.io" + c.Request.URL.Path + } + } else { + // 构建默认Docker认证URL + authURL = "https://auth.docker.io" + c.Request.URL.Path + } + + if c.Request.URL.RawQuery != "" { + authURL += "?" + c.Request.URL.RawQuery + } + + // 创建HTTP客户端,复用全局传输配置(包含代理设置) + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: GetGlobalHTTPClient().Transport, + } + + // 创建请求 + req, err := http.NewRequestWithContext( + context.Background(), + c.Request.Method, + authURL, + c.Request.Body, + ) + if err != nil { + c.String(http.StatusInternalServerError, "Failed to create request") + return + } + + // 复制请求头 + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // 执行请求 + resp, err := client.Do(req) + if err != nil { + c.String(http.StatusBadGateway, "Auth request failed") + return + } + defer resp.Body.Close() + + // 获取当前代理的Host地址 + proxyHost := c.Request.Host + if proxyHost == "" { + // 使用配置中的服务器地址和端口 + cfg := GetConfig() + proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) + if cfg.Server.Host == "0.0.0.0" { + proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port) + } + } + + // 复制响应头并重写认证URL + for key, values := range resp.Header { + for _, value := range values { + // 重写WWW-Authenticate头中的realm URL + if key == "Www-Authenticate" { + // 支持多Registry的URL重写 + value = rewriteAuthHeader(value, proxyHost) + } + c.Header(key, value) + } + } + + // 返回响应 + c.Status(resp.StatusCode) + io.Copy(c.Writer, resp.Body) +} + +// rewriteAuthHeader 重写认证头 +func rewriteAuthHeader(authHeader, proxyHost string) string { + // 重写各种Registry的认证URL + authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost) + authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost) + authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost) + authHeader = strings.ReplaceAll(authHeader, "https://quay.io", "http://"+proxyHost) + + return authHeader +} + +// handleMultiRegistryRequest 处理多Registry请求 +func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) { + // 获取Registry映射配置 + mapping, exists := registryDetector.getRegistryMapping(registryDomain) + if !exists { + c.String(http.StatusBadRequest, "Registry not configured") + return + } + + // 解析剩余路径 + imageName, apiType, reference := parseRegistryPath(remainingPath) + if imageName == "" || apiType == "" { + c.String(http.StatusBadRequest, "Invalid path format") + return + } + + // 访问控制检查(使用完整的镜像路径) + fullImageName := registryDomain + "/" + imageName + if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed { + fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason) + c.String(http.StatusForbidden, "镜像访问被限制") + return + } + + // 构建上游Registry引用 + upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName) + + // 根据API类型处理请求 + switch apiType { + case "manifests": + handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping) + case "blobs": + handleUpstreamBlobRequest(c, upstreamImageRef, reference, mapping) + case "tags": + handleUpstreamTagsRequest(c, upstreamImageRef, mapping) + default: + c.String(http.StatusNotFound, "API endpoint not found") + } +} + +// handleUpstreamManifestRequest 处理上游Registry的manifest请求 +func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) { + // Manifest缓存逻辑(仅对GET请求缓存) + if isCacheEnabled() && c.Request.Method == http.MethodGet { + cacheKey := buildManifestCacheKey(imageRef, reference) + + // 优先从缓存获取 + if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { + writeCachedResponse(c, cachedItem) + return + } + } + + var ref name.Reference + var err error + + // 判断reference是digest还是tag + if strings.HasPrefix(reference, "sha256:") { + ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) + } else { + ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) + } + + if err != nil { + fmt.Printf("解析镜像引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid reference") + return + } + + // 创建针对上游Registry的选项 + options := createUpstreamOptions(mapping) + + // 根据请求方法选择操作 + if c.Request.Method == http.MethodHead { + desc, err := remote.Head(ref, options...) + if err != nil { + fmt.Printf("HEAD请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + c.Header("Content-Type", string(desc.MediaType)) + c.Header("Docker-Content-Digest", desc.Digest.String()) + c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) + c.Status(http.StatusOK) + } else { + desc, err := remote.Get(ref, options...) + if err != nil { + fmt.Printf("GET请求失败: %v\n", err) + c.String(http.StatusNotFound, "Manifest not found") + return + } + + // 设置响应头 + headers := map[string]string{ + "Docker-Content-Digest": desc.Digest.String(), + "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), + } + + // 缓存响应 + if isCacheEnabled() { + cacheKey := buildManifestCacheKey(imageRef, reference) + ttl := getManifestTTL(reference) + globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) + } + + // 设置响应头 + c.Header("Content-Type", string(desc.MediaType)) + for key, value := range headers { + c.Header(key, value) + } + + c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) + } +} + +// handleUpstreamBlobRequest 处理上游Registry的blob请求 +func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) { + digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) + if err != nil { + fmt.Printf("解析digest引用失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid digest reference") + return + } + + options := createUpstreamOptions(mapping) + layer, err := remote.Layer(digestRef, options...) + if err != nil { + fmt.Printf("获取layer失败: %v\n", err) + c.String(http.StatusNotFound, "Layer not found") + return + } + + size, err := layer.Size() + if err != nil { + fmt.Printf("获取layer大小失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer size") + return + } + + reader, err := layer.Compressed() + if err != nil { + fmt.Printf("获取layer内容失败: %v\n", err) + c.String(http.StatusInternalServerError, "Failed to get layer content") + return + } + defer reader.Close() + + c.Header("Content-Type", "application/octet-stream") + c.Header("Content-Length", fmt.Sprintf("%d", size)) + c.Header("Docker-Content-Digest", digest) + + c.Status(http.StatusOK) + io.Copy(c.Writer, reader) +} + +// handleUpstreamTagsRequest 处理上游Registry的tags请求 +func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) { + repo, err := name.NewRepository(imageRef) + if err != nil { + fmt.Printf("解析repository失败: %v\n", err) + c.String(http.StatusBadRequest, "Invalid repository") + return + } + + options := createUpstreamOptions(mapping) + tags, err := remote.List(repo, options...) + if err != nil { + fmt.Printf("获取tags失败: %v\n", err) + c.String(http.StatusNotFound, "Tags not found") + return + } + + response := map[string]interface{}{ + "name": strings.TrimPrefix(imageRef, mapping.Upstream+"/"), + "tags": tags, + } + + c.JSON(http.StatusOK, response) +} + +// createUpstreamOptions 创建上游Registry选项 +func createUpstreamOptions(mapping RegistryMapping) []remote.Option { + options := []remote.Option{ + remote.WithAuth(authn.Anonymous), + remote.WithUserAgent("hubproxy/go-containerregistry"), + remote.WithTransport(GetGlobalHTTPClient().Transport), + } + + // 根据Registry类型添加特定的认证选项(方便后续扩展) + switch mapping.AuthType { + case "github": + case "google": + case "quay": + } + + return options +} diff --git a/src/imagetar.go b/src/imagetar.go index 6dfc585..54a18ea 100644 --- a/src/imagetar.go +++ b/src/imagetar.go @@ -52,28 +52,28 @@ func NewDownloadDebouncer(window time.Duration) *DownloadDebouncer { func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool { d.mu.Lock() defer d.mu.Unlock() - + key := userID + ":" + contentKey now := time.Now() - + if entry, exists := d.entries[key]; exists { if now.Sub(entry.LastRequest) < d.window { return false // 在防抖窗口内,拒绝请求 } } - + // 更新或创建条目 d.entries[key] = &DebounceEntry{ LastRequest: now, UserID: userID, } - + // 清理过期条目(每5分钟清理一次) if time.Since(d.lastCleanup) > 5*time.Minute { d.cleanup(now) d.lastCleanup = now } - + return true } @@ -92,10 +92,10 @@ func generateContentFingerprint(images []string, platform string) string { sortedImages := make([]string, len(images)) copy(sortedImages, images) sort.Strings(sortedImages) - + // 组合内容:镜像列表 + 平台信息 content := strings.Join(sortedImages, "|") + ":" + platform - + // 生成MD5哈希 hash := md5.Sum([]byte(content)) return hex.EncodeToString(hash[:]) @@ -107,14 +107,14 @@ func getUserID(c *gin.Context) string { if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" { return "session:" + sessionID } - + // 备用方案:IP + User-Agent组合 ip := c.ClientIP() userAgent := c.GetHeader("User-Agent") if userAgent == "" { userAgent = "unknown" } - + // 生成简短标识 combined := ip + ":" + userAgent hash := md5.Sum([]byte(combined)) @@ -228,7 +228,7 @@ func (is *ImageStreamer) StreamImageToGin(ctx context.Context, imageRef string, filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar" c.Header("Content-Type", "application/octet-stream") c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) - + if options.Compression { c.Header("Content-Encoding", "gzip") } @@ -295,18 +295,18 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr if err != nil { return err } - + configData, err := json.Marshal(configFile) if err != nil { return err } - + configHeader := &tar.Header{ Name: configDigest.String() + ".json", Size: int64(len(configData)), Mode: 0644, } - + if err := tarWriter.WriteHeader(configHeader); err != nil { return err } @@ -335,14 +335,14 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr Typeflag: tar.TypeDir, Mode: 0755, } - + if err := tarWriter.WriteHeader(layerHeader); err != nil { return err } var layerSize int64 var layerReader io.ReadCloser - + // 根据配置选择使用压缩层或未压缩层 if options != nil && options.UseCompressedLayers { layerSize, err = layer.Size() @@ -357,7 +357,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr } layerReader, err = layer.Uncompressed() } - + if err != nil { return err } @@ -368,7 +368,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr Size: layerSize, Mode: 0644, } - + if err := tarWriter.WriteHeader(layerTarHeader); err != nil { return err } @@ -385,12 +385,11 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr log.Printf("已处理层 %d/%d", i+1, len(layers)) } - // 构建单个镜像的manifest信息 singleManifest := map[string]interface{}{ "Config": configDigest.String() + ".json", "RepoTags": []string{imageRef}, - "Layers": func() []string { + "Layers": func() []string { var layers []string for _, digest := range layerDigests { layers = append(layers, digest+"/layer.tar") @@ -417,22 +416,22 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr // 单镜像下载,直接写入manifest.json manifest := []map[string]interface{}{singleManifest} - + manifestData, err := json.Marshal(manifest) if err != nil { return err } - + manifestHeader := &tar.Header{ Name: "manifest.json", Size: int64(len(manifestData)), Mode: 0644, } - + if err := tarWriter.WriteHeader(manifestHeader); err != nil { return err } - + if _, err := tarWriter.Write(manifestData); err != nil { return err } @@ -442,17 +441,17 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr if err != nil { return err } - + repositoriesHeader := &tar.Header{ Name: "repositories", Size: int64(len(repositoriesData)), Mode: 0644, } - + if err := tarWriter.WriteHeader(repositoriesHeader); err != nil { return err } - + _, err = tarWriter.Write(repositoriesData) return err } @@ -473,12 +472,12 @@ func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, var manifest map[string]interface{} var repositories map[string]map[string]string - + err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories, options) if err != nil { return nil, nil, err } - + return manifest, repositories, nil } @@ -537,7 +536,7 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S if m.Platform == nil { continue } - + if options.Platform != "" { platformParts := strings.Split(options.Platform, "/") if len(platformParts) >= 2 { @@ -547,10 +546,10 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S if len(platformParts) >= 3 { targetVariant = platformParts[2] } - - if m.Platform.OS == targetOS && - m.Platform.Architecture == targetArch && - m.Platform.Variant == targetVariant { + + if m.Platform.OS == targetOS && + m.Platform.Architecture == targetArch && + m.Platform.Variant == targetVariant { selectedDesc = &m break } @@ -629,10 +628,10 @@ func handleDirectImageDownload(c *gin.Context) { // 防抖检查 userID := getUserID(c) contentKey := generateContentFingerprint([]string{imageRef}, platform) - + if !singleImageDebouncer.ShouldAllow(userID, contentKey) { c.JSON(http.StatusTooManyRequests, gin.H{ - "error": "请求过于频繁,请稍后再试", + "error": "请求过于频繁,请稍后再试", "retry_after": 5, }) return @@ -689,10 +688,10 @@ func handleSimpleBatchDownload(c *gin.Context) { // 批量下载防抖检查 userID := getUserID(c) contentKey := generateContentFingerprint(req.Images, req.Platform) - + if !batchImageDebouncer.ShouldAllow(userID, contentKey) { c.JSON(http.StatusTooManyRequests, gin.H{ - "error": "批量下载请求过于频繁,请稍后再试", + "error": "批量下载请求过于频繁,请稍后再试", "retry_after": 60, }) return @@ -713,7 +712,7 @@ func handleSimpleBatchDownload(c *gin.Context) { log.Printf("批量下载 %d 个镜像 (平台: %s)", len(req.Images), formatPlatformText(req.Platform)) filename := fmt.Sprintf("batch_%d_images.tar", len(req.Images)) - + c.Header("Content-Type", "application/octet-stream") c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) @@ -811,12 +810,12 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s } log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef) - + // 防止单个镜像处理时间过长 timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options) cancel() - + if err != nil { log.Printf("下载镜像 %s 失败: %v", imageRef, err) return fmt.Errorf("下载镜像 %s 失败: %w", imageRef, err) @@ -845,17 +844,17 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s if err != nil { return fmt.Errorf("序列化manifest失败: %w", err) } - + manifestHeader := &tar.Header{ Name: "manifest.json", Size: int64(len(manifestData)), Mode: 0644, } - + if err := tarWriter.WriteHeader(manifestHeader); err != nil { return fmt.Errorf("写入manifest header失败: %w", err) } - + if _, err := tarWriter.Write(manifestData); err != nil { return fmt.Errorf("写入manifest数据失败: %w", err) } @@ -865,21 +864,21 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s if err != nil { return fmt.Errorf("序列化repositories失败: %w", err) } - + repositoriesHeader := &tar.Header{ Name: "repositories", Size: int64(len(repositoriesData)), Mode: 0644, } - + if err := tarWriter.WriteHeader(repositoriesHeader); err != nil { return fmt.Errorf("写入repositories header失败: %w", err) } - + if _, err := tarWriter.Write(repositoriesData); err != nil { return fmt.Errorf("写入repositories数据失败: %w", err) } log.Printf("批量下载完成,共处理 %d 个镜像", len(imageRefs)) return nil -} \ No newline at end of file +} diff --git a/src/main.go b/src/main.go index f4a43a4..a3ef587 100644 --- a/src/main.go +++ b/src/main.go @@ -1,382 +1,379 @@ -package main - -import ( - "embed" - "fmt" - "io" - "log" - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/gin-gonic/gin" -) - -//go:embed public/* -var staticFiles embed.FS - -// 服务嵌入的静态文件 -func serveEmbedFile(c *gin.Context, filename string) { - data, err := staticFiles.ReadFile(filename) - if err != nil { - c.Status(404) - return - } - contentType := "text/html; charset=utf-8" - if strings.HasSuffix(filename, ".ico") { - contentType = "image/x-icon" - } - c.Data(200, contentType, data) -} - -var ( - exps = []*regexp.Regexp{ - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), - regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), - regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), - regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), - regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), - regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), - regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), - regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), - } - globalLimiter *IPRateLimiter - - // 服务启动时间 - serviceStartTime = time.Now() -) - -func main() { - // 加载配置 - if err := LoadConfig(); err != nil { - fmt.Printf("配置加载失败: %v\n", err) - return - } - - // 初始化HTTP客户端 - initHTTPClients() - - // 初始化限流器 - initLimiter() - - // 初始化Docker流式代理 - initDockerProxy() - - // 初始化镜像流式下载器 - initImageStreamer() - - // 初始化防抖器 - initDebouncer() - - gin.SetMode(gin.ReleaseMode) - router := gin.Default() - - // 全局Panic恢复保护 - router.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { - log.Printf("🚨 Panic recovered: %v", recovered) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Internal server error", - "code": "INTERNAL_ERROR", - }) - })) - - // 全局限流中间件 - 应用到所有路由 - router.Use(RateLimitMiddleware(globalLimiter)) - - // 初始化监控端点 - initHealthRoutes(router) - - // 初始化镜像tar下载路由 - initImageTarRoutes(router) - - // 静态文件路由 - router.GET("/", func(c *gin.Context) { - serveEmbedFile(c, "public/index.html") - }) - router.GET("/public/*filepath", func(c *gin.Context) { - filepath := strings.TrimPrefix(c.Param("filepath"), "/") - serveEmbedFile(c, "public/"+filepath) - }) - - router.GET("/images.html", func(c *gin.Context) { - serveEmbedFile(c, "public/images.html") - }) - router.GET("/search.html", func(c *gin.Context) { - serveEmbedFile(c, "public/search.html") - }) - router.GET("/favicon.ico", func(c *gin.Context) { - serveEmbedFile(c, "public/favicon.ico") - }) - - // 注册dockerhub搜索路由 - RegisterSearchRoute(router) - - // 注册Docker认证路由(/token*) - router.Any("/token", ProxyDockerAuthGin) - router.Any("/token/*path", ProxyDockerAuthGin) - - // 注册Docker Registry代理路由 - router.Any("/v2/*path", ProxyDockerRegistryGin) - - - // 注册NoRoute处理器 - router.NoRoute(handler) - - cfg := GetConfig() - fmt.Printf("🚀 HubProxy 启动成功\n") - fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port) - fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours) - fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n") - - err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) - if err != nil { - fmt.Printf("启动服务失败: %v\n", err) - } -} - -func handler(c *gin.Context) { - rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") - - for strings.HasPrefix(rawPath, "/") { - rawPath = strings.TrimPrefix(rawPath, "/") - } - - if !strings.HasPrefix(rawPath, "http") { - c.String(http.StatusForbidden, "无效输入") - return - } - - matches := checkURL(rawPath) - if matches != nil { - // GitHub仓库访问控制检查 - if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed { - // 构建仓库名用于日志 - var repoPath string - if len(matches) >= 2 { - username := matches[0] - repoName := strings.TrimSuffix(matches[1], ".git") - repoPath = username + "/" + repoName - } - fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) - c.String(http.StatusForbidden, reason) - return - } - } else { - c.String(http.StatusForbidden, "无效输入") - return - } - - if exps[1].MatchString(rawPath) { - rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) - } - - proxyRequest(c, rawPath) -} - - -func proxyRequest(c *gin.Context, u string) { - proxyWithRedirect(c, u, 0) -} - - -func proxyWithRedirect(c *gin.Context, u string, redirectCount int) { - // 限制最大重定向次数,防止无限递归 - const maxRedirects = 20 - if redirectCount > maxRedirects { - c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") - return - } - req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - - for key, values := range c.Request.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - req.Header.Del("Host") - - resp, err := GetGlobalHTTPClient().Do(req) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - defer func() { - if err := resp.Body.Close(); err != nil { - fmt.Printf("关闭响应体失败: %v\n", err) - } - }() - - // 检查文件大小限制 - cfg := GetConfig() - if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { - if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { - c.String(http.StatusRequestEntityTooLarge, - fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) - return - } - } - - // 清理安全相关的头 - resp.Header.Del("Content-Security-Policy") - resp.Header.Del("Referrer-Policy") - resp.Header.Del("Strict-Transport-Security") - - // 获取真实域名 - realHost := c.Request.Header.Get("X-Forwarded-Host") - if realHost == "" { - realHost = c.Request.Host - } - // 如果域名中没有协议前缀,添加https:// - if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { - realHost = "https://" + realHost - } - - if strings.HasSuffix(strings.ToLower(u), ".sh") { - isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip" - - processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost) - if err != nil { - fmt.Printf("智能处理失败,回退到直接代理: %v\n", err) - processedBody = resp.Body - processedSize = 0 - } - - // 智能设置响应头 - if processedSize > 0 { - resp.Header.Del("Content-Length") - resp.Header.Del("Content-Encoding") - resp.Header.Set("Transfer-Encoding", "chunked") - } - - // 复制其他响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - if location := resp.Header.Get("Location"); location != "" { - if checkURL(location) != nil { - c.Header("Location", "/"+location) - } else { - proxyWithRedirect(c, location, redirectCount+1) - return - } - } - - c.Status(resp.StatusCode) - - // 输出处理后的内容 - if _, err := io.Copy(c.Writer, processedBody); err != nil { - return - } - } else { - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 处理重定向 - if location := resp.Header.Get("Location"); location != "" { - if checkURL(location) != nil { - c.Header("Location", "/"+location) - } else { - proxyWithRedirect(c, location, redirectCount+1) - return - } - } - - c.Status(resp.StatusCode) - - // 直接流式转发 - io.Copy(c.Writer, resp.Body) - } -} - -func checkURL(u string) []string { - for _, exp := range exps { - if matches := exp.FindStringSubmatch(u); matches != nil { - return matches[1:] - } - } - return nil -} - -// 初始化健康监控路由 -func initHealthRoutes(router *gin.Engine) { - // 健康检查端点 - router.GET("/health", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "status": "healthy", - "timestamp": time.Now().Unix(), - "uptime": time.Since(serviceStartTime).Seconds(), - "service": "hubproxy", - }) - }) - - // 就绪检查端点 - router.GET("/ready", func(c *gin.Context) { - checks := make(map[string]string) - allReady := true - - if GetConfig() != nil { - checks["config"] = "ok" - } else { - checks["config"] = "failed" - allReady = false - } - - // 检查全局缓存状态 - if globalCache != nil { - checks["cache"] = "ok" - } else { - checks["cache"] = "failed" - allReady = false - } - - // 检查限流器状态 - if globalLimiter != nil { - checks["ratelimiter"] = "ok" - } else { - checks["ratelimiter"] = "failed" - allReady = false - } - - // 检查镜像下载器状态 - if globalImageStreamer != nil { - checks["imagestreamer"] = "ok" - } else { - checks["imagestreamer"] = "failed" - allReady = false - } - - // 检查HTTP客户端状态 - if GetGlobalHTTPClient() != nil { - checks["httpclient"] = "ok" - } else { - checks["httpclient"] = "failed" - allReady = false - } - - status := http.StatusOK - if !allReady { - status = http.StatusServiceUnavailable - } - - c.JSON(status, gin.H{ - "ready": allReady, - "checks": checks, - "timestamp": time.Now().Unix(), - "uptime": time.Since(serviceStartTime).Seconds(), - }) - }) -} +package main + +import ( + "embed" + "fmt" + "io" + "log" + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +//go:embed public/* +var staticFiles embed.FS + +// 服务嵌入的静态文件 +func serveEmbedFile(c *gin.Context, filename string) { + data, err := staticFiles.ReadFile(filename) + if err != nil { + c.Status(404) + return + } + contentType := "text/html; charset=utf-8" + if strings.HasSuffix(filename, ".ico") { + contentType = "image/x-icon" + } + c.Data(200, contentType, data) +} + +var ( + exps = []*regexp.Regexp{ + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), + regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), + regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), + regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), + regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), + regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), + regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), + regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), + } + globalLimiter *IPRateLimiter + + // 服务启动时间 + serviceStartTime = time.Now() +) + +func main() { + // 加载配置 + if err := LoadConfig(); err != nil { + fmt.Printf("配置加载失败: %v\n", err) + return + } + + // 初始化HTTP客户端 + initHTTPClients() + + // 初始化限流器 + initLimiter() + + // 初始化Docker流式代理 + initDockerProxy() + + // 初始化镜像流式下载器 + initImageStreamer() + + // 初始化防抖器 + initDebouncer() + + gin.SetMode(gin.ReleaseMode) + router := gin.Default() + + // 全局Panic恢复保护 + router.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { + log.Printf("🚨 Panic recovered: %v", recovered) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Internal server error", + "code": "INTERNAL_ERROR", + }) + })) + + // 全局限流中间件 - 应用到所有路由 + router.Use(RateLimitMiddleware(globalLimiter)) + + // 初始化监控端点 + initHealthRoutes(router) + + // 初始化镜像tar下载路由 + initImageTarRoutes(router) + + // 静态文件路由 + router.GET("/", func(c *gin.Context) { + serveEmbedFile(c, "public/index.html") + }) + router.GET("/public/*filepath", func(c *gin.Context) { + filepath := strings.TrimPrefix(c.Param("filepath"), "/") + serveEmbedFile(c, "public/"+filepath) + }) + + router.GET("/images.html", func(c *gin.Context) { + serveEmbedFile(c, "public/images.html") + }) + router.GET("/search.html", func(c *gin.Context) { + serveEmbedFile(c, "public/search.html") + }) + router.GET("/favicon.ico", func(c *gin.Context) { + serveEmbedFile(c, "public/favicon.ico") + }) + + // 注册dockerhub搜索路由 + RegisterSearchRoute(router) + + // 注册Docker认证路由(/token*) + router.Any("/token", ProxyDockerAuthGin) + router.Any("/token/*path", ProxyDockerAuthGin) + + // 注册Docker Registry代理路由 + router.Any("/v2/*path", ProxyDockerRegistryGin) + + // 注册NoRoute处理器 + router.NoRoute(handler) + + cfg := GetConfig() + fmt.Printf("🚀 HubProxy 启动成功\n") + fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port) + fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours) + fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n") + + err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) + if err != nil { + fmt.Printf("启动服务失败: %v\n", err) + } +} + +func handler(c *gin.Context) { + rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") + + for strings.HasPrefix(rawPath, "/") { + rawPath = strings.TrimPrefix(rawPath, "/") + } + + if !strings.HasPrefix(rawPath, "http") { + c.String(http.StatusForbidden, "无效输入") + return + } + + matches := checkURL(rawPath) + if matches != nil { + // GitHub仓库访问控制检查 + if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed { + // 构建仓库名用于日志 + var repoPath string + if len(matches) >= 2 { + username := matches[0] + repoName := strings.TrimSuffix(matches[1], ".git") + repoPath = username + "/" + repoName + } + fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) + c.String(http.StatusForbidden, reason) + return + } + } else { + c.String(http.StatusForbidden, "无效输入") + return + } + + if exps[1].MatchString(rawPath) { + rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) + } + + proxyRequest(c, rawPath) +} + +func proxyRequest(c *gin.Context, u string) { + proxyWithRedirect(c, u, 0) +} + +func proxyWithRedirect(c *gin.Context, u string, redirectCount int) { + // 限制最大重定向次数,防止无限递归 + const maxRedirects = 20 + if redirectCount > maxRedirects { + c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") + return + } + req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + req.Header.Del("Host") + + resp, err := GetGlobalHTTPClient().Do(req) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭响应体失败: %v\n", err) + } + }() + + // 检查文件大小限制 + cfg := GetConfig() + if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { + if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { + c.String(http.StatusRequestEntityTooLarge, + fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) + return + } + } + + // 清理安全相关的头 + resp.Header.Del("Content-Security-Policy") + resp.Header.Del("Referrer-Policy") + resp.Header.Del("Strict-Transport-Security") + + // 获取真实域名 + realHost := c.Request.Header.Get("X-Forwarded-Host") + if realHost == "" { + realHost = c.Request.Host + } + // 如果域名中没有协议前缀,添加https:// + if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { + realHost = "https://" + realHost + } + + if strings.HasSuffix(strings.ToLower(u), ".sh") { + isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip" + + processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost) + if err != nil { + fmt.Printf("智能处理失败,回退到直接代理: %v\n", err) + processedBody = resp.Body + processedSize = 0 + } + + // 智能设置响应头 + if processedSize > 0 { + resp.Header.Del("Content-Length") + resp.Header.Del("Content-Encoding") + resp.Header.Set("Transfer-Encoding", "chunked") + } + + // 复制其他响应头 + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + if location := resp.Header.Get("Location"); location != "" { + if checkURL(location) != nil { + c.Header("Location", "/"+location) + } else { + proxyWithRedirect(c, location, redirectCount+1) + return + } + } + + c.Status(resp.StatusCode) + + // 输出处理后的内容 + if _, err := io.Copy(c.Writer, processedBody); err != nil { + return + } + } else { + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + // 处理重定向 + if location := resp.Header.Get("Location"); location != "" { + if checkURL(location) != nil { + c.Header("Location", "/"+location) + } else { + proxyWithRedirect(c, location, redirectCount+1) + return + } + } + + c.Status(resp.StatusCode) + + // 直接流式转发 + io.Copy(c.Writer, resp.Body) + } +} + +func checkURL(u string) []string { + for _, exp := range exps { + if matches := exp.FindStringSubmatch(u); matches != nil { + return matches[1:] + } + } + return nil +} + +// 初始化健康监控路由 +func initHealthRoutes(router *gin.Engine) { + // 健康检查端点 + router.GET("/health", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "healthy", + "timestamp": time.Now().Unix(), + "uptime": time.Since(serviceStartTime).Seconds(), + "service": "hubproxy", + }) + }) + + // 就绪检查端点 + router.GET("/ready", func(c *gin.Context) { + checks := make(map[string]string) + allReady := true + + if GetConfig() != nil { + checks["config"] = "ok" + } else { + checks["config"] = "failed" + allReady = false + } + + // 检查全局缓存状态 + if globalCache != nil { + checks["cache"] = "ok" + } else { + checks["cache"] = "failed" + allReady = false + } + + // 检查限流器状态 + if globalLimiter != nil { + checks["ratelimiter"] = "ok" + } else { + checks["ratelimiter"] = "failed" + allReady = false + } + + // 检查镜像下载器状态 + if globalImageStreamer != nil { + checks["imagestreamer"] = "ok" + } else { + checks["imagestreamer"] = "failed" + allReady = false + } + + // 检查HTTP客户端状态 + if GetGlobalHTTPClient() != nil { + checks["httpclient"] = "ok" + } else { + checks["httpclient"] = "failed" + allReady = false + } + + status := http.StatusOK + if !allReady { + status = http.StatusServiceUnavailable + } + + c.JSON(status, gin.H{ + "ready": allReady, + "checks": checks, + "timestamp": time.Now().Unix(), + "uptime": time.Since(serviceStartTime).Seconds(), + }) + }) +} diff --git a/src/proxysh.go b/src/proxysh.go index f3f7f58..e49e917 100644 --- a/src/proxysh.go +++ b/src/proxysh.go @@ -1,95 +1,95 @@ -package main - -import ( - "bytes" - "compress/gzip" - "fmt" - "io" - "regexp" - "strings" -) - -// GitHub URL正则表达式 -var githubRegex = regexp.MustCompile(`https?://(?:github\.com|raw\.githubusercontent\.com|raw\.github\.com|gist\.githubusercontent\.com|gist\.github\.com|api\.github\.com)[^\s'"]+`) - -// ProcessSmart Shell脚本智能处理函数 -func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reader, int64, error) { - defer input.Close() - - content, err := readShellContent(input, isCompressed) - if err != nil { - return nil, 0, fmt.Errorf("内容读取失败: %v", err) - } - - if len(content) == 0 { - return strings.NewReader(""), 0, nil - } - - if len(content) > 10*1024*1024 { - return strings.NewReader(content), int64(len(content)), nil - } - - if !strings.Contains(content, "github.com") && !strings.Contains(content, "githubusercontent.com") { - return strings.NewReader(content), int64(len(content)), nil - } - - processed := processGitHubURLs(content, host) - - return strings.NewReader(processed), int64(len(processed)), nil -} - -func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) { - var reader io.Reader = input - - // 处理gzip压缩 - if isCompressed { - peek := make([]byte, 2) - n, err := input.Read(peek) - if err != nil && err != io.EOF { - return "", fmt.Errorf("读取数据失败: %v", err) - } - - if n >= 2 && peek[0] == 0x1f && peek[1] == 0x8b { - combinedReader := io.MultiReader(bytes.NewReader(peek[:n]), input) - gzReader, err := gzip.NewReader(combinedReader) - if err != nil { - return "", fmt.Errorf("gzip解压失败: %v", err) - } - defer gzReader.Close() - reader = gzReader - } else { - reader = io.MultiReader(bytes.NewReader(peek[:n]), input) - } - } - - data, err := io.ReadAll(reader) - if err != nil { - return "", fmt.Errorf("读取内容失败: %v", err) - } - - return string(data), nil -} - -func processGitHubURLs(content, host string) string { - return githubRegex.ReplaceAllStringFunc(content, func(url string) string { - return transformURL(url, host) - }) -} - -// transformURL URL转换函数 -func transformURL(url, host string) string { - if strings.Contains(url, host) { - return url - } - - if strings.HasPrefix(url, "http://") { - url = "https" + url[4:] - } else if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "//") { - url = "https://" + url - } - cleanHost := strings.TrimPrefix(host, "https://") - cleanHost = strings.TrimPrefix(cleanHost, "http://") - cleanHost = strings.TrimSuffix(cleanHost, "/") - - return cleanHost + "/" + url -} \ No newline at end of file +package main + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "regexp" + "strings" +) + +// GitHub URL正则表达式 +var githubRegex = regexp.MustCompile(`https?://(?:github\.com|raw\.githubusercontent\.com|raw\.github\.com|gist\.githubusercontent\.com|gist\.github\.com|api\.github\.com)[^\s'"]+`) + +// ProcessSmart Shell脚本智能处理函数 +func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reader, int64, error) { + defer input.Close() + + content, err := readShellContent(input, isCompressed) + if err != nil { + return nil, 0, fmt.Errorf("内容读取失败: %v", err) + } + + if len(content) == 0 { + return strings.NewReader(""), 0, nil + } + + if len(content) > 10*1024*1024 { + return strings.NewReader(content), int64(len(content)), nil + } + + if !strings.Contains(content, "github.com") && !strings.Contains(content, "githubusercontent.com") { + return strings.NewReader(content), int64(len(content)), nil + } + + processed := processGitHubURLs(content, host) + + return strings.NewReader(processed), int64(len(processed)), nil +} + +func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) { + var reader io.Reader = input + + // 处理gzip压缩 + if isCompressed { + peek := make([]byte, 2) + n, err := input.Read(peek) + if err != nil && err != io.EOF { + return "", fmt.Errorf("读取数据失败: %v", err) + } + + if n >= 2 && peek[0] == 0x1f && peek[1] == 0x8b { + combinedReader := io.MultiReader(bytes.NewReader(peek[:n]), input) + gzReader, err := gzip.NewReader(combinedReader) + if err != nil { + return "", fmt.Errorf("gzip解压失败: %v", err) + } + defer gzReader.Close() + reader = gzReader + } else { + reader = io.MultiReader(bytes.NewReader(peek[:n]), input) + } + } + + data, err := io.ReadAll(reader) + if err != nil { + return "", fmt.Errorf("读取内容失败: %v", err) + } + + return string(data), nil +} + +func processGitHubURLs(content, host string) string { + return githubRegex.ReplaceAllStringFunc(content, func(url string) string { + return transformURL(url, host) + }) +} + +// transformURL URL转换函数 +func transformURL(url, host string) string { + if strings.Contains(url, host) { + return url + } + + if strings.HasPrefix(url, "http://") { + url = "https" + url[4:] + } else if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "//") { + url = "https://" + url + } + cleanHost := strings.TrimPrefix(host, "https://") + cleanHost = strings.TrimPrefix(cleanHost, "http://") + cleanHost = strings.TrimSuffix(cleanHost, "/") + + return cleanHost + "/" + url +} diff --git a/src/ratelimiter.go b/src/ratelimiter.go index 671a6cd..9f7bdc8 100644 --- a/src/ratelimiter.go +++ b/src/ratelimiter.go @@ -1,303 +1,301 @@ -package main - -import ( - "fmt" - "net" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" - "golang.org/x/time/rate" -) - -const ( - // 清理间隔 - CleanupInterval = 10 * time.Minute - MaxIPCacheSize = 10000 -) - -// IPRateLimiter IP限流器结构体 -type IPRateLimiter struct { - ips map[string]*rateLimiterEntry // IP到限流器的映射 - mu *sync.RWMutex // 读写锁,保证并发安全 - r rate.Limit // 速率限制(每秒允许的请求数) - b int // 令牌桶容量(突发请求数) - whitelist []*net.IPNet // 白名单IP段 - blacklist []*net.IPNet // 黑名单IP段 -} - -// rateLimiterEntry 限流器条目 -type rateLimiterEntry struct { - limiter *rate.Limiter - lastAccess time.Time -} - -// initGlobalLimiter 初始化全局限流器 -func initGlobalLimiter() *IPRateLimiter { - cfg := GetConfig() - - whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) - for _, item := range cfg.Security.WhiteList { - if item = strings.TrimSpace(item); item != "" { - if !strings.Contains(item, "/") { - item = item + "/32" // 单个IP转为CIDR格式 - } - _, ipnet, err := net.ParseCIDR(item) - if err == nil { - whitelist = append(whitelist, ipnet) - } else { - fmt.Printf("警告: 无效的白名单IP格式: %s\n", item) - } - } - } - - // 解析黑名单IP段 - blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) - for _, item := range cfg.Security.BlackList { - if item = strings.TrimSpace(item); item != "" { - if !strings.Contains(item, "/") { - item = item + "/32" // 单个IP转为CIDR格式 - } - _, ipnet, err := net.ParseCIDR(item) - if err == nil { - blacklist = append(blacklist, ipnet) - } else { - fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item) - } - } - } - - // 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求" - ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) - - burstSize := cfg.RateLimit.RequestLimit - if burstSize < 1 { - burstSize = 1 - } - - limiter := &IPRateLimiter{ - ips: make(map[string]*rateLimiterEntry), - mu: &sync.RWMutex{}, - r: ratePerSecond, - b: burstSize, - whitelist: whitelist, - blacklist: blacklist, - } - - // 启动定期清理goroutine - go limiter.cleanupRoutine() - - return limiter -} - -// initLimiter 初始化限流器 -func initLimiter() { - globalLimiter = initGlobalLimiter() -} - -// cleanupRoutine 定期清理过期的限流器 -func (i *IPRateLimiter) cleanupRoutine() { - ticker := time.NewTicker(CleanupInterval) - defer ticker.Stop() - - for range ticker.C { - now := time.Now() - expired := make([]string, 0) - - // 查找过期的条目 - i.mu.RLock() - for ip, entry := range i.ips { - // 如果最后访问时间超过1小时,认为过期 - if now.Sub(entry.lastAccess) > 1*time.Hour { - expired = append(expired, ip) - } - } - i.mu.RUnlock() - - // 如果有过期条目或者缓存过大,进行清理 - if len(expired) > 0 || len(i.ips) > MaxIPCacheSize { - i.mu.Lock() - // 删除过期条目 - for _, ip := range expired { - delete(i.ips, ip) - } - - // 如果缓存仍然过大,全部清理 - if len(i.ips) > MaxIPCacheSize { - i.ips = make(map[string]*rateLimiterEntry) - } - i.mu.Unlock() - } - } -} - -// extractIPFromAddress 从地址中提取纯IP -func extractIPFromAddress(address string) string { - if host, _, err := net.SplitHostPort(address); err == nil { - return host - } - return address -} - -// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段 -func normalizeIPForRateLimit(ipStr string) string { - ip := net.ParseIP(ipStr) - if ip == nil { - return ipStr // 解析失败,返回原值 - } - - if ip.To4() != nil { - return ipStr // IPv4保持不变 - } - - // IPv6:标准化为 /64 网段 - ipv6 := ip.To16() - for i := 8; i < 16; i++ { - ipv6[i] = 0 // 清零后64位 - } - return ipv6.String() + "/64" -} - -// isIPInCIDRList 检查IP是否在CIDR列表中 -func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { - // 先提取纯IP地址 - cleanIP := extractIPFromAddress(ip) - parsedIP := net.ParseIP(cleanIP) - if parsedIP == nil { - return false - } - - for _, cidr := range cidrList { - if cidr.Contains(parsedIP) { - return true - } - } - return false -} - -// GetLimiter 获取指定IP的限流器,同时返回是否允许访问 -func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { - // 提取纯IP地址 - cleanIP := extractIPFromAddress(ip) - - // 检查是否在黑名单中 - if isIPInCIDRList(cleanIP, i.blacklist) { - return nil, false - } - - // 检查是否在白名单中 - if isIPInCIDRList(cleanIP, i.whitelist) { - return rate.NewLimiter(rate.Inf, i.b), true - } - - // 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段 - normalizedIP := normalizeIPForRateLimit(cleanIP) - - now := time.Now() - - i.mu.RLock() - entry, exists := i.ips[normalizedIP] - i.mu.RUnlock() - - if exists { - i.mu.Lock() - if entry, stillExists := i.ips[normalizedIP]; stillExists { - entry.lastAccess = now - i.mu.Unlock() - return entry.limiter, true - } - i.mu.Unlock() - } - - i.mu.Lock() - if entry, exists := i.ips[normalizedIP]; exists { - entry.lastAccess = now - i.mu.Unlock() - return entry.limiter, true - } - - entry = &rateLimiterEntry{ - limiter: rate.NewLimiter(i.r, i.b), - lastAccess: now, - } - i.ips[normalizedIP] = entry - i.mu.Unlock() - - return entry.limiter, true -} - -// RateLimitMiddleware 速率限制中间件 -func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { - return func(c *gin.Context) { - // 静态文件豁免:跳过限流检查 - path := c.Request.URL.Path - if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" || - strings.HasPrefix(path, "/public/") { - c.Next() - return - } - - // 获取客户端真实IP - var ip string - - // 优先尝试从请求头获取真实IP - if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" { - // X-Forwarded-For可能包含多个IP,取第一个 - ips := strings.Split(forwarded, ",") - ip = strings.TrimSpace(ips[0]) - } else if realIP := c.GetHeader("X-Real-IP"); realIP != "" { - // 如果有X-Real-IP头 - ip = realIP - } else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" { - // 某些代理可能使用此头 - ips := strings.Split(remoteIP, ",") - ip = strings.TrimSpace(ips[0]) - } else { - // 回退到ClientIP方法 - ip = c.ClientIP() - } - - // 提取纯IP地址(去除可能存在的端口) - cleanIP := extractIPFromAddress(ip) - - // 日志记录请求IP和头信息 - normalizedIP := normalizeIPForRateLimit(cleanIP) - if cleanIP != normalizedIP { - fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", - ip, cleanIP, normalizedIP, - c.GetHeader("X-Forwarded-For"), - c.GetHeader("X-Real-IP")) - } else { - fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", - ip, cleanIP, - c.GetHeader("X-Forwarded-For"), - c.GetHeader("X-Real-IP")) - } - - // 获取限流器并检查是否允许访问 - ipLimiter, allowed := limiter.GetLimiter(cleanIP) - - // 如果IP在黑名单中 - if !allowed { - c.JSON(403, gin.H{ - "error": "您已被限制访问", - }) - c.Abort() - return - } - - // 检查限流 - if !ipLimiter.Allow() { - c.JSON(429, gin.H{ - "error": "请求频率过快,暂时限制访问", - }) - c.Abort() - return - } - - c.Next() - } -} - - +package main + +import ( + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +const ( + // 清理间隔 + CleanupInterval = 10 * time.Minute + MaxIPCacheSize = 10000 +) + +// IPRateLimiter IP限流器结构体 +type IPRateLimiter struct { + ips map[string]*rateLimiterEntry // IP到限流器的映射 + mu *sync.RWMutex // 读写锁,保证并发安全 + r rate.Limit // 速率限制(每秒允许的请求数) + b int // 令牌桶容量(突发请求数) + whitelist []*net.IPNet // 白名单IP段 + blacklist []*net.IPNet // 黑名单IP段 +} + +// rateLimiterEntry 限流器条目 +type rateLimiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// initGlobalLimiter 初始化全局限流器 +func initGlobalLimiter() *IPRateLimiter { + cfg := GetConfig() + + whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) + for _, item := range cfg.Security.WhiteList { + if item = strings.TrimSpace(item); item != "" { + if !strings.Contains(item, "/") { + item = item + "/32" // 单个IP转为CIDR格式 + } + _, ipnet, err := net.ParseCIDR(item) + if err == nil { + whitelist = append(whitelist, ipnet) + } else { + fmt.Printf("警告: 无效的白名单IP格式: %s\n", item) + } + } + } + + // 解析黑名单IP段 + blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) + for _, item := range cfg.Security.BlackList { + if item = strings.TrimSpace(item); item != "" { + if !strings.Contains(item, "/") { + item = item + "/32" // 单个IP转为CIDR格式 + } + _, ipnet, err := net.ParseCIDR(item) + if err == nil { + blacklist = append(blacklist, ipnet) + } else { + fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item) + } + } + } + + // 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求" + ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) + + burstSize := cfg.RateLimit.RequestLimit + if burstSize < 1 { + burstSize = 1 + } + + limiter := &IPRateLimiter{ + ips: make(map[string]*rateLimiterEntry), + mu: &sync.RWMutex{}, + r: ratePerSecond, + b: burstSize, + whitelist: whitelist, + blacklist: blacklist, + } + + // 启动定期清理goroutine + go limiter.cleanupRoutine() + + return limiter +} + +// initLimiter 初始化限流器 +func initLimiter() { + globalLimiter = initGlobalLimiter() +} + +// cleanupRoutine 定期清理过期的限流器 +func (i *IPRateLimiter) cleanupRoutine() { + ticker := time.NewTicker(CleanupInterval) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + expired := make([]string, 0) + + // 查找过期的条目 + i.mu.RLock() + for ip, entry := range i.ips { + // 如果最后访问时间超过1小时,认为过期 + if now.Sub(entry.lastAccess) > 1*time.Hour { + expired = append(expired, ip) + } + } + i.mu.RUnlock() + + // 如果有过期条目或者缓存过大,进行清理 + if len(expired) > 0 || len(i.ips) > MaxIPCacheSize { + i.mu.Lock() + // 删除过期条目 + for _, ip := range expired { + delete(i.ips, ip) + } + + // 如果缓存仍然过大,全部清理 + if len(i.ips) > MaxIPCacheSize { + i.ips = make(map[string]*rateLimiterEntry) + } + i.mu.Unlock() + } + } +} + +// extractIPFromAddress 从地址中提取纯IP +func extractIPFromAddress(address string) string { + if host, _, err := net.SplitHostPort(address); err == nil { + return host + } + return address +} + +// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段 +func normalizeIPForRateLimit(ipStr string) string { + ip := net.ParseIP(ipStr) + if ip == nil { + return ipStr // 解析失败,返回原值 + } + + if ip.To4() != nil { + return ipStr // IPv4保持不变 + } + + // IPv6:标准化为 /64 网段 + ipv6 := ip.To16() + for i := 8; i < 16; i++ { + ipv6[i] = 0 // 清零后64位 + } + return ipv6.String() + "/64" +} + +// isIPInCIDRList 检查IP是否在CIDR列表中 +func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { + // 先提取纯IP地址 + cleanIP := extractIPFromAddress(ip) + parsedIP := net.ParseIP(cleanIP) + if parsedIP == nil { + return false + } + + for _, cidr := range cidrList { + if cidr.Contains(parsedIP) { + return true + } + } + return false +} + +// GetLimiter 获取指定IP的限流器,同时返回是否允许访问 +func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { + // 提取纯IP地址 + cleanIP := extractIPFromAddress(ip) + + // 检查是否在黑名单中 + if isIPInCIDRList(cleanIP, i.blacklist) { + return nil, false + } + + // 检查是否在白名单中 + if isIPInCIDRList(cleanIP, i.whitelist) { + return rate.NewLimiter(rate.Inf, i.b), true + } + + // 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段 + normalizedIP := normalizeIPForRateLimit(cleanIP) + + now := time.Now() + + i.mu.RLock() + entry, exists := i.ips[normalizedIP] + i.mu.RUnlock() + + if exists { + i.mu.Lock() + if entry, stillExists := i.ips[normalizedIP]; stillExists { + entry.lastAccess = now + i.mu.Unlock() + return entry.limiter, true + } + i.mu.Unlock() + } + + i.mu.Lock() + if entry, exists := i.ips[normalizedIP]; exists { + entry.lastAccess = now + i.mu.Unlock() + return entry.limiter, true + } + + entry = &rateLimiterEntry{ + limiter: rate.NewLimiter(i.r, i.b), + lastAccess: now, + } + i.ips[normalizedIP] = entry + i.mu.Unlock() + + return entry.limiter, true +} + +// RateLimitMiddleware 速率限制中间件 +func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + // 静态文件豁免:跳过限流检查 + path := c.Request.URL.Path + if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" || + strings.HasPrefix(path, "/public/") { + c.Next() + return + } + + // 获取客户端真实IP + var ip string + + // 优先尝试从请求头获取真实IP + if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" { + // X-Forwarded-For可能包含多个IP,取第一个 + ips := strings.Split(forwarded, ",") + ip = strings.TrimSpace(ips[0]) + } else if realIP := c.GetHeader("X-Real-IP"); realIP != "" { + // 如果有X-Real-IP头 + ip = realIP + } else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" { + // 某些代理可能使用此头 + ips := strings.Split(remoteIP, ",") + ip = strings.TrimSpace(ips[0]) + } else { + // 回退到ClientIP方法 + ip = c.ClientIP() + } + + // 提取纯IP地址(去除可能存在的端口) + cleanIP := extractIPFromAddress(ip) + + // 日志记录请求IP和头信息 + normalizedIP := normalizeIPForRateLimit(cleanIP) + if cleanIP != normalizedIP { + fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", + ip, cleanIP, normalizedIP, + c.GetHeader("X-Forwarded-For"), + c.GetHeader("X-Real-IP")) + } else { + fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", + ip, cleanIP, + c.GetHeader("X-Forwarded-For"), + c.GetHeader("X-Real-IP")) + } + + // 获取限流器并检查是否允许访问 + ipLimiter, allowed := limiter.GetLimiter(cleanIP) + + // 如果IP在黑名单中 + if !allowed { + c.JSON(403, gin.H{ + "error": "您已被限制访问", + }) + c.Abort() + return + } + + // 检查限流 + if !ipLimiter.Allow() { + c.JSON(429, gin.H{ + "error": "请求频率过快,暂时限制访问", + }) + c.Abort() + return + } + + c.Next() + } +} diff --git a/src/search.go b/src/search.go index db284f4..1d06cc1 100644 --- a/src/search.go +++ b/src/search.go @@ -1,500 +1,500 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "sort" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" -) - -// SearchResult Docker Hub搜索结果 -type SearchResult struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` -} - -// Repository 仓库信息 -type Repository struct { - Name string `json:"repo_name"` - Description string `json:"short_description"` - IsOfficial bool `json:"is_official"` - IsAutomated bool `json:"is_automated"` - StarCount int `json:"star_count"` - PullCount int `json:"pull_count"` - RepoOwner string `json:"repo_owner"` - LastUpdated string `json:"last_updated"` - Status int `json:"status"` - Organization string `json:"affiliation"` - PullsLastWeek int `json:"pulls_last_week"` - Namespace string `json:"namespace"` -} - -// TagInfo 标签信息 -type TagInfo struct { - Name string `json:"name"` - FullSize int64 `json:"full_size"` - LastUpdated time.Time `json:"last_updated"` - LastPusher string `json:"last_pusher"` - Images []Image `json:"images"` - Vulnerabilities struct { - Critical int `json:"critical"` - High int `json:"high"` - Medium int `json:"medium"` - Low int `json:"low"` - Unknown int `json:"unknown"` - } `json:"vulnerabilities"` -} - -// Image 镜像信息 -type Image struct { - Architecture string `json:"architecture"` - Features string `json:"features"` - Variant string `json:"variant,omitempty"` - Digest string `json:"digest"` - OS string `json:"os"` - OSFeatures string `json:"os_features"` - Size int64 `json:"size"` -} - -type cacheEntry struct { - data interface{} - timestamp time.Time -} - -const ( - maxCacheSize = 1000 // 最大缓存条目数 - cacheTTL = 30 * time.Minute -) - -type Cache struct { - data map[string]cacheEntry - mu sync.RWMutex - maxSize int -} - -var ( - searchCache = &Cache{ - data: make(map[string]cacheEntry), - maxSize: maxCacheSize, - } -) - -func (c *Cache) Get(key string) (interface{}, bool) { - c.mu.RLock() - entry, exists := c.data[key] - c.mu.RUnlock() - - if !exists { - return nil, false - } - - if time.Since(entry.timestamp) > cacheTTL { - c.mu.Lock() - delete(c.data, key) - c.mu.Unlock() - return nil, false - } - - return entry.data, true -} - -func (c *Cache) Set(key string, data interface{}) { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - for k, v := range c.data { - if now.Sub(v.timestamp) > cacheTTL { - delete(c.data, k) - } - } - - if len(c.data) >= c.maxSize { - toDelete := len(c.data) / 4 - for k := range c.data { - if toDelete <= 0 { - break - } - delete(c.data, k) - toDelete-- - } - } - - c.data[key] = cacheEntry{ - data: data, - timestamp: now, - } -} - -func (c *Cache) Cleanup() { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - for key, entry := range c.data { - if now.Sub(entry.timestamp) > cacheTTL { - delete(c.data, key) - } - } -} - -// 定期清理过期缓存 -func init() { - go func() { - ticker := time.NewTicker(5 * time.Minute) - for range ticker.C { - searchCache.Cleanup() - } - }() -} - -func filterSearchResults(results []Repository, query string) []Repository { - searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) - filtered := make([]Repository, 0) - - for _, repo := range results { - // 标准化仓库名称 - repoName := strings.ToLower(repo.Name) - repoDesc := strings.ToLower(repo.Description) - - // 计算相关性得分 - score := 0 - - // 完全匹配 - if repoName == searchTerm { - score += 100 - } - - // 前缀匹配 - if strings.HasPrefix(repoName, searchTerm) { - score += 50 - } - - // 包含匹配 - if strings.Contains(repoName, searchTerm) { - score += 30 - } - - // 描述匹配 - if strings.Contains(repoDesc, searchTerm) { - score += 10 - } - - // 官方镜像加分 - if repo.IsOfficial { - score += 20 - } - - // 分数达到阈值的结果才保留 - if score > 0 { - filtered = append(filtered, repo) - } - } - - // 按相关性排序 - sort.Slice(filtered, func(i, j int) bool { - // 优先考虑官方镜像 - if filtered[i].IsOfficial != filtered[j].IsOfficial { - return filtered[i].IsOfficial - } - // 其次考虑拉取次数 - return filtered[i].PullCount > filtered[j].PullCount - }) - - return filtered -} - -// searchDockerHub 搜索镜像 -func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { - cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) - - // 尝试从缓存获取 - if cached, ok := searchCache.Get(cacheKey); ok { - return cached.(*SearchResult), nil - } - - // 判断是否是用户/仓库格式的搜索 - isUserRepo := strings.Contains(query, "/") - var namespace, repoName string - - if isUserRepo { - parts := strings.Split(query, "/") - if len(parts) == 2 { - namespace = parts[0] - repoName = parts[1] - } - } - - // 构建搜索URL - baseURL := "https://registry.hub.docker.com/v2" - var fullURL string - var params url.Values - - if isUserRepo && namespace != "" { - // 如果是用户/仓库格式,使用repositories接口 - fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) - params = url.Values{ - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } else { - // 普通搜索 - fullURL = baseURL + "/search/repositories/" - params = url.Values{ - "query": {query}, - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } - - fullURL = fullURL + "?" + params.Encode() - - // 使用统一的搜索HTTP客户端 - resp, err := GetSearchHTTPClient().Get(fullURL) - if err != nil { - return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) - } - defer func() { - if err := resp.Body.Close(); err != nil { - fmt.Printf("关闭搜索响应体失败: %v\n", err) - } - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %v", err) - } - - if resp.StatusCode != http.StatusOK { - switch resp.StatusCode { - case http.StatusTooManyRequests: - return nil, fmt.Errorf("请求过于频繁,请稍后重试") - case http.StatusNotFound: - if isUserRepo && namespace != "" { - // 如果用户仓库搜索失败,尝试普通搜索 - return searchDockerHub(ctx, repoName, page, pageSize) - } - return nil, fmt.Errorf("未找到相关镜像") - case http.StatusBadGateway, http.StatusServiceUnavailable: - return nil, fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") - default: - return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) - } - } - - // 解析响应 - var result *SearchResult - if isUserRepo && namespace != "" { - // 解析用户仓库列表响应 - var userRepos struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` - } - if err := json.Unmarshal(body, &userRepos); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 转换为SearchResult格式 - result = &SearchResult{ - Count: userRepos.Count, - Next: userRepos.Next, - Previous: userRepos.Previous, - Results: make([]Repository, 0), - } - - // 处理结果 - for _, repo := range userRepos.Results { - // 如果指定了仓库名,只保留匹配的结果 - if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { - // 确保设置正确的命名空间和名称 - repo.Namespace = namespace - if !strings.Contains(repo.Name, "/") { - repo.Name = fmt.Sprintf("%s/%s", namespace, repo.Name) - } - result.Results = append(result.Results, repo) - } - } - - // 如果没有找到结果,尝试普通搜索 - if len(result.Results) == 0 { - return searchDockerHub(ctx, repoName, page, pageSize) - } - - result.Count = len(result.Results) - } else { - // 解析普通搜索响应 - result = &SearchResult{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 处理搜索结果 - for i := range result.Results { - if result.Results[i].IsOfficial { - if !strings.Contains(result.Results[i].Name, "/") { - result.Results[i].Name = "library/" + result.Results[i].Name - } - result.Results[i].Namespace = "library" - } else { - parts := strings.Split(result.Results[i].Name, "/") - if len(parts) > 1 { - result.Results[i].Namespace = parts[0] - result.Results[i].Name = parts[1] - } else if result.Results[i].RepoOwner != "" { - result.Results[i].Namespace = result.Results[i].RepoOwner - result.Results[i].Name = fmt.Sprintf("%s/%s", result.Results[i].RepoOwner, result.Results[i].Name) - } - } - } - - // 如果是用户/仓库搜索,过滤结果 - if isUserRepo && namespace != "" { - filteredResults := make([]Repository, 0) - for _, repo := range result.Results { - if strings.EqualFold(repo.Namespace, namespace) { - filteredResults = append(filteredResults, repo) - } - } - result.Results = filteredResults - result.Count = len(filteredResults) - } - } - - // 缓存结果 - searchCache.Set(cacheKey, result) - return result, nil -} - -// 判断错误是否可重试 -func isRetryableError(err error) bool { - if err == nil { - return false - } - - // 网络错误、超时等可以重试 - if strings.Contains(err.Error(), "timeout") || - strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "too many requests") { - return true - } - - return false -} - -// getRepositoryTags 获取仓库标签信息 -func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo, error) { - if namespace == "" || name == "" { - return nil, fmt.Errorf("无效输入:命名空间和名称不能为空") - } - - cacheKey := fmt.Sprintf("tags:%s:%s", namespace, name) - if cached, ok := searchCache.Get(cacheKey); ok { - return cached.([]TagInfo), nil - } - - // 构建API URL - baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) - params := url.Values{} - params.Set("page_size", "100") - params.Set("ordering", "last_updated") - - fullURL := baseURL + "?" + params.Encode() - - // 使用统一的搜索HTTP客户端 - resp, err := GetSearchHTTPClient().Get(fullURL) - if err != nil { - return nil, fmt.Errorf("发送请求失败: %v", err) - } - defer func() { - if err := resp.Body.Close(); err != nil { - fmt.Printf("关闭搜索响应体失败: %v\n", err) - } - }() - - // 读取响应体 - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %v", err) - } - - // 检查响应状态码 - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) - } - - // 解析响应 - var result struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []TagInfo `json:"results"` - } - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 缓存结果 - searchCache.Set(cacheKey, result.Results) - return result.Results, nil -} - -// RegisterSearchRoute 注册搜索相关路由 -func RegisterSearchRoute(r *gin.Engine) { - // 搜索镜像 - r.GET("/search", func(c *gin.Context) { - query := c.Query("q") - if query == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "搜索关键词不能为空"}) - return - } - - page := 1 - pageSize := 25 - if p := c.Query("page"); p != "" { - fmt.Sscanf(p, "%d", &page) - } - if ps := c.Query("page_size"); ps != "" { - fmt.Sscanf(ps, "%d", &pageSize) - } - - result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, result) - }) - - // 获取标签信息 - r.GET("/tags/:namespace/:name", func(c *gin.Context) { - namespace := c.Param("namespace") - name := c.Param("name") - - if namespace == "" || name == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "命名空间和名称不能为空"}) - return - } - - tags, err := getRepositoryTags(c.Request.Context(), namespace, name) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, tags) - }) -} +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// SearchResult Docker Hub搜索结果 +type SearchResult struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` +} + +// Repository 仓库信息 +type Repository struct { + Name string `json:"repo_name"` + Description string `json:"short_description"` + IsOfficial bool `json:"is_official"` + IsAutomated bool `json:"is_automated"` + StarCount int `json:"star_count"` + PullCount int `json:"pull_count"` + RepoOwner string `json:"repo_owner"` + LastUpdated string `json:"last_updated"` + Status int `json:"status"` + Organization string `json:"affiliation"` + PullsLastWeek int `json:"pulls_last_week"` + Namespace string `json:"namespace"` +} + +// TagInfo 标签信息 +type TagInfo struct { + Name string `json:"name"` + FullSize int64 `json:"full_size"` + LastUpdated time.Time `json:"last_updated"` + LastPusher string `json:"last_pusher"` + Images []Image `json:"images"` + Vulnerabilities struct { + Critical int `json:"critical"` + High int `json:"high"` + Medium int `json:"medium"` + Low int `json:"low"` + Unknown int `json:"unknown"` + } `json:"vulnerabilities"` +} + +// Image 镜像信息 +type Image struct { + Architecture string `json:"architecture"` + Features string `json:"features"` + Variant string `json:"variant,omitempty"` + Digest string `json:"digest"` + OS string `json:"os"` + OSFeatures string `json:"os_features"` + Size int64 `json:"size"` +} + +type cacheEntry struct { + data interface{} + timestamp time.Time +} + +const ( + maxCacheSize = 1000 // 最大缓存条目数 + cacheTTL = 30 * time.Minute +) + +type Cache struct { + data map[string]cacheEntry + mu sync.RWMutex + maxSize int +} + +var ( + searchCache = &Cache{ + data: make(map[string]cacheEntry), + maxSize: maxCacheSize, + } +) + +func (c *Cache) Get(key string) (interface{}, bool) { + c.mu.RLock() + entry, exists := c.data[key] + c.mu.RUnlock() + + if !exists { + return nil, false + } + + if time.Since(entry.timestamp) > cacheTTL { + c.mu.Lock() + delete(c.data, key) + c.mu.Unlock() + return nil, false + } + + return entry.data, true +} + +func (c *Cache) Set(key string, data interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for k, v := range c.data { + if now.Sub(v.timestamp) > cacheTTL { + delete(c.data, k) + } + } + + if len(c.data) >= c.maxSize { + toDelete := len(c.data) / 4 + for k := range c.data { + if toDelete <= 0 { + break + } + delete(c.data, k) + toDelete-- + } + } + + c.data[key] = cacheEntry{ + data: data, + timestamp: now, + } +} + +func (c *Cache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + now := time.Now() + for key, entry := range c.data { + if now.Sub(entry.timestamp) > cacheTTL { + delete(c.data, key) + } + } +} + +// 定期清理过期缓存 +func init() { + go func() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + searchCache.Cleanup() + } + }() +} + +func filterSearchResults(results []Repository, query string) []Repository { + searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) + filtered := make([]Repository, 0) + + for _, repo := range results { + // 标准化仓库名称 + repoName := strings.ToLower(repo.Name) + repoDesc := strings.ToLower(repo.Description) + + // 计算相关性得分 + score := 0 + + // 完全匹配 + if repoName == searchTerm { + score += 100 + } + + // 前缀匹配 + if strings.HasPrefix(repoName, searchTerm) { + score += 50 + } + + // 包含匹配 + if strings.Contains(repoName, searchTerm) { + score += 30 + } + + // 描述匹配 + if strings.Contains(repoDesc, searchTerm) { + score += 10 + } + + // 官方镜像加分 + if repo.IsOfficial { + score += 20 + } + + // 分数达到阈值的结果才保留 + if score > 0 { + filtered = append(filtered, repo) + } + } + + // 按相关性排序 + sort.Slice(filtered, func(i, j int) bool { + // 优先考虑官方镜像 + if filtered[i].IsOfficial != filtered[j].IsOfficial { + return filtered[i].IsOfficial + } + // 其次考虑拉取次数 + return filtered[i].PullCount > filtered[j].PullCount + }) + + return filtered +} + +// searchDockerHub 搜索镜像 +func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { + cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) + + // 尝试从缓存获取 + if cached, ok := searchCache.Get(cacheKey); ok { + return cached.(*SearchResult), nil + } + + // 判断是否是用户/仓库格式的搜索 + isUserRepo := strings.Contains(query, "/") + var namespace, repoName string + + if isUserRepo { + parts := strings.Split(query, "/") + if len(parts) == 2 { + namespace = parts[0] + repoName = parts[1] + } + } + + // 构建搜索URL + baseURL := "https://registry.hub.docker.com/v2" + var fullURL string + var params url.Values + + if isUserRepo && namespace != "" { + // 如果是用户/仓库格式,使用repositories接口 + fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) + params = url.Values{ + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } else { + // 普通搜索 + fullURL = baseURL + "/search/repositories/" + params = url.Values{ + "query": {query}, + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } + + fullURL = fullURL + "?" + params.Encode() + + // 使用统一的搜索HTTP客户端 + resp, err := GetSearchHTTPClient().Get(fullURL) + if err != nil { + return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭搜索响应体失败: %v\n", err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusTooManyRequests: + return nil, fmt.Errorf("请求过于频繁,请稍后重试") + case http.StatusNotFound: + if isUserRepo && namespace != "" { + // 如果用户仓库搜索失败,尝试普通搜索 + return searchDockerHub(ctx, repoName, page, pageSize) + } + return nil, fmt.Errorf("未找到相关镜像") + case http.StatusBadGateway, http.StatusServiceUnavailable: + return nil, fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") + default: + return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) + } + } + + // 解析响应 + var result *SearchResult + if isUserRepo && namespace != "" { + // 解析用户仓库列表响应 + var userRepos struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` + } + if err := json.Unmarshal(body, &userRepos); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 转换为SearchResult格式 + result = &SearchResult{ + Count: userRepos.Count, + Next: userRepos.Next, + Previous: userRepos.Previous, + Results: make([]Repository, 0), + } + + // 处理结果 + for _, repo := range userRepos.Results { + // 如果指定了仓库名,只保留匹配的结果 + if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { + // 确保设置正确的命名空间和名称 + repo.Namespace = namespace + if !strings.Contains(repo.Name, "/") { + repo.Name = fmt.Sprintf("%s/%s", namespace, repo.Name) + } + result.Results = append(result.Results, repo) + } + } + + // 如果没有找到结果,尝试普通搜索 + if len(result.Results) == 0 { + return searchDockerHub(ctx, repoName, page, pageSize) + } + + result.Count = len(result.Results) + } else { + // 解析普通搜索响应 + result = &SearchResult{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 处理搜索结果 + for i := range result.Results { + if result.Results[i].IsOfficial { + if !strings.Contains(result.Results[i].Name, "/") { + result.Results[i].Name = "library/" + result.Results[i].Name + } + result.Results[i].Namespace = "library" + } else { + parts := strings.Split(result.Results[i].Name, "/") + if len(parts) > 1 { + result.Results[i].Namespace = parts[0] + result.Results[i].Name = parts[1] + } else if result.Results[i].RepoOwner != "" { + result.Results[i].Namespace = result.Results[i].RepoOwner + result.Results[i].Name = fmt.Sprintf("%s/%s", result.Results[i].RepoOwner, result.Results[i].Name) + } + } + } + + // 如果是用户/仓库搜索,过滤结果 + if isUserRepo && namespace != "" { + filteredResults := make([]Repository, 0) + for _, repo := range result.Results { + if strings.EqualFold(repo.Namespace, namespace) { + filteredResults = append(filteredResults, repo) + } + } + result.Results = filteredResults + result.Count = len(filteredResults) + } + } + + // 缓存结果 + searchCache.Set(cacheKey, result) + return result, nil +} + +// 判断错误是否可重试 +func isRetryableError(err error) bool { + if err == nil { + return false + } + + // 网络错误、超时等可以重试 + if strings.Contains(err.Error(), "timeout") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "no such host") || + strings.Contains(err.Error(), "too many requests") { + return true + } + + return false +} + +// getRepositoryTags 获取仓库标签信息 +func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo, error) { + if namespace == "" || name == "" { + return nil, fmt.Errorf("无效输入:命名空间和名称不能为空") + } + + cacheKey := fmt.Sprintf("tags:%s:%s", namespace, name) + if cached, ok := searchCache.Get(cacheKey); ok { + return cached.([]TagInfo), nil + } + + // 构建API URL + baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) + params := url.Values{} + params.Set("page_size", "100") + params.Set("ordering", "last_updated") + + fullURL := baseURL + "?" + params.Encode() + + // 使用统一的搜索HTTP客户端 + resp, err := GetSearchHTTPClient().Get(fullURL) + if err != nil { + return nil, fmt.Errorf("发送请求失败: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭搜索响应体失败: %v\n", err) + } + }() + + // 读取响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) + } + + // 解析响应 + var result struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []TagInfo `json:"results"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + // 缓存结果 + searchCache.Set(cacheKey, result.Results) + return result.Results, nil +} + +// RegisterSearchRoute 注册搜索相关路由 +func RegisterSearchRoute(r *gin.Engine) { + // 搜索镜像 + r.GET("/search", func(c *gin.Context) { + query := c.Query("q") + if query == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "搜索关键词不能为空"}) + return + } + + page := 1 + pageSize := 25 + if p := c.Query("page"); p != "" { + fmt.Sscanf(p, "%d", &page) + } + if ps := c.Query("page_size"); ps != "" { + fmt.Sscanf(ps, "%d", &pageSize) + } + + result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) + }) + + // 获取标签信息 + r.GET("/tags/:namespace/:name", func(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + + if namespace == "" || name == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "命名空间和名称不能为空"}) + return + } + + tags, err := getRepositoryTags(c.Request.Context(), namespace, name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, tags) + }) +} diff --git a/src/token_cache.go b/src/token_cache.go index 77a56c7..5b4b730 100644 --- a/src/token_cache.go +++ b/src/token_cache.go @@ -13,10 +13,10 @@ import ( // CachedItem 通用缓存项,支持Token和Manifest type CachedItem struct { - Data []byte // 缓存数据(token字符串或manifest字节) - ContentType string // 内容类型 + Data []byte // 缓存数据(token字符串或manifest字节) + ContentType string // 内容类型 Headers map[string]string // 额外的响应头 - ExpiresAt time.Time // 过期时间 + ExpiresAt time.Time // 过期时间 } // UniversalCache 通用缓存,支持Token和Manifest @@ -79,18 +79,18 @@ func getManifestTTL(reference string) time.Duration { defaultTTL = parsed } } - + if strings.HasPrefix(reference, "sha256:") { return 24 * time.Hour } - + // mutable tag的智能判断 - if reference == "latest" || reference == "main" || reference == "master" || - reference == "dev" || reference == "develop" { + if reference == "latest" || reference == "main" || reference == "master" || + reference == "dev" || reference == "develop" { // 热门可变标签: 短期缓存 return 10 * time.Minute } - + return defaultTTL } @@ -99,17 +99,17 @@ func extractTTLFromResponse(responseBody []byte) time.Duration { var tokenResp struct { ExpiresIn int `json:"expires_in"` } - + // 默认30分钟TTL,确保稳定性 defaultTTL := 30 * time.Minute - + if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 { safeTTL := time.Duration(tokenResp.ExpiresIn-300) * time.Second if safeTTL > 5*time.Minute { return safeTTL } } - + return defaultTTL } @@ -122,12 +122,12 @@ func writeCachedResponse(c *gin.Context, item *CachedItem) { if item.ContentType != "" { c.Header("Content-Type", item.ContentType) } - + // 设置额外的响应头 for key, value := range item.Headers { c.Header(key, value) } - + // 返回数据 c.Data(200, item.ContentType, item.Data) } @@ -148,21 +148,21 @@ func init() { go func() { ticker := time.NewTicker(20 * time.Minute) defer ticker.Stop() - + for range ticker.C { now := time.Now() expiredKeys := make([]string, 0) - + globalCache.cache.Range(func(key, value interface{}) bool { if cached := value.(*CachedItem); now.After(cached.ExpiresAt) { expiredKeys = append(expiredKeys, key.(string)) } return true }) - + for _, key := range expiredKeys { globalCache.cache.Delete(key) } } }() -} \ No newline at end of file +}