Files
hubproxy/ghproxy/main.go
2024-12-25 06:29:19 +08:00

234 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"sync"
"time"
)
const (
sizeLimit = 1024 * 1024 * 1024 * 10 // 允许的文件大小默认10GB
host = "0.0.0.0" // 监听地址
port = 5000 // 监听端口
)
var (
exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co/spaces/([^/]+)/([^/]+).*`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co/repos/.+`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
}
httpClient *http.Client
config *Config
configLock sync.RWMutex
)
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
}
func main() {
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
httpClient = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 300 * time.Second,
},
}
loadConfig()
go func() {
for {
time.Sleep(10 * time.Minute)
loadConfig()
}
}()
// 前端访问路径,默认根路径
router.Static("/", "./public")
router.NoRoute(handler)
err := router.Run(fmt.Sprintf("%s:%d", host, port))
if err != nil {
fmt.Printf("Error starting server: %v\n", err)
}
}
func handler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
if !strings.HasPrefix(rawPath, "http") {
c.String(http.StatusForbidden, "无效输入")
return
}
matches := checkURL(rawPath)
if matches != nil {
if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) {
c.String(http.StatusForbidden, "不在白名单内,限制访问。")
return
}
if len(config.BlackList) > 0 && checkList(matches, config.BlackList) {
c.String(http.StatusForbidden, "黑名单限制访问")
return
}
} else {
c.String(http.StatusForbidden, "无效输入")
return
}
if exps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
proxy(c, rawPath)
}
func proxy(c *gin.Context, u string) {
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := httpClient.Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(resp.Body)
if contentLength, ok := resp.Header["Content-Length"]; ok {
if size, err := strconv.Atoi(contentLength[0]); err == nil && size > sizeLimit {
c.String(http.StatusRequestEntityTooLarge, "File too large.")
return
}
}
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxy(c, location)
return
}
}
c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return
}
}
func loadConfig() {
file, err := os.Open("config.json")
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
return
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
}
}(file)
var newConfig Config
decoder := json.NewDecoder(file)
if err := decoder.Decode(&newConfig); err != nil {
fmt.Printf("Error decoding config: %v\n", err)
return
}
configLock.Lock()
config = &newConfig
configLock.Unlock()
}
func checkURL(u string) []string {
for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil {
return matches[1:]
}
}
return nil
}
// 匹配用户名、仓库名,或者用户名/仓库名
func checkList(matches, list []string) bool {
if len(matches) == 0 {
return false
}
for _, item := range list {
// 检查第一个匹配
if strings.HasPrefix(matches[0], item) {
return true
}
// 只有当存在第二个匹配时才检查
if len(matches) > 1 {
if strings.HasPrefix(matches[1], item) {
return true
}
// 只有当存在两个捕获组时才检查组合
if len(matches) > 2 {
if strings.HasPrefix(matches[0]+"/"+matches[1], item) {
return true
}
}
}
}
return false
}