修复ipv6标准化的潜在BUG

This commit is contained in:
user123456
2025-06-17 18:38:48 +08:00
parent aea36939a3
commit 182dced403
3 changed files with 32 additions and 19 deletions

View File

@@ -171,11 +171,11 @@ func handler(c *gin.Context) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
proxy(c, rawPath)
proxyRequest(c, rawPath)
}
func proxy(c *gin.Context, u string) {
func proxyRequest(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0)
}

View File

@@ -132,23 +132,33 @@ func (i *IPRateLimiter) cleanupRoutine() {
}
}
// extractIPFromAddress 从地址中提取纯IP,去除端口号
// extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string {
// 处理IPv6地址 [::1]:8080 格式
if strings.HasPrefix(address, "[") {
if endIndex := strings.Index(address, "]"); endIndex != -1 {
return address[1:endIndex]
}
if host, _, err := net.SplitHostPort(address); err == nil {
return host
}
// 处理IPv4地址 192.168.1.1:8080 格式
if lastColon := strings.LastIndex(address, ":"); lastColon != -1 {
return address[:lastColon]
}
return address
}
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr // 解析失败,返回原值
}
if ip.To4() != nil {
return ipStr // IPv4保持不变
}
// IPv6标准化为 /64 网段
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址
@@ -181,15 +191,18 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
return rate.NewLimiter(rate.Inf, i.b), true
}
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now()
i.mu.RLock()
entry, exists := i.ips[cleanIP]
entry, exists := i.ips[normalizedIP]
i.mu.RUnlock()
if exists {
i.mu.Lock()
if entry, stillExists := i.ips[cleanIP]; stillExists {
if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
@@ -198,7 +211,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
}
i.mu.Lock()
if entry, exists := i.ips[cleanIP]; exists {
if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
@@ -208,7 +221,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now,
}
i.ips[cleanIP] = entry
i.ips[normalizedIP] = entry
i.mu.Unlock()
return entry.limiter, true

View File

@@ -41,7 +41,7 @@ func (s *SmartRateLimit) ShouldSkipRateLimit(ip, path string) bool {
return false
}
sessionKey := ip
sessionKey := normalizeIPForRateLimit(ip)
sessionInterface, _ := s.sessions.LoadOrStore(sessionKey, &PullSession{})
session := sessionInterface.(*PullSession)