github.com/yaling888/clash@v1.53.0/hub/route/server.go (about)

     1  package route
     2  
     3  import (
     4  	"crypto/subtle"
     5  	"encoding/json"
     6  	"net"
     7  	"net/http"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/go-chi/chi/v5"
    12  	"github.com/go-chi/cors"
    13  	"github.com/go-chi/render"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/phuslu/log"
    16  	"go.uber.org/atomic"
    17  
    18  	"github.com/yaling888/clash/common/observable"
    19  	"github.com/yaling888/clash/common/pool"
    20  	C "github.com/yaling888/clash/constant"
    21  	L "github.com/yaling888/clash/log"
    22  	"github.com/yaling888/clash/tunnel/statistic"
    23  )
    24  
    25  var (
    26  	serverSecret []byte
    27  	serverAddr   = ""
    28  
    29  	uiPath = ""
    30  
    31  	enablePPORF bool
    32  
    33  	bootTime = atomic.NewTime(time.Now())
    34  
    35  	upgrader = websocket.Upgrader{
    36  		CheckOrigin: func(r *http.Request) bool {
    37  			return true
    38  		},
    39  	}
    40  )
    41  
    42  type Traffic struct {
    43  	Up   int64 `json:"up"`
    44  	Down int64 `json:"down"`
    45  }
    46  
    47  func SetUIPath(path string) {
    48  	uiPath = C.Path.Resolve(path)
    49  }
    50  
    51  func SetPPROF(pprof bool) {
    52  	enablePPORF = pprof
    53  }
    54  
    55  func Start(addr string, secret string) {
    56  	if serverAddr != "" {
    57  		return
    58  	}
    59  
    60  	serverAddr = addr
    61  	serverSecret = []byte(secret)
    62  
    63  	r := chi.NewRouter()
    64  
    65  	corsM := cors.New(cors.Options{
    66  		AllowedOrigins: []string{"*"},
    67  		AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
    68  		AllowedHeaders: []string{"Content-Type", "Authorization"},
    69  		MaxAge:         300,
    70  	})
    71  
    72  	r.Use(corsM.Handler)
    73  	r.Group(func(r chi.Router) {
    74  		r.Use(authentication)
    75  
    76  		r.Get("/", hello)
    77  		r.Get("/logs", getLogs)
    78  		r.Get("/traffic", traffic)
    79  		r.Get("/version", version)
    80  		r.Get("/uptime", uptime)
    81  		r.Mount("/configs", configRouter())
    82  		r.Mount("/configs/geo", configGeoRouter())
    83  		r.Mount("/inbounds", inboundRouter())
    84  		r.Mount("/proxies", proxyRouter())
    85  		r.Mount("/rules", ruleRouter())
    86  		r.Mount("/connections", connectionRouter())
    87  		r.Mount("/providers/proxies", proxyProviderRouter())
    88  		r.Mount("/cache", cacheRouter())
    89  		r.Mount("/dns", dnsRouter())
    90  	})
    91  
    92  	if uiPath != "" {
    93  		r.Group(func(r chi.Router) {
    94  			fs := http.StripPrefix("/ui", http.FileServer(http.Dir(uiPath)))
    95  			r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
    96  			r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
    97  				fs.ServeHTTP(w, r)
    98  			})
    99  		})
   100  	}
   101  
   102  	if enablePPORF {
   103  		r.Mount("/debug/pprof", pprofRouter())
   104  	}
   105  
   106  	l, err := net.Listen("tcp", addr)
   107  	if err != nil {
   108  		log.Error().Err(err).Msg("[API] external controller listen failed")
   109  		return
   110  	}
   111  	serverAddr = l.Addr().String()
   112  	log.Info().Str("addr", serverAddr).Msg("[API] listening")
   113  	if err = http.Serve(l, r); err != nil {
   114  		log.Error().Err(err).Msg("[API] external controller serve failed")
   115  	}
   116  }
   117  
   118  func authentication(next http.Handler) http.Handler {
   119  	fn := func(w http.ResponseWriter, r *http.Request) {
   120  		if len(serverSecret) == 0 {
   121  			next.ServeHTTP(w, r)
   122  			return
   123  		}
   124  
   125  		// Browser websocket not support custom header
   126  		if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
   127  			token := r.URL.Query().Get("token")
   128  			if subtle.ConstantTimeCompare([]byte(token), serverSecret) != 1 {
   129  				render.Status(r, http.StatusUnauthorized)
   130  				render.JSON(w, r, ErrUnauthorized)
   131  				return
   132  			}
   133  			next.ServeHTTP(w, r)
   134  			return
   135  		}
   136  
   137  		header := r.Header.Get("Authorization")
   138  		bearer, token, found := strings.Cut(header, " ")
   139  
   140  		hasInvalidHeader := bearer != "Bearer"
   141  		hasInvalidSecret := !found || subtle.ConstantTimeCompare([]byte(token), serverSecret) != 1
   142  		if hasInvalidHeader || hasInvalidSecret {
   143  			render.Status(r, http.StatusUnauthorized)
   144  			render.JSON(w, r, ErrUnauthorized)
   145  			return
   146  		}
   147  		next.ServeHTTP(w, r)
   148  	}
   149  	return http.HandlerFunc(fn)
   150  }
   151  
   152  func hello(w http.ResponseWriter, r *http.Request) {
   153  	render.JSON(w, r, render.M{"hello": "clash plus pro"})
   154  }
   155  
   156  func traffic(w http.ResponseWriter, r *http.Request) {
   157  	var wsConn *websocket.Conn
   158  	if websocket.IsWebSocketUpgrade(r) {
   159  		var err error
   160  		wsConn, err = upgrader.Upgrade(w, r, nil)
   161  		if err != nil {
   162  			return
   163  		}
   164  	}
   165  
   166  	if wsConn == nil {
   167  		w.Header().Set("Content-Type", "application/json")
   168  		render.Status(r, http.StatusOK)
   169  	}
   170  
   171  	tick := time.NewTicker(time.Second)
   172  	defer tick.Stop()
   173  	t := statistic.DefaultManager
   174  	buf := pool.BufferWriter{}
   175  	encoder := json.NewEncoder(&buf)
   176  	var err error
   177  	for range tick.C {
   178  		buf.Reset()
   179  		up, down := t.Now()
   180  		if err := encoder.Encode(Traffic{
   181  			Up:   up,
   182  			Down: down,
   183  		}); err != nil {
   184  			break
   185  		}
   186  
   187  		if wsConn == nil {
   188  			_, err = w.Write(buf.Bytes())
   189  			w.(http.Flusher).Flush()
   190  		} else {
   191  			err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   192  		}
   193  
   194  		if err != nil {
   195  			break
   196  		}
   197  	}
   198  }
   199  
   200  type Log struct {
   201  	Type    string `json:"type"`
   202  	Payload string `json:"payload"`
   203  }
   204  
   205  func getLogs(w http.ResponseWriter, r *http.Request) {
   206  	var (
   207  		levelText = r.URL.Query().Get("level")
   208  		format    = r.URL.Query().Get("format")
   209  	)
   210  	if levelText == "" {
   211  		levelText = "info"
   212  	}
   213  	if format == "" {
   214  		format = "text"
   215  	}
   216  
   217  	level, ok := L.LogLevelMapping[levelText]
   218  	if !ok {
   219  		render.Status(r, http.StatusBadRequest)
   220  		render.JSON(w, r, ErrBadRequest)
   221  		return
   222  	}
   223  
   224  	var wsConn *websocket.Conn
   225  	if websocket.IsWebSocketUpgrade(r) {
   226  		var err error
   227  		wsConn, err = upgrader.Upgrade(w, r, nil)
   228  		if err != nil {
   229  			return
   230  		}
   231  	}
   232  
   233  	var (
   234  		sub     observable.Subscription[L.Event]
   235  		ch      = make(chan L.Event, 1024)
   236  		buf     = pool.BufferWriter{}
   237  		encoder = json.NewEncoder(&buf)
   238  		closed  = false
   239  	)
   240  
   241  	if wsConn == nil {
   242  		w.Header().Set("Content-Type", "application/json")
   243  		render.Status(r, http.StatusOK)
   244  	} else if level > L.INFO {
   245  		go func() {
   246  			for _, _, err := wsConn.ReadMessage(); err != nil; {
   247  				closed = true
   248  				break
   249  			}
   250  		}()
   251  	}
   252  
   253  	if strings.EqualFold(format, "structured") {
   254  		sub = L.SubscribeJson()
   255  		defer L.UnSubscribeJson(sub)
   256  	} else {
   257  		sub = L.SubscribeText()
   258  		defer L.UnSubscribeText(sub)
   259  	}
   260  
   261  	go func() {
   262  		for elm := range sub {
   263  			select {
   264  			case ch <- elm:
   265  			default:
   266  			}
   267  		}
   268  		close(ch)
   269  	}()
   270  
   271  	for logM := range ch {
   272  		if closed {
   273  			break
   274  		}
   275  		if logM.LogLevel < level {
   276  			continue
   277  		}
   278  		buf.Reset()
   279  
   280  		if err := encoder.Encode(Log{
   281  			Type:    logM.Type(),
   282  			Payload: logM.Payload,
   283  		}); err != nil {
   284  			break
   285  		}
   286  
   287  		var err error
   288  		if wsConn == nil {
   289  			_, err = w.Write(buf.Bytes())
   290  			w.(http.Flusher).Flush()
   291  		} else {
   292  			err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   293  		}
   294  
   295  		if err != nil {
   296  			break
   297  		}
   298  	}
   299  }
   300  
   301  func version(w http.ResponseWriter, r *http.Request) {
   302  	render.JSON(w, r, render.M{"version": "PlusPro-" + C.Version, "plus-pro": true})
   303  }
   304  
   305  func uptime(w http.ResponseWriter, r *http.Request) {
   306  	bt := bootTime.Load()
   307  	render.JSON(w, r, render.M{
   308  		"bootTime": bt.Format("2006-01-02 15:04:05 Mon -0700"),
   309  	})
   310  }