diff --git a/README.md b/README.md index 228d951..61d62e0 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ docker run -d \ ### 一键脚本安装 ```bash -curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash +curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash ``` 也可以直接下载二进制文件执行`./hubproxy`使用,无需配置文件即可启动,内置默认配置,支持所有功能。初始内存占用约18M,二进制文件大小约12M @@ -109,12 +109,14 @@ host = "0.0.0.0" port = 5000 # Github文件大小限制(字节),默认2GB fileSize = 2147483648 +# HTTP/2 多路复用 +enableH2C = false [rateLimit] # 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数) requestLimit = 500 # 限流周期(小时) -periodHours = 1.0 +periodHours = 3.0 [security] # IP白名单,支持单个IP或IP段 @@ -132,7 +134,7 @@ blackList = [ "192.168.100.0/24" ] -[proxy] +[access] # 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符) # 只允许访问白名单中的仓库/镜像,为空时不限制 whiteList = [] @@ -148,12 +150,6 @@ blackList = [ # 代理配置,支持有用户名/密码认证和无认证模式 # 无认证: socks5://127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080 -# HTTP 代理示例 -# http://username:password@127.0.0.1:7890 -# SOCKS5 代理示例 -# socks5://username:password@127.0.0.1:1080 -# SOCKS5H 代理示例 -# socks5h://username:password@127.0.0.1:1080 # 留空不使用代理 proxy = "" diff --git a/docker-compose.yml b/docker-compose.yml index f744ced..41aab58 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,8 @@ services: - hubproxy: - build: . - restart: always - ports: - - '5000:5000' - volumes: - - ./src/config.toml:/root/config.toml \ No newline at end of file + hubproxy: + build: . + restart: always + ports: + - '5000:5000' + volumes: + - ./src/config.toml:/root/config.toml \ No newline at end of file diff --git a/install-service.sh b/install.sh similarity index 100% rename from install-service.sh rename to install.sh diff --git a/src/config.toml b/src/config.toml index 7526b75..f7d4c08 100644 --- a/src/config.toml +++ b/src/config.toml @@ -4,12 +4,14 @@ host = "0.0.0.0" port = 5000 # Github文件大小限制(字节),默认2GB fileSize = 2147483648 +# HTTP/2 多路复用 +enableH2C = false [rateLimit] # 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数) requestLimit = 500 # 限流周期(小时) -periodHours = 1.0 +periodHours = 3.0 [security] # IP白名单,支持单个IP或IP段 @@ -43,12 +45,6 @@ blackList = [ # 代理配置,支持有用户名/密码认证和无认证模式 # 无认证: socks5://127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080 -# HTTP 代理示例 -# http://username:password@127.0.0.1:7890 -# SOCKS5 代理示例 -# socks5://username:password@127.0.0.1:1080 -# SOCKS5H 代理示例 -# socks5h://username:password@127.0.0.1:1080 # 留空不使用代理 proxy = "" diff --git a/src/config.go b/src/config/config.go similarity index 73% rename from src/config.go rename to src/config/config.go index 1dbae37..700b4c9 100644 --- a/src/config.go +++ b/src/config/config.go @@ -1,4 +1,4 @@ -package main +package config import ( "fmt" @@ -13,45 +13,46 @@ import ( // RegistryMapping Registry映射配置 type RegistryMapping struct { - Upstream string `toml:"upstream"` // 上游Registry地址 - AuthHost string `toml:"authHost"` // 认证服务器地址 - AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic - Enabled bool `toml:"enabled"` // 是否启用 + Upstream string `toml:"upstream"` + AuthHost string `toml:"authHost"` + AuthType string `toml:"authType"` + Enabled bool `toml:"enabled"` } // AppConfig 应用配置结构体 type AppConfig struct { Server struct { - Host string `toml:"host"` // 监听地址 - Port int `toml:"port"` // 监听端口 - FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) + Host string `toml:"host"` + Port int `toml:"port"` + FileSize int64 `toml:"fileSize"` + EnableH2C bool `toml:"enableH2C"` } `toml:"server"` RateLimit struct { - RequestLimit int `toml:"requestLimit"` // 每小时请求限制 - PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) + RequestLimit int `toml:"requestLimit"` + PeriodHours float64 `toml:"periodHours"` } `toml:"rateLimit"` Security struct { - WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 - BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` } `toml:"security"` Access struct { - WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) - BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) - Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` + Proxy string `toml:"proxy"` } `toml:"access"` Download struct { - MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 + MaxImages int `toml:"maxImages"` } `toml:"download"` Registries map[string]RegistryMapping `toml:"registries"` TokenCache struct { - Enabled bool `toml:"enabled"` // 是否启用token缓存 - DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间 + Enabled bool `toml:"enabled"` + DefaultTTL string `toml:"defaultTTL"` } `toml:"tokenCache"` } @@ -65,24 +66,25 @@ var ( configCacheMutex sync.RWMutex ) -// todo:Refactoring is needed // DefaultConfig 返回默认配置 func DefaultConfig() *AppConfig { return &AppConfig{ Server: struct { - Host string `toml:"host"` - Port int `toml:"port"` - FileSize int64 `toml:"fileSize"` + Host string `toml:"host"` + Port int `toml:"port"` + FileSize int64 `toml:"fileSize"` + EnableH2C bool `toml:"enableH2C"` }{ - Host: "0.0.0.0", - Port: 5000, - FileSize: 2 * 1024 * 1024 * 1024, // 2GB + Host: "0.0.0.0", + Port: 5000, + FileSize: 2 * 1024 * 1024 * 1024, // 2GB + EnableH2C: false, // 默认关闭H2C }, RateLimit: struct { RequestLimit int `toml:"requestLimit"` PeriodHours float64 `toml:"periodHours"` }{ - RequestLimit: 20, + RequestLimit: 200, PeriodHours: 1.0, }, Security: struct { @@ -99,12 +101,12 @@ func DefaultConfig() *AppConfig { }{ WhiteList: []string{}, BlackList: []string{}, - Proxy: "", // 默认不使用代理 + Proxy: "", }, Download: struct { MaxImages int `toml:"maxImages"` }{ - MaxImages: 10, // 默认值:最多同时下载10个镜像 + MaxImages: 10, }, Registries: map[string]RegistryMapping{ "ghcr.io": { @@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig { Enabled bool `toml:"enabled"` DefaultTTL string `toml:"defaultTTL"` }{ - Enabled: true, // docker认证的匿名Token缓存配置,用于提升性能 + Enabled: true, DefaultTTL: "20m", }, } @@ -152,11 +154,9 @@ func GetConfig() *AppConfig { } configCacheMutex.RUnlock() - // 缓存过期,重新生成配置 configCacheMutex.Lock() defer configCacheMutex.Unlock() - // 双重检查,防止重复生成 if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { return cachedConfig } @@ -170,7 +170,6 @@ func GetConfig() *AppConfig { return defaultCfg } - // 生成新的配置深拷贝 configCopy := *appConfig configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) @@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) { // LoadConfig 加载配置文件 func LoadConfig() error { - // 首先使用默认配置 cfg := DefaultConfig() - // 尝试加载TOML配置文件 if data, err := os.ReadFile("config.toml"); err == nil { if err := toml.Unmarshal(data, cfg); err != nil { return fmt.Errorf("解析配置文件失败: %v", err) @@ -209,10 +206,7 @@ func LoadConfig() error { fmt.Println("未找到config.toml,使用默认配置") } - // 从环境变量覆盖配置 overrideFromEnv(cfg) - - // 设置配置 setConfig(cfg) return nil @@ -220,7 +214,6 @@ func LoadConfig() error { // overrideFromEnv 从环境变量覆盖配置 func overrideFromEnv(cfg *AppConfig) { - // 服务器配置 if val := os.Getenv("SERVER_HOST"); val != "" { cfg.Server.Host = val } @@ -229,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) { cfg.Server.Port = port } } + if val := os.Getenv("ENABLE_H2C"); val != "" { + if enable, err := strconv.ParseBool(val); err == nil { + cfg.Server.EnableH2C = enable + } + } if val := os.Getenv("MAX_FILE_SIZE"); val != "" { if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { cfg.Server.FileSize = size } } - // 限流配置 if val := os.Getenv("RATE_LIMIT"); val != "" { if limit, err := strconv.Atoi(val); err == nil && limit > 0 { cfg.RateLimit.RequestLimit = limit @@ -247,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) { } } - // IP限制配置 if val := os.Getenv("IP_WHITELIST"); val != "" { cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) } @@ -255,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) { cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) } - // 下载限制配置 if val := os.Getenv("MAX_IMAGES"); val != "" { if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { cfg.Download.MaxImages = maxImages diff --git a/src/go.mod b/src/go.mod index 7c2806d..6e8a2f8 100644 --- a/src/go.mod +++ b/src/go.mod @@ -6,6 +6,7 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/google/go-containerregistry v0.20.5 github.com/pelletier/go-toml/v2 v2.2.3 + golang.org/x/net v0.33.0 golang.org/x/time v0.11.0 ) @@ -43,7 +44,6 @@ require ( github.com/vbatts/tar-split v0.12.1 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.32.0 // indirect - golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/src/docker.go b/src/handlers/docker.go similarity index 77% rename from src/docker.go rename to src/handlers/docker.go index 2760fec..ffccc73 100644 --- a/src/docker.go +++ b/src/handlers/docker.go @@ -1,4 +1,4 @@ -package main +package handlers import ( "context" @@ -12,6 +12,8 @@ import ( "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/remote" + "hubproxy/config" + "hubproxy/utils" ) // DockerProxy Docker代理配置 @@ -27,12 +29,10 @@ type RegistryDetector struct{} // detectRegistryDomain 检测Registry域名并返回域名和剩余路径 func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) { - cfg := GetConfig() + cfg := config.GetConfig() - // 检查路径是否以已知Registry域名开头 for domain := range cfg.Registries { if strings.HasPrefix(path, domain+"/") { - // 找到匹配的域名,返回域名和剩余路径 remainingPath := strings.TrimPrefix(path, domain+"/") return domain, remainingPath } @@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) { // isRegistryEnabled 检查Registry是否启用 func (rd *RegistryDetector) isRegistryEnabled(domain string) bool { - cfg := GetConfig() + cfg := config.GetConfig() if mapping, exists := cfg.Registries[domain]; exists { return mapping.Enabled } @@ -51,28 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool { } // getRegistryMapping 获取Registry映射配置 -func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) { - cfg := GetConfig() +func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) { + cfg := config.GetConfig() mapping, exists := cfg.Registries[domain] return mapping, exists && mapping.Enabled } var registryDetector = &RegistryDetector{} -// 初始化Docker代理 -func initDockerProxy() { - // 创建目标registry +// InitDockerProxy 初始化Docker代理 +func InitDockerProxy() { registry, err := name.NewRegistry("registry-1.docker.io") if err != nil { fmt.Printf("创建Docker registry失败: %v\n", err) return } - // 配置代理选项 options := []remote.Option{ remote.WithAuth(authn.Anonymous), remote.WithUserAgent("hubproxy/go-containerregistry"), - remote.WithTransport(GetGlobalHTTPClient().Transport), + remote.WithTransport(utils.GetGlobalHTTPClient().Transport), } dockerProxy = &DockerProxy{ @@ -85,13 +83,11 @@ func initDockerProxy() { func ProxyDockerRegistryGin(c *gin.Context) { path := c.Request.URL.Path - // 处理 /v2/ API版本检查 if path == "/v2/" { c.JSON(http.StatusOK, gin.H{}) return } - // 处理不同的API端点 if strings.HasPrefix(path, "/v2/") { handleRegistryRequest(c, path) } else { @@ -101,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) { // handleRegistryRequest 处理Registry请求 func handleRegistryRequest(c *gin.Context, path string) { - // 移除 /v2/ 前缀 pathWithoutV2 := strings.TrimPrefix(path, "/v2/") if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" { if registryDetector.isRegistryEnabled(registryDomain) { - // 设置目标Registry信息到Context c.Set("target_registry_domain", registryDomain) c.Set("target_path", remainingPath) - // 处理多Registry请求 handleMultiRegistryRequest(c, registryDomain, remainingPath) return } @@ -122,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) { return } - // 自动处理官方镜像的library命名空间 if !strings.Contains(imageName, "/") { imageName = "library/" + imageName } - // Docker镜像访问控制检查 - if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed { + if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed { fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason) c.String(http.StatusForbidden, "镜像访问被限制") return } - // 构建完整的镜像引用 imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName) switch apiType { @@ -151,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) { // parseRegistryPath 解析Registry路径 func parseRegistryPath(path string) (imageName, apiType, reference string) { - // 查找API端点关键字 if idx := strings.Index(path, "/manifests/"); idx != -1 { imageName = path[:idx] apiType = "manifests" @@ -178,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) { // handleManifestRequest 处理manifest请求 func handleManifestRequest(c *gin.Context, imageRef, reference string) { - // Manifest缓存逻辑(仅对GET请求缓存) - if isCacheEnabled() && c.Request.Method == http.MethodGet { - cacheKey := buildManifestCacheKey(imageRef, reference) + if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet { + cacheKey := utils.BuildManifestCacheKey(imageRef, reference) - // 优先从缓存获取 - if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { - writeCachedResponse(c, cachedItem) + if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil { + utils.WriteCachedResponse(c, cachedItem) return } } @@ -192,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) { var ref name.Reference var err error - // 判断reference是digest还是tag if strings.HasPrefix(reference, "sha256:") { - // 是digest ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) } else { - // 是tag ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) } @@ -207,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) { return } - // 根据请求方法选择操作 if c.Request.Method == http.MethodHead { - // HEAD请求,使用remote.Head desc, err := remote.Head(ref, dockerProxy.options...) if err != nil { fmt.Printf("HEAD请求失败: %v\n", err) @@ -217,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) { return } - // 设置响应头 c.Header("Content-Type", string(desc.MediaType)) c.Header("Docker-Content-Digest", desc.Digest.String()) c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) c.Status(http.StatusOK) } else { - // GET请求,使用remote.Get desc, err := remote.Get(ref, dockerProxy.options...) if err != nil { fmt.Printf("GET请求失败: %v\n", err) @@ -231,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) { return } - // 设置响应头 headers := map[string]string{ "Docker-Content-Digest": desc.Digest.String(), "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), } - // 缓存响应 - if isCacheEnabled() { - cacheKey := buildManifestCacheKey(imageRef, reference) - ttl := getManifestTTL(reference) - globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) + if utils.IsCacheEnabled() { + cacheKey := utils.BuildManifestCacheKey(imageRef, reference) + ttl := utils.GetManifestTTL(reference) + utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) } - // 设置响应头 c.Header("Content-Type", string(desc.MediaType)) for key, value := range headers { c.Header(key, value) } - // 返回manifest内容 c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) } } // handleBlobRequest 处理blob请求 func handleBlobRequest(c *gin.Context, imageRef, digest string) { - // 构建digest引用 digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) if err != nil { fmt.Printf("解析digest引用失败: %v\n", err) @@ -265,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) { return } - // 使用remote.Layer获取layer layer, err := remote.Layer(digestRef, dockerProxy.options...) if err != nil { fmt.Printf("获取layer失败: %v\n", err) @@ -273,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) { return } - // 获取layer信息 size, err := layer.Size() if err != nil { fmt.Printf("获取layer大小失败: %v\n", err) @@ -281,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) { return } - // 获取layer内容 reader, err := layer.Compressed() if err != nil { fmt.Printf("获取layer内容失败: %v\n", err) @@ -290,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) { } defer reader.Close() - // 设置响应头 c.Header("Content-Type", "application/octet-stream") c.Header("Content-Length", fmt.Sprintf("%d", size)) c.Header("Docker-Content-Digest", digest) - // 流式传输blob内容 c.Status(http.StatusOK) io.Copy(c.Writer, reader) } // handleTagsRequest 处理tags列表请求 func handleTagsRequest(c *gin.Context, imageRef string) { - // 解析repository repo, err := name.NewRepository(imageRef) if err != nil { fmt.Printf("解析repository失败: %v\n", err) @@ -310,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) { return } - // 使用remote.List获取tags tags, err := remote.List(repo, dockerProxy.options...) if err != nil { fmt.Printf("获取tags失败: %v\n", err) @@ -318,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) { return } - // 构建响应 response := map[string]interface{}{ "name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"), "tags": tags, @@ -327,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) { c.JSON(http.StatusOK, response) } -// ProxyDockerAuthGin Docker认证代理(带缓存优化) +// ProxyDockerAuthGin Docker认证代理 func ProxyDockerAuthGin(c *gin.Context) { - // 检查是否启用token缓存 - if isTokenCacheEnabled() { + if utils.IsTokenCacheEnabled() { proxyDockerAuthWithCache(c) } else { proxyDockerAuthOriginal(c) @@ -339,32 +305,26 @@ func ProxyDockerAuthGin(c *gin.Context) { // proxyDockerAuthWithCache 带缓存的认证代理 func proxyDockerAuthWithCache(c *gin.Context) { - // 1. 构建缓存key(基于完整的查询参数) - cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery) + cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery) - // 2. 尝试从缓存获取token - if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" { - writeTokenResponse(c, cachedToken) + if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" { + utils.WriteTokenResponse(c, cachedToken) return } - // 3. 缓存未命中,创建响应记录器 recorder := &ResponseRecorder{ ResponseWriter: c.Writer, statusCode: 200, } c.Writer = recorder - // 4. 调用原有认证逻辑 proxyDockerAuthOriginal(c) - // 5. 如果认证成功,缓存响应 if recorder.statusCode == 200 && len(recorder.body) > 0 { - ttl := extractTTLFromResponse(recorder.body) - globalCache.SetToken(cacheKey, string(recorder.body), ttl) + ttl := utils.ExtractTTLFromResponse(recorder.body) + utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl) } - // 6. 写入实际响应 c.Writer = recorder.ResponseWriter c.Data(recorder.statusCode, "application/json", recorder.body) } @@ -389,14 +349,11 @@ func proxyDockerAuthOriginal(c *gin.Context) { var authURL string if targetDomain, exists := c.Get("target_registry_domain"); exists { if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found { - // 使用Registry特定的认证服务器 authURL = "https://" + mapping.AuthHost + c.Request.URL.Path } else { - // fallback到默认Docker认证 authURL = "https://auth.docker.io" + c.Request.URL.Path } } else { - // 构建默认Docker认证URL authURL = "https://auth.docker.io" + c.Request.URL.Path } @@ -404,13 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) { authURL += "?" + c.Request.URL.RawQuery } - // 创建HTTP客户端,复用全局传输配置(包含代理设置) client := &http.Client{ Timeout: 30 * time.Second, - Transport: GetGlobalHTTPClient().Transport, + Transport: utils.GetGlobalHTTPClient().Transport, } - // 创建请求 req, err := http.NewRequestWithContext( context.Background(), c.Request.Method, @@ -422,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) { return } - // 复制请求头 for key, values := range c.Request.Header { for _, value := range values { req.Header.Add(key, value) } } - // 执行请求 resp, err := client.Do(req) if err != nil { c.String(http.StatusBadGateway, "Auth request failed") @@ -437,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) { } defer resp.Body.Close() - // 获取当前代理的Host地址 proxyHost := c.Request.Host if proxyHost == "" { - // 使用配置中的服务器地址和端口 - cfg := GetConfig() + cfg := config.GetConfig() proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) if cfg.Server.Host == "0.0.0.0" { proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port) } } - // 复制响应头并重写认证URL for key, values := range resp.Header { for _, value := range values { - // 重写WWW-Authenticate头中的realm URL if key == "Www-Authenticate" { - // 支持多Registry的URL重写 value = rewriteAuthHeader(value, proxyHost) } c.Header(key, value) } } - // 返回响应 c.Status(resp.StatusCode) io.Copy(c.Writer, resp.Body) } // rewriteAuthHeader 重写认证头 func rewriteAuthHeader(authHeader, proxyHost string) string { - // 重写各种Registry的认证URL authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost) authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost) authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost) @@ -478,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string { // handleMultiRegistryRequest 处理多Registry请求 func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) { - // 获取Registry映射配置 mapping, exists := registryDetector.getRegistryMapping(registryDomain) if !exists { c.String(http.StatusBadRequest, "Registry not configured") return } - // 解析剩余路径 imageName, apiType, reference := parseRegistryPath(remainingPath) if imageName == "" || apiType == "" { c.String(http.StatusBadRequest, "Invalid path format") return } - // 访问控制检查(使用完整的镜像路径) fullImageName := registryDomain + "/" + imageName - if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed { + if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed { fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason) c.String(http.StatusForbidden, "镜像访问被限制") return } - // 构建上游Registry引用 upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName) - // 根据API类型处理请求 switch apiType { case "manifests": handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping) @@ -517,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st } // handleUpstreamManifestRequest 处理上游Registry的manifest请求 -func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) { - // Manifest缓存逻辑(仅对GET请求缓存) - if isCacheEnabled() && c.Request.Method == http.MethodGet { - cacheKey := buildManifestCacheKey(imageRef, reference) +func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) { + if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet { + cacheKey := utils.BuildManifestCacheKey(imageRef, reference) - // 优先从缓存获取 - if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { - writeCachedResponse(c, cachedItem) + if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil { + utils.WriteCachedResponse(c, cachedItem) return } } @@ -532,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m var ref name.Reference var err error - // 判断reference是digest还是tag if strings.HasPrefix(reference, "sha256:") { ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) } else { @@ -545,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m return } - // 创建针对上游Registry的选项 options := createUpstreamOptions(mapping) - // 根据请求方法选择操作 if c.Request.Method == http.MethodHead { desc, err := remote.Head(ref, options...) if err != nil { @@ -569,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m return } - // 设置响应头 headers := map[string]string{ "Docker-Content-Digest": desc.Digest.String(), "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), } - // 缓存响应 - if isCacheEnabled() { - cacheKey := buildManifestCacheKey(imageRef, reference) - ttl := getManifestTTL(reference) - globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) + if utils.IsCacheEnabled() { + cacheKey := utils.BuildManifestCacheKey(imageRef, reference) + ttl := utils.GetManifestTTL(reference) + utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl) } - // 设置响应头 c.Header("Content-Type", string(desc.MediaType)) for key, value := range headers { c.Header(key, value) @@ -593,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m } // handleUpstreamBlobRequest 处理上游Registry的blob请求 -func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) { +func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) { digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) if err != nil { fmt.Printf("解析digest引用失败: %v\n", err) @@ -633,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping } // handleUpstreamTagsRequest 处理上游Registry的tags请求 -func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) { +func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) { repo, err := name.NewRepository(imageRef) if err != nil { fmt.Printf("解析repository失败: %v\n", err) @@ -658,14 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry } // createUpstreamOptions 创建上游Registry选项 -func createUpstreamOptions(mapping RegistryMapping) []remote.Option { +func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option { options := []remote.Option{ remote.WithAuth(authn.Anonymous), remote.WithUserAgent("hubproxy/go-containerregistry"), - remote.WithTransport(GetGlobalHTTPClient().Transport), + remote.WithTransport(utils.GetGlobalHTTPClient().Transport), } - // 根据Registry类型添加特定的认证选项(方便后续扩展) switch mapping.AuthType { case "github": case "google": diff --git a/src/handlers/github.go b/src/handlers/github.go new file mode 100644 index 0000000..e29f174 --- /dev/null +++ b/src/handlers/github.go @@ -0,0 +1,213 @@ +package handlers + +import ( + "fmt" + "io" + "net/http" + "regexp" + "strconv" + "strings" + + "github.com/gin-gonic/gin" + "hubproxy/config" + "hubproxy/utils" +) + +var ( + // GitHub URL匹配正则表达式 + githubExps = []*regexp.Regexp{ + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`), + regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`), + regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`), + regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), + regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), + regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`), + regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`), + regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`), + regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`), + } +) + +// GitHubProxyHandler GitHub代理处理器 +func GitHubProxyHandler(c *gin.Context) { + rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") + + for strings.HasPrefix(rawPath, "/") { + rawPath = strings.TrimPrefix(rawPath, "/") + } + + // 自动补全协议头 + if !strings.HasPrefix(rawPath, "https://") { + if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") { + rawPath = strings.Replace(rawPath, "http:/", "", 1) + rawPath = strings.Replace(rawPath, "https:/", "", 1) + } else if strings.HasPrefix(rawPath, "http://") { + rawPath = strings.TrimPrefix(rawPath, "http://") + } + rawPath = "https://" + rawPath + } + + matches := CheckGitHubURL(rawPath) + if matches != nil { + if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed { + var repoPath string + if len(matches) >= 2 { + username := matches[0] + repoName := strings.TrimSuffix(matches[1], ".git") + repoPath = username + "/" + repoName + } + fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) + c.String(http.StatusForbidden, reason) + return + } + } else { + c.String(http.StatusForbidden, "无效输入") + return + } + + // 将blob链接转换为raw链接 + if githubExps[1].MatchString(rawPath) { + rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) + } + + ProxyGitHubRequest(c, rawPath) +} + +// CheckGitHubURL 检查URL是否匹配GitHub模式 +func CheckGitHubURL(u string) []string { + for _, exp := range githubExps { + if matches := exp.FindStringSubmatch(u); matches != nil { + return matches[1:] + } + } + return nil +} + +// ProxyGitHubRequest 代理GitHub请求 +func ProxyGitHubRequest(c *gin.Context, u string) { + proxyGitHubWithRedirect(c, u, 0) +} + +// proxyGitHubWithRedirect 带重定向的GitHub代理请求 +func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) { + const maxRedirects = 20 + if redirectCount > maxRedirects { + c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") + return + } + + req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + + // 复制请求头 + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + req.Header.Del("Host") + + resp, err := utils.GetGlobalHTTPClient().Do(req) + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("关闭响应体失败: %v\n", err) + } + }() + + // 检查文件大小限制 + cfg := config.GetConfig() + if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { + if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { + c.String(http.StatusRequestEntityTooLarge, + fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) + return + } + } + + // 清理安全相关的头 + resp.Header.Del("Content-Security-Policy") + resp.Header.Del("Referrer-Policy") + resp.Header.Del("Strict-Transport-Security") + + // 获取真实域名 + realHost := c.Request.Header.Get("X-Forwarded-Host") + if realHost == "" { + realHost = c.Request.Host + } + if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { + realHost = "https://" + realHost + } + + // 处理.sh文件的智能处理 + if strings.HasSuffix(strings.ToLower(u), ".sh") { + isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip" + + processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost) + if err != nil { + fmt.Printf("智能处理失败,回退到直接代理: %v\n", err) + processedBody = resp.Body + processedSize = 0 + } + + // 智能设置响应头 + if processedSize > 0 { + resp.Header.Del("Content-Length") + resp.Header.Del("Content-Encoding") + resp.Header.Set("Transfer-Encoding", "chunked") + } + + // 复制其他响应头 + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + // 处理重定向 + if location := resp.Header.Get("Location"); location != "" { + if CheckGitHubURL(location) != nil { + c.Header("Location", "/"+location) + } else { + proxyGitHubWithRedirect(c, location, redirectCount+1) + return + } + } + + c.Status(resp.StatusCode) + + // 输出处理后的内容 + if _, err := io.Copy(c.Writer, processedBody); err != nil { + return + } + } else { + // 复制响应头 + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + // 处理重定向 + if location := resp.Header.Get("Location"); location != "" { + if CheckGitHubURL(location) != nil { + c.Header("Location", "/"+location) + } else { + proxyGitHubWithRedirect(c, location, redirectCount+1) + return + } + } + + c.Status(resp.StatusCode) + + // 直接流式转发 + io.Copy(c.Writer, resp.Body) + } +} diff --git a/src/imagetar.go b/src/handlers/imagetar.go similarity index 92% rename from src/imagetar.go rename to src/handlers/imagetar.go index 54a18ea..106d630 100644 --- a/src/imagetar.go +++ b/src/handlers/imagetar.go @@ -1,4 +1,4 @@ -package main +package handlers import ( "archive/tar" @@ -23,6 +23,8 @@ import ( "github.com/google/go-containerregistry/pkg/v1/partial" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/types" + "hubproxy/config" + "hubproxy/utils" ) // DebounceEntry 防抖条目 @@ -58,17 +60,15 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool { if entry, exists := d.entries[key]; exists { if now.Sub(entry.LastRequest) < d.window { - return false // 在防抖窗口内,拒绝请求 + return false } } - // 更新或创建条目 d.entries[key] = &DebounceEntry{ LastRequest: now, UserID: userID, } - // 清理过期条目(每5分钟清理一次) if time.Since(d.lastCleanup) > 5*time.Minute { d.cleanup(now) d.lastCleanup = now @@ -88,50 +88,41 @@ func (d *DownloadDebouncer) cleanup(now time.Time) { // generateContentFingerprint 生成内容指纹 func generateContentFingerprint(images []string, platform string) string { - // 对镜像列表排序确保顺序无关 sortedImages := make([]string, len(images)) copy(sortedImages, images) sort.Strings(sortedImages) - // 组合内容:镜像列表 + 平台信息 content := strings.Join(sortedImages, "|") + ":" + platform - // 生成MD5哈希 hash := md5.Sum([]byte(content)) return hex.EncodeToString(hash[:]) } // getUserID 获取用户标识 func getUserID(c *gin.Context) string { - // 优先使用会话Cookie if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" { return "session:" + sessionID } - // 备用方案:IP + User-Agent组合 ip := c.ClientIP() userAgent := c.GetHeader("User-Agent") if userAgent == "" { userAgent = "unknown" } - // 生成简短标识 combined := ip + ":" + userAgent hash := md5.Sum([]byte(combined)) - return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节 + return "ip:" + hex.EncodeToString(hash[:8]) } -// 全局防抖器实例 var ( singleImageDebouncer *DownloadDebouncer batchImageDebouncer *DownloadDebouncer ) -// initDebouncer 初始化防抖器 -func initDebouncer() { - // 单个镜像:5秒防抖窗口 +// InitDebouncer 初始化防抖器 +func InitDebouncer() { singleImageDebouncer = NewDownloadDebouncer(5 * time.Second) - // 批量镜像:60秒防抖窗口 batchImageDebouncer = NewDownloadDebouncer(60 * time.Second) } @@ -147,15 +138,15 @@ type ImageStreamerConfig struct { } // NewImageStreamer 创建镜像下载器 -func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer { - if config == nil { - config = &ImageStreamerConfig{} +func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer { + if cfg == nil { + cfg = &ImageStreamerConfig{} } - concurrency := config.Concurrency + concurrency := cfg.Concurrency if concurrency <= 0 { - cfg := GetConfig() - concurrency = cfg.Download.MaxImages + appCfg := config.GetConfig() + concurrency = appCfg.Download.MaxImages if concurrency <= 0 { concurrency = 10 } @@ -163,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer { remoteOptions := []remote.Option{ remote.WithAuth(authn.Anonymous), - remote.WithTransport(GetGlobalHTTPClient().Transport), + remote.WithTransport(utils.GetGlobalHTTPClient().Transport), } return &ImageStreamer{ @@ -176,7 +167,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer { type StreamOptions struct { Platform string Compression bool - UseCompressedLayers bool // 是否保存原始压缩层,默认开启 + UseCompressedLayers bool } // StreamImageToWriter 流式下载镜像到Writer @@ -215,7 +206,6 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote // getImageDescriptorWithPlatform 获取指定平台的镜像描述符 func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) { - // 直接从网络获取完整的descriptor,确保对象完整性 return remote.Get(ref, options...) } @@ -343,7 +333,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr var layerSize int64 var layerReader io.ReadCloser - // 根据配置选择使用压缩层或未压缩层 if options != nil && options.UseCompressedLayers { layerSize, err = layer.Size() if err != nil { @@ -385,7 +374,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr log.Printf("已处理层 %d/%d", i+1, len(layers)) } - // 构建单个镜像的manifest信息 singleManifest := map[string]interface{}{ "Config": configDigest.String() + ".json", "RepoTags": []string{imageRef}, @@ -398,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr }(), } - // 构建repositories信息 repositories := make(map[string]map[string]string) parts := strings.Split(imageRef, ":") if len(parts) == 2 { @@ -407,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr repositories[repoName] = map[string]string{tag: configDigest.String()} } - // 如果是批量下载,返回信息而不写入文件 if manifestOut != nil && repositoriesOut != nil { *manifestOut = singleManifest *repositoriesOut = repositories return nil } - // 单镜像下载,直接写入manifest.json manifest := []map[string]interface{}{singleManifest} manifestData, err := json.Marshal(manifest) @@ -436,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr return err } - // 写入repositories文件 repositoriesData, err := json.Marshal(repositories) if err != nil { return err @@ -456,7 +440,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr return err } -// processImageForBatch 处理镜像的公共逻辑,用于批量下载 +// processImageForBatch 处理镜像的公共逻辑 func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) { layers, err := img.Layers() if err != nil { @@ -498,7 +482,6 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite switch desc.MediaType { case types.OCIImageIndex, types.DockerManifestList: - // 处理多架构镜像 img, err = is.selectPlatformImage(desc, options) if err != nil { return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err) @@ -530,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S return nil, fmt.Errorf("获取索引清单失败: %w", err) } - // 选择合适的平台 var selectedDesc *v1.Descriptor for _, m := range manifest.Manifests { if m.Platform == nil { @@ -578,8 +560,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S var globalImageStreamer *ImageStreamer -// initImageStreamer 初始化镜像下载器 -func initImageStreamer() { +// InitImageStreamer 初始化镜像下载器 +func InitImageStreamer() { globalImageStreamer = NewImageStreamer(nil) } @@ -591,8 +573,8 @@ func formatPlatformText(platform string) string { return platform } -// initImageTarRoutes 初始化镜像下载路由 -func initImageTarRoutes(router *gin.Engine) { +// InitImageTarRoutes 初始化镜像下载路由 +func InitImageTarRoutes(router *gin.Engine) { imageAPI := router.Group("/api/image") { imageAPI.GET("/download/:image", handleDirectImageDownload) @@ -625,7 +607,6 @@ func handleDirectImageDownload(c *gin.Context) { return } - // 防抖检查 userID := getUserID(c) contentKey := generateContentFingerprint([]string{imageRef}, platform) @@ -677,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) { } } - cfg := GetConfig() + cfg := config.GetConfig() if len(req.Images) > cfg.Download.MaxImages { c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages), @@ -685,7 +666,6 @@ func handleSimpleBatchDownload(c *gin.Context) { return } - // 批量下载防抖检查 userID := getUserID(c) contentKey := generateContentFingerprint(req.Images, req.Platform) @@ -697,7 +677,7 @@ func handleSimpleBatchDownload(c *gin.Context) { return } - useCompressed := true // 默认启用原始压缩层 + useCompressed := true if req.UseCompressedLayers != nil { useCompressed = *req.UseCompressedLayers } @@ -801,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s var allManifests []map[string]interface{} var allRepositories = make(map[string]map[string]string) - // 流式处理每个镜像 for i, imageRef := range imageRefs { select { case <-ctx.Done(): @@ -811,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef) - // 防止单个镜像处理时间过长 timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options) cancel() @@ -825,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s return fmt.Errorf("镜像 %s manifest数据为空", imageRef) } - // 收集manifest信息 allManifests = append(allManifests, manifest) - // 合并repositories信息 for repo, tags := range repositories { if allRepositories[repo] == nil { allRepositories[repo] = make(map[string]string) @@ -839,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s } } - // 写入合并的manifest.json manifestData, err := json.Marshal(allManifests) if err != nil { return fmt.Errorf("序列化manifest失败: %w", err) @@ -859,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s return fmt.Errorf("写入manifest数据失败: %w", err) } - // 写入合并的repositories文件 repositoriesData, err := json.Marshal(allRepositories) if err != nil { return fmt.Errorf("序列化repositories失败: %w", err) diff --git a/src/search.go b/src/handlers/search.go similarity index 73% rename from src/search.go rename to src/handlers/search.go index ebf2c18..7c2ec2a 100644 --- a/src/search.go +++ b/src/handlers/search.go @@ -1,614 +1,559 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "sort" - "strings" - "sync" - "time" - - "github.com/gin-gonic/gin" -) - -// SearchResult Docker Hub搜索结果 -type SearchResult struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` -} - -// Repository 仓库信息 -type Repository struct { - Name string `json:"repo_name"` - Description string `json:"short_description"` - IsOfficial bool `json:"is_official"` - IsAutomated bool `json:"is_automated"` - StarCount int `json:"star_count"` - PullCount int `json:"pull_count"` - RepoOwner string `json:"repo_owner"` - LastUpdated string `json:"last_updated"` - Status int `json:"status"` - Organization string `json:"affiliation"` - PullsLastWeek int `json:"pulls_last_week"` - Namespace string `json:"namespace"` -} - -// TagInfo 标签信息 -type TagInfo struct { - Name string `json:"name"` - FullSize int64 `json:"full_size"` - LastUpdated time.Time `json:"last_updated"` - LastPusher string `json:"last_pusher"` - Images []Image `json:"images"` - Vulnerabilities struct { - Critical int `json:"critical"` - High int `json:"high"` - Medium int `json:"medium"` - Low int `json:"low"` - Unknown int `json:"unknown"` - } `json:"vulnerabilities"` -} - -// Image 镜像信息 -type Image struct { - Architecture string `json:"architecture"` - Features string `json:"features"` - Variant string `json:"variant,omitempty"` - Digest string `json:"digest"` - OS string `json:"os"` - OSFeatures string `json:"os_features"` - Size int64 `json:"size"` -} - -// TagPageResult 分页标签结果 -type TagPageResult struct { - Tags []TagInfo `json:"tags"` - HasMore bool `json:"has_more"` -} - -type cacheEntry struct { - data interface{} - expiresAt time.Time // 存储过期时间 -} - -const ( - maxCacheSize = 1000 // 最大缓存条目数 - maxPaginationCache = 200 // 分页缓存最大条目数 - cacheTTL = 30 * time.Minute -) - -type Cache struct { - data map[string]cacheEntry - mu sync.RWMutex - maxSize int -} - -var ( - searchCache = &Cache{ - data: make(map[string]cacheEntry), - maxSize: maxCacheSize, - } -) - -func (c *Cache) Get(key string) (interface{}, bool) { - c.mu.RLock() - entry, exists := c.data[key] - c.mu.RUnlock() - - if !exists { - return nil, false - } - - // 比较过期时间 - if time.Now().After(entry.expiresAt) { - c.mu.Lock() - delete(c.data, key) - c.mu.Unlock() - return nil, false - } - - return entry.data, true -} - -func (c *Cache) Set(key string, data interface{}) { - c.SetWithTTL(key, data, cacheTTL) -} - -func (c *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - - // 惰性清理:仅在容量超限时清理过期项 - if len(c.data) >= c.maxSize { - c.cleanupExpiredLocked() - } - - // 计算过期时间 - c.data[key] = cacheEntry{ - data: data, - expiresAt: time.Now().Add(ttl), - } -} - -func (c *Cache) Cleanup() { - c.mu.Lock() - defer c.mu.Unlock() - c.cleanupExpiredLocked() -} - -// cleanupExpiredLocked 清理过期缓存(需要已持有锁) -func (c *Cache) cleanupExpiredLocked() { - now := time.Now() - for key, entry := range c.data { - if now.After(entry.expiresAt) { - delete(c.data, key) - } - } -} - -// 定期清理过期缓存 -func init() { - go func() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() // 确保ticker资源释放 - - for range ticker.C { - searchCache.Cleanup() - } - }() -} - -func filterSearchResults(results []Repository, query string) []Repository { - searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) - filtered := make([]Repository, 0) - - for _, repo := range results { - // 标准化仓库名称 - repoName := strings.ToLower(repo.Name) - repoDesc := strings.ToLower(repo.Description) - - // 计算相关性得分 - score := 0 - - // 完全匹配 - if repoName == searchTerm { - score += 100 - } - - // 前缀匹配 - if strings.HasPrefix(repoName, searchTerm) { - score += 50 - } - - // 包含匹配 - if strings.Contains(repoName, searchTerm) { - score += 30 - } - - // 描述匹配 - if strings.Contains(repoDesc, searchTerm) { - score += 10 - } - - // 官方镜像加分 - if repo.IsOfficial { - score += 20 - } - - // 分数达到阈值的结果才保留 - if score > 0 { - filtered = append(filtered, repo) - } - } - - // 按相关性排序 - sort.Slice(filtered, func(i, j int) bool { - // 优先考虑官方镜像 - if filtered[i].IsOfficial != filtered[j].IsOfficial { - return filtered[i].IsOfficial - } - // 其次考虑拉取次数 - return filtered[i].PullCount > filtered[j].PullCount - }) - - return filtered -} - -// normalizeRepository 统一规范化仓库信息(消除重复逻辑) -func normalizeRepository(repo *Repository) { - if repo.IsOfficial { - repo.Namespace = "library" - if !strings.Contains(repo.Name, "/") { - repo.Name = "library/" + repo.Name - } - } else { - // 处理用户仓库:设置命名空间但保持Name为纯仓库名 - if repo.Namespace == "" && repo.RepoOwner != "" { - repo.Namespace = repo.RepoOwner - } - - // 如果Name包含斜杠,提取纯仓库名 - if strings.Contains(repo.Name, "/") { - parts := strings.Split(repo.Name, "/") - if len(parts) > 1 { - if repo.Namespace == "" { - repo.Namespace = parts[0] - } - repo.Name = parts[len(parts)-1] // 取最后部分作为仓库名 - } - } - } -} - -// searchDockerHub 搜索镜像 -func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { - return searchDockerHubWithDepth(ctx, query, page, pageSize, 0) -} - -// searchDockerHubWithDepth 搜索镜像(带递归深度控制) -func searchDockerHubWithDepth(ctx context.Context, query string, page, pageSize int, depth int) (*SearchResult, error) { - // 防止无限递归:最多允许1次递归调用 - if depth > 1 { - return nil, fmt.Errorf("搜索请求过于复杂,请尝试更具体的关键词") - } - cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) - - // 尝试从缓存获取 - if cached, ok := searchCache.Get(cacheKey); ok { - return cached.(*SearchResult), nil - } - - // 判断是否是用户/仓库格式的搜索 - isUserRepo := strings.Contains(query, "/") - var namespace, repoName string - - if isUserRepo { - parts := strings.Split(query, "/") - if len(parts) == 2 { - namespace = parts[0] - repoName = parts[1] - } - } - - // 构建搜索URL - baseURL := "https://registry.hub.docker.com/v2" - var fullURL string - var params url.Values - - if isUserRepo && namespace != "" { - // 如果是用户/仓库格式,使用repositories接口 - fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) - params = url.Values{ - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } else { - // 普通搜索 - fullURL = baseURL + "/search/repositories/" - params = url.Values{ - "query": {query}, - "page": {fmt.Sprintf("%d", page)}, - "page_size": {fmt.Sprintf("%d", pageSize)}, - } - } - - fullURL = fullURL + "?" + params.Encode() - - // 使用统一的搜索HTTP客户端 - resp, err := GetSearchHTTPClient().Get(fullURL) - if err != nil { - return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) - } - defer safeCloseResponseBody(resp.Body, "搜索响应体") - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %v", err) - } - - if resp.StatusCode != http.StatusOK { - switch resp.StatusCode { - case http.StatusTooManyRequests: - return nil, fmt.Errorf("请求过于频繁,请稍后重试") - case http.StatusNotFound: - if isUserRepo && namespace != "" { - // 如果用户仓库搜索失败,尝试普通搜索(递归调用) - return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1) - } - return nil, fmt.Errorf("未找到相关镜像") - case http.StatusBadGateway, http.StatusServiceUnavailable: - return nil, fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") - default: - return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) - } - } - - // 解析响应 - var result *SearchResult - if isUserRepo && namespace != "" { - // 解析用户仓库列表响应 - var userRepos struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []Repository `json:"results"` - } - if err := json.Unmarshal(body, &userRepos); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 转换为SearchResult格式 - result = &SearchResult{ - Count: userRepos.Count, - Next: userRepos.Next, - Previous: userRepos.Previous, - Results: make([]Repository, 0), - } - - // 处理结果 - for _, repo := range userRepos.Results { - // 如果指定了仓库名,只保留匹配的结果 - if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { - // 设置命名空间并使用统一的规范化函数 - repo.Namespace = namespace - normalizeRepository(&repo) - result.Results = append(result.Results, repo) - } - } - - // 如果没有找到结果,尝试普通搜索(递归调用) - if len(result.Results) == 0 { - return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1) - } - - result.Count = len(result.Results) - } else { - // 解析普通搜索响应 - result = &SearchResult{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - // 处理搜索结果:使用统一的规范化函数 - for i := range result.Results { - normalizeRepository(&result.Results[i]) - } - - // 如果是用户/仓库搜索,过滤结果 - if isUserRepo && namespace != "" { - filteredResults := make([]Repository, 0) - for _, repo := range result.Results { - if strings.EqualFold(repo.Namespace, namespace) { - filteredResults = append(filteredResults, repo) - } - } - result.Results = filteredResults - result.Count = len(filteredResults) - } - } - - // 缓存结果 - searchCache.Set(cacheKey, result) - return result, nil -} - -// 判断错误是否可重试 -func isRetryableError(err error) bool { - if err == nil { - return false - } - - // 网络错误、超时等可以重试 - if strings.Contains(err.Error(), "timeout") || - strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "too many requests") { - return true - } - - return false -} - -// getRepositoryTags 获取仓库标签信息(支持分页) -func getRepositoryTags(ctx context.Context, namespace, name string, page, pageSize int) ([]TagInfo, bool, error) { - if namespace == "" || name == "" { - return nil, false, fmt.Errorf("无效输入:命名空间和名称不能为空") - } - - // 默认参数 - if page <= 0 { - page = 1 - } - if pageSize <= 0 || pageSize > 100 { - pageSize = 100 - } - - // 分页缓存key - cacheKey := fmt.Sprintf("tags:%s:%s:page_%d", namespace, name, page) - if cached, ok := searchCache.Get(cacheKey); ok { - result := cached.(TagPageResult) - return result.Tags, result.HasMore, nil - } - - // 构建API URL - baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) - params := url.Values{} - params.Set("page", fmt.Sprintf("%d", page)) - params.Set("page_size", fmt.Sprintf("%d", pageSize)) - params.Set("ordering", "last_updated") - - fullURL := baseURL + "?" + params.Encode() - - // 获取当前页数据 - pageResult, err := fetchTagPage(ctx, fullURL, 3) - if err != nil { - return nil, false, fmt.Errorf("获取标签失败: %v", err) - } - - hasMore := pageResult.Next != "" - - // 缓存结果(分页缓存时间较短) - result := TagPageResult{Tags: pageResult.Results, HasMore: hasMore} - searchCache.SetWithTTL(cacheKey, result, 30*time.Minute) - - return pageResult.Results, hasMore, nil -} - -// fetchTagPage 获取单页标签数据,带重试机制 -func fetchTagPage(ctx context.Context, url string, maxRetries int) (*struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []TagInfo `json:"results"` -}, error) { - var lastErr error - - for retry := 0; retry < maxRetries; retry++ { - if retry > 0 { - // 重试前等待一段时间 - time.Sleep(time.Duration(retry) * 500 * time.Millisecond) - } - - resp, err := GetSearchHTTPClient().Get(url) - if err != nil { - lastErr = err - if isRetryableError(err) && retry < maxRetries-1 { - continue - } - return nil, fmt.Errorf("发送请求失败: %v", err) - } - - // 读取响应体(立即关闭,避免defer在循环中累积) - body, err := func() ([]byte, error) { - defer safeCloseResponseBody(resp.Body, "标签响应体") - return io.ReadAll(resp.Body) - }() - - if err != nil { - lastErr = err - if retry < maxRetries-1 { - continue - } - return nil, fmt.Errorf("读取响应失败: %v", err) - } - - // 检查响应状态码 - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("状态码=%d, 响应=%s", resp.StatusCode, string(body)) - // 4xx错误通常不需要重试 - if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 { - return nil, fmt.Errorf("请求失败: %v", lastErr) - } - if retry < maxRetries-1 { - continue - } - return nil, fmt.Errorf("请求失败: %v", lastErr) - } - - // 解析响应 - var result struct { - Count int `json:"count"` - Next string `json:"next"` - Previous string `json:"previous"` - Results []TagInfo `json:"results"` - } - if err := json.Unmarshal(body, &result); err != nil { - lastErr = err - if retry < maxRetries-1 { - continue - } - return nil, fmt.Errorf("解析响应失败: %v", err) - } - - return &result, nil - } - - return nil, lastErr -} - -// parsePaginationParams 解析分页参数 -func parsePaginationParams(c *gin.Context, defaultPageSize int) (page, pageSize int) { - page = 1 - pageSize = defaultPageSize - - if p := c.Query("page"); p != "" { - fmt.Sscanf(p, "%d", &page) - } - if ps := c.Query("page_size"); ps != "" { - fmt.Sscanf(ps, "%d", &pageSize) - } - - return page, pageSize -} - -// safeCloseResponseBody 安全关闭HTTP响应体(统一资源管理) -func safeCloseResponseBody(body io.ReadCloser, context string) { - if body != nil { - if err := body.Close(); err != nil { - fmt.Printf("关闭%s失败: %v\n", context, err) - } - } -} - -// sendErrorResponse 统一错误响应处理 -func sendErrorResponse(c *gin.Context, message string) { - c.JSON(http.StatusBadRequest, gin.H{"error": message}) -} - -// RegisterSearchRoute 注册搜索相关路由 -func RegisterSearchRoute(r *gin.Engine) { - // 搜索镜像 - r.GET("/search", func(c *gin.Context) { - query := c.Query("q") - if query == "" { - sendErrorResponse(c, "搜索关键词不能为空") - return - } - - page, pageSize := parsePaginationParams(c, 25) - - result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) - if err != nil { - sendErrorResponse(c, err.Error()) - return - } - - c.JSON(http.StatusOK, result) - }) - - // 获取标签信息 - r.GET("/tags/:namespace/:name", func(c *gin.Context) { - namespace := c.Param("namespace") - name := c.Param("name") - - if namespace == "" || name == "" { - sendErrorResponse(c, "命名空间和名称不能为空") - return - } - - page, pageSize := parsePaginationParams(c, 100) - - tags, hasMore, err := getRepositoryTags(c.Request.Context(), namespace, name, page, pageSize) - if err != nil { - sendErrorResponse(c, err.Error()) - return - } - - if c.Query("page") != "" || c.Query("page_size") != "" { - c.JSON(http.StatusOK, gin.H{ - "tags": tags, - "has_more": hasMore, - "page": page, - "page_size": pageSize, - }) - } else { - c.JSON(http.StatusOK, tags) - } - }) -} +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "hubproxy/utils" +) + +// SearchResult Docker Hub搜索结果 +type SearchResult struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` +} + +// Repository 仓库信息 +type Repository struct { + Name string `json:"repo_name"` + Description string `json:"short_description"` + IsOfficial bool `json:"is_official"` + IsAutomated bool `json:"is_automated"` + StarCount int `json:"star_count"` + PullCount int `json:"pull_count"` + RepoOwner string `json:"repo_owner"` + LastUpdated string `json:"last_updated"` + Status int `json:"status"` + Organization string `json:"affiliation"` + PullsLastWeek int `json:"pulls_last_week"` + Namespace string `json:"namespace"` +} + +// TagInfo 标签信息 +type TagInfo struct { + Name string `json:"name"` + FullSize int64 `json:"full_size"` + LastUpdated time.Time `json:"last_updated"` + LastPusher string `json:"last_pusher"` + Images []Image `json:"images"` + Vulnerabilities struct { + Critical int `json:"critical"` + High int `json:"high"` + Medium int `json:"medium"` + Low int `json:"low"` + Unknown int `json:"unknown"` + } `json:"vulnerabilities"` +} + +// Image 镜像信息 +type Image struct { + Architecture string `json:"architecture"` + Features string `json:"features"` + Variant string `json:"variant,omitempty"` + Digest string `json:"digest"` + OS string `json:"os"` + OSFeatures string `json:"os_features"` + Size int64 `json:"size"` +} + +// TagPageResult 分页标签结果 +type TagPageResult struct { + Tags []TagInfo `json:"tags"` + HasMore bool `json:"has_more"` +} + +type cacheEntry struct { + data interface{} + expiresAt time.Time +} + +const ( + maxCacheSize = 1000 + maxPaginationCache = 200 + cacheTTL = 30 * time.Minute +) + +type Cache struct { + data map[string]cacheEntry + mu sync.RWMutex + maxSize int +} + +var ( + searchCache = &Cache{ + data: make(map[string]cacheEntry), + maxSize: maxCacheSize, + } +) + +func (c *Cache) Get(key string) (interface{}, bool) { + c.mu.RLock() + entry, exists := c.data[key] + c.mu.RUnlock() + + if !exists { + return nil, false + } + + if time.Now().After(entry.expiresAt) { + c.mu.Lock() + delete(c.data, key) + c.mu.Unlock() + return nil, false + } + + return entry.data, true +} + +func (c *Cache) Set(key string, data interface{}) { + c.SetWithTTL(key, data, cacheTTL) +} + +func (c *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + if len(c.data) >= c.maxSize { + c.cleanupExpiredLocked() + } + + c.data[key] = cacheEntry{ + data: data, + expiresAt: time.Now().Add(ttl), + } +} + +func (c *Cache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + c.cleanupExpiredLocked() +} + +func (c *Cache) cleanupExpiredLocked() { + now := time.Now() + for key, entry := range c.data { + if now.After(entry.expiresAt) { + delete(c.data, key) + } + } +} + +func init() { + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + searchCache.Cleanup() + } + }() +} + +func filterSearchResults(results []Repository, query string) []Repository { + searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/")) + filtered := make([]Repository, 0) + + for _, repo := range results { + repoName := strings.ToLower(repo.Name) + repoDesc := strings.ToLower(repo.Description) + + score := 0 + + if repoName == searchTerm { + score += 100 + } + + if strings.HasPrefix(repoName, searchTerm) { + score += 50 + } + + if strings.Contains(repoName, searchTerm) { + score += 30 + } + + if strings.Contains(repoDesc, searchTerm) { + score += 10 + } + + if repo.IsOfficial { + score += 20 + } + + if score > 0 { + filtered = append(filtered, repo) + } + } + + sort.Slice(filtered, func(i, j int) bool { + if filtered[i].IsOfficial != filtered[j].IsOfficial { + return filtered[i].IsOfficial + } + return filtered[i].PullCount > filtered[j].PullCount + }) + + return filtered +} + +// normalizeRepository 统一规范化仓库信息 +func normalizeRepository(repo *Repository) { + if repo.IsOfficial { + repo.Namespace = "library" + if !strings.Contains(repo.Name, "/") { + repo.Name = "library/" + repo.Name + } + } else { + if repo.Namespace == "" && repo.RepoOwner != "" { + repo.Namespace = repo.RepoOwner + } + + if strings.Contains(repo.Name, "/") { + parts := strings.Split(repo.Name, "/") + if len(parts) > 1 { + if repo.Namespace == "" { + repo.Namespace = parts[0] + } + repo.Name = parts[len(parts)-1] + } + } + } +} + +// searchDockerHub 搜索镜像 +func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { + return searchDockerHubWithDepth(ctx, query, page, pageSize, 0) +} + +func searchDockerHubWithDepth(ctx context.Context, query string, page, pageSize int, depth int) (*SearchResult, error) { + if depth > 1 { + return nil, fmt.Errorf("搜索请求过于复杂,请尝试更具体的关键词") + } + cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) + + if cached, ok := searchCache.Get(cacheKey); ok { + return cached.(*SearchResult), nil + } + + isUserRepo := strings.Contains(query, "/") + var namespace, repoName string + + if isUserRepo { + parts := strings.Split(query, "/") + if len(parts) == 2 { + namespace = parts[0] + repoName = parts[1] + } + } + + baseURL := "https://registry.hub.docker.com/v2" + var fullURL string + var params url.Values + + if isUserRepo && namespace != "" { + fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) + params = url.Values{ + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } else { + fullURL = baseURL + "/search/repositories/" + params = url.Values{ + "query": {query}, + "page": {fmt.Sprintf("%d", page)}, + "page_size": {fmt.Sprintf("%d", pageSize)}, + } + } + + fullURL = fullURL + "?" + params.Encode() + + resp, err := utils.GetSearchHTTPClient().Get(fullURL) + if err != nil { + return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) + } + defer safeCloseResponseBody(resp.Body, "搜索响应体") + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusTooManyRequests: + return nil, fmt.Errorf("请求过于频繁,请稍后重试") + case http.StatusNotFound: + if isUserRepo && namespace != "" { + return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1) + } + return nil, fmt.Errorf("未找到相关镜像") + case http.StatusBadGateway, http.StatusServiceUnavailable: + return nil, fmt.Errorf("Docker Hub服务暂时不可用,请稍后重试") + default: + return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body)) + } + } + + var result *SearchResult + if isUserRepo && namespace != "" { + var userRepos struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []Repository `json:"results"` + } + if err := json.Unmarshal(body, &userRepos); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + result = &SearchResult{ + Count: userRepos.Count, + Next: userRepos.Next, + Previous: userRepos.Previous, + Results: make([]Repository, 0), + } + + for _, repo := range userRepos.Results { + if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { + repo.Namespace = namespace + normalizeRepository(&repo) + result.Results = append(result.Results, repo) + } + } + + if len(result.Results) == 0 { + return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1) + } + + result.Count = len(result.Results) + } else { + result = &SearchResult{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + for i := range result.Results { + normalizeRepository(&result.Results[i]) + } + + if isUserRepo && namespace != "" { + filteredResults := make([]Repository, 0) + for _, repo := range result.Results { + if strings.EqualFold(repo.Namespace, namespace) { + filteredResults = append(filteredResults, repo) + } + } + result.Results = filteredResults + result.Count = len(filteredResults) + } + } + + searchCache.Set(cacheKey, result) + return result, nil +} + +func isRetryableError(err error) bool { + if err == nil { + return false + } + + if strings.Contains(err.Error(), "timeout") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "no such host") || + strings.Contains(err.Error(), "too many requests") { + return true + } + + return false +} + +// getRepositoryTags 获取仓库标签信息 +func getRepositoryTags(ctx context.Context, namespace, name string, page, pageSize int) ([]TagInfo, bool, error) { + if namespace == "" || name == "" { + return nil, false, fmt.Errorf("无效输入:命名空间和名称不能为空") + } + + if page <= 0 { + page = 1 + } + if pageSize <= 0 || pageSize > 100 { + pageSize = 100 + } + + cacheKey := fmt.Sprintf("tags:%s:%s:page_%d", namespace, name, page) + if cached, ok := searchCache.Get(cacheKey); ok { + result := cached.(TagPageResult) + return result.Tags, result.HasMore, nil + } + + baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) + params := url.Values{} + params.Set("page", fmt.Sprintf("%d", page)) + params.Set("page_size", fmt.Sprintf("%d", pageSize)) + params.Set("ordering", "last_updated") + + fullURL := baseURL + "?" + params.Encode() + + pageResult, err := fetchTagPage(ctx, fullURL, 3) + if err != nil { + return nil, false, fmt.Errorf("获取标签失败: %v", err) + } + + hasMore := pageResult.Next != "" + + result := TagPageResult{Tags: pageResult.Results, HasMore: hasMore} + searchCache.SetWithTTL(cacheKey, result, 30*time.Minute) + + return pageResult.Results, hasMore, nil +} + +func fetchTagPage(ctx context.Context, url string, maxRetries int) (*struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []TagInfo `json:"results"` +}, error) { + var lastErr error + + for retry := 0; retry < maxRetries; retry++ { + if retry > 0 { + time.Sleep(time.Duration(retry) * 500 * time.Millisecond) + } + + resp, err := utils.GetSearchHTTPClient().Get(url) + if err != nil { + lastErr = err + if isRetryableError(err) && retry < maxRetries-1 { + continue + } + return nil, fmt.Errorf("发送请求失败: %v", err) + } + + body, err := func() ([]byte, error) { + defer safeCloseResponseBody(resp.Body, "标签响应体") + return io.ReadAll(resp.Body) + }() + + if err != nil { + lastErr = err + if retry < maxRetries-1 { + continue + } + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("状态码=%d, 响应=%s", resp.StatusCode, string(body)) + if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 { + return nil, fmt.Errorf("请求失败: %v", lastErr) + } + if retry < maxRetries-1 { + continue + } + return nil, fmt.Errorf("请求失败: %v", lastErr) + } + + var result struct { + Count int `json:"count"` + Next string `json:"next"` + Previous string `json:"previous"` + Results []TagInfo `json:"results"` + } + if err := json.Unmarshal(body, &result); err != nil { + lastErr = err + if retry < maxRetries-1 { + continue + } + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + return &result, nil + } + + return nil, lastErr +} + +func parsePaginationParams(c *gin.Context, defaultPageSize int) (page, pageSize int) { + page = 1 + pageSize = defaultPageSize + + if p := c.Query("page"); p != "" { + fmt.Sscanf(p, "%d", &page) + } + if ps := c.Query("page_size"); ps != "" { + fmt.Sscanf(ps, "%d", &pageSize) + } + + return page, pageSize +} + +func safeCloseResponseBody(body io.ReadCloser, context string) { + if body != nil { + if err := body.Close(); err != nil { + fmt.Printf("关闭%s失败: %v\n", context, err) + } + } +} + +func sendErrorResponse(c *gin.Context, message string) { + c.JSON(http.StatusBadRequest, gin.H{"error": message}) +} + +// RegisterSearchRoute 注册搜索相关路由 +func RegisterSearchRoute(r *gin.Engine) { + r.GET("/search", func(c *gin.Context) { + query := c.Query("q") + if query == "" { + sendErrorResponse(c, "搜索关键词不能为空") + return + } + + page, pageSize := parsePaginationParams(c, 25) + + result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) + if err != nil { + sendErrorResponse(c, err.Error()) + return + } + + c.JSON(http.StatusOK, result) + }) + + r.GET("/tags/:namespace/:name", func(c *gin.Context) { + namespace := c.Param("namespace") + name := c.Param("name") + + if namespace == "" || name == "" { + sendErrorResponse(c, "命名空间和名称不能为空") + return + } + + page, pageSize := parsePaginationParams(c, 100) + + tags, hasMore, err := getRepositoryTags(c.Request.Context(), namespace, name, page, pageSize) + if err != nil { + sendErrorResponse(c, err.Error()) + return + } + + if c.Query("page") != "" || c.Query("page_size") != "" { + c.JSON(http.StatusOK, gin.H{ + "tags": tags, + "has_more": hasMore, + "page": page, + "page_size": pageSize, + }) + } else { + c.JSON(http.StatusOK, tags) + } + }) +} diff --git a/src/main.go b/src/main.go index bdb33bd..e1bf45c 100644 --- a/src/main.go +++ b/src/main.go @@ -3,15 +3,17 @@ package main import ( "embed" "fmt" - "io" "log" "net/http" - "regexp" - "strconv" "strings" "time" "github.com/gin-gonic/gin" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "hubproxy/config" + "hubproxy/handlers" + "hubproxy/utils" ) //go:embed public/* @@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) { } var ( - exps = []*regexp.Regexp{ - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), - regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), - regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), - regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), - regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), - regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), - regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), - regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), - regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), - } - globalLimiter *IPRateLimiter + globalLimiter *utils.IPRateLimiter // 服务启动时间 serviceStartTime = time.Now() @@ -52,25 +42,25 @@ var ( func main() { // 加载配置 - if err := LoadConfig(); err != nil { + if err := config.LoadConfig(); err != nil { fmt.Printf("配置加载失败: %v\n", err) return } // 初始化HTTP客户端 - initHTTPClients() + utils.InitHTTPClients() // 初始化限流器 - initLimiter() + globalLimiter = utils.InitGlobalLimiter() // 初始化Docker流式代理 - initDockerProxy() + handlers.InitDockerProxy() // 初始化镜像流式下载器 - initImageStreamer() + handlers.InitImageStreamer() // 初始化防抖器 - initDebouncer() + handlers.InitDebouncer() gin.SetMode(gin.ReleaseMode) router := gin.Default() @@ -84,14 +74,14 @@ func main() { }) })) - // 全局限流中间件 - 应用到所有路由 - router.Use(RateLimitMiddleware(globalLimiter)) + // 全局限流中间件 + router.Use(utils.RateLimitMiddleware(globalLimiter)) // 初始化监控端点 initHealthRoutes(router) // 初始化镜像tar下载路由 - initImageTarRoutes(router) + handlers.InitImageTarRoutes(router) // 静态文件路由 router.GET("/", func(c *gin.Context) { @@ -113,217 +103,59 @@ func main() { }) // 注册dockerhub搜索路由 - RegisterSearchRoute(router) + handlers.RegisterSearchRoute(router) - // 注册Docker认证路由(/token*) - router.Any("/token", ProxyDockerAuthGin) - router.Any("/token/*path", ProxyDockerAuthGin) + // 注册Docker认证路由 + router.Any("/token", handlers.ProxyDockerAuthGin) + router.Any("/token/*path", handlers.ProxyDockerAuthGin) // 注册Docker Registry代理路由 - router.Any("/v2/*path", ProxyDockerRegistryGin) + router.Any("/v2/*path", handlers.ProxyDockerRegistryGin) - // 注册NoRoute处理器 - router.NoRoute(handler) + // 注册GitHub代理路由(NoRoute处理器) + router.NoRoute(handlers.GitHubProxyHandler) - cfg := GetConfig() + cfg := config.GetConfig() fmt.Printf("🚀 HubProxy 启动成功\n") fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port) fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours) + + // 显示HTTP/2支持状态 + if cfg.Server.EnableH2C { + fmt.Printf("H2c: 已启用\n") + } + fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n") - err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) + // 创建HTTP2服务器 + server := &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), + ReadTimeout: 60 * time.Second, + WriteTimeout: 300 * time.Second, + IdleTimeout: 120 * time.Second, + } + + // 根据配置决定是否启用H2C + if cfg.Server.EnableH2C { + h2cHandler := h2c.NewHandler(router, &http2.Server{ + MaxConcurrentStreams: 250, + IdleTimeout: 300 * time.Second, + MaxReadFrameSize: 4 << 20, + MaxUploadBufferPerConnection: 8 << 20, + MaxUploadBufferPerStream: 2 << 20, + }) + server.Handler = h2cHandler + } else { + server.Handler = router + } + + err := server.ListenAndServe() if err != nil { fmt.Printf("启动服务失败: %v\n", err) } } -func handler(c *gin.Context) { - rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") - - for strings.HasPrefix(rawPath, "/") { - rawPath = strings.TrimPrefix(rawPath, "/") - } - // 自动补全协议头 - if !strings.HasPrefix(rawPath, "https://") { - // 修复 http:/ 和 https:/ 的情况 - if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") { - rawPath = strings.Replace(rawPath, "http:/", "", 1) - rawPath = strings.Replace(rawPath, "https:/", "", 1) - } else if strings.HasPrefix(rawPath, "http://") { - rawPath = strings.TrimPrefix(rawPath, "http://") - } - rawPath = "https://" + rawPath - } - - matches := checkURL(rawPath) - if matches != nil { - // GitHub仓库访问控制检查 - if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed { - // 构建仓库名用于日志 - var repoPath string - if len(matches) >= 2 { - username := matches[0] - repoName := strings.TrimSuffix(matches[1], ".git") - repoPath = username + "/" + repoName - } - fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) - c.String(http.StatusForbidden, reason) - return - } - } else { - c.String(http.StatusForbidden, "无效输入") - return - } - - if exps[1].MatchString(rawPath) { - rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) - } - - proxyRequest(c, rawPath) -} - -func proxyRequest(c *gin.Context, u string) { - proxyWithRedirect(c, u, 0) -} - -func proxyWithRedirect(c *gin.Context, u string, redirectCount int) { - // 限制最大重定向次数,防止无限递归 - const maxRedirects = 20 - if redirectCount > maxRedirects { - c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") - return - } - req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - - for key, values := range c.Request.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - req.Header.Del("Host") - - resp, err := GetGlobalHTTPClient().Do(req) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) - return - } - defer func() { - if err := resp.Body.Close(); err != nil { - fmt.Printf("关闭响应体失败: %v\n", err) - } - }() - - // 检查文件大小限制 - cfg := GetConfig() - if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { - if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { - c.String(http.StatusRequestEntityTooLarge, - fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) - return - } - } - - // 清理安全相关的头 - resp.Header.Del("Content-Security-Policy") - resp.Header.Del("Referrer-Policy") - resp.Header.Del("Strict-Transport-Security") - - // 获取真实域名 - realHost := c.Request.Header.Get("X-Forwarded-Host") - if realHost == "" { - realHost = c.Request.Host - } - // 如果域名中没有协议前缀,添加https:// - if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") { - realHost = "https://" + realHost - } - - if strings.HasSuffix(strings.ToLower(u), ".sh") { - isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip" - - processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost) - if err != nil { - fmt.Printf("智能处理失败,回退到直接代理: %v\n", err) - processedBody = resp.Body - processedSize = 0 - } - - // 智能设置响应头 - if processedSize > 0 { - resp.Header.Del("Content-Length") - resp.Header.Del("Content-Encoding") - resp.Header.Set("Transfer-Encoding", "chunked") - } - - // 复制其他响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - if location := resp.Header.Get("Location"); location != "" { - if checkURL(location) != nil { - c.Header("Location", "/"+location) - } else { - proxyWithRedirect(c, location, redirectCount+1) - return - } - } - - c.Status(resp.StatusCode) - - // 输出处理后的内容 - if _, err := io.Copy(c.Writer, processedBody); err != nil { - return - } - } else { - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 处理重定向 - if location := resp.Header.Get("Location"); location != "" { - if checkURL(location) != nil { - c.Header("Location", "/"+location) - } else { - proxyWithRedirect(c, location, redirectCount+1) - return - } - } - - c.Status(resp.StatusCode) - - // 直接流式转发 - io.Copy(c.Writer, resp.Body) - } -} - -func checkURL(u string) []string { - for _, exp := range exps { - if matches := exp.FindStringSubmatch(u); matches != nil { - return matches[1:] - } - } - return nil -} - // 简单的健康检查 -func formatBeijingTime(t time.Time) string { - loc, err := time.LoadLocation("Asia/Shanghai") - if err != nil { - loc = time.FixedZone("CST", 8*3600) // 兜底时区 - } - return t.In(loc).Format("2006-01-02 15:04:05") -} - -// 转换为可读时间 func formatDuration(d time.Duration) string { if d < time.Minute { return fmt.Sprintf("%d秒", int(d.Seconds())) @@ -338,26 +170,20 @@ func formatDuration(d time.Duration) string { } } -func initHealthRoutes(router *gin.Engine) { - router.GET("/health", func(c *gin.Context) { - uptime := time.Since(serviceStartTime) - c.JSON(http.StatusOK, gin.H{ - "status": "healthy", - "timestamp_unix": serviceStartTime.Unix(), - "uptime_sec": uptime.Seconds(), - "service": "hubproxy", - "start_time_bj": formatBeijingTime(serviceStartTime), - "uptime_human": formatDuration(uptime), - }) - }) +func getUptimeInfo() (time.Duration, float64, string) { + uptime := time.Since(serviceStartTime) + return uptime, uptime.Seconds(), formatDuration(uptime) +} +func initHealthRoutes(router *gin.Engine) { router.GET("/ready", func(c *gin.Context) { - uptime := time.Since(serviceStartTime) + _, uptimeSec, uptimeHuman := getUptimeInfo() c.JSON(http.StatusOK, gin.H{ - "ready": true, - "timestamp_unix": time.Now().Unix(), - "uptime_sec": uptime.Seconds(), - "uptime_human": formatDuration(uptime), + "ready": true, + "service": "hubproxy", + "start_time_unix": serviceStartTime.Unix(), + "uptime_sec": uptimeSec, + "uptime_human": uptimeHuman, }) }) } diff --git a/src/access_control.go b/src/utils/access_control.go similarity index 91% rename from src/access_control.go rename to src/utils/access_control.go index 7d71a14..48a685b 100644 --- a/src/access_control.go +++ b/src/utils/access_control.go @@ -1,8 +1,10 @@ -package main +package utils import ( "strings" "sync" + + "hubproxy/config" ) // ResourceType 资源类型 @@ -26,7 +28,7 @@ type DockerImageInfo struct { FullName string } -// 全局访问控制器实例 +// GlobalAccessController 全局访问控制器实例 var GlobalAccessController = &AccessController{} // ParseDockerImage 解析Docker镜像名称 @@ -79,19 +81,16 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { // CheckDockerAccess 检查Docker镜像访问权限 func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { - cfg := GetConfig() + cfg := config.GetConfig() - // 解析镜像名称 imageInfo := ac.ParseDockerImage(image) - // 检查白名单(如果配置了白名单,则只允许白名单中的镜像) if len(cfg.Access.WhiteList) > 0 { if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) { return false, "不在Docker镜像白名单内" } } - // 检查黑名单 if len(cfg.Access.BlackList) > 0 { if ac.matchImageInList(imageInfo, cfg.Access.BlackList) { return false, "Docker镜像在黑名单内" @@ -107,14 +106,12 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r return false, "无效的GitHub仓库格式" } - cfg := GetConfig() + cfg := config.GetConfig() - // 检查白名单 if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) { return false, "不在GitHub仓库白名单内" } - // 检查黑名单 if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) { return false, "GitHub仓库在黑名单内" } @@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool { continue } - // 支持多种匹配模式 if fullRepo == item { return true } - // 用户级匹配 if item == username || item == username+"/*" { return true } - // 前缀匹配(支持通配符) if strings.HasSuffix(item, "*") { prefix := strings.TrimSuffix(item, "*") if strings.HasPrefix(fullRepo, prefix) { @@ -203,7 +197,6 @@ func (ac *AccessController) checkList(matches, list []string) bool { } } - // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) if strings.HasPrefix(fullRepo, item+"/") { return true } diff --git a/src/token_cache.go b/src/utils/cache.go similarity index 62% rename from src/token_cache.go rename to src/utils/cache.go index 5b4b730..488ce80 100644 --- a/src/token_cache.go +++ b/src/utils/cache.go @@ -1,4 +1,4 @@ -package main +package utils import ( "crypto/md5" @@ -9,22 +9,23 @@ import ( "time" "github.com/gin-gonic/gin" + "hubproxy/config" ) -// CachedItem 通用缓存项,支持Token和Manifest +// CachedItem 通用缓存项 type CachedItem struct { - Data []byte // 缓存数据(token字符串或manifest字节) - ContentType string // 内容类型 - Headers map[string]string // 额外的响应头 - ExpiresAt time.Time // 过期时间 + Data []byte + ContentType string + Headers map[string]string + ExpiresAt time.Time } -// UniversalCache 通用缓存,支持Token和Manifest +// UniversalCache 通用缓存 type UniversalCache struct { cache sync.Map } -var globalCache = &UniversalCache{} +var GlobalCache = &UniversalCache{} // Get 获取缓存项 func (c *UniversalCache) Get(key string) *CachedItem { @@ -57,22 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) { c.Set(key, []byte(token), "application/json", nil, ttl) } -// buildCacheKey 构建稳定的缓存key -func buildCacheKey(prefix, query string) string { +// BuildCacheKey 构建稳定的缓存key +func BuildCacheKey(prefix, query string) string { return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query))) } -func buildTokenCacheKey(query string) string { - return buildCacheKey("token", query) +func BuildTokenCacheKey(query string) string { + return BuildCacheKey("token", query) } -func buildManifestCacheKey(imageRef, reference string) string { +func BuildManifestCacheKey(imageRef, reference string) string { key := fmt.Sprintf("%s:%s", imageRef, reference) - return buildCacheKey("manifest", key) + return BuildCacheKey("manifest", key) } -func getManifestTTL(reference string) time.Duration { - cfg := GetConfig() +func GetManifestTTL(reference string) time.Duration { + cfg := config.GetConfig() defaultTTL := 30 * time.Minute if cfg.TokenCache.DefaultTTL != "" { if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil { @@ -84,23 +85,20 @@ func getManifestTTL(reference string) time.Duration { return 24 * time.Hour } - // mutable tag的智能判断 if reference == "latest" || reference == "main" || reference == "master" || reference == "dev" || reference == "develop" { - // 热门可变标签: 短期缓存 return 10 * time.Minute } return defaultTTL } -// extractTTLFromResponse 从响应中智能提取TTL -func extractTTLFromResponse(responseBody []byte) time.Duration { +// ExtractTTLFromResponse 从响应中智能提取TTL +func ExtractTTLFromResponse(responseBody []byte) time.Duration { var tokenResp struct { ExpiresIn int `json:"expires_in"` } - // 默认30分钟TTL,确保稳定性 defaultTTL := 30 * time.Minute if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 { @@ -113,37 +111,35 @@ func extractTTLFromResponse(responseBody []byte) time.Duration { return defaultTTL } -func writeTokenResponse(c *gin.Context, cachedBody string) { +func WriteTokenResponse(c *gin.Context, cachedBody string) { c.Header("Content-Type", "application/json") c.String(200, cachedBody) } -func writeCachedResponse(c *gin.Context, item *CachedItem) { +func WriteCachedResponse(c *gin.Context, item *CachedItem) { if item.ContentType != "" { c.Header("Content-Type", item.ContentType) } - // 设置额外的响应头 for key, value := range item.Headers { c.Header(key, value) } - // 返回数据 c.Data(200, item.ContentType, item.Data) } -// isCacheEnabled 检查缓存是否启用 -func isCacheEnabled() bool { - cfg := GetConfig() +// IsCacheEnabled 检查缓存是否启用 +func IsCacheEnabled() bool { + cfg := config.GetConfig() return cfg.TokenCache.Enabled } -// isTokenCacheEnabled 检查token缓存是否启用(向后兼容) -func isTokenCacheEnabled() bool { - return isCacheEnabled() +// IsTokenCacheEnabled 检查token缓存是否启用 +func IsTokenCacheEnabled() bool { + return IsCacheEnabled() } -// 定期清理过期缓存,防止内存泄漏 +// 定期清理过期缓存 func init() { go func() { ticker := time.NewTicker(20 * time.Minute) @@ -153,7 +149,7 @@ func init() { now := time.Now() expiredKeys := make([]string, 0) - globalCache.cache.Range(func(key, value interface{}) bool { + GlobalCache.cache.Range(func(key, value interface{}) bool { if cached := value.(*CachedItem); now.After(cached.ExpiresAt) { expiredKeys = append(expiredKeys, key.(string)) } @@ -161,7 +157,7 @@ func init() { }) for _, key := range expiredKeys { - globalCache.cache.Delete(key) + GlobalCache.cache.Delete(key) } } }() diff --git a/src/http_client.go b/src/utils/http_client.go similarity index 73% rename from src/http_client.go rename to src/utils/http_client.go index 93988eb..9bb250f 100644 --- a/src/http_client.go +++ b/src/utils/http_client.go @@ -1,28 +1,28 @@ -package main +package utils import ( "net" "net/http" "os" "time" + + "hubproxy/config" ) var ( - // 全局HTTP客户端 - 用于代理请求(长超时) globalHTTPClient *http.Client - // 搜索HTTP客户端 - 用于API请求(短超时) searchHTTPClient *http.Client ) -// initHTTPClients 初始化HTTP客户端 -func initHTTPClients() { - cfg := GetConfig() +// InitHTTPClients 初始化HTTP客户端 +func InitHTTPClients() { + cfg := config.GetConfig() if p := cfg.Access.Proxy; p != "" { os.Setenv("HTTP_PROXY", p) os.Setenv("HTTPS_PROXY", p) } - // 代理客户端配置 - 适用于大文件传输 + globalHTTPClient = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -39,7 +39,6 @@ func initHTTPClients() { }, } - // 搜索客户端配置 - 适用于API调用 searchHTTPClient = &http.Client{ Timeout: 10 * time.Second, Transport: &http.Transport{ @@ -57,12 +56,12 @@ func initHTTPClients() { } } -// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理) +// GetGlobalHTTPClient 获取全局HTTP客户端 func GetGlobalHTTPClient() *http.Client { return globalHTTPClient } -// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用) +// GetSearchHTTPClient 获取搜索HTTP客户端 func GetSearchHTTPClient() *http.Client { return searchHTTPClient } diff --git a/src/proxysh.go b/src/utils/proxy_shell.go similarity index 98% rename from src/proxysh.go rename to src/utils/proxy_shell.go index e49e917..d83313f 100644 --- a/src/proxysh.go +++ b/src/utils/proxy_shell.go @@ -1,4 +1,4 @@ -package main +package utils import ( "bytes" @@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) { var reader io.Reader = input - // 处理gzip压缩 if isCompressed { peek := make([]byte, 2) n, err := input.Read(peek) diff --git a/src/ratelimiter.go b/src/utils/ratelimiter.go similarity index 69% rename from src/ratelimiter.go rename to src/utils/ratelimiter.go index 9f7bdc8..bc6f2f2 100644 --- a/src/ratelimiter.go +++ b/src/utils/ratelimiter.go @@ -1,4 +1,4 @@ -package main +package utils import ( "fmt" @@ -9,22 +9,23 @@ import ( "github.com/gin-gonic/gin" "golang.org/x/time/rate" + "hubproxy/config" ) const ( - // 清理间隔 CleanupInterval = 10 * time.Minute MaxIPCacheSize = 10000 ) // IPRateLimiter IP限流器结构体 type IPRateLimiter struct { - ips map[string]*rateLimiterEntry // IP到限流器的映射 - mu *sync.RWMutex // 读写锁,保证并发安全 - r rate.Limit // 速率限制(每秒允许的请求数) - b int // 令牌桶容量(突发请求数) - whitelist []*net.IPNet // 白名单IP段 - blacklist []*net.IPNet // 黑名单IP段 + ips map[string]*rateLimiterEntry + mu *sync.RWMutex + r rate.Limit + b int + whitelist []*net.IPNet + blacklist []*net.IPNet + whitelistLimiter *rate.Limiter // 全局共享的白名单限流器 } // rateLimiterEntry 限流器条目 @@ -33,15 +34,15 @@ type rateLimiterEntry struct { lastAccess time.Time } -// initGlobalLimiter 初始化全局限流器 -func initGlobalLimiter() *IPRateLimiter { - cfg := GetConfig() +// InitGlobalLimiter 初始化全局限流器 +func InitGlobalLimiter() *IPRateLimiter { + cfg := config.GetConfig() whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) for _, item := range cfg.Security.WhiteList { if item = strings.TrimSpace(item); item != "" { if !strings.Contains(item, "/") { - item = item + "/32" // 单个IP转为CIDR格式 + item = item + "/32" } _, ipnet, err := net.ParseCIDR(item) if err == nil { @@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter { } } - // 解析黑名单IP段 blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) for _, item := range cfg.Security.BlackList { if item = strings.TrimSpace(item); item != "" { if !strings.Contains(item, "/") { - item = item + "/32" // 单个IP转为CIDR格式 + item = item + "/32" } _, ipnet, err := net.ParseCIDR(item) if err == nil { @@ -68,7 +68,6 @@ func initGlobalLimiter() *IPRateLimiter { } } - // 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求" ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) burstSize := cfg.RateLimit.RequestLimit @@ -77,25 +76,20 @@ func initGlobalLimiter() *IPRateLimiter { } limiter := &IPRateLimiter{ - ips: make(map[string]*rateLimiterEntry), - mu: &sync.RWMutex{}, - r: ratePerSecond, - b: burstSize, - whitelist: whitelist, - blacklist: blacklist, + ips: make(map[string]*rateLimiterEntry), + mu: &sync.RWMutex{}, + r: ratePerSecond, + b: burstSize, + whitelist: whitelist, + blacklist: blacklist, + whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize), } - // 启动定期清理goroutine go limiter.cleanupRoutine() return limiter } -// initLimiter 初始化限流器 -func initLimiter() { - globalLimiter = initGlobalLimiter() -} - // cleanupRoutine 定期清理过期的限流器 func (i *IPRateLimiter) cleanupRoutine() { ticker := time.NewTicker(CleanupInterval) @@ -105,25 +99,20 @@ func (i *IPRateLimiter) cleanupRoutine() { now := time.Now() expired := make([]string, 0) - // 查找过期的条目 i.mu.RLock() for ip, entry := range i.ips { - // 如果最后访问时间超过1小时,认为过期 if now.Sub(entry.lastAccess) > 1*time.Hour { expired = append(expired, ip) } } i.mu.RUnlock() - // 如果有过期条目或者缓存过大,进行清理 if len(expired) > 0 || len(i.ips) > MaxIPCacheSize { i.mu.Lock() - // 删除过期条目 for _, ip := range expired { delete(i.ips, ip) } - // 如果缓存仍然过大,全部清理 if len(i.ips) > MaxIPCacheSize { i.ips = make(map[string]*rateLimiterEntry) } @@ -140,28 +129,26 @@ func extractIPFromAddress(address string) string { return address } -// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段 +// normalizeIPForRateLimit 标准化IP地址用于限流 func normalizeIPForRateLimit(ipStr string) string { ip := net.ParseIP(ipStr) if ip == nil { - return ipStr // 解析失败,返回原值 + return ipStr } if ip.To4() != nil { - return ipStr // IPv4保持不变 + return ipStr } - // IPv6:标准化为 /64 网段 ipv6 := ip.To16() for i := 8; i < 16; i++ { - ipv6[i] = 0 // 清零后64位 + ipv6[i] = 0 } return ipv6.String() + "/64" } // isIPInCIDRList 检查IP是否在CIDR列表中 func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { - // 先提取纯IP地址 cleanIP := extractIPFromAddress(ip) parsedIP := net.ParseIP(cleanIP) if parsedIP == nil { @@ -176,22 +163,18 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { return false } -// GetLimiter 获取指定IP的限流器,同时返回是否允许访问 +// GetLimiter 获取指定IP的限流器 func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { - // 提取纯IP地址 cleanIP := extractIPFromAddress(ip) - // 检查是否在黑名单中 if isIPInCIDRList(cleanIP, i.blacklist) { return nil, false } - // 检查是否在白名单中 if isIPInCIDRList(cleanIP, i.whitelist) { - return rate.NewLimiter(rate.Inf, i.b), true + return i.whitelistLimiter, true } - // 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段 normalizedIP := normalizeIPForRateLimit(cleanIP) now := time.Now() @@ -230,7 +213,6 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { // RateLimitMiddleware 速率限制中间件 func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { return func(c *gin.Context) { - // 静态文件豁免:跳过限流检查 path := c.Request.URL.Path if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" || strings.HasPrefix(path, "/public/") { @@ -238,30 +220,22 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { return } - // 获取客户端真实IP var ip string - // 优先尝试从请求头获取真实IP if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" { - // X-Forwarded-For可能包含多个IP,取第一个 ips := strings.Split(forwarded, ",") ip = strings.TrimSpace(ips[0]) } else if realIP := c.GetHeader("X-Real-IP"); realIP != "" { - // 如果有X-Real-IP头 ip = realIP } else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" { - // 某些代理可能使用此头 ips := strings.Split(remoteIP, ",") ip = strings.TrimSpace(ips[0]) } else { - // 回退到ClientIP方法 ip = c.ClientIP() } - // 提取纯IP地址(去除可能存在的端口) cleanIP := extractIPFromAddress(ip) - // 日志记录请求IP和头信息 normalizedIP := normalizeIPForRateLimit(cleanIP) if cleanIP != normalizedIP { fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", @@ -275,10 +249,8 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { c.GetHeader("X-Real-IP")) } - // 获取限流器并检查是否允许访问 ipLimiter, allowed := limiter.GetLimiter(cleanIP) - // 如果IP在黑名单中 if !allowed { c.JSON(403, gin.H{ "error": "您已被限制访问", @@ -287,7 +259,6 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { return } - // 检查限流 if !ipLimiter.Allow() { c.JSON(429, gin.H{ "error": "请求频率过快,暂时限制访问",