优化代码结构,支持h2 #39
14
README.md
@@ -37,7 +37,7 @@ docker run -d \
|
||||
### 一键脚本安装
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash
|
||||
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash
|
||||
```
|
||||
|
||||
也可以直接下载二进制文件执行`./hubproxy`使用,无需配置文件即可启动,内置默认配置,支持所有功能。初始内存占用约18M,二进制文件大小约12M
|
||||
@@ -109,12 +109,14 @@ host = "0.0.0.0"
|
||||
port = 5000
|
||||
# Github文件大小限制(字节),默认2GB
|
||||
fileSize = 2147483648
|
||||
# HTTP/2 多路复用
|
||||
enableH2C = false
|
||||
|
||||
[rateLimit]
|
||||
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数)
|
||||
requestLimit = 500
|
||||
# 限流周期(小时)
|
||||
periodHours = 1.0
|
||||
periodHours = 3.0
|
||||
|
||||
[security]
|
||||
# IP白名单,支持单个IP或IP段
|
||||
@@ -132,7 +134,7 @@ blackList = [
|
||||
"192.168.100.0/24"
|
||||
]
|
||||
|
||||
[proxy]
|
||||
[access]
|
||||
# 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符)
|
||||
# 只允许访问白名单中的仓库/镜像,为空时不限制
|
||||
whiteList = []
|
||||
@@ -148,12 +150,6 @@ blackList = [
|
||||
# 代理配置,支持有用户名/密码认证和无认证模式
|
||||
# 无认证: socks5://127.0.0.1:1080
|
||||
# 有认证: socks5://username:password@127.0.0.1:1080
|
||||
# HTTP 代理示例
|
||||
# http://username:password@127.0.0.1:7890
|
||||
# SOCKS5 代理示例
|
||||
# socks5://username:password@127.0.0.1:1080
|
||||
# SOCKS5H 代理示例
|
||||
# socks5h://username:password@127.0.0.1:1080
|
||||
# 留空不使用代理
|
||||
proxy = ""
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
services:
|
||||
hubproxy:
|
||||
build: .
|
||||
restart: always
|
||||
ports:
|
||||
- '5000:5000'
|
||||
volumes:
|
||||
- ./src/config.toml:/root/config.toml
|
||||
hubproxy:
|
||||
build: .
|
||||
restart: always
|
||||
ports:
|
||||
- '5000:5000'
|
||||
volumes:
|
||||
- ./src/config.toml:/root/config.toml
|
||||
@@ -4,12 +4,14 @@ host = "0.0.0.0"
|
||||
port = 5000
|
||||
# Github文件大小限制(字节),默认2GB
|
||||
fileSize = 2147483648
|
||||
# HTTP/2 多路复用
|
||||
enableH2C = false
|
||||
|
||||
[rateLimit]
|
||||
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数)
|
||||
requestLimit = 500
|
||||
# 限流周期(小时)
|
||||
periodHours = 1.0
|
||||
periodHours = 3.0
|
||||
|
||||
[security]
|
||||
# IP白名单,支持单个IP或IP段
|
||||
@@ -43,12 +45,6 @@ blackList = [
|
||||
# 代理配置,支持有用户名/密码认证和无认证模式
|
||||
# 无认证: socks5://127.0.0.1:1080
|
||||
# 有认证: socks5://username:password@127.0.0.1:1080
|
||||
# HTTP 代理示例
|
||||
# http://username:password@127.0.0.1:7890
|
||||
# SOCKS5 代理示例
|
||||
# socks5://username:password@127.0.0.1:1080
|
||||
# SOCKS5H 代理示例
|
||||
# socks5h://username:password@127.0.0.1:1080
|
||||
# 留空不使用代理
|
||||
proxy = ""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -13,45 +13,46 @@ import (
|
||||
|
||||
// RegistryMapping Registry映射配置
|
||||
type RegistryMapping struct {
|
||||
Upstream string `toml:"upstream"` // 上游Registry地址
|
||||
AuthHost string `toml:"authHost"` // 认证服务器地址
|
||||
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic
|
||||
Enabled bool `toml:"enabled"` // 是否启用
|
||||
Upstream string `toml:"upstream"`
|
||||
AuthHost string `toml:"authHost"`
|
||||
AuthType string `toml:"authType"`
|
||||
Enabled bool `toml:"enabled"`
|
||||
}
|
||||
|
||||
// AppConfig 应用配置结构体
|
||||
type AppConfig struct {
|
||||
Server struct {
|
||||
Host string `toml:"host"` // 监听地址
|
||||
Port int `toml:"port"` // 监听端口
|
||||
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
FileSize int64 `toml:"fileSize"`
|
||||
EnableH2C bool `toml:"enableH2C"`
|
||||
} `toml:"server"`
|
||||
|
||||
RateLimit struct {
|
||||
RequestLimit int `toml:"requestLimit"` // 每小时请求限制
|
||||
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
|
||||
RequestLimit int `toml:"requestLimit"`
|
||||
PeriodHours float64 `toml:"periodHours"`
|
||||
} `toml:"rateLimit"`
|
||||
|
||||
Security struct {
|
||||
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
|
||||
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
|
||||
WhiteList []string `toml:"whiteList"`
|
||||
BlackList []string `toml:"blackList"`
|
||||
} `toml:"security"`
|
||||
|
||||
Access struct {
|
||||
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
|
||||
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
|
||||
Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h
|
||||
WhiteList []string `toml:"whiteList"`
|
||||
BlackList []string `toml:"blackList"`
|
||||
Proxy string `toml:"proxy"`
|
||||
} `toml:"access"`
|
||||
|
||||
Download struct {
|
||||
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
|
||||
MaxImages int `toml:"maxImages"`
|
||||
} `toml:"download"`
|
||||
|
||||
Registries map[string]RegistryMapping `toml:"registries"`
|
||||
|
||||
TokenCache struct {
|
||||
Enabled bool `toml:"enabled"` // 是否启用token缓存
|
||||
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间
|
||||
Enabled bool `toml:"enabled"`
|
||||
DefaultTTL string `toml:"defaultTTL"`
|
||||
} `toml:"tokenCache"`
|
||||
}
|
||||
|
||||
@@ -65,24 +66,25 @@ var (
|
||||
configCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// todo:Refactoring is needed
|
||||
// DefaultConfig 返回默认配置
|
||||
func DefaultConfig() *AppConfig {
|
||||
return &AppConfig{
|
||||
Server: struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
FileSize int64 `toml:"fileSize"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
FileSize int64 `toml:"fileSize"`
|
||||
EnableH2C bool `toml:"enableH2C"`
|
||||
}{
|
||||
Host: "0.0.0.0",
|
||||
Port: 5000,
|
||||
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
|
||||
Host: "0.0.0.0",
|
||||
Port: 5000,
|
||||
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
|
||||
EnableH2C: false, // 默认关闭H2C
|
||||
},
|
||||
RateLimit: struct {
|
||||
RequestLimit int `toml:"requestLimit"`
|
||||
PeriodHours float64 `toml:"periodHours"`
|
||||
}{
|
||||
RequestLimit: 20,
|
||||
RequestLimit: 200,
|
||||
PeriodHours: 1.0,
|
||||
},
|
||||
Security: struct {
|
||||
@@ -99,12 +101,12 @@ func DefaultConfig() *AppConfig {
|
||||
}{
|
||||
WhiteList: []string{},
|
||||
BlackList: []string{},
|
||||
Proxy: "", // 默认不使用代理
|
||||
Proxy: "",
|
||||
},
|
||||
Download: struct {
|
||||
MaxImages int `toml:"maxImages"`
|
||||
}{
|
||||
MaxImages: 10, // 默认值:最多同时下载10个镜像
|
||||
MaxImages: 10,
|
||||
},
|
||||
Registries: map[string]RegistryMapping{
|
||||
"ghcr.io": {
|
||||
@@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig {
|
||||
Enabled bool `toml:"enabled"`
|
||||
DefaultTTL string `toml:"defaultTTL"`
|
||||
}{
|
||||
Enabled: true, // docker认证的匿名Token缓存配置,用于提升性能
|
||||
Enabled: true,
|
||||
DefaultTTL: "20m",
|
||||
},
|
||||
}
|
||||
@@ -152,11 +154,9 @@ func GetConfig() *AppConfig {
|
||||
}
|
||||
configCacheMutex.RUnlock()
|
||||
|
||||
// 缓存过期,重新生成配置
|
||||
configCacheMutex.Lock()
|
||||
defer configCacheMutex.Unlock()
|
||||
|
||||
// 双重检查,防止重复生成
|
||||
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
|
||||
return cachedConfig
|
||||
}
|
||||
@@ -170,7 +170,6 @@ func GetConfig() *AppConfig {
|
||||
return defaultCfg
|
||||
}
|
||||
|
||||
// 生成新的配置深拷贝
|
||||
configCopy := *appConfig
|
||||
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
|
||||
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
|
||||
@@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) {
|
||||
|
||||
// LoadConfig 加载配置文件
|
||||
func LoadConfig() error {
|
||||
// 首先使用默认配置
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// 尝试加载TOML配置文件
|
||||
if data, err := os.ReadFile("config.toml"); err == nil {
|
||||
if err := toml.Unmarshal(data, cfg); err != nil {
|
||||
return fmt.Errorf("解析配置文件失败: %v", err)
|
||||
@@ -209,10 +206,7 @@ func LoadConfig() error {
|
||||
fmt.Println("未找到config.toml,使用默认配置")
|
||||
}
|
||||
|
||||
// 从环境变量覆盖配置
|
||||
overrideFromEnv(cfg)
|
||||
|
||||
// 设置配置
|
||||
setConfig(cfg)
|
||||
|
||||
return nil
|
||||
@@ -220,7 +214,6 @@ func LoadConfig() error {
|
||||
|
||||
// overrideFromEnv 从环境变量覆盖配置
|
||||
func overrideFromEnv(cfg *AppConfig) {
|
||||
// 服务器配置
|
||||
if val := os.Getenv("SERVER_HOST"); val != "" {
|
||||
cfg.Server.Host = val
|
||||
}
|
||||
@@ -229,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) {
|
||||
cfg.Server.Port = port
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("ENABLE_H2C"); val != "" {
|
||||
if enable, err := strconv.ParseBool(val); err == nil {
|
||||
cfg.Server.EnableH2C = enable
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
|
||||
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
|
||||
cfg.Server.FileSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// 限流配置
|
||||
if val := os.Getenv("RATE_LIMIT"); val != "" {
|
||||
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
|
||||
cfg.RateLimit.RequestLimit = limit
|
||||
@@ -247,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
// IP限制配置
|
||||
if val := os.Getenv("IP_WHITELIST"); val != "" {
|
||||
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
|
||||
}
|
||||
@@ -255,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) {
|
||||
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
|
||||
}
|
||||
|
||||
// 下载限制配置
|
||||
if val := os.Getenv("MAX_IMAGES"); val != "" {
|
||||
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
|
||||
cfg.Download.MaxImages = maxImages
|
||||
@@ -6,6 +6,7 @@ require (
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/go-containerregistry v0.20.5
|
||||
github.com/pelletier/go-toml/v2 v2.2.3
|
||||
golang.org/x/net v0.33.0
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
||||
|
||||
@@ -43,7 +44,6 @@ require (
|
||||
github.com/vbatts/tar-split v0.12.1 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.32.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sync v0.14.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/google/go-containerregistry/pkg/authn"
|
||||
"github.com/google/go-containerregistry/pkg/name"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"hubproxy/config"
|
||||
"hubproxy/utils"
|
||||
)
|
||||
|
||||
// DockerProxy Docker代理配置
|
||||
@@ -27,12 +29,10 @@ type RegistryDetector struct{}
|
||||
|
||||
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
|
||||
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
|
||||
// 检查路径是否以已知Registry域名开头
|
||||
for domain := range cfg.Registries {
|
||||
if strings.HasPrefix(path, domain+"/") {
|
||||
// 找到匹配的域名,返回域名和剩余路径
|
||||
remainingPath := strings.TrimPrefix(path, domain+"/")
|
||||
return domain, remainingPath
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
|
||||
|
||||
// isRegistryEnabled 检查Registry是否启用
|
||||
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
if mapping, exists := cfg.Registries[domain]; exists {
|
||||
return mapping.Enabled
|
||||
}
|
||||
@@ -51,28 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
|
||||
}
|
||||
|
||||
// getRegistryMapping 获取Registry映射配置
|
||||
func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) {
|
||||
cfg := GetConfig()
|
||||
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
|
||||
cfg := config.GetConfig()
|
||||
mapping, exists := cfg.Registries[domain]
|
||||
return mapping, exists && mapping.Enabled
|
||||
}
|
||||
|
||||
var registryDetector = &RegistryDetector{}
|
||||
|
||||
// 初始化Docker代理
|
||||
func initDockerProxy() {
|
||||
// 创建目标registry
|
||||
// InitDockerProxy 初始化Docker代理
|
||||
func InitDockerProxy() {
|
||||
registry, err := name.NewRegistry("registry-1.docker.io")
|
||||
if err != nil {
|
||||
fmt.Printf("创建Docker registry失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 配置代理选项
|
||||
options := []remote.Option{
|
||||
remote.WithAuth(authn.Anonymous),
|
||||
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
||||
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||
}
|
||||
|
||||
dockerProxy = &DockerProxy{
|
||||
@@ -85,13 +83,11 @@ func initDockerProxy() {
|
||||
func ProxyDockerRegistryGin(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 处理 /v2/ API版本检查
|
||||
if path == "/v2/" {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
return
|
||||
}
|
||||
|
||||
// 处理不同的API端点
|
||||
if strings.HasPrefix(path, "/v2/") {
|
||||
handleRegistryRequest(c, path)
|
||||
} else {
|
||||
@@ -101,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) {
|
||||
|
||||
// handleRegistryRequest 处理Registry请求
|
||||
func handleRegistryRequest(c *gin.Context, path string) {
|
||||
// 移除 /v2/ 前缀
|
||||
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
|
||||
|
||||
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
|
||||
if registryDetector.isRegistryEnabled(registryDomain) {
|
||||
// 设置目标Registry信息到Context
|
||||
c.Set("target_registry_domain", registryDomain)
|
||||
c.Set("target_path", remainingPath)
|
||||
|
||||
// 处理多Registry请求
|
||||
handleMultiRegistryRequest(c, registryDomain, remainingPath)
|
||||
return
|
||||
}
|
||||
@@ -122,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 自动处理官方镜像的library命名空间
|
||||
if !strings.Contains(imageName, "/") {
|
||||
imageName = "library/" + imageName
|
||||
}
|
||||
|
||||
// Docker镜像访问控制检查
|
||||
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
|
||||
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
|
||||
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
|
||||
c.String(http.StatusForbidden, "镜像访问被限制")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建完整的镜像引用
|
||||
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
|
||||
|
||||
switch apiType {
|
||||
@@ -151,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) {
|
||||
|
||||
// parseRegistryPath 解析Registry路径
|
||||
func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
||||
// 查找API端点关键字
|
||||
if idx := strings.Index(path, "/manifests/"); idx != -1 {
|
||||
imageName = path[:idx]
|
||||
apiType = "manifests"
|
||||
@@ -178,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
||||
|
||||
// handleManifestRequest 处理manifest请求
|
||||
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||
// Manifest缓存逻辑(仅对GET请求缓存)
|
||||
if isCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
||||
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||
|
||||
// 优先从缓存获取
|
||||
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
|
||||
writeCachedResponse(c, cachedItem)
|
||||
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
|
||||
utils.WriteCachedResponse(c, cachedItem)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -192,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||
var ref name.Reference
|
||||
var err error
|
||||
|
||||
// 判断reference是digest还是tag
|
||||
if strings.HasPrefix(reference, "sha256:") {
|
||||
// 是digest
|
||||
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
||||
} else {
|
||||
// 是tag
|
||||
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
|
||||
}
|
||||
|
||||
@@ -207,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 根据请求方法选择操作
|
||||
if c.Request.Method == http.MethodHead {
|
||||
// HEAD请求,使用remote.Head
|
||||
desc, err := remote.Head(ref, dockerProxy.options...)
|
||||
if err != nil {
|
||||
fmt.Printf("HEAD请求失败: %v\n", err)
|
||||
@@ -217,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", string(desc.MediaType))
|
||||
c.Header("Docker-Content-Digest", desc.Digest.String())
|
||||
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
|
||||
c.Status(http.StatusOK)
|
||||
} else {
|
||||
// GET请求,使用remote.Get
|
||||
desc, err := remote.Get(ref, dockerProxy.options...)
|
||||
if err != nil {
|
||||
fmt.Printf("GET请求失败: %v\n", err)
|
||||
@@ -231,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
headers := map[string]string{
|
||||
"Docker-Content-Digest": desc.Digest.String(),
|
||||
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
||||
}
|
||||
|
||||
// 缓存响应
|
||||
if isCacheEnabled() {
|
||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
||||
ttl := getManifestTTL(reference)
|
||||
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||
if utils.IsCacheEnabled() {
|
||||
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||
ttl := utils.GetManifestTTL(reference)
|
||||
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", string(desc.MediaType))
|
||||
for key, value := range headers {
|
||||
c.Header(key, value)
|
||||
}
|
||||
|
||||
// 返回manifest内容
|
||||
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleBlobRequest 处理blob请求
|
||||
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||
// 构建digest引用
|
||||
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
||||
if err != nil {
|
||||
fmt.Printf("解析digest引用失败: %v\n", err)
|
||||
@@ -265,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用remote.Layer获取layer
|
||||
layer, err := remote.Layer(digestRef, dockerProxy.options...)
|
||||
if err != nil {
|
||||
fmt.Printf("获取layer失败: %v\n", err)
|
||||
@@ -273,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取layer信息
|
||||
size, err := layer.Size()
|
||||
if err != nil {
|
||||
fmt.Printf("获取layer大小失败: %v\n", err)
|
||||
@@ -281,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取layer内容
|
||||
reader, err := layer.Compressed()
|
||||
if err != nil {
|
||||
fmt.Printf("获取layer内容失败: %v\n", err)
|
||||
@@ -290,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", "application/octet-stream")
|
||||
c.Header("Content-Length", fmt.Sprintf("%d", size))
|
||||
c.Header("Docker-Content-Digest", digest)
|
||||
|
||||
// 流式传输blob内容
|
||||
c.Status(http.StatusOK)
|
||||
io.Copy(c.Writer, reader)
|
||||
}
|
||||
|
||||
// handleTagsRequest 处理tags列表请求
|
||||
func handleTagsRequest(c *gin.Context, imageRef string) {
|
||||
// 解析repository
|
||||
repo, err := name.NewRepository(imageRef)
|
||||
if err != nil {
|
||||
fmt.Printf("解析repository失败: %v\n", err)
|
||||
@@ -310,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用remote.List获取tags
|
||||
tags, err := remote.List(repo, dockerProxy.options...)
|
||||
if err != nil {
|
||||
fmt.Printf("获取tags失败: %v\n", err)
|
||||
@@ -318,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
response := map[string]interface{}{
|
||||
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
|
||||
"tags": tags,
|
||||
@@ -327,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ProxyDockerAuthGin Docker认证代理(带缓存优化)
|
||||
// ProxyDockerAuthGin Docker认证代理
|
||||
func ProxyDockerAuthGin(c *gin.Context) {
|
||||
// 检查是否启用token缓存
|
||||
if isTokenCacheEnabled() {
|
||||
if utils.IsTokenCacheEnabled() {
|
||||
proxyDockerAuthWithCache(c)
|
||||
} else {
|
||||
proxyDockerAuthOriginal(c)
|
||||
@@ -339,32 +305,26 @@ func ProxyDockerAuthGin(c *gin.Context) {
|
||||
|
||||
// proxyDockerAuthWithCache 带缓存的认证代理
|
||||
func proxyDockerAuthWithCache(c *gin.Context) {
|
||||
// 1. 构建缓存key(基于完整的查询参数)
|
||||
cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery)
|
||||
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
|
||||
|
||||
// 2. 尝试从缓存获取token
|
||||
if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" {
|
||||
writeTokenResponse(c, cachedToken)
|
||||
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
|
||||
utils.WriteTokenResponse(c, cachedToken)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 缓存未命中,创建响应记录器
|
||||
recorder := &ResponseRecorder{
|
||||
ResponseWriter: c.Writer,
|
||||
statusCode: 200,
|
||||
}
|
||||
c.Writer = recorder
|
||||
|
||||
// 4. 调用原有认证逻辑
|
||||
proxyDockerAuthOriginal(c)
|
||||
|
||||
// 5. 如果认证成功,缓存响应
|
||||
if recorder.statusCode == 200 && len(recorder.body) > 0 {
|
||||
ttl := extractTTLFromResponse(recorder.body)
|
||||
globalCache.SetToken(cacheKey, string(recorder.body), ttl)
|
||||
ttl := utils.ExtractTTLFromResponse(recorder.body)
|
||||
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
|
||||
}
|
||||
|
||||
// 6. 写入实际响应
|
||||
c.Writer = recorder.ResponseWriter
|
||||
c.Data(recorder.statusCode, "application/json", recorder.body)
|
||||
}
|
||||
@@ -389,14 +349,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
||||
var authURL string
|
||||
if targetDomain, exists := c.Get("target_registry_domain"); exists {
|
||||
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
|
||||
// 使用Registry特定的认证服务器
|
||||
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
|
||||
} else {
|
||||
// fallback到默认Docker认证
|
||||
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
||||
}
|
||||
} else {
|
||||
// 构建默认Docker认证URL
|
||||
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
||||
}
|
||||
|
||||
@@ -404,13 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
||||
authURL += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
|
||||
// 创建HTTP客户端,复用全局传输配置(包含代理设置)
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: GetGlobalHTTPClient().Transport,
|
||||
Transport: utils.GetGlobalHTTPClient().Transport,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(),
|
||||
c.Request.Method,
|
||||
@@ -422,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 复制请求头
|
||||
for key, values := range c.Request.Header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadGateway, "Auth request failed")
|
||||
@@ -437,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 获取当前代理的Host地址
|
||||
proxyHost := c.Request.Host
|
||||
if proxyHost == "" {
|
||||
// 使用配置中的服务器地址和端口
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||
if cfg.Server.Host == "0.0.0.0" {
|
||||
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// 复制响应头并重写认证URL
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
// 重写WWW-Authenticate头中的realm URL
|
||||
if key == "Www-Authenticate" {
|
||||
// 支持多Registry的URL重写
|
||||
value = rewriteAuthHeader(value, proxyHost)
|
||||
}
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.Status(resp.StatusCode)
|
||||
io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
|
||||
// rewriteAuthHeader 重写认证头
|
||||
func rewriteAuthHeader(authHeader, proxyHost string) string {
|
||||
// 重写各种Registry的认证URL
|
||||
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
|
||||
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
|
||||
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
|
||||
@@ -478,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string {
|
||||
|
||||
// handleMultiRegistryRequest 处理多Registry请求
|
||||
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
|
||||
// 获取Registry映射配置
|
||||
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
|
||||
if !exists {
|
||||
c.String(http.StatusBadRequest, "Registry not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析剩余路径
|
||||
imageName, apiType, reference := parseRegistryPath(remainingPath)
|
||||
if imageName == "" || apiType == "" {
|
||||
c.String(http.StatusBadRequest, "Invalid path format")
|
||||
return
|
||||
}
|
||||
|
||||
// 访问控制检查(使用完整的镜像路径)
|
||||
fullImageName := registryDomain + "/" + imageName
|
||||
if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
|
||||
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
|
||||
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
|
||||
c.String(http.StatusForbidden, "镜像访问被限制")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建上游Registry引用
|
||||
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
|
||||
|
||||
// 根据API类型处理请求
|
||||
switch apiType {
|
||||
case "manifests":
|
||||
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
|
||||
@@ -517,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st
|
||||
}
|
||||
|
||||
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
|
||||
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) {
|
||||
// Manifest缓存逻辑(仅对GET请求缓存)
|
||||
if isCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
||||
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
|
||||
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||
|
||||
// 优先从缓存获取
|
||||
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
|
||||
writeCachedResponse(c, cachedItem)
|
||||
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
|
||||
utils.WriteCachedResponse(c, cachedItem)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -532,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
||||
var ref name.Reference
|
||||
var err error
|
||||
|
||||
// 判断reference是digest还是tag
|
||||
if strings.HasPrefix(reference, "sha256:") {
|
||||
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
||||
} else {
|
||||
@@ -545,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
||||
return
|
||||
}
|
||||
|
||||
// 创建针对上游Registry的选项
|
||||
options := createUpstreamOptions(mapping)
|
||||
|
||||
// 根据请求方法选择操作
|
||||
if c.Request.Method == http.MethodHead {
|
||||
desc, err := remote.Head(ref, options...)
|
||||
if err != nil {
|
||||
@@ -569,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
headers := map[string]string{
|
||||
"Docker-Content-Digest": desc.Digest.String(),
|
||||
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
||||
}
|
||||
|
||||
// 缓存响应
|
||||
if isCacheEnabled() {
|
||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
||||
ttl := getManifestTTL(reference)
|
||||
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||
if utils.IsCacheEnabled() {
|
||||
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||
ttl := utils.GetManifestTTL(reference)
|
||||
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", string(desc.MediaType))
|
||||
for key, value := range headers {
|
||||
c.Header(key, value)
|
||||
@@ -593,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
||||
}
|
||||
|
||||
// handleUpstreamBlobRequest 处理上游Registry的blob请求
|
||||
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) {
|
||||
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
|
||||
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
||||
if err != nil {
|
||||
fmt.Printf("解析digest引用失败: %v\n", err)
|
||||
@@ -633,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping
|
||||
}
|
||||
|
||||
// handleUpstreamTagsRequest 处理上游Registry的tags请求
|
||||
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) {
|
||||
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
|
||||
repo, err := name.NewRepository(imageRef)
|
||||
if err != nil {
|
||||
fmt.Printf("解析repository失败: %v\n", err)
|
||||
@@ -658,14 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry
|
||||
}
|
||||
|
||||
// createUpstreamOptions 创建上游Registry选项
|
||||
func createUpstreamOptions(mapping RegistryMapping) []remote.Option {
|
||||
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
|
||||
options := []remote.Option{
|
||||
remote.WithAuth(authn.Anonymous),
|
||||
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
||||
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||
}
|
||||
|
||||
// 根据Registry类型添加特定的认证选项(方便后续扩展)
|
||||
switch mapping.AuthType {
|
||||
case "github":
|
||||
case "google":
|
||||
213
src/handlers/github.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"hubproxy/config"
|
||||
"hubproxy/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
// GitHub URL匹配正则表达式
|
||||
githubExps = []*regexp.Regexp{
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
|
||||
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
|
||||
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
|
||||
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
|
||||
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
|
||||
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
|
||||
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
|
||||
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
|
||||
}
|
||||
)
|
||||
|
||||
// GitHubProxyHandler GitHub代理处理器
|
||||
func GitHubProxyHandler(c *gin.Context) {
|
||||
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
|
||||
|
||||
for strings.HasPrefix(rawPath, "/") {
|
||||
rawPath = strings.TrimPrefix(rawPath, "/")
|
||||
}
|
||||
|
||||
// 自动补全协议头
|
||||
if !strings.HasPrefix(rawPath, "https://") {
|
||||
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
|
||||
rawPath = strings.Replace(rawPath, "http:/", "", 1)
|
||||
rawPath = strings.Replace(rawPath, "https:/", "", 1)
|
||||
} else if strings.HasPrefix(rawPath, "http://") {
|
||||
rawPath = strings.TrimPrefix(rawPath, "http://")
|
||||
}
|
||||
rawPath = "https://" + rawPath
|
||||
}
|
||||
|
||||
matches := CheckGitHubURL(rawPath)
|
||||
if matches != nil {
|
||||
if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
|
||||
var repoPath string
|
||||
if len(matches) >= 2 {
|
||||
username := matches[0]
|
||||
repoName := strings.TrimSuffix(matches[1], ".git")
|
||||
repoPath = username + "/" + repoName
|
||||
}
|
||||
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
|
||||
c.String(http.StatusForbidden, reason)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.String(http.StatusForbidden, "无效输入")
|
||||
return
|
||||
}
|
||||
|
||||
// 将blob链接转换为raw链接
|
||||
if githubExps[1].MatchString(rawPath) {
|
||||
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
|
||||
}
|
||||
|
||||
ProxyGitHubRequest(c, rawPath)
|
||||
}
|
||||
|
||||
// CheckGitHubURL 检查URL是否匹配GitHub模式
|
||||
func CheckGitHubURL(u string) []string {
|
||||
for _, exp := range githubExps {
|
||||
if matches := exp.FindStringSubmatch(u); matches != nil {
|
||||
return matches[1:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyGitHubRequest 代理GitHub请求
|
||||
func ProxyGitHubRequest(c *gin.Context, u string) {
|
||||
proxyGitHubWithRedirect(c, u, 0)
|
||||
}
|
||||
|
||||
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
|
||||
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
|
||||
const maxRedirects = 20
|
||||
if redirectCount > maxRedirects {
|
||||
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 复制请求头
|
||||
for key, values := range c.Request.Header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
req.Header.Del("Host")
|
||||
|
|
||||
|
||||
resp, err := utils.GetGlobalHTTPClient().Do(req)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
fmt.Printf("关闭响应体失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查文件大小限制
|
||||
cfg := config.GetConfig()
|
||||
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
|
||||
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
|
||||
c.String(http.StatusRequestEntityTooLarge,
|
||||
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 清理安全相关的头
|
||||
resp.Header.Del("Content-Security-Policy")
|
||||
resp.Header.Del("Referrer-Policy")
|
||||
resp.Header.Del("Strict-Transport-Security")
|
||||
|
||||
// 获取真实域名
|
||||
realHost := c.Request.Header.Get("X-Forwarded-Host")
|
||||
if realHost == "" {
|
||||
realHost = c.Request.Host
|
||||
}
|
||||
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
|
||||
realHost = "https://" + realHost
|
||||
}
|
||||
|
||||
// 处理.sh文件的智能处理
|
||||
if strings.HasSuffix(strings.ToLower(u), ".sh") {
|
||||
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
|
||||
|
||||
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
|
||||
if err != nil {
|
||||
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
|
||||
processedBody = resp.Body
|
||||
processedSize = 0
|
||||
}
|
||||
|
||||
// 智能设置响应头
|
||||
if processedSize > 0 {
|
||||
resp.Header.Del("Content-Length")
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
// 复制其他响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理重定向
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
if CheckGitHubURL(location) != nil {
|
||||
c.Header("Location", "/"+location)
|
||||
} else {
|
||||
proxyGitHubWithRedirect(c, location, redirectCount+1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 输出处理后的内容
|
||||
if _, err := io.Copy(c.Writer, processedBody); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
💡 Verification agent 🧩 Analysis chainClarify the purpose of .sh file processing. The code performs special processing on
🏁 Script executed: Length of output: 55 🏁 Script executed: Length of output: 337 🏁 Script executed: Length of output: 2472 Please document and safeguard We’ve confirmed that
To help future maintainers and ensure safety, please: • In • In • Add or update unit/integration tests covering: This will clarify intent, guard against inadvertent breakage, and surface any security risks. 🤖 Prompt for AI Agents_💡 Verification agent_
<details>
<summary>🧩 Analysis chain</summary>
**Clarify the purpose of .sh file processing.**
The code performs special processing on `.sh` files using `utils.ProcessSmart`. Please document:
1. What modifications are being made to shell scripts?
2. Are there security implications of modifying executable content?
3. What's the use case for this feature?
---
🏁 Script executed:
```shell
#!/bin/bash
# Search for the ProcessSmart implementation to understand what it does
ast-grep --pattern 'func ProcessSmart($$$) {
$$$
}'
```
Length of output: 55
---
🏁 Script executed:
```shell
#!/bin/bash
# Search for ProcessSmart references and implementation across the repository
rg "ProcessSmart" -n .
```
Length of output: 337
---
🏁 Script executed:
```shell
#!/bin/bash
# Display the full implementation of ProcessSmart in utils/proxy_shell.go
sed -n '1,200p' src/utils/proxy_shell.go
```
Length of output: 2472
---
</details>
**Please document and safeguard `.sh` file processing logic**
We’ve confirmed that `ProcessSmart` in `utils/proxy_shell.go`:
- Handles optional gzip by peeking and unwrapping if needed
- Reads the entire script into memory (unless >10 MB or no GitHub URLs)
- Uses `githubRegex` to locate all GitHub‐hosted URLs and rewrites each via `transformURL(host)`
- Returns a new `io.Reader` of the (potentially) modified script
To help future maintainers and ensure safety, please:
• In `src/utils/proxy_shell.go` (ProcessSmart, lines ~16–54):
– Add a top-level doc comment describing each step (decompression, size check, URL rewrite)
– Call out any edge cases (e.g. very large scripts, scripts without GitHub links)
– Note security considerations around modifying executable content (shebang integrity, unintended regex matches)
• In `src/handlers/github.go` (lines 150–190):
– Explain the `.sh` branch’s purpose: why and when you proxy/modify scripts
– Document the impact on downstream clients (chunked responses, URL schemes)
• Add or update unit/integration tests covering:
– A small script with GitHub URLs (verify correct rewrite and valid URL syntax)
– A gzipped script
– A large script (>10 MB) and one without GitHub URLs (ensure no modifications)
This will clarify intent, guard against inadvertent breakage, and surface any security risks.
<details>
<summary>🤖 Prompt for AI Agents</summary>
```
In src/handlers/github.go around lines 150 to 190, add clear comments explaining
the purpose of the .sh file processing branch, detailing why scripts are proxied
and modified, and how the response headers are adjusted (e.g., chunked transfer
encoding) to inform downstream clients. In src/utils/proxy_shell.go within the
ProcessSmart function (lines ~16–54), add a comprehensive top-level doc comment
describing each processing step: decompression if gzip encoded, size checks to
avoid large memory usage, GitHub URL detection and rewriting, and returning the
modified script reader. Also document edge cases like very large scripts or
scripts without GitHub URLs, and note security considerations such as preserving
shebang integrity and avoiding unintended regex matches. Finally, add or update
unit and integration tests to cover scenarios including small scripts with
GitHub URLs (verifying correct URL rewriting), gzipped scripts, large scripts
over 10 MB, and scripts without GitHub URLs to ensure no modifications occur.
```
</details>
<!-- fingerprinting:phantom:medusa:lion -->
<!-- This is an auto-generated comment by CodeRabbit -->
|
||||
// 复制响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理重定向
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
if CheckGitHubURL(location) != nil {
|
||||
c.Header("Location", "/"+location)
|
||||
} else {
|
||||
proxyGitHubWithRedirect(c, location, redirectCount+1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 直接流式转发
|
||||
io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/google/go-containerregistry/pkg/v1/partial"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"github.com/google/go-containerregistry/pkg/v1/types"
|
||||
"hubproxy/config"
|
||||
"hubproxy/utils"
|
||||
)
|
||||
|
||||
// DebounceEntry 防抖条目
|
||||
@@ -58,17 +60,15 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
|
||||
|
||||
if entry, exists := d.entries[key]; exists {
|
||||
if now.Sub(entry.LastRequest) < d.window {
|
||||
return false // 在防抖窗口内,拒绝请求
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 更新或创建条目
|
||||
d.entries[key] = &DebounceEntry{
|
||||
LastRequest: now,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
// 清理过期条目(每5分钟清理一次)
|
||||
if time.Since(d.lastCleanup) > 5*time.Minute {
|
||||
d.cleanup(now)
|
||||
d.lastCleanup = now
|
||||
@@ -88,50 +88,41 @@ func (d *DownloadDebouncer) cleanup(now time.Time) {
|
||||
|
||||
// generateContentFingerprint 生成内容指纹
|
||||
func generateContentFingerprint(images []string, platform string) string {
|
||||
// 对镜像列表排序确保顺序无关
|
||||
sortedImages := make([]string, len(images))
|
||||
copy(sortedImages, images)
|
||||
sort.Strings(sortedImages)
|
||||
|
||||
// 组合内容:镜像列表 + 平台信息
|
||||
content := strings.Join(sortedImages, "|") + ":" + platform
|
||||
|
||||
// 生成MD5哈希
|
||||
hash := md5.Sum([]byte(content))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// getUserID 获取用户标识
|
||||
func getUserID(c *gin.Context) string {
|
||||
// 优先使用会话Cookie
|
||||
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
|
||||
return "session:" + sessionID
|
||||
}
|
||||
|
||||
// 备用方案:IP + User-Agent组合
|
||||
ip := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
userAgent = "unknown"
|
||||
}
|
||||
|
||||
// 生成简短标识
|
||||
combined := ip + ":" + userAgent
|
||||
hash := md5.Sum([]byte(combined))
|
||||
return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节
|
||||
return "ip:" + hex.EncodeToString(hash[:8])
|
||||
}
|
||||
|
||||
// 全局防抖器实例
|
||||
var (
|
||||
singleImageDebouncer *DownloadDebouncer
|
||||
batchImageDebouncer *DownloadDebouncer
|
||||
)
|
||||
|
||||
// initDebouncer 初始化防抖器
|
||||
func initDebouncer() {
|
||||
// 单个镜像:5秒防抖窗口
|
||||
// InitDebouncer 初始化防抖器
|
||||
func InitDebouncer() {
|
||||
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
|
||||
// 批量镜像:60秒防抖窗口
|
||||
batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
|
||||
}
|
||||
|
||||
@@ -147,15 +138,15 @@ type ImageStreamerConfig struct {
|
||||
}
|
||||
|
||||
// NewImageStreamer 创建镜像下载器
|
||||
func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
||||
if config == nil {
|
||||
config = &ImageStreamerConfig{}
|
||||
func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
|
||||
if cfg == nil {
|
||||
cfg = &ImageStreamerConfig{}
|
||||
}
|
||||
|
||||
concurrency := config.Concurrency
|
||||
concurrency := cfg.Concurrency
|
||||
if concurrency <= 0 {
|
||||
cfg := GetConfig()
|
||||
concurrency = cfg.Download.MaxImages
|
||||
appCfg := config.GetConfig()
|
||||
concurrency = appCfg.Download.MaxImages
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
@@ -163,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
||||
|
||||
remoteOptions := []remote.Option{
|
||||
remote.WithAuth(authn.Anonymous),
|
||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
||||
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||
}
|
||||
|
||||
return &ImageStreamer{
|
||||
@@ -176,7 +167,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
||||
type StreamOptions struct {
|
||||
Platform string
|
||||
Compression bool
|
||||
UseCompressedLayers bool // 是否保存原始压缩层,默认开启
|
||||
UseCompressedLayers bool
|
||||
}
|
||||
|
||||
// StreamImageToWriter 流式下载镜像到Writer
|
||||
@@ -215,7 +206,6 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote
|
||||
|
||||
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
|
||||
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
|
||||
// 直接从网络获取完整的descriptor,确保对象完整性
|
||||
return remote.Get(ref, options...)
|
||||
}
|
||||
|
||||
@@ -343,7 +333,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
var layerSize int64
|
||||
var layerReader io.ReadCloser
|
||||
|
||||
// 根据配置选择使用压缩层或未压缩层
|
||||
if options != nil && options.UseCompressedLayers {
|
||||
layerSize, err = layer.Size()
|
||||
if err != nil {
|
||||
@@ -385,7 +374,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
log.Printf("已处理层 %d/%d", i+1, len(layers))
|
||||
}
|
||||
|
||||
// 构建单个镜像的manifest信息
|
||||
singleManifest := map[string]interface{}{
|
||||
"Config": configDigest.String() + ".json",
|
||||
"RepoTags": []string{imageRef},
|
||||
@@ -398,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
}(),
|
||||
}
|
||||
|
||||
// 构建repositories信息
|
||||
repositories := make(map[string]map[string]string)
|
||||
parts := strings.Split(imageRef, ":")
|
||||
if len(parts) == 2 {
|
||||
@@ -407,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
repositories[repoName] = map[string]string{tag: configDigest.String()}
|
||||
}
|
||||
|
||||
// 如果是批量下载,返回信息而不写入文件
|
||||
if manifestOut != nil && repositoriesOut != nil {
|
||||
*manifestOut = singleManifest
|
||||
*repositoriesOut = repositories
|
||||
return nil
|
||||
}
|
||||
|
||||
// 单镜像下载,直接写入manifest.json
|
||||
manifest := []map[string]interface{}{singleManifest}
|
||||
|
||||
manifestData, err := json.Marshal(manifest)
|
||||
@@ -436,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入repositories文件
|
||||
repositoriesData, err := json.Marshal(repositories)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -456,7 +440,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
||||
return err
|
||||
}
|
||||
|
||||
// processImageForBatch 处理镜像的公共逻辑,用于批量下载
|
||||
// processImageForBatch 处理镜像的公共逻辑
|
||||
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
|
||||
layers, err := img.Layers()
|
||||
if err != nil {
|
||||
@@ -498,7 +482,6 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite
|
||||
|
||||
switch desc.MediaType {
|
||||
case types.OCIImageIndex, types.DockerManifestList:
|
||||
// 处理多架构镜像
|
||||
img, err = is.selectPlatformImage(desc, options)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
|
||||
@@ -530,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
|
||||
return nil, fmt.Errorf("获取索引清单失败: %w", err)
|
||||
}
|
||||
|
||||
// 选择合适的平台
|
||||
var selectedDesc *v1.Descriptor
|
||||
for _, m := range manifest.Manifests {
|
||||
if m.Platform == nil {
|
||||
@@ -578,8 +560,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
|
||||
|
||||
var globalImageStreamer *ImageStreamer
|
||||
|
||||
// initImageStreamer 初始化镜像下载器
|
||||
func initImageStreamer() {
|
||||
// InitImageStreamer 初始化镜像下载器
|
||||
func InitImageStreamer() {
|
||||
globalImageStreamer = NewImageStreamer(nil)
|
||||
}
|
||||
|
||||
@@ -591,8 +573,8 @@ func formatPlatformText(platform string) string {
|
||||
return platform
|
||||
}
|
||||
|
||||
// initImageTarRoutes 初始化镜像下载路由
|
||||
func initImageTarRoutes(router *gin.Engine) {
|
||||
// InitImageTarRoutes 初始化镜像下载路由
|
||||
func InitImageTarRoutes(router *gin.Engine) {
|
||||
imageAPI := router.Group("/api/image")
|
||||
{
|
||||
imageAPI.GET("/download/:image", handleDirectImageDownload)
|
||||
@@ -625,7 +607,6 @@ func handleDirectImageDownload(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 防抖检查
|
||||
userID := getUserID(c)
|
||||
contentKey := generateContentFingerprint([]string{imageRef}, platform)
|
||||
|
||||
@@ -677,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
if len(req.Images) > cfg.Download.MaxImages {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
|
||||
@@ -685,7 +666,6 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 批量下载防抖检查
|
||||
userID := getUserID(c)
|
||||
contentKey := generateContentFingerprint(req.Images, req.Platform)
|
||||
|
||||
@@ -697,7 +677,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
useCompressed := true // 默认启用原始压缩层
|
||||
useCompressed := true
|
||||
if req.UseCompressedLayers != nil {
|
||||
useCompressed = *req.UseCompressedLayers
|
||||
}
|
||||
@@ -801,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
||||
var allManifests []map[string]interface{}
|
||||
var allRepositories = make(map[string]map[string]string)
|
||||
|
||||
// 流式处理每个镜像
|
||||
for i, imageRef := range imageRefs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -811,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
||||
|
||||
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
|
||||
|
||||
// 防止单个镜像处理时间过长
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
|
||||
cancel()
|
||||
@@ -825,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
||||
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
|
||||
}
|
||||
|
||||
// 收集manifest信息
|
||||
allManifests = append(allManifests, manifest)
|
||||
|
||||
// 合并repositories信息
|
||||
for repo, tags := range repositories {
|
||||
if allRepositories[repo] == nil {
|
||||
allRepositories[repo] = make(map[string]string)
|
||||
@@ -839,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
||||
}
|
||||
}
|
||||
|
||||
// 写入合并的manifest.json
|
||||
manifestData, err := json.Marshal(allManifests)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化manifest失败: %w", err)
|
||||
@@ -859,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
||||
return fmt.Errorf("写入manifest数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入合并的repositories文件
|
||||
repositoriesData, err := json.Marshal(allRepositories)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化repositories失败: %w", err)
|
||||
300
src/main.go
@@ -3,15 +3,17 @@ package main
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"hubproxy/config"
|
||||
"hubproxy/handlers"
|
||||
"hubproxy/utils"
|
||||
)
|
||||
|
||||
//go:embed public/*
|
||||
@@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) {
|
||||
}
|
||||
|
||||
var (
|
||||
exps = []*regexp.Regexp{
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
|
||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
|
||||
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
|
||||
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
|
||||
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
|
||||
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
|
||||
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
|
||||
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
|
||||
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
|
||||
}
|
||||
globalLimiter *IPRateLimiter
|
||||
globalLimiter *utils.IPRateLimiter
|
||||
|
||||
// 服务启动时间
|
||||
serviceStartTime = time.Now()
|
||||
@@ -52,25 +42,25 @@ var (
|
||||
|
||||
func main() {
|
||||
// 加载配置
|
||||
if err := LoadConfig(); err != nil {
|
||||
if err := config.LoadConfig(); err != nil {
|
||||
fmt.Printf("配置加载失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化HTTP客户端
|
||||
initHTTPClients()
|
||||
utils.InitHTTPClients()
|
||||
|
||||
// 初始化限流器
|
||||
initLimiter()
|
||||
globalLimiter = utils.InitGlobalLimiter()
|
||||
|
||||
// 初始化Docker流式代理
|
||||
initDockerProxy()
|
||||
handlers.InitDockerProxy()
|
||||
|
||||
// 初始化镜像流式下载器
|
||||
initImageStreamer()
|
||||
handlers.InitImageStreamer()
|
||||
|
||||
// 初始化防抖器
|
||||
initDebouncer()
|
||||
handlers.InitDebouncer()
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
@@ -84,14 +74,14 @@ func main() {
|
||||
})
|
||||
}))
|
||||
|
||||
// 全局限流中间件 - 应用到所有路由
|
||||
router.Use(RateLimitMiddleware(globalLimiter))
|
||||
// 全局限流中间件
|
||||
router.Use(utils.RateLimitMiddleware(globalLimiter))
|
||||
|
||||
// 初始化监控端点
|
||||
initHealthRoutes(router)
|
||||
|
||||
// 初始化镜像tar下载路由
|
||||
initImageTarRoutes(router)
|
||||
handlers.InitImageTarRoutes(router)
|
||||
|
||||
// 静态文件路由
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
@@ -113,217 +103,59 @@ func main() {
|
||||
})
|
||||
|
||||
// 注册dockerhub搜索路由
|
||||
RegisterSearchRoute(router)
|
||||
handlers.RegisterSearchRoute(router)
|
||||
|
||||
// 注册Docker认证路由(/token*)
|
||||
router.Any("/token", ProxyDockerAuthGin)
|
||||
router.Any("/token/*path", ProxyDockerAuthGin)
|
||||
// 注册Docker认证路由
|
||||
router.Any("/token", handlers.ProxyDockerAuthGin)
|
||||
router.Any("/token/*path", handlers.ProxyDockerAuthGin)
|
||||
|
||||
// 注册Docker Registry代理路由
|
||||
router.Any("/v2/*path", ProxyDockerRegistryGin)
|
||||
router.Any("/v2/*path", handlers.ProxyDockerRegistryGin)
|
||||
|
||||
// 注册NoRoute处理器
|
||||
router.NoRoute(handler)
|
||||
// 注册GitHub代理路由(NoRoute处理器)
|
||||
router.NoRoute(handlers.GitHubProxyHandler)
|
||||
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
fmt.Printf("🚀 HubProxy 启动成功\n")
|
||||
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
||||
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
|
||||
|
||||
// 显示HTTP/2支持状态
|
||||
if cfg.Server.EnableH2C {
|
||||
fmt.Printf("H2c: 已启用\n")
|
||||
}
|
||||
|
||||
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
|
||||
|
||||
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
|
||||
// 创建HTTP2服务器
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 300 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// 根据配置决定是否启用H2C
|
||||
if cfg.Server.EnableH2C {
|
||||
h2cHandler := h2c.NewHandler(router, &http2.Server{
|
||||
MaxConcurrentStreams: 250,
|
||||
IdleTimeout: 300 * time.Second,
|
||||
MaxReadFrameSize: 4 << 20,
|
||||
MaxUploadBufferPerConnection: 8 << 20,
|
||||
MaxUploadBufferPerStream: 2 << 20,
|
||||
})
|
||||
server.Handler = h2cHandler
|
||||
} else {
|
||||
server.Handler = router
|
||||
}
|
||||
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
fmt.Printf("启动服务失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handler(c *gin.Context) {
|
||||
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
|
||||
|
||||
for strings.HasPrefix(rawPath, "/") {
|
||||
rawPath = strings.TrimPrefix(rawPath, "/")
|
||||
}
|
||||
// 自动补全协议头
|
||||
if !strings.HasPrefix(rawPath, "https://") {
|
||||
// 修复 http:/ 和 https:/ 的情况
|
||||
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
|
||||
rawPath = strings.Replace(rawPath, "http:/", "", 1)
|
||||
rawPath = strings.Replace(rawPath, "https:/", "", 1)
|
||||
} else if strings.HasPrefix(rawPath, "http://") {
|
||||
rawPath = strings.TrimPrefix(rawPath, "http://")
|
||||
}
|
||||
rawPath = "https://" + rawPath
|
||||
}
|
||||
|
||||
matches := checkURL(rawPath)
|
||||
if matches != nil {
|
||||
// GitHub仓库访问控制检查
|
||||
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
|
||||
// 构建仓库名用于日志
|
||||
var repoPath string
|
||||
if len(matches) >= 2 {
|
||||
username := matches[0]
|
||||
repoName := strings.TrimSuffix(matches[1], ".git")
|
||||
repoPath = username + "/" + repoName
|
||||
}
|
||||
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
|
||||
c.String(http.StatusForbidden, reason)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.String(http.StatusForbidden, "无效输入")
|
||||
return
|
||||
}
|
||||
|
||||
if exps[1].MatchString(rawPath) {
|
||||
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
|
||||
}
|
||||
|
||||
proxyRequest(c, rawPath)
|
||||
}
|
||||
|
||||
func proxyRequest(c *gin.Context, u string) {
|
||||
proxyWithRedirect(c, u, 0)
|
||||
}
|
||||
|
||||
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
|
||||
// 限制最大重定向次数,防止无限递归
|
||||
const maxRedirects = 20
|
||||
if redirectCount > maxRedirects {
|
||||
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range c.Request.Header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
req.Header.Del("Host")
|
||||
|
||||
resp, err := GetGlobalHTTPClient().Do(req)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
fmt.Printf("关闭响应体失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查文件大小限制
|
||||
cfg := GetConfig()
|
||||
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
|
||||
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
|
||||
c.String(http.StatusRequestEntityTooLarge,
|
||||
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 清理安全相关的头
|
||||
resp.Header.Del("Content-Security-Policy")
|
||||
resp.Header.Del("Referrer-Policy")
|
||||
resp.Header.Del("Strict-Transport-Security")
|
||||
|
||||
// 获取真实域名
|
||||
realHost := c.Request.Header.Get("X-Forwarded-Host")
|
||||
if realHost == "" {
|
||||
realHost = c.Request.Host
|
||||
}
|
||||
// 如果域名中没有协议前缀,添加https://
|
||||
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
|
||||
realHost = "https://" + realHost
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(u), ".sh") {
|
||||
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
|
||||
|
||||
processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost)
|
||||
if err != nil {
|
||||
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
|
||||
processedBody = resp.Body
|
||||
processedSize = 0
|
||||
}
|
||||
|
||||
// 智能设置响应头
|
||||
if processedSize > 0 {
|
||||
resp.Header.Del("Content-Length")
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
// 复制其他响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
if checkURL(location) != nil {
|
||||
c.Header("Location", "/"+location)
|
||||
} else {
|
||||
proxyWithRedirect(c, location, redirectCount+1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 输出处理后的内容
|
||||
if _, err := io.Copy(c.Writer, processedBody); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理重定向
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
if checkURL(location) != nil {
|
||||
c.Header("Location", "/"+location)
|
||||
} else {
|
||||
proxyWithRedirect(c, location, redirectCount+1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 直接流式转发
|
||||
io.Copy(c.Writer, resp.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func checkURL(u string) []string {
|
||||
for _, exp := range exps {
|
||||
if matches := exp.FindStringSubmatch(u); matches != nil {
|
||||
return matches[1:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 简单的健康检查
|
||||
func formatBeijingTime(t time.Time) string {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil {
|
||||
loc = time.FixedZone("CST", 8*3600) // 兜底时区
|
||||
}
|
||||
return t.In(loc).Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// 转换为可读时间
|
||||
func formatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%d秒", int(d.Seconds()))
|
||||
@@ -338,26 +170,20 @@ func formatDuration(d time.Duration) string {
|
||||
}
|
||||
}
|
||||
|
||||
func initHealthRoutes(router *gin.Engine) {
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
uptime := time.Since(serviceStartTime)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"timestamp_unix": serviceStartTime.Unix(),
|
||||
"uptime_sec": uptime.Seconds(),
|
||||
"service": "hubproxy",
|
||||
"start_time_bj": formatBeijingTime(serviceStartTime),
|
||||
"uptime_human": formatDuration(uptime),
|
||||
})
|
||||
})
|
||||
func getUptimeInfo() (time.Duration, float64, string) {
|
||||
uptime := time.Since(serviceStartTime)
|
||||
return uptime, uptime.Seconds(), formatDuration(uptime)
|
||||
}
|
||||
|
||||
func initHealthRoutes(router *gin.Engine) {
|
||||
router.GET("/ready", func(c *gin.Context) {
|
||||
uptime := time.Since(serviceStartTime)
|
||||
_, uptimeSec, uptimeHuman := getUptimeInfo()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ready": true,
|
||||
"timestamp_unix": time.Now().Unix(),
|
||||
"uptime_sec": uptime.Seconds(),
|
||||
"uptime_human": formatDuration(uptime),
|
||||
"ready": true,
|
||||
"service": "hubproxy",
|
||||
"start_time_unix": serviceStartTime.Unix(),
|
||||
"uptime_sec": uptimeSec,
|
||||
"uptime_human": uptimeHuman,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package main
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"hubproxy/config"
|
||||
)
|
||||
|
||||
// ResourceType 资源类型
|
||||
@@ -26,7 +28,7 @@ type DockerImageInfo struct {
|
||||
FullName string
|
||||
}
|
||||
|
||||
// 全局访问控制器实例
|
||||
// GlobalAccessController 全局访问控制器实例
|
||||
var GlobalAccessController = &AccessController{}
|
||||
|
||||
// ParseDockerImage 解析Docker镜像名称
|
||||
@@ -79,19 +81,16 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
|
||||
|
||||
// CheckDockerAccess 检查Docker镜像访问权限
|
||||
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
|
||||
// 解析镜像名称
|
||||
imageInfo := ac.ParseDockerImage(image)
|
||||
|
||||
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
|
||||
if len(cfg.Access.WhiteList) > 0 {
|
||||
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
|
||||
return false, "不在Docker镜像白名单内"
|
||||
}
|
||||
}
|
||||
|
||||
// 检查黑名单
|
||||
if len(cfg.Access.BlackList) > 0 {
|
||||
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
|
||||
return false, "Docker镜像在黑名单内"
|
||||
@@ -107,14 +106,12 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r
|
||||
return false, "无效的GitHub仓库格式"
|
||||
}
|
||||
|
||||
cfg := GetConfig()
|
||||
cfg := config.GetConfig()
|
||||
|
||||
// 检查白名单
|
||||
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
|
||||
return false, "不在GitHub仓库白名单内"
|
||||
}
|
||||
|
||||
// 检查黑名单
|
||||
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
|
||||
return false, "GitHub仓库在黑名单内"
|
||||
}
|
||||
@@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool {
|
||||
continue
|
||||
}
|
||||
|
||||
// 支持多种匹配模式
|
||||
if fullRepo == item {
|
||||
return true
|
||||
}
|
||||
|
||||
// 用户级匹配
|
||||
if item == username || item == username+"/*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 前缀匹配(支持通配符)
|
||||
if strings.HasSuffix(item, "*") {
|
||||
prefix := strings.TrimSuffix(item, "*")
|
||||
if strings.HasPrefix(fullRepo, prefix) {
|
||||
@@ -203,7 +197,6 @@ func (ac *AccessController) checkList(matches, list []string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork)
|
||||
if strings.HasPrefix(fullRepo, item+"/") {
|
||||
return true
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
@@ -9,22 +9,23 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"hubproxy/config"
|
||||
)
|
||||
|
||||
// CachedItem 通用缓存项,支持Token和Manifest
|
||||
// CachedItem 通用缓存项
|
||||
type CachedItem struct {
|
||||
Data []byte // 缓存数据(token字符串或manifest字节)
|
||||
ContentType string // 内容类型
|
||||
Headers map[string]string // 额外的响应头
|
||||
ExpiresAt time.Time // 过期时间
|
||||
Data []byte
|
||||
ContentType string
|
||||
Headers map[string]string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// UniversalCache 通用缓存,支持Token和Manifest
|
||||
// UniversalCache 通用缓存
|
||||
type UniversalCache struct {
|
||||
cache sync.Map
|
||||
}
|
||||
|
||||
var globalCache = &UniversalCache{}
|
||||
var GlobalCache = &UniversalCache{}
|
||||
|
||||
// Get 获取缓存项
|
||||
func (c *UniversalCache) Get(key string) *CachedItem {
|
||||
@@ -57,22 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
|
||||
c.Set(key, []byte(token), "application/json", nil, ttl)
|
||||
}
|
||||
|
||||
// buildCacheKey 构建稳定的缓存key
|
||||
func buildCacheKey(prefix, query string) string {
|
||||
// BuildCacheKey 构建稳定的缓存key
|
||||
func BuildCacheKey(prefix, query string) string {
|
||||
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
|
||||
}
|
||||
|
||||
func buildTokenCacheKey(query string) string {
|
||||
return buildCacheKey("token", query)
|
||||
func BuildTokenCacheKey(query string) string {
|
||||
return BuildCacheKey("token", query)
|
||||
}
|
||||
|
||||
func buildManifestCacheKey(imageRef, reference string) string {
|
||||
func BuildManifestCacheKey(imageRef, reference string) string {
|
||||
key := fmt.Sprintf("%s:%s", imageRef, reference)
|
||||
return buildCacheKey("manifest", key)
|
||||
return BuildCacheKey("manifest", key)
|
||||
}
|
||||
|
||||
func getManifestTTL(reference string) time.Duration {
|
||||
cfg := GetConfig()
|
||||
func GetManifestTTL(reference string) time.Duration {
|
||||
cfg := config.GetConfig()
|
||||
defaultTTL := 30 * time.Minute
|
||||
if cfg.TokenCache.DefaultTTL != "" {
|
||||
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
|
||||
@@ -84,23 +85,20 @@ func getManifestTTL(reference string) time.Duration {
|
||||
return 24 * time.Hour
|
||||
}
|
||||
|
||||
// mutable tag的智能判断
|
||||
if reference == "latest" || reference == "main" || reference == "master" ||
|
||||
reference == "dev" || reference == "develop" {
|
||||
// 热门可变标签: 短期缓存
|
||||
return 10 * time.Minute
|
||||
}
|
||||
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
// extractTTLFromResponse 从响应中智能提取TTL
|
||||
func extractTTLFromResponse(responseBody []byte) time.Duration {
|
||||
// ExtractTTLFromResponse 从响应中智能提取TTL
|
||||
func ExtractTTLFromResponse(responseBody []byte) time.Duration {
|
||||
var tokenResp struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// 默认30分钟TTL,确保稳定性
|
||||
defaultTTL := 30 * time.Minute
|
||||
|
||||
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
|
||||
@@ -113,37 +111,35 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
func writeTokenResponse(c *gin.Context, cachedBody string) {
|
||||
func WriteTokenResponse(c *gin.Context, cachedBody string) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(200, cachedBody)
|
||||
}
|
||||
|
||||
func writeCachedResponse(c *gin.Context, item *CachedItem) {
|
||||
func WriteCachedResponse(c *gin.Context, item *CachedItem) {
|
||||
if item.ContentType != "" {
|
||||
c.Header("Content-Type", item.ContentType)
|
||||
}
|
||||
|
||||
// 设置额外的响应头
|
||||
for key, value := range item.Headers {
|
||||
c.Header(key, value)
|
||||
}
|
||||
|
||||
// 返回数据
|
||||
c.Data(200, item.ContentType, item.Data)
|
||||
}
|
||||
|
||||
// isCacheEnabled 检查缓存是否启用
|
||||
func isCacheEnabled() bool {
|
||||
cfg := GetConfig()
|
||||
// IsCacheEnabled 检查缓存是否启用
|
||||
func IsCacheEnabled() bool {
|
||||
cfg := config.GetConfig()
|
||||
return cfg.TokenCache.Enabled
|
||||
}
|
||||
|
||||
// isTokenCacheEnabled 检查token缓存是否启用(向后兼容)
|
||||
func isTokenCacheEnabled() bool {
|
||||
return isCacheEnabled()
|
||||
// IsTokenCacheEnabled 检查token缓存是否启用
|
||||
func IsTokenCacheEnabled() bool {
|
||||
return IsCacheEnabled()
|
||||
}
|
||||
|
||||
// 定期清理过期缓存,防止内存泄漏
|
||||
// 定期清理过期缓存
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(20 * time.Minute)
|
||||
@@ -153,7 +149,7 @@ func init() {
|
||||
now := time.Now()
|
||||
expiredKeys := make([]string, 0)
|
||||
|
||||
globalCache.cache.Range(func(key, value interface{}) bool {
|
||||
GlobalCache.cache.Range(func(key, value interface{}) bool {
|
||||
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
|
||||
expiredKeys = append(expiredKeys, key.(string))
|
||||
}
|
||||
@@ -161,7 +157,7 @@ func init() {
|
||||
})
|
||||
|
||||
for _, key := range expiredKeys {
|
||||
globalCache.cache.Delete(key)
|
||||
GlobalCache.cache.Delete(key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -1,28 +1,28 @@
|
||||
package main
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"hubproxy/config"
|
||||
)
|
||||
|
||||
var (
|
||||
// 全局HTTP客户端 - 用于代理请求(长超时)
|
||||
globalHTTPClient *http.Client
|
||||
// 搜索HTTP客户端 - 用于API请求(短超时)
|
||||
searchHTTPClient *http.Client
|
||||
)
|
||||
|
||||
// initHTTPClients 初始化HTTP客户端
|
||||
func initHTTPClients() {
|
||||
cfg := GetConfig()
|
||||
// InitHTTPClients 初始化HTTP客户端
|
||||
func InitHTTPClients() {
|
||||
cfg := config.GetConfig()
|
||||
|
||||
if p := cfg.Access.Proxy; p != "" {
|
||||
os.Setenv("HTTP_PROXY", p)
|
||||
os.Setenv("HTTPS_PROXY", p)
|
||||
}
|
||||
// 代理客户端配置 - 适用于大文件传输
|
||||
|
||||
globalHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -39,7 +39,6 @@ func initHTTPClients() {
|
||||
},
|
||||
}
|
||||
|
||||
// 搜索客户端配置 - 适用于API调用
|
||||
searchHTTPClient = &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
@@ -57,12 +56,12 @@ func initHTTPClients() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理)
|
||||
// GetGlobalHTTPClient 获取全局HTTP客户端
|
||||
func GetGlobalHTTPClient() *http.Client {
|
||||
return globalHTTPClient
|
||||
}
|
||||
|
||||
// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用)
|
||||
// GetSearchHTTPClient 获取搜索HTTP客户端
|
||||
func GetSearchHTTPClient() *http.Client {
|
||||
return searchHTTPClient
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade
|
||||
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
|
||||
var reader io.Reader = input
|
||||
|
||||
// 处理gzip压缩
|
||||
if isCompressed {
|
||||
peek := make([]byte, 2)
|
||||
n, err := input.Read(peek)
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -9,22 +9,23 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
"hubproxy/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// 清理间隔
|
||||
CleanupInterval = 10 * time.Minute
|
||||
MaxIPCacheSize = 10000
|
||||
)
|
||||
|
||||
// IPRateLimiter IP限流器结构体
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rateLimiterEntry // IP到限流器的映射
|
||||
mu *sync.RWMutex // 读写锁,保证并发安全
|
||||
r rate.Limit // 速率限制(每秒允许的请求数)
|
||||
b int // 令牌桶容量(突发请求数)
|
||||
whitelist []*net.IPNet // 白名单IP段
|
||||
blacklist []*net.IPNet // 黑名单IP段
|
||||
ips map[string]*rateLimiterEntry
|
||||
mu *sync.RWMutex
|
||||
r rate.Limit
|
||||
b int
|
||||
whitelist []*net.IPNet
|
||||
blacklist []*net.IPNet
|
||||
whitelistLimiter *rate.Limiter // 全局共享的白名单限流器
|
||||
}
|
||||
|
||||
// rateLimiterEntry 限流器条目
|
||||
@@ -33,15 +34,15 @@ type rateLimiterEntry struct {
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// initGlobalLimiter 初始化全局限流器
|
||||
func initGlobalLimiter() *IPRateLimiter {
|
||||
cfg := GetConfig()
|
||||
// InitGlobalLimiter 初始化全局限流器
|
||||
func InitGlobalLimiter() *IPRateLimiter {
|
||||
cfg := config.GetConfig()
|
||||
|
||||
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
|
||||
for _, item := range cfg.Security.WhiteList {
|
||||
if item = strings.TrimSpace(item); item != "" {
|
||||
if !strings.Contains(item, "/") {
|
||||
item = item + "/32" // 单个IP转为CIDR格式
|
||||
item = item + "/32"
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(item)
|
||||
if err == nil {
|
||||
@@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
// 解析黑名单IP段
|
||||
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
|
||||
for _, item := range cfg.Security.BlackList {
|
||||
if item = strings.TrimSpace(item); item != "" {
|
||||
if !strings.Contains(item, "/") {
|
||||
item = item + "/32" // 单个IP转为CIDR格式
|
||||
item = item + "/32"
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(item)
|
||||
if err == nil {
|
||||
@@ -68,7 +68,6 @@ func initGlobalLimiter() *IPRateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
|
||||
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
|
||||
|
||||
burstSize := cfg.RateLimit.RequestLimit
|
||||
@@ -77,25 +76,20 @@ func initGlobalLimiter() *IPRateLimiter {
|
||||
}
|
||||
|
||||
limiter := &IPRateLimiter{
|
||||
ips: make(map[string]*rateLimiterEntry),
|
||||
mu: &sync.RWMutex{},
|
||||
r: ratePerSecond,
|
||||
b: burstSize,
|
||||
whitelist: whitelist,
|
||||
blacklist: blacklist,
|
||||
ips: make(map[string]*rateLimiterEntry),
|
||||
mu: &sync.RWMutex{},
|
||||
r: ratePerSecond,
|
||||
b: burstSize,
|
||||
whitelist: whitelist,
|
||||
blacklist: blacklist,
|
||||
whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize),
|
||||
}
|
||||
|
||||
// 启动定期清理goroutine
|
||||
go limiter.cleanupRoutine()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// initLimiter 初始化限流器
|
||||
func initLimiter() {
|
||||
globalLimiter = initGlobalLimiter()
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理过期的限流器
|
||||
func (i *IPRateLimiter) cleanupRoutine() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
@@ -105,25 +99,20 @@ func (i *IPRateLimiter) cleanupRoutine() {
|
||||
now := time.Now()
|
||||
expired := make([]string, 0)
|
||||
|
||||
// 查找过期的条目
|
||||
i.mu.RLock()
|
||||
for ip, entry := range i.ips {
|
||||
// 如果最后访问时间超过1小时,认为过期
|
||||
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
||||
expired = append(expired, ip)
|
||||
}
|
||||
}
|
||||
i.mu.RUnlock()
|
||||
|
||||
// 如果有过期条目或者缓存过大,进行清理
|
||||
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
||||
i.mu.Lock()
|
||||
// 删除过期条目
|
||||
for _, ip := range expired {
|
||||
delete(i.ips, ip)
|
||||
}
|
||||
|
||||
// 如果缓存仍然过大,全部清理
|
||||
if len(i.ips) > MaxIPCacheSize {
|
||||
i.ips = make(map[string]*rateLimiterEntry)
|
||||
}
|
||||
@@ -140,28 +129,26 @@ func extractIPFromAddress(address string) string {
|
||||
return address
|
||||
}
|
||||
|
||||
// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段
|
||||
// normalizeIPForRateLimit 标准化IP地址用于限流
|
||||
func normalizeIPForRateLimit(ipStr string) string {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return ipStr // 解析失败,返回原值
|
||||
return ipStr
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
return ipStr // IPv4保持不变
|
||||
return ipStr
|
||||
}
|
||||
|
||||
// IPv6:标准化为 /64 网段
|
||||
ipv6 := ip.To16()
|
||||
for i := 8; i < 16; i++ {
|
||||
ipv6[i] = 0 // 清零后64位
|
||||
ipv6[i] = 0
|
||||
}
|
||||
return ipv6.String() + "/64"
|
||||
}
|
||||
|
||||
// isIPInCIDRList 检查IP是否在CIDR列表中
|
||||
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
||||
// 先提取纯IP地址
|
||||
cleanIP := extractIPFromAddress(ip)
|
||||
parsedIP := net.ParseIP(cleanIP)
|
||||
if parsedIP == nil {
|
||||
@@ -176,22 +163,18 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
|
||||
// GetLimiter 获取指定IP的限流器
|
||||
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
||||
// 提取纯IP地址
|
||||
cleanIP := extractIPFromAddress(ip)
|
||||
|
||||
// 检查是否在黑名单中
|
||||
if isIPInCIDRList(cleanIP, i.blacklist) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查是否在白名单中
|
||||
if isIPInCIDRList(cleanIP, i.whitelist) {
|
||||
return rate.NewLimiter(rate.Inf, i.b), true
|
||||
return i.whitelistLimiter, true
|
||||
}
|
||||
|
||||
// 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段
|
||||
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
||||
|
||||
now := time.Now()
|
||||
@@ -230,7 +213,6 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
||||
// RateLimitMiddleware 速率限制中间件
|
||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 静态文件豁免:跳过限流检查
|
||||
path := c.Request.URL.Path
|
||||
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
|
||||
strings.HasPrefix(path, "/public/") {
|
||||
@@ -238,30 +220,22 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取客户端真实IP
|
||||
var ip string
|
||||
|
||||
// 优先尝试从请求头获取真实IP
|
||||
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
||||
// X-Forwarded-For可能包含多个IP,取第一个
|
||||
ips := strings.Split(forwarded, ",")
|
||||
ip = strings.TrimSpace(ips[0])
|
||||
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
||||
// 如果有X-Real-IP头
|
||||
ip = realIP
|
||||
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
||||
// 某些代理可能使用此头
|
||||
ips := strings.Split(remoteIP, ",")
|
||||
ip = strings.TrimSpace(ips[0])
|
||||
} else {
|
||||
// 回退到ClientIP方法
|
||||
ip = c.ClientIP()
|
||||
}
|
||||
|
||||
// 提取纯IP地址(去除可能存在的端口)
|
||||
cleanIP := extractIPFromAddress(ip)
|
||||
|
||||
// 日志记录请求IP和头信息
|
||||
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
||||
if cleanIP != normalizedIP {
|
||||
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
|
||||
@@ -275,10 +249,8 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
c.GetHeader("X-Real-IP"))
|
||||
}
|
||||
|
||||
// 获取限流器并检查是否允许访问
|
||||
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
|
||||
|
||||
// 如果IP在黑名单中
|
||||
if !allowed {
|
||||
c.JSON(403, gin.H{
|
||||
"error": "您已被限制访问",
|
||||
@@ -287,7 +259,6 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查限流
|
||||
if !ipLimiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"error": "请求频率过快,暂时限制访问",
|
||||
🛠️ Refactor suggestion
Consider filtering sensitive request headers.
While copying headers, consider filtering out potentially sensitive or problematic headers:
Authorization(unless intentionally proxying auth)CookieandSet-CookieX-Forwarded-*headers that might conflict📝 Committable suggestion
🤖 Prompt for AI Agents