优化代码结构,支持h2 #39

Merged
sky22333 merged 7 commits from dev into main 2025-07-27 12:11:39 +08:00
16 changed files with 1024 additions and 1189 deletions

View File

@@ -37,7 +37,7 @@ docker run -d \
### 一键脚本安装
```bash
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash
```
也可以直接下载二进制文件,执行`./hubproxy`即可使用;无需配置文件即可启动,内置默认配置支持所有功能。初始内存占用约18M,二进制文件大小约12M。
@@ -109,12 +109,14 @@ host = "0.0.0.0"
port = 5000
# Github文件大小限制字节默认2GB
fileSize = 2147483648
# HTTP/2 多路复用
enableH2C = false
[rateLimit]
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层会消耗多个次数)
requestLimit = 500
# 限流周期(小时)
periodHours = 1.0
periodHours = 3.0
[security]
# IP白名单支持单个IP或IP段
@@ -132,7 +134,7 @@ blackList = [
"192.168.100.0/24"
]
[proxy]
[access]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = []
@@ -148,12 +150,6 @@ blackList = [
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""

View File

@@ -1,8 +1,8 @@
services:
hubproxy:
build: .
restart: always
ports:
- '5000:5000'
volumes:
- ./src/config.toml:/root/config.toml
hubproxy:
build: .
restart: always
ports:
- '5000:5000'
volumes:
- ./src/config.toml:/root/config.toml

View File

@@ -4,12 +4,14 @@ host = "0.0.0.0"
port = 5000
# Github文件大小限制字节默认2GB
fileSize = 2147483648
# HTTP/2 多路复用
enableH2C = false
[rateLimit]
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层会消耗多个次数)
requestLimit = 500
# 限流周期(小时)
periodHours = 1.0
periodHours = 3.0
[security]
# IP白名单支持单个IP或IP段
@@ -43,12 +45,6 @@ blackList = [
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""

View File

@@ -1,4 +1,4 @@
package main
package config
import (
"fmt"
@@ -13,45 +13,46 @@ import (
// RegistryMapping Registry映射配置
type RegistryMapping struct {
Upstream string `toml:"upstream"` // 上游Registry地址
AuthHost string `toml:"authHost"` // 认证服务器地址
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic
Enabled bool `toml:"enabled"` // 是否启用
Upstream string `toml:"upstream"`
AuthHost string `toml:"authHost"`
AuthType string `toml:"authType"`
Enabled bool `toml:"enabled"`
}
// AppConfig 应用配置结构体
type AppConfig struct {
Server struct {
Host string `toml:"host"` // 监听地址
Port int `toml:"port"` // 监听端口
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
} `toml:"server"`
RateLimit struct {
RequestLimit int `toml:"requestLimit"` // 每小时请求限制
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"`
} `toml:"rateLimit"`
Security struct {
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
} `toml:"security"`
Access struct {
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
Proxy string `toml:"proxy"`
} `toml:"access"`
Download struct {
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
MaxImages int `toml:"maxImages"`
} `toml:"download"`
Registries map[string]RegistryMapping `toml:"registries"`
TokenCache struct {
Enabled bool `toml:"enabled"` // 是否启用token缓存
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间
Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"`
} `toml:"tokenCache"`
}
@@ -65,24 +66,25 @@ var (
configCacheMutex sync.RWMutex
)
// todo:Refactoring is needed
// DefaultConfig 返回默认配置
func DefaultConfig() *AppConfig {
return &AppConfig{
Server: struct {
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
}{
Host: "0.0.0.0",
Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
Host: "0.0.0.0",
Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
EnableH2C: false, // 默认关闭H2C
},
RateLimit: struct {
RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"`
}{
RequestLimit: 20,
RequestLimit: 200,
PeriodHours: 1.0,
},
Security: struct {
@@ -99,12 +101,12 @@ func DefaultConfig() *AppConfig {
}{
WhiteList: []string{},
BlackList: []string{},
Proxy: "", // 默认不使用代理
Proxy: "",
},
Download: struct {
MaxImages int `toml:"maxImages"`
}{
MaxImages: 10, // 默认值最多同时下载10个镜像
MaxImages: 10,
},
Registries: map[string]RegistryMapping{
"ghcr.io": {
@@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig {
Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"`
}{
Enabled: true, // docker认证的匿名Token缓存配置用于提升性能
Enabled: true,
DefaultTTL: "20m",
},
}
@@ -152,11 +154,9 @@ func GetConfig() *AppConfig {
}
configCacheMutex.RUnlock()
// 缓存过期,重新生成配置
configCacheMutex.Lock()
defer configCacheMutex.Unlock()
// 双重检查,防止重复生成
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
return cachedConfig
}
@@ -170,7 +170,6 @@ func GetConfig() *AppConfig {
return defaultCfg
}
// 生成新的配置深拷贝
configCopy := *appConfig
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
@@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) {
// LoadConfig 加载配置文件
func LoadConfig() error {
// 首先使用默认配置
cfg := DefaultConfig()
// 尝试加载TOML配置文件
if data, err := os.ReadFile("config.toml"); err == nil {
if err := toml.Unmarshal(data, cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %v", err)
@@ -209,10 +206,7 @@ func LoadConfig() error {
fmt.Println("未找到config.toml使用默认配置")
}
// 从环境变量覆盖配置
overrideFromEnv(cfg)
// 设置配置
setConfig(cfg)
return nil
@@ -220,7 +214,6 @@ func LoadConfig() error {
// overrideFromEnv 从环境变量覆盖配置
func overrideFromEnv(cfg *AppConfig) {
// 服务器配置
if val := os.Getenv("SERVER_HOST"); val != "" {
cfg.Server.Host = val
}
@@ -229,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Server.Port = port
}
}
if val := os.Getenv("ENABLE_H2C"); val != "" {
if enable, err := strconv.ParseBool(val); err == nil {
cfg.Server.EnableH2C = enable
}
}
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
cfg.Server.FileSize = size
}
}
// 限流配置
if val := os.Getenv("RATE_LIMIT"); val != "" {
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
cfg.RateLimit.RequestLimit = limit
@@ -247,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) {
}
}
// IP限制配置
if val := os.Getenv("IP_WHITELIST"); val != "" {
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
}
@@ -255,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
}
// 下载限制配置
if val := os.Getenv("MAX_IMAGES"); val != "" {
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
cfg.Download.MaxImages = maxImages

View File

@@ -6,6 +6,7 @@ require (
github.com/gin-gonic/gin v1.10.0
github.com/google/go-containerregistry v0.20.5
github.com/pelletier/go-toml/v2 v2.2.3
golang.org/x/net v0.33.0
golang.org/x/time v0.11.0
)
@@ -43,7 +44,6 @@ require (
github.com/vbatts/tar-split v0.12.1 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.21.0 // indirect

View File

@@ -1,4 +1,4 @@
package main
package handlers
import (
"context"
@@ -12,6 +12,8 @@ import (
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"hubproxy/config"
"hubproxy/utils"
)
// DockerProxy Docker代理配置
@@ -27,12 +29,10 @@ type RegistryDetector struct{}
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
cfg := GetConfig()
cfg := config.GetConfig()
// 检查路径是否以已知Registry域名开头
for domain := range cfg.Registries {
if strings.HasPrefix(path, domain+"/") {
// 找到匹配的域名,返回域名和剩余路径
remainingPath := strings.TrimPrefix(path, domain+"/")
return domain, remainingPath
}
@@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
// isRegistryEnabled 检查Registry是否启用
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
cfg := GetConfig()
cfg := config.GetConfig()
if mapping, exists := cfg.Registries[domain]; exists {
return mapping.Enabled
}
@@ -51,28 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
}
// getRegistryMapping 获取Registry映射配置
func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) {
cfg := GetConfig()
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
cfg := config.GetConfig()
mapping, exists := cfg.Registries[domain]
return mapping, exists && mapping.Enabled
}
var registryDetector = &RegistryDetector{}
// 初始化Docker代理
func initDockerProxy() {
// 创建目标registry
// InitDockerProxy 初始化Docker代理
func InitDockerProxy() {
registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err)
return
}
// 配置代理选项
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
dockerProxy = &DockerProxy{
@@ -85,13 +83,11 @@ func initDockerProxy() {
func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path
// 处理 /v2/ API版本检查
if path == "/v2/" {
c.JSON(http.StatusOK, gin.H{})
return
}
// 处理不同的API端点
if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path)
} else {
@@ -101,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) {
// handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) {
// 移除 /v2/ 前缀
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
if registryDetector.isRegistryEnabled(registryDomain) {
// 设置目标Registry信息到Context
c.Set("target_registry_domain", registryDomain)
c.Set("target_path", remainingPath)
// 处理多Registry请求
handleMultiRegistryRequest(c, registryDomain, remainingPath)
return
}
@@ -122,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) {
return
}
// 自动处理官方镜像的library命名空间
if !strings.Contains(imageName, "/") {
imageName = "library/" + imageName
}
// Docker镜像访问控制检查
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
// 构建完整的镜像引用
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType {
@@ -151,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) {
// parseRegistryPath 解析Registry路径
func parseRegistryPath(path string) (imageName, apiType, reference string) {
// 查找API端点关键字
if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx]
apiType = "manifests"
@@ -178,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) {
// handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
// Manifest缓存逻辑(仅对GET请求缓存)
if isCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := buildManifestCacheKey(imageRef, reference)
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
// 优先从缓存获取
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
writeCachedResponse(c, cachedItem)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
@@ -192,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
var ref name.Reference
var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") {
// 是digest
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
// 是tag
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
@@ -207,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead {
// HEAD请求使用remote.Head
desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
@@ -217,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
// GET请求使用remote.Get
desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
@@ -231,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 设置响应头
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
// 缓存响应
if isCacheEnabled() {
cacheKey := buildManifestCacheKey(imageRef, reference)
ttl := getManifestTTL(reference)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
// 返回manifest内容
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
// 构建digest引用
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
@@ -265,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 使用remote.Layer获取layer
layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
@@ -273,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 获取layer信息
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
@@ -281,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 获取layer内容
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
@@ -290,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
}
defer reader.Close()
// 设置响应头
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
// 流式传输blob内容
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) {
// 解析repository
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
@@ -310,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return
}
// 使用remote.List获取tags
tags, err := remote.List(repo, dockerProxy.options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
@@ -318,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return
}
// 构建响应
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags,
@@ -327,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
c.JSON(http.StatusOK, response)
}
// ProxyDockerAuthGin Docker认证代理(带缓存优化)
// ProxyDockerAuthGin Docker认证代理
func ProxyDockerAuthGin(c *gin.Context) {
// 检查是否启用token缓存
if isTokenCacheEnabled() {
if utils.IsTokenCacheEnabled() {
proxyDockerAuthWithCache(c)
} else {
proxyDockerAuthOriginal(c)
@@ -339,32 +305,26 @@ func ProxyDockerAuthGin(c *gin.Context) {
// proxyDockerAuthWithCache 带缓存的认证代理
func proxyDockerAuthWithCache(c *gin.Context) {
// 1. 构建缓存key基于完整的查询参数
cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery)
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
// 2. 尝试从缓存获取token
if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" {
writeTokenResponse(c, cachedToken)
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
utils.WriteTokenResponse(c, cachedToken)
return
}
// 3. 缓存未命中,创建响应记录器
recorder := &ResponseRecorder{
ResponseWriter: c.Writer,
statusCode: 200,
}
c.Writer = recorder
// 4. 调用原有认证逻辑
proxyDockerAuthOriginal(c)
// 5. 如果认证成功,缓存响应
if recorder.statusCode == 200 && len(recorder.body) > 0 {
ttl := extractTTLFromResponse(recorder.body)
globalCache.SetToken(cacheKey, string(recorder.body), ttl)
ttl := utils.ExtractTTLFromResponse(recorder.body)
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
}
// 6. 写入实际响应
c.Writer = recorder.ResponseWriter
c.Data(recorder.statusCode, "application/json", recorder.body)
}
@@ -389,14 +349,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
var authURL string
if targetDomain, exists := c.Get("target_registry_domain"); exists {
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
// 使用Registry特定的认证服务器
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
} else {
// fallback到默认Docker认证
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
} else {
// 构建默认Docker认证URL
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
@@ -404,13 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
authURL += "?" + c.Request.URL.RawQuery
}
// 创建HTTP客户端复用全局传输配置包含代理设置
client := &http.Client{
Timeout: 30 * time.Second,
Transport: GetGlobalHTTPClient().Transport,
Transport: utils.GetGlobalHTTPClient().Transport,
}
// 创建请求
req, err := http.NewRequestWithContext(
context.Background(),
c.Request.Method,
@@ -422,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) {
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
// 执行请求
resp, err := client.Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Auth request failed")
@@ -437,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) {
}
defer resp.Body.Close()
// 获取当前代理的Host地址
proxyHost := c.Request.Host
if proxyHost == "" {
// 使用配置中的服务器地址和端口
cfg := GetConfig()
cfg := config.GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
// 复制响应头并重写认证URL
for key, values := range resp.Header {
for _, value := range values {
// 重写WWW-Authenticate头中的realm URL
if key == "Www-Authenticate" {
// 支持多Registry的URL重写
value = rewriteAuthHeader(value, proxyHost)
}
c.Header(key, value)
}
}
// 返回响应
c.Status(resp.StatusCode)
io.Copy(c.Writer, resp.Body)
}
// rewriteAuthHeader 重写认证头
func rewriteAuthHeader(authHeader, proxyHost string) string {
// 重写各种Registry的认证URL
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
@@ -478,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string {
// handleMultiRegistryRequest 处理多Registry请求
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
// 获取Registry映射配置
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
if !exists {
c.String(http.StatusBadRequest, "Registry not configured")
return
}
// 解析剩余路径
imageName, apiType, reference := parseRegistryPath(remainingPath)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
// 访问控制检查(使用完整的镜像路径)
fullImageName := registryDomain + "/" + imageName
if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
// 构建上游Registry引用
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
// 根据API类型处理请求
switch apiType {
case "manifests":
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
@@ -517,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st
}
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) {
// Manifest缓存逻辑(仅对GET请求缓存)
if isCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := buildManifestCacheKey(imageRef, reference)
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
// 优先从缓存获取
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
writeCachedResponse(c, cachedItem)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
@@ -532,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
var ref name.Reference
var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
@@ -545,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return
}
// 创建针对上游Registry的选项
options := createUpstreamOptions(mapping)
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, options...)
if err != nil {
@@ -569,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return
}
// 设置响应头
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
// 缓存响应
if isCacheEnabled() {
cacheKey := buildManifestCacheKey(imageRef, reference)
ttl := getManifestTTL(reference)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
@@ -593,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
}
// handleUpstreamBlobRequest 处理上游Registry的blob请求
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) {
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
@@ -633,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping
}
// handleUpstreamTagsRequest 处理上游Registry的tags请求
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) {
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
@@ -658,14 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry
}
// createUpstreamOptions 创建上游Registry选项
func createUpstreamOptions(mapping RegistryMapping) []remote.Option {
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
// 根据Registry类型添加特定的认证选项方便后续扩展
switch mapping.AuthType {
case "github":
case "google":

213
src/handlers/github.go Normal file
View File

@@ -0,0 +1,213 @@
package handlers
import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"hubproxy/config"
"hubproxy/utils"
)
var (
	// githubExps holds the URL patterns this proxy accepts. Each pattern's
	// capture groups extract the owner/repository (or equivalent) segments,
	// which CheckGitHubURL returns for access-control checks.
	githubExps = []*regexp.Regexp{
		// GitHub release/archive downloads
		regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
		// GitHub blob/raw file links (blob is later rewritten to raw)
		regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
		// git clone/fetch endpoints (info/refs, git-upload-pack, ...)
		regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
		// raw.githubusercontent.com / raw.github.com file links
		regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
		// gist links
		regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
		// GitHub REST API repository endpoints
		regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
		// Hugging Face models and spaces
		regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
		// Hugging Face LFS CDN
		regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
		// Docker static binary downloads (tgz/zip only)
		regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
		// GitHub asset / OpenGraph image hosts
		regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
	}
)
// GitHubProxyHandler proxies GitHub (and related hosts) download/API URLs.
// The target URL is taken from the request path, e.g.
// /https://github.com/owner/repo/releases/... . It normalizes the scheme,
// enforces repository-level access control, rewrites blob links to raw
// links, and delegates streaming to ProxyGitHubRequest.
func GitHubProxyHandler(c *gin.Context) {
	rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
	for strings.HasPrefix(rawPath, "/") {
		rawPath = strings.TrimPrefix(rawPath, "/")
	}

	// Normalize the scheme to https://. Reverse proxies and gin commonly
	// collapse "//" into "/", so "https:/host" and "http:/host" are the
	// typical inputs; a bare "host/..." is also accepted.
	if !strings.HasPrefix(rawPath, "https://") {
		// Order matters: strip the full "http://" before the collapsed
		// "http:/", otherwise "http://host" was previously turned into
		// "https:///host" (the original "http://" branch was unreachable).
		switch {
		case strings.HasPrefix(rawPath, "http://"):
			rawPath = strings.TrimPrefix(rawPath, "http://")
		case strings.HasPrefix(rawPath, "https:/"):
			rawPath = strings.TrimPrefix(rawPath, "https:/")
		case strings.HasPrefix(rawPath, "http:/"):
			rawPath = strings.TrimPrefix(rawPath, "http:/")
		}
		rawPath = "https://" + rawPath
	}

	matches := CheckGitHubURL(rawPath)
	if matches == nil {
		c.String(http.StatusForbidden, "无效输入")
		return
	}

	// Repository-level access control (white/black lists).
	if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
		var repoPath string
		if len(matches) >= 2 {
			username := matches[0]
			repoName := strings.TrimSuffix(matches[1], ".git")
			repoPath = username + "/" + repoName
		}
		fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
		c.String(http.StatusForbidden, reason)
		return
	}

	// Rewrite blob links to raw links so file contents are served directly.
	if githubExps[1].MatchString(rawPath) {
		rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
	}

	ProxyGitHubRequest(c, rawPath)
}
// CheckGitHubURL reports whether u matches one of the supported GitHub-style
// URL patterns. On a match it returns the pattern's capture groups (owner,
// repository, ...); otherwise it returns nil.
func CheckGitHubURL(u string) []string {
	for _, pattern := range githubExps {
		m := pattern.FindStringSubmatch(u)
		if m == nil {
			continue
		}
		return m[1:]
	}
	return nil
}
// ProxyGitHubRequest proxies a single GitHub request to the upstream URL u,
// following redirects manually starting at redirect depth 0.
func ProxyGitHubRequest(c *gin.Context, u string) {
	proxyGitHubWithRedirect(c, u, 0)
}
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
coderabbitai[bot] commented 2025-07-27 12:00:55 +08:00 (Migrated from github.com)
Review

🛠️ Refactor suggestion

Consider filtering sensitive request headers.

While copying headers, consider filtering out potentially sensitive or problematic headers:

  • Authorization (unless intentionally proxying auth)
  • Cookie and Set-Cookie
  • X-Forwarded-* headers that might conflict
 // 复制请求头
 for key, values := range c.Request.Header {
+    // Skip potentially sensitive headers
+    if key == "Authorization" || key == "Cookie" || strings.HasPrefix(key, "X-Forwarded-") {
+        continue
+    }
     for _, value := range values {
         req.Header.Add(key, value)
     }
 }
 req.Header.Del("Host")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

	// 复制请求头
	for key, values := range c.Request.Header {
		// Skip potentially sensitive headers
		if key == "Authorization" || key == "Cookie" || strings.HasPrefix(key, "X-Forwarded-") {
			continue
		}
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}
	req.Header.Del("Host")
🤖 Prompt for AI Agents
In src/handlers/github.go around lines 106 to 112, the code copies all request
headers without filtering, which may expose sensitive headers like
Authorization, Cookie, Set-Cookie, and X-Forwarded-* headers. Modify the header
copying loop to skip these sensitive headers by checking the header key and
excluding them from being added to req.Header.
_🛠️ Refactor suggestion_ **Consider filtering sensitive request headers.** While copying headers, consider filtering out potentially sensitive or problematic headers: - `Authorization` (unless intentionally proxying auth) - `Cookie` and `Set-Cookie` - `X-Forwarded-*` headers that might conflict ```diff // 复制请求头 for key, values := range c.Request.Header { + // Skip potentially sensitive headers + if key == "Authorization" || key == "Cookie" || strings.HasPrefix(key, "X-Forwarded-") { + continue + } for _, value := range values { req.Header.Add(key, value) } } req.Header.Del("Host") ``` <!-- suggestion_start --> <details> <summary>📝 Committable suggestion</summary> > ‼️ **IMPORTANT** > Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements. ```suggestion // 复制请求头 for key, values := range c.Request.Header { // Skip potentially sensitive headers if key == "Authorization" || key == "Cookie" || strings.HasPrefix(key, "X-Forwarded-") { continue } for _, value := range values { req.Header.Add(key, value) } } req.Header.Del("Host") ``` </details> <!-- suggestion_end --> <details> <summary>🤖 Prompt for AI Agents</summary> ``` In src/handlers/github.go around lines 106 to 112, the code copies all request headers without filtering, which may expose sensitive headers like Authorization, Cookie, Set-Cookie, and X-Forwarded-* headers. Modify the header copying loop to skip these sensitive headers by checking the header key and excluding them from being added to req.Header. ``` </details> <!-- fingerprinting:phantom:medusa:lion --> <!-- This is an auto-generated comment by CodeRabbit -->
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := config.GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
// 处理.sh文件的智能处理
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
coderabbitai[bot] commented 2025-07-27 12:00:55 +08:00 (Migrated from github.com)
Review

💡 Verification agent

🧩 Analysis chain

Clarify the purpose of .sh file processing.

The code performs special processing on .sh files using utils.ProcessSmart. Please document:

  1. What modifications are being made to shell scripts?
  2. Are there security implications of modifying executable content?
  3. What's the use case for this feature?

🏁 Script executed:

#!/bin/bash
# Search for the ProcessSmart implementation to understand what it does
ast-grep --pattern 'func ProcessSmart($$$) {
  $$$
}'

Length of output: 55


🏁 Script executed:

#!/bin/bash
# Search for ProcessSmart references and implementation across the repository
rg "ProcessSmart" -n .

Length of output: 337


🏁 Script executed:

#!/bin/bash
# Display the full implementation of ProcessSmart in utils/proxy_shell.go
sed -n '1,200p' src/utils/proxy_shell.go

Length of output: 2472


Please document and safeguard .sh file processing logic

We’ve confirmed that ProcessSmart in utils/proxy_shell.go:

  • Handles optional gzip by peeking and unwrapping if needed
  • Reads the entire script into memory (unless >10 MB or no GitHub URLs)
  • Uses githubRegex to locate all GitHub‐hosted URLs and rewrites each via transformURL(host)
  • Returns a new io.Reader of the (potentially) modified script

To help future maintainers and ensure safety, please:

• In src/utils/proxy_shell.go (ProcessSmart, lines ~16–54):
– Add a top-level doc comment describing each step (decompression, size check, URL rewrite)
– Call out any edge cases (e.g. very large scripts, scripts without GitHub links)
– Note security considerations around modifying executable content (shebang integrity, unintended regex matches)

• In src/handlers/github.go (lines 150–190):
– Explain the .sh branch’s purpose: why and when you proxy/modify scripts
– Document the impact on downstream clients (chunked responses, URL schemes)

• Add or update unit/integration tests covering:
– A small script with GitHub URLs (verify correct rewrite and valid URL syntax)
– A gzipped script
– A large script (>10 MB) and one without GitHub URLs (ensure no modifications)

This will clarify intent, guard against inadvertent breakage, and surface any security risks.

🤖 Prompt for AI Agents
In src/handlers/github.go around lines 150 to 190, add clear comments explaining
the purpose of the .sh file processing branch, detailing why scripts are proxied
and modified, and how the response headers are adjusted (e.g., chunked transfer
encoding) to inform downstream clients. In src/utils/proxy_shell.go within the
ProcessSmart function (lines ~16–54), add a comprehensive top-level doc comment
describing each processing step: decompression if gzip encoded, size checks to
avoid large memory usage, GitHub URL detection and rewriting, and returning the
modified script reader. Also document edge cases like very large scripts or
scripts without GitHub URLs, and note security considerations such as preserving
shebang integrity and avoiding unintended regex matches. Finally, add or update
unit and integration tests to cover scenarios including small scripts with
GitHub URLs (verifying correct URL rewriting), gzipped scripts, large scripts
over 10 MB, and scripts without GitHub URLs to ensure no modifications occur.
_💡 Verification agent_ <details> <summary>🧩 Analysis chain</summary> **Clarify the purpose of .sh file processing.** The code performs special processing on `.sh` files using `utils.ProcessSmart`. Please document: 1. What modifications are being made to shell scripts? 2. Are there security implications of modifying executable content? 3. What's the use case for this feature? --- 🏁 Script executed: ```shell #!/bin/bash # Search for the ProcessSmart implementation to understand what it does ast-grep --pattern 'func ProcessSmart($$$) { $$$ }' ``` Length of output: 55 --- 🏁 Script executed: ```shell #!/bin/bash # Search for ProcessSmart references and implementation across the repository rg "ProcessSmart" -n . ``` Length of output: 337 --- 🏁 Script executed: ```shell #!/bin/bash # Display the full implementation of ProcessSmart in utils/proxy_shell.go sed -n '1,200p' src/utils/proxy_shell.go ``` Length of output: 2472 --- </details> **Please document and safeguard `.sh` file processing logic** We’ve confirmed that `ProcessSmart` in `utils/proxy_shell.go`: - Handles optional gzip by peeking and unwrapping if needed - Reads the entire script into memory (unless >10 MB or no GitHub URLs) - Uses `githubRegex` to locate all GitHub‐hosted URLs and rewrites each via `transformURL(host)` - Returns a new `io.Reader` of the (potentially) modified script To help future maintainers and ensure safety, please: • In `src/utils/proxy_shell.go` (ProcessSmart, lines ~16–54): – Add a top-level doc comment describing each step (decompression, size check, URL rewrite) – Call out any edge cases (e.g. 
very large scripts, scripts without GitHub links) – Note security considerations around modifying executable content (shebang integrity, unintended regex matches) • In `src/handlers/github.go` (lines 150–190): – Explain the `.sh` branch’s purpose: why and when you proxy/modify scripts – Document the impact on downstream clients (chunked responses, URL schemes) • Add or update unit/integration tests covering: – A small script with GitHub URLs (verify correct rewrite and valid URL syntax) – A gzipped script – A large script (>10 MB) and one without GitHub URLs (ensure no modifications) This will clarify intent, guard against inadvertent breakage, and surface any security risks. <details> <summary>🤖 Prompt for AI Agents</summary> ``` In src/handlers/github.go around lines 150 to 190, add clear comments explaining the purpose of the .sh file processing branch, detailing why scripts are proxied and modified, and how the response headers are adjusted (e.g., chunked transfer encoding) to inform downstream clients. In src/utils/proxy_shell.go within the ProcessSmart function (lines ~16–54), add a comprehensive top-level doc comment describing each processing step: decompression if gzip encoded, size checks to avoid large memory usage, GitHub URL detection and rewriting, and returning the modified script reader. Also document edge cases like very large scripts or scripts without GitHub URLs, and note security considerations such as preserving shebang integrity and avoiding unintended regex matches. Finally, add or update unit and integration tests to cover scenarios including small scripts with GitHub URLs (verifying correct URL rewriting), gzipped scripts, large scripts over 10 MB, and scripts without GitHub URLs to ensure no modifications occur. ``` </details> <!-- fingerprinting:phantom:medusa:lion --> <!-- This is an auto-generated comment by CodeRabbit -->
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
io.Copy(c.Writer, resp.Body)
}
}

View File

@@ -1,4 +1,4 @@
package main
package handlers
import (
"archive/tar"
@@ -23,6 +23,8 @@ import (
"github.com/google/go-containerregistry/pkg/v1/partial"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/types"
"hubproxy/config"
"hubproxy/utils"
)
// DebounceEntry 防抖条目
@@ -58,17 +60,15 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
if entry, exists := d.entries[key]; exists {
if now.Sub(entry.LastRequest) < d.window {
return false // 在防抖窗口内,拒绝请求
return false
}
}
// 更新或创建条目
d.entries[key] = &DebounceEntry{
LastRequest: now,
UserID: userID,
}
// 清理过期条目每5分钟清理一次
if time.Since(d.lastCleanup) > 5*time.Minute {
d.cleanup(now)
d.lastCleanup = now
@@ -88,50 +88,41 @@ func (d *DownloadDebouncer) cleanup(now time.Time) {
// generateContentFingerprint 生成内容指纹:先对镜像列表排序(保证与传入
// 顺序无关),再与平台信息拼接后取 MD5 的十六进制表示。
// 不会修改调用方传入的 images 切片。
func generateContentFingerprint(images []string, platform string) string {
	ordered := append([]string(nil), images...)
	sort.Strings(ordered)

	// 组合内容:镜像列表 + 平台信息
	payload := strings.Join(ordered, "|") + ":" + platform

	digest := md5.Sum([]byte(payload))
	return hex.EncodeToString(digest[:])
}
// getUserID 获取用户标识
func getUserID(c *gin.Context) string {
// 优先使用会话Cookie
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
return "session:" + sessionID
}
// 备用方案IP + User-Agent组合
ip := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
userAgent = "unknown"
}
// 生成简短标识
combined := ip + ":" + userAgent
hash := md5.Sum([]byte(combined))
return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节
return "ip:" + hex.EncodeToString(hash[:8])
}
// 全局防抖器实例
var (
singleImageDebouncer *DownloadDebouncer
batchImageDebouncer *DownloadDebouncer
)
// initDebouncer 初始化防抖器
func initDebouncer() {
// 单个镜像5秒防抖窗口
// InitDebouncer 初始化全局防抖器:单个镜像下载使用 5 秒防抖窗口,
// 批量镜像下载使用 60 秒防抖窗口。
func InitDebouncer() {
	const (
		singleWindow = 5 * time.Second
		batchWindow  = 60 * time.Second
	)
	singleImageDebouncer = NewDownloadDebouncer(singleWindow)
	batchImageDebouncer = NewDownloadDebouncer(batchWindow)
}
@@ -147,15 +138,15 @@ type ImageStreamerConfig struct {
}
// NewImageStreamer 创建镜像下载器
func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
if config == nil {
config = &ImageStreamerConfig{}
func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
if cfg == nil {
cfg = &ImageStreamerConfig{}
}
concurrency := config.Concurrency
concurrency := cfg.Concurrency
if concurrency <= 0 {
cfg := GetConfig()
concurrency = cfg.Download.MaxImages
appCfg := config.GetConfig()
concurrency = appCfg.Download.MaxImages
if concurrency <= 0 {
concurrency = 10
}
@@ -163,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
remoteOptions := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
return &ImageStreamer{
@@ -176,7 +167,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
type StreamOptions struct {
Platform string
Compression bool
UseCompressedLayers bool // 是否保存原始压缩层,默认开启
UseCompressedLayers bool
}
// StreamImageToWriter 流式下载镜像到Writer
@@ -215,7 +206,6 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
// 直接从网络获取完整的descriptor确保对象完整性
return remote.Get(ref, options...)
}
@@ -343,7 +333,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
var layerSize int64
var layerReader io.ReadCloser
// 根据配置选择使用压缩层或未压缩层
if options != nil && options.UseCompressedLayers {
layerSize, err = layer.Size()
if err != nil {
@@ -385,7 +374,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
log.Printf("已处理层 %d/%d", i+1, len(layers))
}
// 构建单个镜像的manifest信息
singleManifest := map[string]interface{}{
"Config": configDigest.String() + ".json",
"RepoTags": []string{imageRef},
@@ -398,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
}(),
}
// 构建repositories信息
repositories := make(map[string]map[string]string)
parts := strings.Split(imageRef, ":")
if len(parts) == 2 {
@@ -407,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
repositories[repoName] = map[string]string{tag: configDigest.String()}
}
// 如果是批量下载,返回信息而不写入文件
if manifestOut != nil && repositoriesOut != nil {
*manifestOut = singleManifest
*repositoriesOut = repositories
return nil
}
// 单镜像下载直接写入manifest.json
manifest := []map[string]interface{}{singleManifest}
manifestData, err := json.Marshal(manifest)
@@ -436,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err
}
// 写入repositories文件
repositoriesData, err := json.Marshal(repositories)
if err != nil {
return err
@@ -456,7 +440,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err
}
// processImageForBatch 处理镜像的公共逻辑,用于批量下载
// processImageForBatch 处理镜像的公共逻辑
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
layers, err := img.Layers()
if err != nil {
@@ -498,7 +482,6 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite
switch desc.MediaType {
case types.OCIImageIndex, types.DockerManifestList:
// 处理多架构镜像
img, err = is.selectPlatformImage(desc, options)
if err != nil {
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
@@ -530,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
return nil, fmt.Errorf("获取索引清单失败: %w", err)
}
// 选择合适的平台
var selectedDesc *v1.Descriptor
for _, m := range manifest.Manifests {
if m.Platform == nil {
@@ -578,8 +560,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
var globalImageStreamer *ImageStreamer
// initImageStreamer 初始化镜像下载器
func initImageStreamer() {
// InitImageStreamer 初始化镜像下载器
func InitImageStreamer() {
globalImageStreamer = NewImageStreamer(nil)
}
@@ -591,8 +573,8 @@ func formatPlatformText(platform string) string {
return platform
}
// initImageTarRoutes 初始化镜像下载路由
func initImageTarRoutes(router *gin.Engine) {
// InitImageTarRoutes 初始化镜像下载路由
func InitImageTarRoutes(router *gin.Engine) {
imageAPI := router.Group("/api/image")
{
imageAPI.GET("/download/:image", handleDirectImageDownload)
@@ -625,7 +607,6 @@ func handleDirectImageDownload(c *gin.Context) {
return
}
// 防抖检查
userID := getUserID(c)
contentKey := generateContentFingerprint([]string{imageRef}, platform)
@@ -677,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
}
}
cfg := GetConfig()
cfg := config.GetConfig()
if len(req.Images) > cfg.Download.MaxImages {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
@@ -685,7 +666,6 @@ func handleSimpleBatchDownload(c *gin.Context) {
return
}
// 批量下载防抖检查
userID := getUserID(c)
contentKey := generateContentFingerprint(req.Images, req.Platform)
@@ -697,7 +677,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
return
}
useCompressed := true // 默认启用原始压缩层
useCompressed := true
if req.UseCompressedLayers != nil {
useCompressed = *req.UseCompressedLayers
}
@@ -801,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
var allManifests []map[string]interface{}
var allRepositories = make(map[string]map[string]string)
// 流式处理每个镜像
for i, imageRef := range imageRefs {
select {
case <-ctx.Done():
@@ -811,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
// 防止单个镜像处理时间过长
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
cancel()
@@ -825,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
}
// 收集manifest信息
allManifests = append(allManifests, manifest)
// 合并repositories信息
for repo, tags := range repositories {
if allRepositories[repo] == nil {
allRepositories[repo] = make(map[string]string)
@@ -839,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
}
}
// 写入合并的manifest.json
manifestData, err := json.Marshal(allManifests)
if err != nil {
return fmt.Errorf("序列化manifest失败: %w", err)
@@ -859,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("写入manifest数据失败: %w", err)
}
// 写入合并的repositories文件
repositoriesData, err := json.Marshal(allRepositories)
if err != nil {
return fmt.Errorf("序列化repositories失败: %w", err)

File diff suppressed because it is too large Load Diff

View File

@@ -3,15 +3,17 @@ package main
import (
"embed"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"hubproxy/config"
"hubproxy/handlers"
"hubproxy/utils"
)
//go:embed public/*
@@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) {
}
var (
exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
}
globalLimiter *IPRateLimiter
globalLimiter *utils.IPRateLimiter
// 服务启动时间
serviceStartTime = time.Now()
@@ -52,25 +42,25 @@ var (
func main() {
// 加载配置
if err := LoadConfig(); err != nil {
if err := config.LoadConfig(); err != nil {
fmt.Printf("配置加载失败: %v\n", err)
return
}
// 初始化HTTP客户端
initHTTPClients()
utils.InitHTTPClients()
// 初始化限流器
initLimiter()
globalLimiter = utils.InitGlobalLimiter()
// 初始化Docker流式代理
initDockerProxy()
handlers.InitDockerProxy()
// 初始化镜像流式下载器
initImageStreamer()
handlers.InitImageStreamer()
// 初始化防抖器
initDebouncer()
handlers.InitDebouncer()
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
@@ -84,14 +74,14 @@ func main() {
})
}))
// 全局限流中间件 - 应用到所有路由
router.Use(RateLimitMiddleware(globalLimiter))
// 全局限流中间件
router.Use(utils.RateLimitMiddleware(globalLimiter))
// 初始化监控端点
initHealthRoutes(router)
// 初始化镜像tar下载路由
initImageTarRoutes(router)
handlers.InitImageTarRoutes(router)
// 静态文件路由
router.GET("/", func(c *gin.Context) {
@@ -113,217 +103,59 @@ func main() {
})
// 注册dockerhub搜索路由
RegisterSearchRoute(router)
handlers.RegisterSearchRoute(router)
// 注册Docker认证路由/token*
router.Any("/token", ProxyDockerAuthGin)
router.Any("/token/*path", ProxyDockerAuthGin)
// 注册Docker认证路由
router.Any("/token", handlers.ProxyDockerAuthGin)
router.Any("/token/*path", handlers.ProxyDockerAuthGin)
// 注册Docker Registry代理路由
router.Any("/v2/*path", ProxyDockerRegistryGin)
router.Any("/v2/*path", handlers.ProxyDockerRegistryGin)
// 注册NoRoute处理器
router.NoRoute(handler)
// 注册GitHub代理路由NoRoute处理器
router.NoRoute(handlers.GitHubProxyHandler)
cfg := GetConfig()
cfg := config.GetConfig()
fmt.Printf("🚀 HubProxy 启动成功\n")
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
// 显示HTTP/2支持状态
if cfg.Server.EnableH2C {
fmt.Printf("H2c: 已启用\n")
}
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
// 创建HTTP2服务器
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
ReadTimeout: 60 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
}
// 根据配置决定是否启用H2C
if cfg.Server.EnableH2C {
h2cHandler := h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: 250,
IdleTimeout: 300 * time.Second,
MaxReadFrameSize: 4 << 20,
MaxUploadBufferPerConnection: 8 << 20,
MaxUploadBufferPerStream: 2 << 20,
})
server.Handler = h2cHandler
} else {
server.Handler = router
}
err := server.ListenAndServe()
if err != nil {
fmt.Printf("启动服务失败: %v\n", err)
}
}
// handler 处理代理入口请求:从请求 URI 还原目标地址、补全协议头、
// 按白名单正则校验、执行 GitHub 仓库访问控制,最后转发给 proxyRequest。
func handler(c *gin.Context) {
	rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
	for strings.HasPrefix(rawPath, "/") {
		rawPath = strings.TrimPrefix(rawPath, "/")
	}

	// 自动补全协议头,统一规范化为 https://。
	// 注意:部分反向代理会把路径中的 "https://" 压缩成 "https:/"(单斜杠)。
	if !strings.HasPrefix(rawPath, "https://") {
		switch {
		case strings.HasPrefix(rawPath, "http://"):
			// 修复:原实现对 "http://..." 只剥离 "http:/",留下多余的 "/",
			// 生成 "https:///..." 导致正则匹配失败;这里剥离完整前缀。
			rawPath = strings.TrimPrefix(rawPath, "http://")
		case strings.HasPrefix(rawPath, "https:/"):
			rawPath = strings.TrimPrefix(rawPath, "https:/")
		case strings.HasPrefix(rawPath, "http:/"):
			rawPath = strings.TrimPrefix(rawPath, "http:/")
		}
		rawPath = "https://" + rawPath
	}

	matches := checkURL(rawPath)
	if matches == nil {
		c.String(http.StatusForbidden, "无效输入")
		return
	}

	// GitHub仓库访问控制检查(白名单/黑名单)
	if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
		// 构建仓库名用于日志
		var repoPath string
		if len(matches) >= 2 {
			username := matches[0]
			repoName := strings.TrimSuffix(matches[1], ".git")
			repoPath = username + "/" + repoName
		}
		fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
		c.String(http.StatusForbidden, reason)
		return
	}

	// blob 页面地址改写为 raw 直链,便于直接下载文件内容
	if exps[1].MatchString(rawPath) {
		rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
	}

	proxyRequest(c, rawPath)
}
// proxyRequest 代理请求入口:以零重定向计数开始跟随重定向链。
func proxyRequest(c *gin.Context, u string) {
	const initialRedirectCount = 0
	proxyWithRedirect(c, u, initialRedirectCount)
}
// proxyWithRedirect 代理请求并手动跟随重定向。
// redirectCount 记录已跟随的重定向次数,超过 maxRedirects 即判定为循环重定向。
// 对 .sh 脚本响应会经 ProcessSmart 重写其中的 GitHub 链接,其余内容直接流式转发。
//
// 改进点:原实现把"复制响应头 → 处理重定向 → 写状态码 → 拷贝响应体"
// 在 .sh 与非 .sh 两个分支里重复了一遍,且非 .sh 分支忽略了 io.Copy 的错误;
// 此版本统一为一条路径,仅响应体来源不同。
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
	// 限制最大重定向次数,防止无限递归
	const maxRedirects = 20
	if redirectCount > maxRedirects {
		c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
		return
	}

	req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
	if err != nil {
		c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
		return
	}

	// 复制请求头;Host 由 http.Client 按目标 URL 自动填充
	for key, values := range c.Request.Header {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}
	req.Header.Del("Host")

	resp, err := GetGlobalHTTPClient().Do(req)
	if err != nil {
		c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
		return
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			fmt.Printf("关闭响应体失败: %v\n", err)
		}
	}()

	// 检查文件大小限制(依赖上游返回的 Content-Length;chunked 响应无法在此拦截)
	cfg := GetConfig()
	if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
		if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
			c.String(http.StatusRequestEntityTooLarge,
				fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
			return
		}
	}

	// 清理安全相关的头,避免上游安全策略影响被代理页面
	resp.Header.Del("Content-Security-Policy")
	resp.Header.Del("Referrer-Policy")
	resp.Header.Del("Strict-Transport-Security")

	// 获取真实域名,供 .sh 脚本中的链接重写使用
	realHost := c.Request.Header.Get("X-Forwarded-Host")
	if realHost == "" {
		realHost = c.Request.Host
	}
	if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
		realHost = "https://" + realHost
	}

	// 默认直接流式转发响应体;.sh 脚本先做智能重写
	var body io.Reader = resp.Body
	if strings.HasSuffix(strings.ToLower(u), ".sh") {
		isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
		processedBody, processedSize, perr := ProcessSmart(resp.Body, isGzipCompressed, realHost)
		if perr != nil {
			fmt.Printf("智能处理失败,回退到直接代理: %v\n", perr)
			processedBody = resp.Body
			processedSize = 0
		}
		// 内容被改写后长度已变化,改用 chunked 传输
		if processedSize > 0 {
			resp.Header.Del("Content-Length")
			resp.Header.Del("Content-Encoding")
			resp.Header.Set("Transfer-Encoding", "chunked")
		}
		body = processedBody
	}

	// 复制响应头
	for key, values := range resp.Header {
		for _, value := range values {
			c.Header(key, value)
		}
	}

	// 处理重定向:可代理的地址改写为相对路径交还客户端;否则服务端继续跟随
	if location := resp.Header.Get("Location"); location != "" {
		if checkURL(location) != nil {
			c.Header("Location", "/"+location)
		} else {
			proxyWithRedirect(c, location, redirectCount+1)
			return
		}
	}

	c.Status(resp.StatusCode)
	// 客户端中途断开会导致写入失败,属正常情况,直接返回即可
	if _, err := io.Copy(c.Writer, body); err != nil {
		return
	}
}
// checkURL 依次用白名单正则匹配目标地址;
// 命中时返回捕获组(去掉整体匹配项),全部未命中返回 nil。
func checkURL(u string) []string {
	for i := range exps {
		m := exps[i].FindStringSubmatch(u)
		if m == nil {
			continue
		}
		return m[1:]
	}
	return nil
}
// 简单的健康检查
func formatBeijingTime(t time.Time) string {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil {
loc = time.FixedZone("CST", 8*3600) // 兜底时区
}
return t.In(loc).Format("2006-01-02 15:04:05")
}
// 转换为可读时间
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%d秒", int(d.Seconds()))
@@ -338,26 +170,20 @@ func formatDuration(d time.Duration) string {
}
}
func initHealthRoutes(router *gin.Engine) {
router.GET("/health", func(c *gin.Context) {
uptime := time.Since(serviceStartTime)
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp_unix": serviceStartTime.Unix(),
"uptime_sec": uptime.Seconds(),
"service": "hubproxy",
"start_time_bj": formatBeijingTime(serviceStartTime),
"uptime_human": formatDuration(uptime),
})
})
// getUptimeInfo 返回服务运行时长的三种表示:
// 原始 Duration、秒数(float64)和可读文本。
func getUptimeInfo() (time.Duration, float64, string) {
	elapsed := time.Since(serviceStartTime)
	return elapsed, elapsed.Seconds(), formatDuration(elapsed)
}
func initHealthRoutes(router *gin.Engine) {
router.GET("/ready", func(c *gin.Context) {
uptime := time.Since(serviceStartTime)
_, uptimeSec, uptimeHuman := getUptimeInfo()
c.JSON(http.StatusOK, gin.H{
"ready": true,
"timestamp_unix": time.Now().Unix(),
"uptime_sec": uptime.Seconds(),
"uptime_human": formatDuration(uptime),
"ready": true,
"service": "hubproxy",
"start_time_unix": serviceStartTime.Unix(),
"uptime_sec": uptimeSec,
"uptime_human": uptimeHuman,
})
})
}

View File

@@ -1,8 +1,10 @@
package main
package utils
import (
"strings"
"sync"
"hubproxy/config"
)
// ResourceType 资源类型
@@ -26,7 +28,7 @@ type DockerImageInfo struct {
FullName string
}
// 全局访问控制器实例
// GlobalAccessController 全局访问控制器实例
var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称
@@ -79,19 +81,16 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
// CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := GetConfig()
cfg := config.GetConfig()
// 解析镜像名称
imageInfo := ac.ParseDockerImage(image)
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
if len(cfg.Access.WhiteList) > 0 {
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
return false, "不在Docker镜像白名单内"
}
}
// 检查黑名单
if len(cfg.Access.BlackList) > 0 {
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
return false, "Docker镜像在黑名单内"
@@ -107,14 +106,12 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r
return false, "无效的GitHub仓库格式"
}
cfg := GetConfig()
cfg := config.GetConfig()
// 检查白名单
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
return false, "不在GitHub仓库白名单内"
}
// 检查黑名单
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
return false, "GitHub仓库在黑名单内"
}
@@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool {
continue
}
// 支持多种匹配模式
if fullRepo == item {
return true
}
// 用户级匹配
if item == username || item == username+"/*" {
return true
}
// 前缀匹配(支持通配符)
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) {
@@ -203,7 +197,6 @@ func (ac *AccessController) checkList(matches, list []string) bool {
}
}
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullRepo, item+"/") {
return true
}

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"crypto/md5"
@@ -9,22 +9,23 @@ import (
"time"
"github.com/gin-gonic/gin"
"hubproxy/config"
)
// CachedItem 通用缓存项支持Token和Manifest
// CachedItem 通用缓存项
type CachedItem struct {
Data []byte // 缓存数据(token字符串或manifest字节)
ContentType string // 内容类型
Headers map[string]string // 额外的响应头
ExpiresAt time.Time // 过期时间
Data []byte
ContentType string
Headers map[string]string
ExpiresAt time.Time
}
// UniversalCache 通用缓存支持Token和Manifest
// UniversalCache 通用缓存
type UniversalCache struct {
cache sync.Map
}
var globalCache = &UniversalCache{}
var GlobalCache = &UniversalCache{}
// Get 获取缓存项
func (c *UniversalCache) Get(key string) *CachedItem {
@@ -57,22 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
c.Set(key, []byte(token), "application/json", nil, ttl)
}
// buildCacheKey 构建稳定的缓存key
func buildCacheKey(prefix, query string) string {
// BuildCacheKey 构建稳定的缓存key
func BuildCacheKey(prefix, query string) string {
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
}
func buildTokenCacheKey(query string) string {
return buildCacheKey("token", query)
func BuildTokenCacheKey(query string) string {
return BuildCacheKey("token", query)
}
func buildManifestCacheKey(imageRef, reference string) string {
func BuildManifestCacheKey(imageRef, reference string) string {
key := fmt.Sprintf("%s:%s", imageRef, reference)
return buildCacheKey("manifest", key)
return BuildCacheKey("manifest", key)
}
func getManifestTTL(reference string) time.Duration {
cfg := GetConfig()
func GetManifestTTL(reference string) time.Duration {
cfg := config.GetConfig()
defaultTTL := 30 * time.Minute
if cfg.TokenCache.DefaultTTL != "" {
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
@@ -84,23 +85,20 @@ func getManifestTTL(reference string) time.Duration {
return 24 * time.Hour
}
// mutable tag的智能判断
if reference == "latest" || reference == "main" || reference == "master" ||
reference == "dev" || reference == "develop" {
// 热门可变标签: 短期缓存
return 10 * time.Minute
}
return defaultTTL
}
// extractTTLFromResponse 从响应中智能提取TTL
func extractTTLFromResponse(responseBody []byte) time.Duration {
// ExtractTTLFromResponse 从响应中智能提取TTL
func ExtractTTLFromResponse(responseBody []byte) time.Duration {
var tokenResp struct {
ExpiresIn int `json:"expires_in"`
}
// 默认30分钟TTL确保稳定性
defaultTTL := 30 * time.Minute
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
@@ -113,37 +111,35 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
return defaultTTL
}
func writeTokenResponse(c *gin.Context, cachedBody string) {
func WriteTokenResponse(c *gin.Context, cachedBody string) {
c.Header("Content-Type", "application/json")
c.String(200, cachedBody)
}
func writeCachedResponse(c *gin.Context, item *CachedItem) {
func WriteCachedResponse(c *gin.Context, item *CachedItem) {
if item.ContentType != "" {
c.Header("Content-Type", item.ContentType)
}
// 设置额外的响应头
for key, value := range item.Headers {
c.Header(key, value)
}
// 返回数据
c.Data(200, item.ContentType, item.Data)
}
// isCacheEnabled 检查缓存是否启用
func isCacheEnabled() bool {
cfg := GetConfig()
// IsCacheEnabled 检查缓存是否启用
func IsCacheEnabled() bool {
cfg := config.GetConfig()
return cfg.TokenCache.Enabled
}
// isTokenCacheEnabled 检查token缓存是否启用(向后兼容)
func isTokenCacheEnabled() bool {
return isCacheEnabled()
// IsTokenCacheEnabled 检查token缓存是否启用
func IsTokenCacheEnabled() bool {
return IsCacheEnabled()
}
// 定期清理过期缓存,防止内存泄漏
// 定期清理过期缓存
func init() {
go func() {
ticker := time.NewTicker(20 * time.Minute)
@@ -153,7 +149,7 @@ func init() {
now := time.Now()
expiredKeys := make([]string, 0)
globalCache.cache.Range(func(key, value interface{}) bool {
GlobalCache.cache.Range(func(key, value interface{}) bool {
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
expiredKeys = append(expiredKeys, key.(string))
}
@@ -161,7 +157,7 @@ func init() {
})
for _, key := range expiredKeys {
globalCache.cache.Delete(key)
GlobalCache.cache.Delete(key)
}
}
}()

View File

@@ -1,28 +1,28 @@
package main
package utils
import (
"net"
"net/http"
"os"
"time"
"hubproxy/config"
)
var (
// 全局HTTP客户端 - 用于代理请求(长超时)
globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时
searchHTTPClient *http.Client
)
// initHTTPClients 初始化HTTP客户端
func initHTTPClients() {
cfg := GetConfig()
// InitHTTPClients 初始化HTTP客户端
func InitHTTPClients() {
cfg := config.GetConfig()
if p := cfg.Access.Proxy; p != "" {
os.Setenv("HTTP_PROXY", p)
os.Setenv("HTTPS_PROXY", p)
}
// 代理客户端配置 - 适用于大文件传输
globalHTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -39,7 +39,6 @@ func initHTTPClients() {
},
}
// 搜索客户端配置 - 适用于API调用
searchHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
@@ -57,12 +56,12 @@ func initHTTPClients() {
}
}
// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理)
// GetGlobalHTTPClient 获取全局HTTP客户端
func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient
}
// GetSearchHTTPClient 获取搜索HTTP客户端用于API调用
// GetSearchHTTPClient 获取搜索HTTP客户端
func GetSearchHTTPClient() *http.Client {
return searchHTTPClient
}

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"bytes"
@@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
var reader io.Reader = input
// 处理gzip压缩
if isCompressed {
peek := make([]byte, 2)
n, err := input.Read(peek)

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"fmt"
@@ -9,22 +9,23 @@ import (
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"hubproxy/config"
)
const (
// 清理间隔
CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000
)
// IPRateLimiter IP限流器结构体
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全
r rate.Limit // 速率限制(每秒允许的请求数)
b int // 令牌桶容量(突发请求数)
whitelist []*net.IPNet // 白名单IP段
blacklist []*net.IPNet // 黑名单IP段
ips map[string]*rateLimiterEntry
mu *sync.RWMutex
r rate.Limit
b int
whitelist []*net.IPNet
blacklist []*net.IPNet
whitelistLimiter *rate.Limiter // 全局共享的白名单限流器
}
// rateLimiterEntry 限流器条目
@@ -33,15 +34,15 @@ type rateLimiterEntry struct {
lastAccess time.Time
}
// initGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter {
cfg := GetConfig()
// InitGlobalLimiter 初始化全局限流器
func InitGlobalLimiter() *IPRateLimiter {
cfg := config.GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
@@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter {
}
}
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
@@ -68,7 +68,6 @@ func initGlobalLimiter() *IPRateLimiter {
}
}
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit
@@ -77,25 +76,20 @@ func initGlobalLimiter() *IPRateLimiter {
}
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize),
}
// 启动定期清理goroutine
go limiter.cleanupRoutine()
return limiter
}
// initLimiter 初始化限流器
func initLimiter() {
globalLimiter = initGlobalLimiter()
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
@@ -105,25 +99,20 @@ func (i *IPRateLimiter) cleanupRoutine() {
now := time.Now()
expired := make([]string, 0)
// 查找过期的条目
i.mu.RLock()
for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip)
}
}
i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock()
// 删除过期条目
for _, ip := range expired {
delete(i.ips, ip)
}
// 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry)
}
@@ -140,28 +129,26 @@ func extractIPFromAddress(address string) string {
return address
}
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
// normalizeIPForRateLimit 标准化IP地址用于限流
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr // 解析失败,返回原值
return ipStr
}
if ip.To4() != nil {
return ipStr // IPv4保持不变
return ipStr
}
// IPv6标准化为 /64 网段
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位
ipv6[i] = 0
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址
cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil {
@@ -176,22 +163,18 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
return false
}
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
// GetLimiter 获取指定IP的限流器
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址
cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中
if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false
}
// 检查是否在白名单中
if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true
return i.whitelistLimiter, true
}
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now()
@@ -230,7 +213,6 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 静态文件豁免:跳过限流检查
path := c.Request.URL.Path
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
strings.HasPrefix(path, "/public/") {
@@ -238,30 +220,22 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return
}
// 获取客户端真实IP
var ip string
// 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头
ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头
ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0])
} else {
// 回退到ClientIP方法
ip = c.ClientIP()
}
// 提取纯IP地址去除可能存在的端口
cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息
normalizedIP := normalizeIPForRateLimit(cleanIP)
if cleanIP != normalizedIP {
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
@@ -275,10 +249,8 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
c.GetHeader("X-Real-IP"))
}
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中
if !allowed {
c.JSON(403, gin.H{
"error": "您已被限制访问",
@@ -287,7 +259,6 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return
}
// 检查限流
if !ipLimiter.Allow() {
c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问",