diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b9602e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea +.vscode +.DS_Store +hubproxy* \ No newline at end of file diff --git a/README.md b/README.md index 6a242e9..200e943 100644 --- a/README.md +++ b/README.md @@ -138,11 +138,17 @@ blackList = [ "baduser/*" ] -# SOCKS5代理配置,支持有用户名/密码认证和无认证模式 +# 代理配置,支持有用户名/密码认证和无认证模式 # 无认证: socks5://127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080 +# HTTP 代理示例 +# http://username:password@127.0.0.1:7890 +# SOCKS5 代理示例 +# socks5://username:password@127.0.0.1:1080 +# SOCKS5H 代理示例 +# socks5h://username:password@127.0.0.1:1080 # 留空不使用代理 -socks5 = "" +proxy = "" [download] # 批量下载离线镜像数量限制 diff --git a/src/access_control.go b/src/access_control.go index e23ee52..b8c6ab1 100644 --- a/src/access_control.go +++ b/src/access_control.go @@ -32,7 +32,7 @@ var GlobalAccessController = &AccessController{} // ParseDockerImage 解析Docker镜像名称 func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { image = strings.TrimPrefix(image, "docker://") - + var tag string if idx := strings.LastIndex(image, ":"); idx != -1 { part := image[idx+1:] @@ -44,7 +44,7 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { if tag == "" { tag = "latest" } - + var namespace, repository string if strings.Contains(image, "/") { parts := strings.Split(image, "/") @@ -66,9 +66,9 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { namespace = "library" repository = image } - + fullName := namespace + "/" + repository - + return DockerImageInfo{ Namespace: namespace, Repository: repository, @@ -80,24 +80,24 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { // CheckDockerAccess 检查Docker镜像访问权限 func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { cfg := GetConfig() - + // 解析镜像名称 imageInfo := ac.ParseDockerImage(image) - + // 检查白名单(如果配置了白名单,则只允许白名单中的镜像) - if len(cfg.Proxy.WhiteList) > 0 { - if !ac.matchImageInList(imageInfo, cfg.Proxy.WhiteList) { + if len(cfg.Access.WhiteList) > 0 { + if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) { return false, "不在Docker镜像白名单内" } } - + // 检查黑名单 - if len(cfg.Proxy.BlackList) > 0 { - if ac.matchImageInList(imageInfo, cfg.Proxy.BlackList) { + if len(cfg.Access.BlackList) > 0 { + if ac.matchImageInList(imageInfo, cfg.Access.BlackList) { return false, "Docker镜像在黑名单内" } } - + return true, "" } @@ -106,19 +106,19 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r if len(matches) < 2 { return false, "无效的GitHub仓库格式" } - + cfg := GetConfig() - + // 检查白名单 - if len(cfg.Proxy.WhiteList) > 0 && !ac.checkList(matches, cfg.Proxy.WhiteList) { + if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) { return false, "不在GitHub仓库白名单内" } - + // 检查黑名单 - if len(cfg.Proxy.BlackList) > 0 && ac.checkList(matches, cfg.Proxy.BlackList) { + if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) { return false, "GitHub仓库在黑名单内" } - + return true, "" } @@ -126,28 +126,28 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool { fullName := strings.ToLower(imageInfo.FullName) namespace := strings.ToLower(imageInfo.Namespace) - + for _, item := range list { item = strings.ToLower(strings.TrimSpace(item)) if item == "" { continue } - + if fullName == item { return true } - + if item == namespace || item == namespace+"/*" { return true } - + if strings.HasSuffix(item, "*") { prefix := strings.TrimSuffix(item, "*") if strings.HasPrefix(fullName, prefix) { return true } } - + if strings.HasPrefix(item, "*/") { repoPattern := strings.TrimPrefix(item, "*/") if strings.HasSuffix(repoPattern, "*") { @@ -161,7 +161,7 @@ func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []s } } } - + if strings.HasPrefix(fullName, item+"/") { return true } @@ -174,27 +174,27 @@ func (ac *AccessController) checkList(matches, list []string) bool { if len(matches) < 2 { return false } - + username := strings.ToLower(strings.TrimSpace(matches[0])) repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git"))) fullRepo := username + "/" + repoName - + for _, item := range list { item = strings.ToLower(strings.TrimSpace(item)) if item == "" { continue } - + // 支持多种匹配模式 if fullRepo == item { return true } - + // 用户级匹配 if item == username || item == username+"/*" { return true } - + // 前缀匹配(支持通配符) if strings.HasSuffix(item, "*") { prefix := strings.TrimSuffix(item, "*") @@ -202,7 +202,7 @@ func (ac *AccessController) checkList(matches, list []string) bool { return true } } - + // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork) if strings.HasPrefix(fullRepo, item+"/") { return true @@ -210,5 +210,3 @@ func (ac *AccessController) checkList(matches, list []string) bool { } return false } - - \ No newline at end of file diff --git a/src/config.go b/src/config.go index d85c217..fd22dc5 100644 --- a/src/config.go +++ b/src/config.go @@ -1,275 +1,276 @@ -package main - -import ( - "fmt" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/pelletier/go-toml/v2" -) - -// RegistryMapping Registry映射配置 -type RegistryMapping struct { - Upstream string `toml:"upstream"` // 上游Registry地址 - AuthHost string `toml:"authHost"` // 认证服务器地址 - AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic - Enabled bool `toml:"enabled"` // 是否启用 -} - -// AppConfig 应用配置结构体 -type AppConfig struct { - Server struct { - Host string `toml:"host"` // 监听地址 - Port int `toml:"port"` // 监听端口 - FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) - } `toml:"server"` - - RateLimit struct { - RequestLimit int `toml:"requestLimit"` // 每小时请求限制 - PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) - } `toml:"rateLimit"` - - Security struct { - WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 - BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 - } `toml:"security"` - - Proxy struct { - WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) - BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) - Socks5 string `toml:"socks5"` // SOCKS5代理地址: socks5://[user:pass@]host:port - } `toml:"proxy"` - - Download struct { - MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 - } `toml:"download"` - - Registries map[string]RegistryMapping `toml:"registries"` - - TokenCache struct { - Enabled bool `toml:"enabled"` // 是否启用token缓存 - DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间 - } `toml:"tokenCache"` -} - -var ( - appConfig *AppConfig - appConfigLock sync.RWMutex - - cachedConfig *AppConfig - configCacheTime time.Time - configCacheTTL = 5 * time.Second - configCacheMutex sync.RWMutex -) - -// DefaultConfig 返回默认配置 -func DefaultConfig() *AppConfig { - return &AppConfig{ - Server: struct { - Host string `toml:"host"` - Port int `toml:"port"` - FileSize int64 `toml:"fileSize"` - }{ - Host: "0.0.0.0", - Port: 5000, - FileSize: 2 * 1024 * 1024 * 1024, // 2GB - }, - RateLimit: struct { - RequestLimit int `toml:"requestLimit"` - PeriodHours float64 `toml:"periodHours"` - }{ - RequestLimit: 20, - PeriodHours: 1.0, - }, - Security: struct { - WhiteList []string `toml:"whiteList"` - BlackList []string `toml:"blackList"` - }{ - WhiteList: []string{}, - BlackList: []string{}, - }, - Proxy: struct { - WhiteList []string `toml:"whiteList"` - BlackList []string `toml:"blackList"` - Socks5 string `toml:"socks5"` - }{ - WhiteList: []string{}, - BlackList: []string{}, - Socks5: "", // 默认不使用代理 - }, - Download: struct { - MaxImages int `toml:"maxImages"` - }{ - MaxImages: 10, // 默认值:最多同时下载10个镜像 - }, - Registries: map[string]RegistryMapping{ - "ghcr.io": { - Upstream: "ghcr.io", - AuthHost: "ghcr.io/token", - AuthType: "github", - Enabled: true, - }, - "gcr.io": { - Upstream: "gcr.io", - AuthHost: "gcr.io/v2/token", - AuthType: "google", - Enabled: true, - }, - "quay.io": { - Upstream: "quay.io", - AuthHost: "quay.io/v2/auth", - AuthType: "quay", - Enabled: true, - }, - "registry.k8s.io": { - Upstream: "registry.k8s.io", - AuthHost: "registry.k8s.io", - AuthType: "anonymous", - Enabled: true, - }, - }, - TokenCache: struct { - Enabled bool `toml:"enabled"` - DefaultTTL string `toml:"defaultTTL"` - }{ - Enabled: true, // docker认证的匿名Token缓存配置,用于提升性能 - DefaultTTL: "20m", - }, - } -} - -// GetConfig 安全地获取配置副本 -func GetConfig() *AppConfig { - configCacheMutex.RLock() - if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { - config := cachedConfig - configCacheMutex.RUnlock() - return config - } - configCacheMutex.RUnlock() - - // 缓存过期,重新生成配置 - configCacheMutex.Lock() - defer configCacheMutex.Unlock() - - // 双重检查,防止重复生成 - if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { - return cachedConfig - } - - appConfigLock.RLock() - if appConfig == nil { - appConfigLock.RUnlock() - defaultCfg := DefaultConfig() - cachedConfig = defaultCfg - configCacheTime = time.Now() - return defaultCfg - } - - // 生成新的配置深拷贝 - configCopy := *appConfig - configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) - configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) - configCopy.Proxy.WhiteList = append([]string(nil), appConfig.Proxy.WhiteList...) - configCopy.Proxy.BlackList = append([]string(nil), appConfig.Proxy.BlackList...) - appConfigLock.RUnlock() - - cachedConfig = &configCopy - configCacheTime = time.Now() - - return cachedConfig -} - -// setConfig 安全地设置配置 -func setConfig(cfg *AppConfig) { - appConfigLock.Lock() - defer appConfigLock.Unlock() - appConfig = cfg - - configCacheMutex.Lock() - cachedConfig = nil - configCacheMutex.Unlock() -} - -// LoadConfig 加载配置文件 -func LoadConfig() error { - // 首先使用默认配置 - cfg := DefaultConfig() - - // 尝试加载TOML配置文件 - if data, err := os.ReadFile("config.toml"); err == nil { - if err := toml.Unmarshal(data, cfg); err != nil { - return fmt.Errorf("解析配置文件失败: %v", err) - } - } else { - fmt.Println("未找到config.toml,使用默认配置") - } - - // 从环境变量覆盖配置 - overrideFromEnv(cfg) - - // 设置配置 - setConfig(cfg) - - return nil -} - -// overrideFromEnv 从环境变量覆盖配置 -func overrideFromEnv(cfg *AppConfig) { - // 服务器配置 - if val := os.Getenv("SERVER_HOST"); val != "" { - cfg.Server.Host = val - } - if val := os.Getenv("SERVER_PORT"); val != "" { - if port, err := strconv.Atoi(val); err == nil && port > 0 { - cfg.Server.Port = port - } - } - if val := os.Getenv("MAX_FILE_SIZE"); val != "" { - if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { - cfg.Server.FileSize = size - } - } - - // 限流配置 - if val := os.Getenv("RATE_LIMIT"); val != "" { - if limit, err := strconv.Atoi(val); err == nil && limit > 0 { - cfg.RateLimit.RequestLimit = limit - } - } - if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" { - if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 { - cfg.RateLimit.PeriodHours = period - } - } - - // IP限制配置 - if val := os.Getenv("IP_WHITELIST"); val != "" { - cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) - } - if val := os.Getenv("IP_BLACKLIST"); val != "" { - cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) - } - - // 下载限制配置 - if val := os.Getenv("MAX_IMAGES"); val != "" { - if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { - cfg.Download.MaxImages = maxImages - } - } -} - -// CreateDefaultConfigFile 创建默认配置文件 -func CreateDefaultConfigFile() error { - cfg := DefaultConfig() - - data, err := toml.Marshal(cfg) - if err != nil { - return fmt.Errorf("序列化默认配置失败: %v", err) - } - - return os.WriteFile("config.toml", data, 0644) -} \ No newline at end of file +package main + +import ( + "fmt" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/pelletier/go-toml/v2" +) + +// RegistryMapping Registry映射配置 +type RegistryMapping struct { + Upstream string `toml:"upstream"` // 上游Registry地址 + AuthHost string `toml:"authHost"` // 认证服务器地址 + AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic + Enabled bool `toml:"enabled"` // 是否启用 +} + +// AppConfig 应用配置结构体 +type AppConfig struct { + Server struct { + Host string `toml:"host"` // 监听地址 + Port int `toml:"port"` // 监听端口 + FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) + } `toml:"server"` + + RateLimit struct { + RequestLimit int `toml:"requestLimit"` // 每小时请求限制 + PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) + } `toml:"rateLimit"` + + Security struct { + WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 + BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 + } `toml:"security"` + + Access struct { + WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) + BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) + Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h + } `toml:"proxy"` + + Download struct { + MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 + } `toml:"download"` + + Registries map[string]RegistryMapping `toml:"registries"` + + TokenCache struct { + Enabled bool `toml:"enabled"` // 是否启用token缓存 + DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间 + } `toml:"tokenCache"` +} + +var ( + appConfig *AppConfig + appConfigLock sync.RWMutex + + cachedConfig *AppConfig + configCacheTime time.Time + configCacheTTL = 5 * time.Second + configCacheMutex sync.RWMutex +) + +// todo:Refactoring is needed +// DefaultConfig 返回默认配置 +func DefaultConfig() *AppConfig { + return &AppConfig{ + Server: struct { + Host string `toml:"host"` + Port int `toml:"port"` + FileSize int64 `toml:"fileSize"` + }{ + Host: "0.0.0.0", + Port: 5000, + FileSize: 2 * 1024 * 1024 * 1024, // 2GB + }, + RateLimit: struct { + RequestLimit int `toml:"requestLimit"` + PeriodHours float64 `toml:"periodHours"` + }{ + RequestLimit: 20, + PeriodHours: 1.0, + }, + Security: struct { + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` + }{ + WhiteList: []string{}, + BlackList: []string{}, + }, + Access: struct { + WhiteList []string `toml:"whiteList"` + BlackList []string `toml:"blackList"` + Proxy string `toml:"proxy"` + }{ + WhiteList: []string{}, + BlackList: []string{}, + Proxy: "", // 默认不使用代理 + }, + Download: struct { + MaxImages int `toml:"maxImages"` + }{ + MaxImages: 10, // 默认值:最多同时下载10个镜像 + }, + Registries: map[string]RegistryMapping{ + "ghcr.io": { + Upstream: "ghcr.io", + AuthHost: "ghcr.io/token", + AuthType: "github", + Enabled: true, + }, + "gcr.io": { + Upstream: "gcr.io", + AuthHost: "gcr.io/v2/token", + AuthType: "google", + Enabled: true, + }, + "quay.io": { + Upstream: "quay.io", + AuthHost: "quay.io/v2/auth", + AuthType: "quay", + Enabled: true, + }, + "registry.k8s.io": { + Upstream: "registry.k8s.io", + AuthHost: "registry.k8s.io", + AuthType: "anonymous", + Enabled: true, + }, + }, + TokenCache: struct { + Enabled bool `toml:"enabled"` + DefaultTTL string `toml:"defaultTTL"` + }{ + Enabled: true, // docker认证的匿名Token缓存配置,用于提升性能 + DefaultTTL: "20m", + }, + } +} + +// GetConfig 安全地获取配置副本 +func GetConfig() *AppConfig { + configCacheMutex.RLock() + if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { + config := cachedConfig + configCacheMutex.RUnlock() + return config + } + configCacheMutex.RUnlock() + + // 缓存过期,重新生成配置 + configCacheMutex.Lock() + defer configCacheMutex.Unlock() + + // 双重检查,防止重复生成 + if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { + return cachedConfig + } + + appConfigLock.RLock() + if appConfig == nil { + appConfigLock.RUnlock() + defaultCfg := DefaultConfig() + cachedConfig = defaultCfg + configCacheTime = time.Now() + return defaultCfg + } + + // 生成新的配置深拷贝 + configCopy := *appConfig + configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) + configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) + configCopy.Access.WhiteList = append([]string(nil), appConfig.Access.WhiteList...) + configCopy.Access.BlackList = append([]string(nil), appConfig.Access.BlackList...) + appConfigLock.RUnlock() + + cachedConfig = &configCopy + configCacheTime = time.Now() + + return cachedConfig +} + +// setConfig 安全地设置配置 +func setConfig(cfg *AppConfig) { + appConfigLock.Lock() + defer appConfigLock.Unlock() + appConfig = cfg + + configCacheMutex.Lock() + cachedConfig = nil + configCacheMutex.Unlock() +} + +// LoadConfig 加载配置文件 +func LoadConfig() error { + // 首先使用默认配置 + cfg := DefaultConfig() + + // 尝试加载TOML配置文件 + if data, err := os.ReadFile("config.toml"); err == nil { + if err := toml.Unmarshal(data, cfg); err != nil { + return fmt.Errorf("解析配置文件失败: %v", err) + } + } else { + fmt.Println("未找到config.toml,使用默认配置") + } + + // 从环境变量覆盖配置 + overrideFromEnv(cfg) + + // 设置配置 + setConfig(cfg) + + return nil +} + +// overrideFromEnv 从环境变量覆盖配置 +func overrideFromEnv(cfg *AppConfig) { + // 服务器配置 + if val := os.Getenv("SERVER_HOST"); val != "" { + cfg.Server.Host = val + } + if val := os.Getenv("SERVER_PORT"); val != "" { + if port, err := strconv.Atoi(val); err == nil && port > 0 { + cfg.Server.Port = port + } + } + if val := os.Getenv("MAX_FILE_SIZE"); val != "" { + if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { + cfg.Server.FileSize = size + } + } + + // 限流配置 + if val := os.Getenv("RATE_LIMIT"); val != "" { + if limit, err := strconv.Atoi(val); err == nil && limit > 0 { + cfg.RateLimit.RequestLimit = limit + } + } + if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" { + if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 { + cfg.RateLimit.PeriodHours = period + } + } + + // IP限制配置 + if val := os.Getenv("IP_WHITELIST"); val != "" { + cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) + } + if val := os.Getenv("IP_BLACKLIST"); val != "" { + cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) + } + + // 下载限制配置 + if val := os.Getenv("MAX_IMAGES"); val != "" { + if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { + cfg.Download.MaxImages = maxImages + } + } +} + +// CreateDefaultConfigFile 创建默认配置文件 +func CreateDefaultConfigFile() error { + cfg := DefaultConfig() + + data, err := toml.Marshal(cfg) + if err != nil { + return fmt.Errorf("序列化默认配置失败: %v", err) + } + + return os.WriteFile("config.toml", data, 0644) +} diff --git a/src/config.toml b/src/config.toml index 1411c8f..6a96c11 100644 --- a/src/config.toml +++ b/src/config.toml @@ -26,7 +26,7 @@ blackList = [ "192.168.100.0/24" ] -[proxy] +[access] # 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符) # 只允许访问白名单中的仓库/镜像,为空时不限制 whiteList = [] @@ -39,11 +39,17 @@ blackList = [ "baduser/*" ] -# SOCKS5代理配置,支持有用户名/密码认证和无认证模式 +# 代理配置,支持有用户名/密码认证和无认证模式 # 无认证: socks5://127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080 +# HTTP 代理示例 +# http://username:password@127.0.0.1:7890 +# SOCKS5 代理示例 +# socks5://username:password@127.0.0.1:1080 +# SOCKS5H 代理示例 +# socks5h://username:password@127.0.0.1:1080 # 留空不使用代理 -socks5 = "" +proxy = "" [download] # 批量下载离线镜像数量限制 diff --git a/src/go.mod b/src/go.mod index 6e8a2f8..7c2806d 100644 --- a/src/go.mod +++ b/src/go.mod @@ -6,7 +6,6 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/google/go-containerregistry v0.20.5 github.com/pelletier/go-toml/v2 v2.2.3 - golang.org/x/net v0.33.0 golang.org/x/time v0.11.0 ) @@ -44,6 +43,7 @@ require ( github.com/vbatts/tar-split v0.12.1 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.32.0 // indirect + golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/src/http_client.go b/src/http_client.go index 24d509b..93988eb 100644 --- a/src/http_client.go +++ b/src/http_client.go @@ -1,113 +1,68 @@ -package main - -import ( - "context" - "log" - "net" - "net/http" - "net/url" - "time" - - "golang.org/x/net/proxy" -) - -var ( - // 全局HTTP客户端 - 用于代理请求(长超时) - globalHTTPClient *http.Client - // 搜索HTTP客户端 - 用于API请求(短超时) - searchHTTPClient *http.Client -) - -// initHTTPClients 初始化HTTP客户端 -func initHTTPClients() { - cfg := GetConfig() - - // 创建DialContext函数,支持SOCKS5代理 - createDialContext := func(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) { - if cfg.Proxy.Socks5 == "" { - // 没有配置代理,使用直连 - dialer := &net.Dialer{ - Timeout: timeout, - KeepAlive: 30 * time.Second, - } - return dialer.DialContext - } - - // 解析SOCKS5代理URL - proxyURL, err := url.Parse(cfg.Proxy.Socks5) - if err != nil { - log.Printf("SOCKS5代理配置错误,使用直连: %v", err) - dialer := &net.Dialer{ - Timeout: timeout, - KeepAlive: 30 * time.Second, - } - return dialer.DialContext - } - - // 创建基础dialer - baseDialer := &net.Dialer{ - Timeout: timeout, - KeepAlive: 30 * time.Second, - } - - // 创建SOCKS5代理dialer - var auth *proxy.Auth - if proxyURL.User != nil { - if password, ok := proxyURL.User.Password(); ok { - auth = &proxy.Auth{ - User: proxyURL.User.Username(), - Password: password, - } - } - } - - socks5Dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, baseDialer) - if err != nil { - log.Printf("创建SOCKS5代理失败,使用直连: %v", err) - return baseDialer.DialContext - } - - log.Printf("使用SOCKS5代理: %s", proxyURL.Host) - - // 返回带上下文的dial函数 - return func(ctx context.Context, network, addr string) (net.Conn, error) { - return socks5Dialer.Dial(network, addr) - } - } - - // 代理客户端配置 - 适用于大文件传输 - globalHTTPClient = &http.Client{ - Transport: &http.Transport{ - DialContext: createDialContext(30 * time.Second), - MaxIdleConns: 1000, - MaxIdleConnsPerHost: 1000, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 300 * time.Second, - }, - } - - // 搜索客户端配置 - 适用于API调用 - searchHTTPClient = &http.Client{ - Timeout: 10 * time.Second, - Transport: &http.Transport{ - DialContext: createDialContext(5 * time.Second), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - DisableCompression: false, - }, - } -} - -// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理) -func GetGlobalHTTPClient() *http.Client { - return globalHTTPClient -} - -// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用) -func GetSearchHTTPClient() *http.Client { - return searchHTTPClient -} \ No newline at end of file +package main + +import ( + "net" + "net/http" + "os" + "time" +) + +var ( + // 全局HTTP客户端 - 用于代理请求(长超时) + globalHTTPClient *http.Client + // 搜索HTTP客户端 - 用于API请求(短超时) + searchHTTPClient *http.Client +) + +// initHTTPClients 初始化HTTP客户端 +func initHTTPClients() { + cfg := GetConfig() + + if p := cfg.Access.Proxy; p != "" { + os.Setenv("HTTP_PROXY", p) + os.Setenv("HTTPS_PROXY", p) + } + // 代理客户端配置 - 适用于大文件传输 + globalHTTPClient = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 1000, + MaxIdleConnsPerHost: 1000, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: 300 * time.Second, + }, + } + + // 搜索客户端配置 - 适用于API调用 + searchHTTPClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 5 * time.Second, + DisableCompression: false, + }, + } +} + +// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理) +func GetGlobalHTTPClient() *http.Client { + return globalHTTPClient +} + +// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用) +func GetSearchHTTPClient() *http.Client { + return searchHTTPClient +}