Merge pull request #39 from sky22333/dev

优化代码结构,支持h2
This commit was merged in pull request #39.
This commit is contained in:
starry
2025-07-27 12:11:39 +08:00
committed by GitHub
16 changed files with 1024 additions and 1189 deletions

View File

@@ -37,7 +37,7 @@ docker run -d \
### 一键脚本安装
```bash
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash
```
也可以直接下载二进制文件执行`./hubproxy`使用无需配置文件即可启动内置默认配置支持所有功能。初始内存占用约18M二进制文件大小约12M
@@ -109,12 +109,14 @@ host = "0.0.0.0"
port = 5000
# Github文件大小限制字节默认2GB
fileSize = 2147483648
# HTTP/2 多路复用
enableH2C = false
[rateLimit]
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层会消耗多个次数)
requestLimit = 500
# 限流周期(小时)
periodHours = 1.0
periodHours = 3.0
[security]
# IP白名单支持单个IP或IP段
@@ -132,7 +134,7 @@ blackList = [
"192.168.100.0/24"
]
[proxy]
[access]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = []
@@ -148,12 +150,6 @@ blackList = [
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""

View File

@@ -1,8 +1,8 @@
services:
hubproxy:
build: .
restart: always
ports:
- '5000:5000'
volumes:
- ./src/config.toml:/root/config.toml
hubproxy:
build: .
restart: always
ports:
- '5000:5000'
volumes:
- ./src/config.toml:/root/config.toml

View File

@@ -4,12 +4,14 @@ host = "0.0.0.0"
port = 5000
# Github文件大小限制字节默认2GB
fileSize = 2147483648
# HTTP/2 多路复用
enableH2C = false
[rateLimit]
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层会消耗多个次数)
requestLimit = 500
# 限流周期(小时)
periodHours = 1.0
periodHours = 3.0
[security]
# IP白名单支持单个IP或IP段
@@ -43,12 +45,6 @@ blackList = [
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""

View File

@@ -1,4 +1,4 @@
package main
package config
import (
"fmt"
@@ -13,45 +13,46 @@ import (
// RegistryMapping Registry映射配置
type RegistryMapping struct {
Upstream string `toml:"upstream"` // 上游Registry地址
AuthHost string `toml:"authHost"` // 认证服务器地址
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic
Enabled bool `toml:"enabled"` // 是否启用
Upstream string `toml:"upstream"`
AuthHost string `toml:"authHost"`
AuthType string `toml:"authType"`
Enabled bool `toml:"enabled"`
}
// AppConfig 应用配置结构体
type AppConfig struct {
Server struct {
Host string `toml:"host"` // 监听地址
Port int `toml:"port"` // 监听端口
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
} `toml:"server"`
RateLimit struct {
RequestLimit int `toml:"requestLimit"` // 每小时请求限制
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"`
} `toml:"rateLimit"`
Security struct {
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
} `toml:"security"`
Access struct {
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
Proxy string `toml:"proxy"`
} `toml:"access"`
Download struct {
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
MaxImages int `toml:"maxImages"`
} `toml:"download"`
Registries map[string]RegistryMapping `toml:"registries"`
TokenCache struct {
Enabled bool `toml:"enabled"` // 是否启用token缓存
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间
Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"`
} `toml:"tokenCache"`
}
@@ -65,24 +66,25 @@ var (
configCacheMutex sync.RWMutex
)
// todo:Refactoring is needed
// DefaultConfig 返回默认配置
func DefaultConfig() *AppConfig {
return &AppConfig{
Server: struct {
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
}{
Host: "0.0.0.0",
Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
Host: "0.0.0.0",
Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
EnableH2C: false, // 默认关闭H2C
},
RateLimit: struct {
RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"`
}{
RequestLimit: 20,
RequestLimit: 200,
PeriodHours: 1.0,
},
Security: struct {
@@ -99,12 +101,12 @@ func DefaultConfig() *AppConfig {
}{
WhiteList: []string{},
BlackList: []string{},
Proxy: "", // 默认不使用代理
Proxy: "",
},
Download: struct {
MaxImages int `toml:"maxImages"`
}{
MaxImages: 10, // 默认值最多同时下载10个镜像
MaxImages: 10,
},
Registries: map[string]RegistryMapping{
"ghcr.io": {
@@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig {
Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"`
}{
Enabled: true, // docker认证的匿名Token缓存配置用于提升性能
Enabled: true,
DefaultTTL: "20m",
},
}
@@ -152,11 +154,9 @@ func GetConfig() *AppConfig {
}
configCacheMutex.RUnlock()
// 缓存过期,重新生成配置
configCacheMutex.Lock()
defer configCacheMutex.Unlock()
// 双重检查,防止重复生成
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
return cachedConfig
}
@@ -170,7 +170,6 @@ func GetConfig() *AppConfig {
return defaultCfg
}
// 生成新的配置深拷贝
configCopy := *appConfig
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
@@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) {
// LoadConfig 加载配置文件
func LoadConfig() error {
// 首先使用默认配置
cfg := DefaultConfig()
// 尝试加载TOML配置文件
if data, err := os.ReadFile("config.toml"); err == nil {
if err := toml.Unmarshal(data, cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %v", err)
@@ -209,10 +206,7 @@ func LoadConfig() error {
fmt.Println("未找到config.toml使用默认配置")
}
// 从环境变量覆盖配置
overrideFromEnv(cfg)
// 设置配置
setConfig(cfg)
return nil
@@ -220,7 +214,6 @@ func LoadConfig() error {
// overrideFromEnv 从环境变量覆盖配置
func overrideFromEnv(cfg *AppConfig) {
// 服务器配置
if val := os.Getenv("SERVER_HOST"); val != "" {
cfg.Server.Host = val
}
@@ -229,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Server.Port = port
}
}
if val := os.Getenv("ENABLE_H2C"); val != "" {
if enable, err := strconv.ParseBool(val); err == nil {
cfg.Server.EnableH2C = enable
}
}
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
cfg.Server.FileSize = size
}
}
// 限流配置
if val := os.Getenv("RATE_LIMIT"); val != "" {
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
cfg.RateLimit.RequestLimit = limit
@@ -247,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) {
}
}
// IP限制配置
if val := os.Getenv("IP_WHITELIST"); val != "" {
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
}
@@ -255,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
}
// 下载限制配置
if val := os.Getenv("MAX_IMAGES"); val != "" {
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
cfg.Download.MaxImages = maxImages

View File

@@ -6,6 +6,7 @@ require (
github.com/gin-gonic/gin v1.10.0
github.com/google/go-containerregistry v0.20.5
github.com/pelletier/go-toml/v2 v2.2.3
golang.org/x/net v0.33.0
golang.org/x/time v0.11.0
)
@@ -43,7 +44,6 @@ require (
github.com/vbatts/tar-split v0.12.1 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.21.0 // indirect

View File

@@ -1,4 +1,4 @@
package main
package handlers
import (
"context"
@@ -12,6 +12,8 @@ import (
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"hubproxy/config"
"hubproxy/utils"
)
// DockerProxy Docker代理配置
@@ -27,12 +29,10 @@ type RegistryDetector struct{}
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
cfg := GetConfig()
cfg := config.GetConfig()
// 检查路径是否以已知Registry域名开头
for domain := range cfg.Registries {
if strings.HasPrefix(path, domain+"/") {
// 找到匹配的域名,返回域名和剩余路径
remainingPath := strings.TrimPrefix(path, domain+"/")
return domain, remainingPath
}
@@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
// isRegistryEnabled 检查Registry是否启用
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
cfg := GetConfig()
cfg := config.GetConfig()
if mapping, exists := cfg.Registries[domain]; exists {
return mapping.Enabled
}
@@ -51,28 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
}
// getRegistryMapping 获取Registry映射配置
func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) {
cfg := GetConfig()
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
cfg := config.GetConfig()
mapping, exists := cfg.Registries[domain]
return mapping, exists && mapping.Enabled
}
var registryDetector = &RegistryDetector{}
// 初始化Docker代理
func initDockerProxy() {
// 创建目标registry
// InitDockerProxy 初始化Docker代理
func InitDockerProxy() {
registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err)
return
}
// 配置代理选项
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
dockerProxy = &DockerProxy{
@@ -85,13 +83,11 @@ func initDockerProxy() {
func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path
// 处理 /v2/ API版本检查
if path == "/v2/" {
c.JSON(http.StatusOK, gin.H{})
return
}
// 处理不同的API端点
if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path)
} else {
@@ -101,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) {
// handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) {
// 移除 /v2/ 前缀
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
if registryDetector.isRegistryEnabled(registryDomain) {
// 设置目标Registry信息到Context
c.Set("target_registry_domain", registryDomain)
c.Set("target_path", remainingPath)
// 处理多Registry请求
handleMultiRegistryRequest(c, registryDomain, remainingPath)
return
}
@@ -122,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) {
return
}
// 自动处理官方镜像的library命名空间
if !strings.Contains(imageName, "/") {
imageName = "library/" + imageName
}
// Docker镜像访问控制检查
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
// 构建完整的镜像引用
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType {
@@ -151,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) {
// parseRegistryPath 解析Registry路径
func parseRegistryPath(path string) (imageName, apiType, reference string) {
// 查找API端点关键字
if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx]
apiType = "manifests"
@@ -178,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) {
// handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
// Manifest缓存逻辑(仅对GET请求缓存)
if isCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := buildManifestCacheKey(imageRef, reference)
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
// 优先从缓存获取
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
writeCachedResponse(c, cachedItem)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
@@ -192,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
var ref name.Reference
var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") {
// 是digest
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
// 是tag
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
@@ -207,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead {
// HEAD请求使用remote.Head
desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
@@ -217,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
// GET请求使用remote.Get
desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
@@ -231,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return
}
// 设置响应头
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
// 缓存响应
if isCacheEnabled() {
cacheKey := buildManifestCacheKey(imageRef, reference)
ttl := getManifestTTL(reference)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
// 返回manifest内容
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
// 构建digest引用
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
@@ -265,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 使用remote.Layer获取layer
layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
@@ -273,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 获取layer信息
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
@@ -281,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return
}
// 获取layer内容
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
@@ -290,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
}
defer reader.Close()
// 设置响应头
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
// 流式传输blob内容
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) {
// 解析repository
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
@@ -310,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return
}
// 使用remote.List获取tags
tags, err := remote.List(repo, dockerProxy.options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
@@ -318,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return
}
// 构建响应
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags,
@@ -327,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
c.JSON(http.StatusOK, response)
}
// ProxyDockerAuthGin Docker认证代理(带缓存优化)
// ProxyDockerAuthGin Docker认证代理
func ProxyDockerAuthGin(c *gin.Context) {
// 检查是否启用token缓存
if isTokenCacheEnabled() {
if utils.IsTokenCacheEnabled() {
proxyDockerAuthWithCache(c)
} else {
proxyDockerAuthOriginal(c)
@@ -339,32 +305,26 @@ func ProxyDockerAuthGin(c *gin.Context) {
// proxyDockerAuthWithCache 带缓存的认证代理
func proxyDockerAuthWithCache(c *gin.Context) {
// 1. 构建缓存key基于完整的查询参数
cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery)
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
// 2. 尝试从缓存获取token
if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" {
writeTokenResponse(c, cachedToken)
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
utils.WriteTokenResponse(c, cachedToken)
return
}
// 3. 缓存未命中,创建响应记录器
recorder := &ResponseRecorder{
ResponseWriter: c.Writer,
statusCode: 200,
}
c.Writer = recorder
// 4. 调用原有认证逻辑
proxyDockerAuthOriginal(c)
// 5. 如果认证成功,缓存响应
if recorder.statusCode == 200 && len(recorder.body) > 0 {
ttl := extractTTLFromResponse(recorder.body)
globalCache.SetToken(cacheKey, string(recorder.body), ttl)
ttl := utils.ExtractTTLFromResponse(recorder.body)
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
}
// 6. 写入实际响应
c.Writer = recorder.ResponseWriter
c.Data(recorder.statusCode, "application/json", recorder.body)
}
@@ -389,14 +349,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
var authURL string
if targetDomain, exists := c.Get("target_registry_domain"); exists {
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
// 使用Registry特定的认证服务器
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
} else {
// fallback到默认Docker认证
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
} else {
// 构建默认Docker认证URL
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
@@ -404,13 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
authURL += "?" + c.Request.URL.RawQuery
}
// 创建HTTP客户端复用全局传输配置包含代理设置
client := &http.Client{
Timeout: 30 * time.Second,
Transport: GetGlobalHTTPClient().Transport,
Transport: utils.GetGlobalHTTPClient().Transport,
}
// 创建请求
req, err := http.NewRequestWithContext(
context.Background(),
c.Request.Method,
@@ -422,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) {
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
// 执行请求
resp, err := client.Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Auth request failed")
@@ -437,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) {
}
defer resp.Body.Close()
// 获取当前代理的Host地址
proxyHost := c.Request.Host
if proxyHost == "" {
// 使用配置中的服务器地址和端口
cfg := GetConfig()
cfg := config.GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
// 复制响应头并重写认证URL
for key, values := range resp.Header {
for _, value := range values {
// 重写WWW-Authenticate头中的realm URL
if key == "Www-Authenticate" {
// 支持多Registry的URL重写
value = rewriteAuthHeader(value, proxyHost)
}
c.Header(key, value)
}
}
// 返回响应
c.Status(resp.StatusCode)
io.Copy(c.Writer, resp.Body)
}
// rewriteAuthHeader 重写认证头
func rewriteAuthHeader(authHeader, proxyHost string) string {
// 重写各种Registry的认证URL
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
@@ -478,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string {
// handleMultiRegistryRequest 处理多Registry请求
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
// 获取Registry映射配置
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
if !exists {
c.String(http.StatusBadRequest, "Registry not configured")
return
}
// 解析剩余路径
imageName, apiType, reference := parseRegistryPath(remainingPath)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
// 访问控制检查(使用完整的镜像路径)
fullImageName := registryDomain + "/" + imageName
if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
// 构建上游Registry引用
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
// 根据API类型处理请求
switch apiType {
case "manifests":
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
@@ -517,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st
}
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) {
// Manifest缓存逻辑(仅对GET请求缓存)
if isCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := buildManifestCacheKey(imageRef, reference)
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
// 优先从缓存获取
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
writeCachedResponse(c, cachedItem)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
@@ -532,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
var ref name.Reference
var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
@@ -545,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return
}
// 创建针对上游Registry的选项
options := createUpstreamOptions(mapping)
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, options...)
if err != nil {
@@ -569,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return
}
// 设置响应头
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
// 缓存响应
if isCacheEnabled() {
cacheKey := buildManifestCacheKey(imageRef, reference)
ttl := getManifestTTL(reference)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
@@ -593,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
}
// handleUpstreamBlobRequest 处理上游Registry的blob请求
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) {
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
@@ -633,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping
}
// handleUpstreamTagsRequest 处理上游Registry的tags请求
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) {
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
@@ -658,14 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry
}
// createUpstreamOptions 创建上游Registry选项
func createUpstreamOptions(mapping RegistryMapping) []remote.Option {
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
// 根据Registry类型添加特定的认证选项方便后续扩展
switch mapping.AuthType {
case "github":
case "google":

213
src/handlers/github.go Normal file
View File

@@ -0,0 +1,213 @@
package handlers
import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"hubproxy/config"
"hubproxy/utils"
)
var (
// GitHub URL匹配正则表达式
githubExps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
}
)
// GitHubProxyHandler GitHub代理处理器
func GitHubProxyHandler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
// 自动补全协议头
if !strings.HasPrefix(rawPath, "https://") {
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
rawPath = strings.Replace(rawPath, "http:/", "", 1)
rawPath = strings.Replace(rawPath, "https:/", "", 1)
} else if strings.HasPrefix(rawPath, "http://") {
rawPath = strings.TrimPrefix(rawPath, "http://")
}
rawPath = "https://" + rawPath
}
matches := CheckGitHubURL(rawPath)
if matches != nil {
if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else {
c.String(http.StatusForbidden, "无效输入")
return
}
// 将blob链接转换为raw链接
if githubExps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
ProxyGitHubRequest(c, rawPath)
}
// CheckGitHubURL 检查URL是否匹配GitHub模式
func CheckGitHubURL(u string) []string {
for _, exp := range githubExps {
if matches := exp.FindStringSubmatch(u); matches != nil {
return matches[1:]
}
}
return nil
}
// ProxyGitHubRequest 代理GitHub请求
func ProxyGitHubRequest(c *gin.Context, u string) {
proxyGitHubWithRedirect(c, u, 0)
}
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := config.GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
// 处理.sh文件的智能处理
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
io.Copy(c.Writer, resp.Body)
}
}

View File

@@ -1,4 +1,4 @@
package main
package handlers
import (
"archive/tar"
@@ -23,6 +23,8 @@ import (
"github.com/google/go-containerregistry/pkg/v1/partial"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/types"
"hubproxy/config"
"hubproxy/utils"
)
// DebounceEntry 防抖条目
@@ -58,17 +60,15 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
if entry, exists := d.entries[key]; exists {
if now.Sub(entry.LastRequest) < d.window {
return false // 在防抖窗口内,拒绝请求
return false
}
}
// 更新或创建条目
d.entries[key] = &DebounceEntry{
LastRequest: now,
UserID: userID,
}
// 清理过期条目每5分钟清理一次
if time.Since(d.lastCleanup) > 5*time.Minute {
d.cleanup(now)
d.lastCleanup = now
@@ -88,50 +88,41 @@ func (d *DownloadDebouncer) cleanup(now time.Time) {
// generateContentFingerprint 生成内容指纹
func generateContentFingerprint(images []string, platform string) string {
// 对镜像列表排序确保顺序无关
sortedImages := make([]string, len(images))
copy(sortedImages, images)
sort.Strings(sortedImages)
// 组合内容:镜像列表 + 平台信息
content := strings.Join(sortedImages, "|") + ":" + platform
// 生成MD5哈希
hash := md5.Sum([]byte(content))
return hex.EncodeToString(hash[:])
}
// getUserID 获取用户标识
func getUserID(c *gin.Context) string {
// 优先使用会话Cookie
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
return "session:" + sessionID
}
// 备用方案IP + User-Agent组合
ip := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
userAgent = "unknown"
}
// 生成简短标识
combined := ip + ":" + userAgent
hash := md5.Sum([]byte(combined))
return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节
return "ip:" + hex.EncodeToString(hash[:8])
}
// 全局防抖器实例
var (
singleImageDebouncer *DownloadDebouncer
batchImageDebouncer *DownloadDebouncer
)
// initDebouncer 初始化防抖器
func initDebouncer() {
// 单个镜像5秒防抖窗口
// InitDebouncer 初始化防抖器
func InitDebouncer() {
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
// 批量镜像60秒防抖窗口
batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
}
@@ -147,15 +138,15 @@ type ImageStreamerConfig struct {
}
// NewImageStreamer 创建镜像下载器
func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
if config == nil {
config = &ImageStreamerConfig{}
func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
if cfg == nil {
cfg = &ImageStreamerConfig{}
}
concurrency := config.Concurrency
concurrency := cfg.Concurrency
if concurrency <= 0 {
cfg := GetConfig()
concurrency = cfg.Download.MaxImages
appCfg := config.GetConfig()
concurrency = appCfg.Download.MaxImages
if concurrency <= 0 {
concurrency = 10
}
@@ -163,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
remoteOptions := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithTransport(GetGlobalHTTPClient().Transport),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
return &ImageStreamer{
@@ -176,7 +167,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
type StreamOptions struct {
Platform string
Compression bool
UseCompressedLayers bool // 是否保存原始压缩层,默认开启
UseCompressedLayers bool
}
// StreamImageToWriter 流式下载镜像到Writer
@@ -215,7 +206,6 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
// 直接从网络获取完整的descriptor确保对象完整性
return remote.Get(ref, options...)
}
@@ -343,7 +333,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
var layerSize int64
var layerReader io.ReadCloser
// 根据配置选择使用压缩层或未压缩层
if options != nil && options.UseCompressedLayers {
layerSize, err = layer.Size()
if err != nil {
@@ -385,7 +374,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
log.Printf("已处理层 %d/%d", i+1, len(layers))
}
// 构建单个镜像的manifest信息
singleManifest := map[string]interface{}{
"Config": configDigest.String() + ".json",
"RepoTags": []string{imageRef},
@@ -398,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
}(),
}
// 构建repositories信息
repositories := make(map[string]map[string]string)
parts := strings.Split(imageRef, ":")
if len(parts) == 2 {
@@ -407,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
repositories[repoName] = map[string]string{tag: configDigest.String()}
}
// 如果是批量下载,返回信息而不写入文件
if manifestOut != nil && repositoriesOut != nil {
*manifestOut = singleManifest
*repositoriesOut = repositories
return nil
}
// 单镜像下载直接写入manifest.json
manifest := []map[string]interface{}{singleManifest}
manifestData, err := json.Marshal(manifest)
@@ -436,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err
}
// 写入repositories文件
repositoriesData, err := json.Marshal(repositories)
if err != nil {
return err
@@ -456,7 +440,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err
}
// processImageForBatch 处理镜像的公共逻辑,用于批量下载
// processImageForBatch 处理镜像的公共逻辑
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
layers, err := img.Layers()
if err != nil {
@@ -498,7 +482,6 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite
switch desc.MediaType {
case types.OCIImageIndex, types.DockerManifestList:
// 处理多架构镜像
img, err = is.selectPlatformImage(desc, options)
if err != nil {
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
@@ -530,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
return nil, fmt.Errorf("获取索引清单失败: %w", err)
}
// 选择合适的平台
var selectedDesc *v1.Descriptor
for _, m := range manifest.Manifests {
if m.Platform == nil {
@@ -578,8 +560,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
var globalImageStreamer *ImageStreamer
// initImageStreamer 初始化镜像下载器
func initImageStreamer() {
// InitImageStreamer 初始化镜像下载器
func InitImageStreamer() {
globalImageStreamer = NewImageStreamer(nil)
}
@@ -591,8 +573,8 @@ func formatPlatformText(platform string) string {
return platform
}
// initImageTarRoutes 初始化镜像下载路由
func initImageTarRoutes(router *gin.Engine) {
// InitImageTarRoutes 初始化镜像下载路由
func InitImageTarRoutes(router *gin.Engine) {
imageAPI := router.Group("/api/image")
{
imageAPI.GET("/download/:image", handleDirectImageDownload)
@@ -625,7 +607,6 @@ func handleDirectImageDownload(c *gin.Context) {
return
}
// 防抖检查
userID := getUserID(c)
contentKey := generateContentFingerprint([]string{imageRef}, platform)
@@ -677,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
}
}
cfg := GetConfig()
cfg := config.GetConfig()
if len(req.Images) > cfg.Download.MaxImages {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
@@ -685,7 +666,6 @@ func handleSimpleBatchDownload(c *gin.Context) {
return
}
// 批量下载防抖检查
userID := getUserID(c)
contentKey := generateContentFingerprint(req.Images, req.Platform)
@@ -697,7 +677,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
return
}
useCompressed := true // 默认启用原始压缩层
useCompressed := true
if req.UseCompressedLayers != nil {
useCompressed = *req.UseCompressedLayers
}
@@ -801,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
var allManifests []map[string]interface{}
var allRepositories = make(map[string]map[string]string)
// 流式处理每个镜像
for i, imageRef := range imageRefs {
select {
case <-ctx.Done():
@@ -811,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
// 防止单个镜像处理时间过长
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
cancel()
@@ -825,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
}
// 收集manifest信息
allManifests = append(allManifests, manifest)
// 合并repositories信息
for repo, tags := range repositories {
if allRepositories[repo] == nil {
allRepositories[repo] = make(map[string]string)
@@ -839,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
}
}
// 写入合并的manifest.json
manifestData, err := json.Marshal(allManifests)
if err != nil {
return fmt.Errorf("序列化manifest失败: %w", err)
@@ -859,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("写入manifest数据失败: %w", err)
}
// 写入合并的repositories文件
repositoriesData, err := json.Marshal(allRepositories)
if err != nil {
return fmt.Errorf("序列化repositories失败: %w", err)

File diff suppressed because it is too large Load Diff

View File

@@ -3,15 +3,17 @@ package main
import (
"embed"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"hubproxy/config"
"hubproxy/handlers"
"hubproxy/utils"
)
//go:embed public/*
@@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) {
}
var (
exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
}
globalLimiter *IPRateLimiter
globalLimiter *utils.IPRateLimiter
// 服务启动时间
serviceStartTime = time.Now()
@@ -52,25 +42,25 @@ var (
func main() {
// 加载配置
if err := LoadConfig(); err != nil {
if err := config.LoadConfig(); err != nil {
fmt.Printf("配置加载失败: %v\n", err)
return
}
// 初始化HTTP客户端
initHTTPClients()
utils.InitHTTPClients()
// 初始化限流器
initLimiter()
globalLimiter = utils.InitGlobalLimiter()
// 初始化Docker流式代理
initDockerProxy()
handlers.InitDockerProxy()
// 初始化镜像流式下载器
initImageStreamer()
handlers.InitImageStreamer()
// 初始化防抖器
initDebouncer()
handlers.InitDebouncer()
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
@@ -84,14 +74,14 @@ func main() {
})
}))
// 全局限流中间件 - 应用到所有路由
router.Use(RateLimitMiddleware(globalLimiter))
// 全局限流中间件
router.Use(utils.RateLimitMiddleware(globalLimiter))
// 初始化监控端点
initHealthRoutes(router)
// 初始化镜像tar下载路由
initImageTarRoutes(router)
handlers.InitImageTarRoutes(router)
// 静态文件路由
router.GET("/", func(c *gin.Context) {
@@ -113,217 +103,59 @@ func main() {
})
// 注册dockerhub搜索路由
RegisterSearchRoute(router)
handlers.RegisterSearchRoute(router)
// 注册Docker认证路由/token*
router.Any("/token", ProxyDockerAuthGin)
router.Any("/token/*path", ProxyDockerAuthGin)
// 注册Docker认证路由
router.Any("/token", handlers.ProxyDockerAuthGin)
router.Any("/token/*path", handlers.ProxyDockerAuthGin)
// 注册Docker Registry代理路由
router.Any("/v2/*path", ProxyDockerRegistryGin)
router.Any("/v2/*path", handlers.ProxyDockerRegistryGin)
// 注册NoRoute处理器
router.NoRoute(handler)
// 注册GitHub代理路由NoRoute处理器
router.NoRoute(handlers.GitHubProxyHandler)
cfg := GetConfig()
cfg := config.GetConfig()
fmt.Printf("🚀 HubProxy 启动成功\n")
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
// 显示HTTP/2支持状态
if cfg.Server.EnableH2C {
fmt.Printf("H2c: 已启用\n")
}
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
// 创建HTTP2服务器
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
ReadTimeout: 60 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
}
// 根据配置决定是否启用H2C
if cfg.Server.EnableH2C {
h2cHandler := h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: 250,
IdleTimeout: 300 * time.Second,
MaxReadFrameSize: 4 << 20,
MaxUploadBufferPerConnection: 8 << 20,
MaxUploadBufferPerStream: 2 << 20,
})
server.Handler = h2cHandler
} else {
server.Handler = router
}
err := server.ListenAndServe()
if err != nil {
fmt.Printf("启动服务失败: %v\n", err)
}
}
func handler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
// 自动补全协议头
if !strings.HasPrefix(rawPath, "https://") {
// 修复 http:/ 和 https:/ 的情况
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
rawPath = strings.Replace(rawPath, "http:/", "", 1)
rawPath = strings.Replace(rawPath, "https:/", "", 1)
} else if strings.HasPrefix(rawPath, "http://") {
rawPath = strings.TrimPrefix(rawPath, "http://")
}
rawPath = "https://" + rawPath
}
matches := checkURL(rawPath)
if matches != nil {
// GitHub仓库访问控制检查
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
// 构建仓库名用于日志
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else {
c.String(http.StatusForbidden, "无效输入")
return
}
if exps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
proxyRequest(c, rawPath)
}
func proxyRequest(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0)
}
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
// 限制最大重定向次数,防止无限递归
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
// 如果域名中没有协议前缀添加https://
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
io.Copy(c.Writer, resp.Body)
}
}
func checkURL(u string) []string {
for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil {
return matches[1:]
}
}
return nil
}
// 简单的健康检查
func formatBeijingTime(t time.Time) string {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil {
loc = time.FixedZone("CST", 8*3600) // 兜底时区
}
return t.In(loc).Format("2006-01-02 15:04:05")
}
// 转换为可读时间
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%d秒", int(d.Seconds()))
@@ -338,26 +170,20 @@ func formatDuration(d time.Duration) string {
}
}
func initHealthRoutes(router *gin.Engine) {
router.GET("/health", func(c *gin.Context) {
uptime := time.Since(serviceStartTime)
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp_unix": serviceStartTime.Unix(),
"uptime_sec": uptime.Seconds(),
"service": "hubproxy",
"start_time_bj": formatBeijingTime(serviceStartTime),
"uptime_human": formatDuration(uptime),
})
})
func getUptimeInfo() (time.Duration, float64, string) {
uptime := time.Since(serviceStartTime)
return uptime, uptime.Seconds(), formatDuration(uptime)
}
func initHealthRoutes(router *gin.Engine) {
router.GET("/ready", func(c *gin.Context) {
uptime := time.Since(serviceStartTime)
_, uptimeSec, uptimeHuman := getUptimeInfo()
c.JSON(http.StatusOK, gin.H{
"ready": true,
"timestamp_unix": time.Now().Unix(),
"uptime_sec": uptime.Seconds(),
"uptime_human": formatDuration(uptime),
"ready": true,
"service": "hubproxy",
"start_time_unix": serviceStartTime.Unix(),
"uptime_sec": uptimeSec,
"uptime_human": uptimeHuman,
})
})
}

View File

@@ -1,8 +1,10 @@
package main
package utils
import (
"strings"
"sync"
"hubproxy/config"
)
// ResourceType 资源类型
@@ -26,7 +28,7 @@ type DockerImageInfo struct {
FullName string
}
// 全局访问控制器实例
// GlobalAccessController 全局访问控制器实例
var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称
@@ -79,19 +81,16 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
// CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := GetConfig()
cfg := config.GetConfig()
// 解析镜像名称
imageInfo := ac.ParseDockerImage(image)
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
if len(cfg.Access.WhiteList) > 0 {
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
return false, "不在Docker镜像白名单内"
}
}
// 检查黑名单
if len(cfg.Access.BlackList) > 0 {
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
return false, "Docker镜像在黑名单内"
@@ -107,14 +106,12 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r
return false, "无效的GitHub仓库格式"
}
cfg := GetConfig()
cfg := config.GetConfig()
// 检查白名单
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
return false, "不在GitHub仓库白名单内"
}
// 检查黑名单
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
return false, "GitHub仓库在黑名单内"
}
@@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool {
continue
}
// 支持多种匹配模式
if fullRepo == item {
return true
}
// 用户级匹配
if item == username || item == username+"/*" {
return true
}
// 前缀匹配(支持通配符)
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) {
@@ -203,7 +197,6 @@ func (ac *AccessController) checkList(matches, list []string) bool {
}
}
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullRepo, item+"/") {
return true
}

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"crypto/md5"
@@ -9,22 +9,23 @@ import (
"time"
"github.com/gin-gonic/gin"
"hubproxy/config"
)
// CachedItem 通用缓存项支持Token和Manifest
// CachedItem 通用缓存项
type CachedItem struct {
Data []byte // 缓存数据(token字符串或manifest字节)
ContentType string // 内容类型
Headers map[string]string // 额外的响应头
ExpiresAt time.Time // 过期时间
Data []byte
ContentType string
Headers map[string]string
ExpiresAt time.Time
}
// UniversalCache 通用缓存支持Token和Manifest
// UniversalCache 通用缓存
type UniversalCache struct {
cache sync.Map
}
var globalCache = &UniversalCache{}
var GlobalCache = &UniversalCache{}
// Get 获取缓存项
func (c *UniversalCache) Get(key string) *CachedItem {
@@ -57,22 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
c.Set(key, []byte(token), "application/json", nil, ttl)
}
// buildCacheKey 构建稳定的缓存key
func buildCacheKey(prefix, query string) string {
// BuildCacheKey 构建稳定的缓存key
func BuildCacheKey(prefix, query string) string {
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
}
func buildTokenCacheKey(query string) string {
return buildCacheKey("token", query)
func BuildTokenCacheKey(query string) string {
return BuildCacheKey("token", query)
}
func buildManifestCacheKey(imageRef, reference string) string {
func BuildManifestCacheKey(imageRef, reference string) string {
key := fmt.Sprintf("%s:%s", imageRef, reference)
return buildCacheKey("manifest", key)
return BuildCacheKey("manifest", key)
}
func getManifestTTL(reference string) time.Duration {
cfg := GetConfig()
func GetManifestTTL(reference string) time.Duration {
cfg := config.GetConfig()
defaultTTL := 30 * time.Minute
if cfg.TokenCache.DefaultTTL != "" {
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
@@ -84,23 +85,20 @@ func getManifestTTL(reference string) time.Duration {
return 24 * time.Hour
}
// mutable tag的智能判断
if reference == "latest" || reference == "main" || reference == "master" ||
reference == "dev" || reference == "develop" {
// 热门可变标签: 短期缓存
return 10 * time.Minute
}
return defaultTTL
}
// extractTTLFromResponse 从响应中智能提取TTL
func extractTTLFromResponse(responseBody []byte) time.Duration {
// ExtractTTLFromResponse 从响应中智能提取TTL
func ExtractTTLFromResponse(responseBody []byte) time.Duration {
var tokenResp struct {
ExpiresIn int `json:"expires_in"`
}
// 默认30分钟TTL确保稳定性
defaultTTL := 30 * time.Minute
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
@@ -113,37 +111,35 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
return defaultTTL
}
func writeTokenResponse(c *gin.Context, cachedBody string) {
func WriteTokenResponse(c *gin.Context, cachedBody string) {
c.Header("Content-Type", "application/json")
c.String(200, cachedBody)
}
func writeCachedResponse(c *gin.Context, item *CachedItem) {
func WriteCachedResponse(c *gin.Context, item *CachedItem) {
if item.ContentType != "" {
c.Header("Content-Type", item.ContentType)
}
// 设置额外的响应头
for key, value := range item.Headers {
c.Header(key, value)
}
// 返回数据
c.Data(200, item.ContentType, item.Data)
}
// isCacheEnabled 检查缓存是否启用
func isCacheEnabled() bool {
cfg := GetConfig()
// IsCacheEnabled 检查缓存是否启用
func IsCacheEnabled() bool {
cfg := config.GetConfig()
return cfg.TokenCache.Enabled
}
// isTokenCacheEnabled 检查token缓存是否启用(向后兼容)
func isTokenCacheEnabled() bool {
return isCacheEnabled()
// IsTokenCacheEnabled 检查token缓存是否启用
func IsTokenCacheEnabled() bool {
return IsCacheEnabled()
}
// 定期清理过期缓存,防止内存泄漏
// 定期清理过期缓存
func init() {
go func() {
ticker := time.NewTicker(20 * time.Minute)
@@ -153,7 +149,7 @@ func init() {
now := time.Now()
expiredKeys := make([]string, 0)
globalCache.cache.Range(func(key, value interface{}) bool {
GlobalCache.cache.Range(func(key, value interface{}) bool {
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
expiredKeys = append(expiredKeys, key.(string))
}
@@ -161,7 +157,7 @@ func init() {
})
for _, key := range expiredKeys {
globalCache.cache.Delete(key)
GlobalCache.cache.Delete(key)
}
}
}()

View File

@@ -1,28 +1,28 @@
package main
package utils
import (
"net"
"net/http"
"os"
"time"
"hubproxy/config"
)
var (
// 全局HTTP客户端 - 用于代理请求(长超时)
globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时
searchHTTPClient *http.Client
)
// initHTTPClients 初始化HTTP客户端
func initHTTPClients() {
cfg := GetConfig()
// InitHTTPClients 初始化HTTP客户端
func InitHTTPClients() {
cfg := config.GetConfig()
if p := cfg.Access.Proxy; p != "" {
os.Setenv("HTTP_PROXY", p)
os.Setenv("HTTPS_PROXY", p)
}
// 代理客户端配置 - 适用于大文件传输
globalHTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -39,7 +39,6 @@ func initHTTPClients() {
},
}
// 搜索客户端配置 - 适用于API调用
searchHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
@@ -57,12 +56,12 @@ func initHTTPClients() {
}
}
// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理)
// GetGlobalHTTPClient 获取全局HTTP客户端
func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient
}
// GetSearchHTTPClient 获取搜索HTTP客户端用于API调用
// GetSearchHTTPClient 获取搜索HTTP客户端
func GetSearchHTTPClient() *http.Client {
return searchHTTPClient
}

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"bytes"
@@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
var reader io.Reader = input
// 处理gzip压缩
if isCompressed {
peek := make([]byte, 2)
n, err := input.Read(peek)

View File

@@ -1,4 +1,4 @@
package main
package utils
import (
"fmt"
@@ -9,22 +9,23 @@ import (
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"hubproxy/config"
)
const (
// 清理间隔
CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000
)
// IPRateLimiter IP限流器结构体
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全
r rate.Limit // 速率限制(每秒允许的请求数)
b int // 令牌桶容量(突发请求数)
whitelist []*net.IPNet // 白名单IP段
blacklist []*net.IPNet // 黑名单IP段
ips map[string]*rateLimiterEntry
mu *sync.RWMutex
r rate.Limit
b int
whitelist []*net.IPNet
blacklist []*net.IPNet
whitelistLimiter *rate.Limiter // 全局共享的白名单限流器
}
// rateLimiterEntry 限流器条目
@@ -33,15 +34,15 @@ type rateLimiterEntry struct {
lastAccess time.Time
}
// initGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter {
cfg := GetConfig()
// InitGlobalLimiter 初始化全局限流器
func InitGlobalLimiter() *IPRateLimiter {
cfg := config.GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
@@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter {
}
}
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
@@ -68,7 +68,6 @@ func initGlobalLimiter() *IPRateLimiter {
}
}
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit
@@ -77,25 +76,20 @@ func initGlobalLimiter() *IPRateLimiter {
}
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize),
}
// 启动定期清理goroutine
go limiter.cleanupRoutine()
return limiter
}
// initLimiter 初始化限流器
func initLimiter() {
globalLimiter = initGlobalLimiter()
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
@@ -105,25 +99,20 @@ func (i *IPRateLimiter) cleanupRoutine() {
now := time.Now()
expired := make([]string, 0)
// 查找过期的条目
i.mu.RLock()
for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip)
}
}
i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock()
// 删除过期条目
for _, ip := range expired {
delete(i.ips, ip)
}
// 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry)
}
@@ -140,28 +129,26 @@ func extractIPFromAddress(address string) string {
return address
}
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
// normalizeIPForRateLimit 标准化IP地址用于限流
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr // 解析失败,返回原值
return ipStr
}
if ip.To4() != nil {
return ipStr // IPv4保持不变
return ipStr
}
// IPv6标准化为 /64 网段
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位
ipv6[i] = 0
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址
cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil {
@@ -176,22 +163,18 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
return false
}
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
// GetLimiter 获取指定IP的限流器
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址
cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中
if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false
}
// 检查是否在白名单中
if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true
return i.whitelistLimiter, true
}
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now()
@@ -230,7 +213,6 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 静态文件豁免:跳过限流检查
path := c.Request.URL.Path
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
strings.HasPrefix(path, "/public/") {
@@ -238,30 +220,22 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return
}
// 获取客户端真实IP
var ip string
// 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头
ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头
ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0])
} else {
// 回退到ClientIP方法
ip = c.ClientIP()
}
// 提取纯IP地址去除可能存在的端口
cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息
normalizedIP := normalizeIPForRateLimit(cleanIP)
if cleanIP != normalizedIP {
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
@@ -275,10 +249,8 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
c.GetHeader("X-Real-IP"))
}
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中
if !allowed {
c.JSON(403, gin.H{
"error": "您已被限制访问",
@@ -287,7 +259,6 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return
}
// 检查限流
if !ipLimiter.Allow() {
c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问",