Files
hubproxy/src/ratelimiter.go
2025-06-17 18:49:34 +08:00

320 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
const (
// 清理间隔
CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000
)
// IPRateLimiter IP限流器结构体
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全
r rate.Limit // 速率限制(每秒允许的请求数)
b int // 令牌桶容量(突发请求数)
whitelist []*net.IPNet // 白名单IP段
blacklist []*net.IPNet // 黑名单IP段
}
// rateLimiterEntry 限流器条目
type rateLimiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// initGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter {
cfg := GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
whitelist = append(whitelist, ipnet)
} else {
fmt.Printf("警告: 无效的白名单IP格式: %s\n", item)
}
}
}
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
blacklist = append(blacklist, ipnet)
} else {
fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item)
}
}
}
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit
if burstSize < 1 {
burstSize = 1
}
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
}
// 启动定期清理goroutine
go limiter.cleanupRoutine()
return limiter
}
// initLimiter 初始化限流器
func initLimiter() {
globalLimiter = initGlobalLimiter()
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expired := make([]string, 0)
// 查找过期的条目
i.mu.RLock()
for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip)
}
}
i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock()
// 删除过期条目
for _, ip := range expired {
delete(i.ips, ip)
}
// 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry)
}
i.mu.Unlock()
}
}
}
// extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string {
if host, _, err := net.SplitHostPort(address); err == nil {
return host
}
return address
}
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr // 解析失败,返回原值
}
if ip.To4() != nil {
return ipStr // IPv4保持不变
}
// IPv6标准化为 /64 网段
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址
cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil {
return false
}
for _, cidr := range cidrList {
if cidr.Contains(parsedIP) {
return true
}
}
return false
}
// GetLimiter 获取指定IP的限流器同时返回是否允许访问
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址
cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中
if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false
}
// 检查是否在白名单中
if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true
}
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now()
i.mu.RLock()
entry, exists := i.ips[normalizedIP]
i.mu.RUnlock()
if exists {
i.mu.Lock()
if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
}
i.mu.Unlock()
}
i.mu.Lock()
if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
}
entry = &rateLimiterEntry{
limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now,
}
i.ips[normalizedIP] = entry
i.mu.Unlock()
return entry.limiter, true
}
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端真实IP
var ip string
// 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头
ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头
ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0])
} else {
// 回退到ClientIP方法
ip = c.ClientIP()
}
// 提取纯IP地址去除可能存在的端口
cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息
normalizedIP := normalizeIPForRateLimit(cleanIP)
if cleanIP != normalizedIP {
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP, normalizedIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
} else {
fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
}
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中
if !allowed {
c.JSON(403, gin.H{
"error": "您已被限制访问",
})
c.Abort()
return
}
// 智能限流判断:检查是否应该跳过限流计数
shouldSkip := smartLimiter.ShouldSkipRateLimit(cleanIP, c.Request.URL.Path)
// 只有在不跳过的情况下才检查限流
if !shouldSkip && !ipLimiter.Allow() {
c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问",
})
c.Abort()
return
}
c.Next()
}
}
// ApplyRateLimit 应用限流到特定路由
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
// 使用全局限流器
limiter := globalLimiter
if limiter == nil {
limiter = initGlobalLimiter()
}
// 根据HTTP方法应用限流
switch method {
case "GET":
router.GET(path, RateLimitMiddleware(limiter), handler)
case "POST":
router.POST(path, RateLimitMiddleware(limiter), handler)
case "PUT":
router.PUT(path, RateLimitMiddleware(limiter), handler)
case "DELETE":
router.DELETE(path, RateLimitMiddleware(limiter), handler)
default:
router.Any(path, RateLimitMiddleware(limiter), handler)
}
}