diff --git a/src/main.go b/src/main.go index 66bb42b..036ea9e 100644 --- a/src/main.go +++ b/src/main.go @@ -171,11 +171,11 @@ func handler(c *gin.Context) { rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) } - proxy(c, rawPath) + proxyRequest(c, rawPath) } -func proxy(c *gin.Context, u string) { +func proxyRequest(c *gin.Context, u string) { proxyWithRedirect(c, u, 0) } diff --git a/src/ratelimiter.go b/src/ratelimiter.go index 642a495..2c2a4c1 100644 --- a/src/ratelimiter.go +++ b/src/ratelimiter.go @@ -132,23 +132,33 @@ func (i *IPRateLimiter) cleanupRoutine() { } } -// extractIPFromAddress 从地址中提取纯IP,去除端口号 +// extractIPFromAddress 从地址中提取纯IP func extractIPFromAddress(address string) string { - // 处理IPv6地址 [::1]:8080 格式 - if strings.HasPrefix(address, "[") { - if endIndex := strings.Index(address, "]"); endIndex != -1 { - return address[1:endIndex] - } + if host, _, err := net.SplitHostPort(address); err == nil { + return host } - - // 处理IPv4地址 192.168.1.1:8080 格式 - if lastColon := strings.LastIndex(address, ":"); lastColon != -1 { - return address[:lastColon] - } - return address } +// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段 +func normalizeIPForRateLimit(ipStr string) string { + ip := net.ParseIP(ipStr) + if ip == nil { + return ipStr // 解析失败,返回原值 + } + + if ip.To4() != nil { + return ipStr // IPv4保持不变 + } + + // IPv6:标准化为 /64 网段 + ipv6 := ip.To16() + for i := 8; i < 16; i++ { + ipv6[i] = 0 // 清零后64位 + } + return ipv6.String() + "/64" +} + // isIPInCIDRList 检查IP是否在CIDR列表中 func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { // 先提取纯IP地址 @@ -181,15 +191,18 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { return rate.NewLimiter(rate.Inf, i.b), true } + // 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段 + normalizedIP := normalizeIPForRateLimit(cleanIP) + now := time.Now() i.mu.RLock() - entry, exists := i.ips[cleanIP] + entry, exists := i.ips[normalizedIP] i.mu.RUnlock() if exists { i.mu.Lock() - if entry, stillExists := i.ips[cleanIP]; stillExists { + if entry, stillExists := i.ips[normalizedIP]; stillExists { entry.lastAccess = now i.mu.Unlock() return entry.limiter, true @@ -198,7 +211,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { } i.mu.Lock() - if entry, exists := i.ips[cleanIP]; exists { + if entry, exists := i.ips[normalizedIP]; exists { entry.lastAccess = now i.mu.Unlock() return entry.limiter, true @@ -208,7 +221,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { limiter: rate.NewLimiter(i.r, i.b), lastAccess: now, } - i.ips[cleanIP] = entry + i.ips[normalizedIP] = entry i.mu.Unlock() return entry.limiter, true diff --git a/src/smart_ratelimit.go b/src/smart_ratelimit.go index cdb6074..c154cdc 100644 --- a/src/smart_ratelimit.go +++ b/src/smart_ratelimit.go @@ -41,7 +41,7 @@ func (s *SmartRateLimit) ShouldSkipRateLimit(ip, path string) bool { return false } - sessionKey := ip + sessionKey := normalizeIPForRateLimit(ip) sessionInterface, _ := s.sessions.LoadOrStore(sessionKey, &PullSession{}) session := sessionInterface.(*PullSession)