diff --git a/ghproxy/main.go b/ghproxy/main.go index b288563..e9f4508 100644 --- a/ghproxy/main.go +++ b/ghproxy/main.go @@ -4,38 +4,23 @@ import ( "encoding/json" "fmt" "io" - "log" "net" "net/http" "os" - "path/filepath" "regexp" "strconv" "strings" "sync" "time" - "crypto/sha256" - "encoding/hex" - "github.com/gin-gonic/gin" ) +// 常量定义 const ( - MaxFileSize = 10 * 1024 * 1024 * 1024 // 允许的文件大小,默认10GB - ListenHost = "0.0.0.0" // 监听地址 - ListenPort = 5000 // 监听端口 - CacheDir = "cache" - // 是否开启缓存 - CacheExpiry = 0 * time.Minute // 默认不缓存 -) - -var ( - cache = sync.Map{} - exps = initRegexps() - httpClient = initHTTPClient() - config *Config - configLock sync.RWMutex + sizeLimit = 1024 * 1024 * 1024 * 10 // 允许的文件大小,默认10GB + host = "0.0.0.0" // 监听地址 + port = 5000 // 监听端口 ) type Config struct { @@ -43,28 +28,8 @@ type Config struct { BlackList []string `json:"blackList"` } -type CachedResponse struct { - Header http.Header - StatusCode int - Body []byte - Timestamp time.Time -} - -func init() { - if err := os.MkdirAll(CacheDir, 0755); err != nil { - log.Fatalf("Failed to create cache directory: %v", err) - } - go func() { - for { - time.Sleep(10 * time.Minute) - loadConfig() - } - }() - loadConfig() -} - -func initRegexps() []*regexp.Regexp { - return []*regexp.Regexp{ +var ( + exps = []*regexp.Regexp{ regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), @@ -75,10 +40,16 @@ func initRegexps() []*regexp.Regexp { regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), } -} + httpClient *http.Client + config *Config + configLock sync.RWMutex +) -func initHTTPClient() *http.Client { - return &http.Client{ +func main() { + gin.SetMode(gin.ReleaseMode) + router := gin.Default() + + httpClient = &http.Client{ Transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, @@ -92,23 +63,40 @@ func initHTTPClient() *http.Client { ResponseHeaderTimeout: 300 * time.Second, }, } -} -func main() { - gin.SetMode(gin.ReleaseMode) - router := gin.Default() + loadConfig() + + // 每60分钟热重载黑白名单 + go func() { + ticker := time.NewTicker(60 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + loadConfig() + } + }() + + // 前端访问路径,默认根路径 router.Static("/", "./public") router.NoRoute(handler) - addr := fmt.Sprintf("%s:%d", ListenHost, ListenPort) - if err := router.Run(addr); err != nil { - log.Fatalf("Error starting server: %v", err) + serverAddr := fmt.Sprintf("%s:%d", host, port) + if err := router.Run(serverAddr); err != nil { + fmt.Printf("Error starting server: %v\n", err) } } func handler(c *gin.Context) { rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") - rawPath = strings.TrimPrefix(rawPath, "/") + + for strings.HasPrefix(rawPath, "/") { + rawPath = strings.TrimPrefix(rawPath, "/") + } + // 脚本嵌套路径处理 + if rawPath == "perl-pe-para" { + handlePerlPePara(c) + return + } if !strings.HasPrefix(rawPath, "http") { c.String(http.StatusForbidden, "无效输入") @@ -116,52 +104,56 @@ func handler(c *gin.Context) { } matches := checkURL(rawPath) - if matches == nil { + if matches != nil { + configLock.RLock() + defer configLock.RUnlock() + + if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) { + c.String(http.StatusForbidden, "不在白名单内,限制访问。") + return + } + if len(config.BlackList) > 0 && checkList(matches, config.BlackList) { + c.String(http.StatusForbidden, "黑名单限制访问") + return + } + } else { c.String(http.StatusForbidden, "无效输入") return } - if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) { - c.String(http.StatusForbidden, "不在白名单内,限制访问。") - return - } - if len(config.BlackList) > 0 && checkList(matches, config.BlackList) { - c.String(http.StatusForbidden, "黑名单限制访问") - return - } - if exps[1].MatchString(rawPath) { rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) } proxy(c, rawPath) } +// 处理脚本嵌套相关函数 +func handlePerlPePara(c *gin.Context) { + perlstr := "perl -pe" + responseText := fmt.Sprintf(`s#(bash.*?\.sh)([^/\w\d])#\1 | %s "$(curl -L %s/perl-pe-para)" \2#g; s# (git)# https://\1#g; s#(http.*?git[^/]*?/)#%s/\1#g`, perlstr, c.Request.URL.String(), c.Request.URL.String()) + c.Header("Content-Type", "text/plain") + c.Header("Cache-Control", "max-age=300") + c.String(http.StatusOK, responseText) +} func proxy(c *gin.Context, u string) { - cacheKey := generateCacheKey(u) - // 当 CacheExpiry 为 0 时,不使用缓存 - if CacheExpiry != 0 { - if cachedData, ok := cache.Load(cacheKey); ok { - log.Printf("Using cached response for %s", u) - cached := cachedData.(*CachedResponse) - if time.Since(cached.Timestamp) < CacheExpiry { - setHeaders(c, cached.Header) - c.Status(cached.StatusCode) - c.Writer.Write(cached.Body) - return - } - } + // 检查是否脚本嵌套路径 + if strings.HasSuffix(u, "perl-pe-para") { + handlePerlPePara(c) + return } - log.Printf("use proxy response for %s", u) - req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) if err != nil { c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) return } - copyHeaders(req.Header, c.Request.Header) + for key, values := range c.Request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } req.Header.Del("Host") resp, err := httpClient.Do(req) @@ -169,17 +161,28 @@ func proxy(c *gin.Context, u string) { c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) return } - defer closeWithLog(resp.Body) + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() if contentLength, ok := resp.Header["Content-Length"]; ok { - if size, err := strconv.Atoi(contentLength[0]); err == nil && size > MaxFileSize { + if size, err := strconv.Atoi(contentLength[0]); err == nil && size > sizeLimit { c.String(http.StatusRequestEntityTooLarge, "File too large.") return } } - removeHeaders(resp.Header, "Content-Security-Policy", "Referrer-Policy", "Strict-Transport-Security") - setHeaders(c, resp.Header) + resp.Header.Del("Content-Security-Policy") + resp.Header.Del("Referrer-Policy") + resp.Header.Del("Strict-Transport-Security") + + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } if location := resp.Header.Get("Location"); location != "" { if checkURL(location) != nil { @@ -191,51 +194,28 @@ func proxy(c *gin.Context, u string) { } c.Status(resp.StatusCode) - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("Failed to read response body: %v", err) + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + fmt.Printf("Error copying response body: %v\n", err) return } - if _, err := c.Writer.Write(body); err != nil { - log.Printf("Failed to write response body: %v", err) - return - } - - // 当 CacheExpiry 不为 0 时,保存到缓存 - if CacheExpiry != 0 { - // Save to cache - cached := &CachedResponse{ - Header: resp.Header, - StatusCode: resp.StatusCode, - Body: body, - Timestamp: time.Now(), - } - cache.Store(cacheKey, cached) - cacheFilePath := filepath.Join(CacheDir, cacheKey) - // 修改 ioutil.WriteFile 为 os.WriteFile - if err := os.WriteFile(cacheFilePath, body, 0644); err != nil { - log.Printf("Failed to write cache file: %v", err) - } - } -} - -func generateCacheKey(u string) string { - hash := sha256.Sum256([]byte(u)) - return hex.EncodeToString(hash[:]) } func loadConfig() { file, err := os.Open("config.json") if err != nil { - log.Printf("Error loading config: %v", err) + fmt.Printf("Error loading config: %v\n", err) return } - defer closeWithLog(file) + defer func() { + if err := file.Close(); err != nil { + fmt.Printf("Error closing config file: %v\n", err) + } + }() var newConfig Config decoder := json.NewDecoder(file) if err := decoder.Decode(&newConfig); err != nil { - log.Printf("Error decoding config: %v", err) + fmt.Printf("Error decoding config: %v\n", err) return } @@ -261,31 +241,3 @@ func checkList(matches, list []string) bool { } return false } - -func setHeaders(c *gin.Context, headers http.Header) { - for key, values := range headers { - for _, value := range values { - c.Header(key, value) - } - } -} - -func copyHeaders(dst, src http.Header) { - for key, values := range src { - for _, value := range values { - dst.Add(key, value) - } - } -} - -func removeHeaders(headers http.Header, keys ...string) { - for _, key := range keys { - headers.Del(key) - } -} - -func closeWithLog(c io.Closer) { - if err := c.Close(); err != nil { - log.Printf("Failed to close: %v", err) - } -}