github.com/sagernet/sing-box@v1.2.7/experimental/clashapi/server.go (about)

     1  package clashapi
     2  
import (
	"bytes"
	"context"
	"crypto/subtle"
	"errors"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/json"
	"github.com/sagernet/sing-box/common/urltest"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/experimental"
	"github.com/sagernet/sing-box/experimental/clashapi/cachefile"
	"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	F "github.com/sagernet/sing/common/format"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/websocket"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/cors"
	"github.com/go-chi/render"
)
    32  
    33  func init() {
    34  	experimental.RegisterClashServerConstructor(NewServer)
    35  }
    36  
    37  var _ adapter.ClashServer = (*Server)(nil)
    38  
    39  type Server struct {
    40  	router         adapter.Router
    41  	logger         log.Logger
    42  	httpServer     *http.Server
    43  	trafficManager *trafficontrol.Manager
    44  	urlTestHistory *urltest.HistoryStorage
    45  	mode           string
    46  	storeSelected  bool
    47  	cacheFilePath  string
    48  	cacheFile      adapter.ClashCacheFile
    49  }
    50  
    51  func NewServer(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
    52  	trafficManager := trafficontrol.NewManager()
    53  	chiRouter := chi.NewRouter()
    54  	server := &Server{
    55  		router: router,
    56  		logger: logFactory.NewLogger("clash-api"),
    57  		httpServer: &http.Server{
    58  			Addr:    options.ExternalController,
    59  			Handler: chiRouter,
    60  		},
    61  		trafficManager: trafficManager,
    62  		urlTestHistory: urltest.NewHistoryStorage(),
    63  		mode:           strings.ToLower(options.DefaultMode),
    64  	}
    65  	if server.mode == "" {
    66  		server.mode = "rule"
    67  	}
    68  	if options.StoreSelected {
    69  		server.storeSelected = true
    70  		cachePath := os.ExpandEnv(options.CacheFile)
    71  		if cachePath == "" {
    72  			cachePath = "cache.db"
    73  		}
    74  		if foundPath, loaded := C.FindPath(cachePath); loaded {
    75  			cachePath = foundPath
    76  		} else {
    77  			cachePath = C.BasePath(cachePath)
    78  		}
    79  		server.cacheFilePath = cachePath
    80  	}
    81  	cors := cors.New(cors.Options{
    82  		AllowedOrigins: []string{"*"},
    83  		AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
    84  		AllowedHeaders: []string{"Content-Type", "Authorization"},
    85  		MaxAge:         300,
    86  	})
    87  	chiRouter.Use(cors.Handler)
    88  	chiRouter.Group(func(r chi.Router) {
    89  		r.Use(authentication(options.Secret))
    90  		r.Get("/", hello(options.ExternalUI != ""))
    91  		r.Get("/logs", getLogs(logFactory))
    92  		r.Get("/traffic", traffic(trafficManager))
    93  		r.Get("/version", version)
    94  		r.Mount("/configs", configRouter(server, logFactory, server.logger))
    95  		r.Mount("/proxies", proxyRouter(server, router))
    96  		r.Mount("/rules", ruleRouter(router))
    97  		r.Mount("/connections", connectionRouter(trafficManager))
    98  		r.Mount("/providers/proxies", proxyProviderRouter())
    99  		r.Mount("/providers/rules", ruleProviderRouter())
   100  		r.Mount("/script", scriptRouter())
   101  		r.Mount("/profile", profileRouter())
   102  		r.Mount("/cache", cacheRouter())
   103  		r.Mount("/dns", dnsRouter(router))
   104  	})
   105  	if options.ExternalUI != "" {
   106  		chiRouter.Group(func(r chi.Router) {
   107  			fs := http.StripPrefix("/ui", http.FileServer(http.Dir(C.BasePath(os.ExpandEnv(options.ExternalUI)))))
   108  			r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
   109  			r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
   110  				fs.ServeHTTP(w, r)
   111  			})
   112  		})
   113  	}
   114  	return server, nil
   115  }
   116  
   117  func (s *Server) PreStart() error {
   118  	if s.cacheFilePath != "" {
   119  		cacheFile, err := cachefile.Open(s.cacheFilePath)
   120  		if err != nil {
   121  			return E.Cause(err, "open cache file")
   122  		}
   123  		s.cacheFile = cacheFile
   124  	}
   125  	return nil
   126  }
   127  
   128  func (s *Server) Start() error {
   129  	listener, err := net.Listen("tcp", s.httpServer.Addr)
   130  	if err != nil {
   131  		return E.Cause(err, "external controller listen error")
   132  	}
   133  	s.logger.Info("restful api listening at ", listener.Addr())
   134  	go func() {
   135  		err = s.httpServer.Serve(listener)
   136  		if err != nil && !errors.Is(err, http.ErrServerClosed) {
   137  			s.logger.Error("external controller serve error: ", err)
   138  		}
   139  	}()
   140  	return nil
   141  }
   142  
   143  func (s *Server) Close() error {
   144  	return common.Close(
   145  		common.PtrOrNil(s.httpServer),
   146  		s.trafficManager,
   147  		s.cacheFile,
   148  	)
   149  }
   150  
   151  func (s *Server) Mode() string {
   152  	return s.mode
   153  }
   154  
   155  func (s *Server) StoreSelected() bool {
   156  	return s.storeSelected
   157  }
   158  
   159  func (s *Server) CacheFile() adapter.ClashCacheFile {
   160  	return s.cacheFile
   161  }
   162  
   163  func (s *Server) HistoryStorage() *urltest.HistoryStorage {
   164  	return s.urlTestHistory
   165  }
   166  
   167  func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
   168  	tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   169  	return tracker, tracker
   170  }
   171  
   172  func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
   173  	tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   174  	return tracker, tracker
   175  }
   176  
   177  func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata {
   178  	var inbound string
   179  	if metadata.Inbound != "" {
   180  		inbound = metadata.InboundType + "/" + metadata.Inbound
   181  	} else {
   182  		inbound = metadata.InboundType
   183  	}
   184  	var domain string
   185  	if metadata.Domain != "" {
   186  		domain = metadata.Domain
   187  	} else {
   188  		domain = metadata.Destination.Fqdn
   189  	}
   190  	var processPath string
   191  	if metadata.ProcessInfo != nil {
   192  		if metadata.ProcessInfo.ProcessPath != "" {
   193  			processPath = metadata.ProcessInfo.ProcessPath
   194  		} else if metadata.ProcessInfo.PackageName != "" {
   195  			processPath = metadata.ProcessInfo.PackageName
   196  		}
   197  		if processPath == "" {
   198  			if metadata.ProcessInfo.UserId != -1 {
   199  				processPath = F.ToString(metadata.ProcessInfo.UserId)
   200  			}
   201  		} else if metadata.ProcessInfo.User != "" {
   202  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.User, ")")
   203  		} else if metadata.ProcessInfo.UserId != -1 {
   204  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.UserId, ")")
   205  		}
   206  	}
   207  	return trafficontrol.Metadata{
   208  		NetWork:     metadata.Network,
   209  		Type:        inbound,
   210  		SrcIP:       metadata.Source.Addr,
   211  		DstIP:       metadata.Destination.Addr,
   212  		SrcPort:     F.ToString(metadata.Source.Port),
   213  		DstPort:     F.ToString(metadata.Destination.Port),
   214  		Host:        domain,
   215  		DNSMode:     "normal",
   216  		ProcessPath: processPath,
   217  	}
   218  }
   219  
   220  func authentication(serverSecret string) func(next http.Handler) http.Handler {
   221  	return func(next http.Handler) http.Handler {
   222  		fn := func(w http.ResponseWriter, r *http.Request) {
   223  			if serverSecret == "" {
   224  				next.ServeHTTP(w, r)
   225  				return
   226  			}
   227  
   228  			// Browser websocket not support custom header
   229  			if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
   230  				token := r.URL.Query().Get("token")
   231  				if token != serverSecret {
   232  					render.Status(r, http.StatusUnauthorized)
   233  					render.JSON(w, r, ErrUnauthorized)
   234  					return
   235  				}
   236  				next.ServeHTTP(w, r)
   237  				return
   238  			}
   239  
   240  			header := r.Header.Get("Authorization")
   241  			bearer, token, found := strings.Cut(header, " ")
   242  
   243  			hasInvalidHeader := bearer != "Bearer"
   244  			hasInvalidSecret := !found || token != serverSecret
   245  			if hasInvalidHeader || hasInvalidSecret {
   246  				render.Status(r, http.StatusUnauthorized)
   247  				render.JSON(w, r, ErrUnauthorized)
   248  				return
   249  			}
   250  			next.ServeHTTP(w, r)
   251  		}
   252  		return http.HandlerFunc(fn)
   253  	}
   254  }
   255  
   256  func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
   257  	return func(w http.ResponseWriter, r *http.Request) {
   258  		if redirect {
   259  			http.Redirect(w, r, "/ui/", http.StatusTemporaryRedirect)
   260  		} else {
   261  			render.JSON(w, r, render.M{"hello": "clash"})
   262  		}
   263  	}
   264  }
   265  
   266  var upgrader = websocket.Upgrader{
   267  	CheckOrigin: func(r *http.Request) bool {
   268  		return true
   269  	},
   270  }
   271  
   272  type Traffic struct {
   273  	Up   int64 `json:"up"`
   274  	Down int64 `json:"down"`
   275  }
   276  
   277  func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
   278  	return func(w http.ResponseWriter, r *http.Request) {
   279  		var wsConn *websocket.Conn
   280  		if websocket.IsWebSocketUpgrade(r) {
   281  			var err error
   282  			wsConn, err = upgrader.Upgrade(w, r, nil)
   283  			if err != nil {
   284  				return
   285  			}
   286  		}
   287  
   288  		if wsConn == nil {
   289  			w.Header().Set("Content-Type", "application/json")
   290  			render.Status(r, http.StatusOK)
   291  		}
   292  
   293  		tick := time.NewTicker(time.Second)
   294  		defer tick.Stop()
   295  		buf := &bytes.Buffer{}
   296  		var err error
   297  		for range tick.C {
   298  			buf.Reset()
   299  			up, down := trafficManager.Now()
   300  			if err := json.NewEncoder(buf).Encode(Traffic{
   301  				Up:   up,
   302  				Down: down,
   303  			}); err != nil {
   304  				break
   305  			}
   306  
   307  			if wsConn == nil {
   308  				_, err = w.Write(buf.Bytes())
   309  				w.(http.Flusher).Flush()
   310  			} else {
   311  				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   312  			}
   313  
   314  			if err != nil {
   315  				break
   316  			}
   317  		}
   318  	}
   319  }
   320  
   321  type Log struct {
   322  	Type    string `json:"type"`
   323  	Payload string `json:"payload"`
   324  }
   325  
   326  func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
   327  	return func(w http.ResponseWriter, r *http.Request) {
   328  		levelText := r.URL.Query().Get("level")
   329  		if levelText == "" {
   330  			levelText = "info"
   331  		}
   332  
   333  		level, ok := log.ParseLevel(levelText)
   334  		if ok != nil {
   335  			render.Status(r, http.StatusBadRequest)
   336  			render.JSON(w, r, ErrBadRequest)
   337  			return
   338  		}
   339  
   340  		subscription, done, err := logFactory.Subscribe()
   341  		if err != nil {
   342  			render.Status(r, http.StatusNoContent)
   343  			return
   344  		}
   345  		defer logFactory.UnSubscribe(subscription)
   346  
   347  		var wsConn *websocket.Conn
   348  		if websocket.IsWebSocketUpgrade(r) {
   349  			var err error
   350  			wsConn, err = upgrader.Upgrade(w, r, nil)
   351  			if err != nil {
   352  				return
   353  			}
   354  		}
   355  
   356  		if wsConn == nil {
   357  			w.Header().Set("Content-Type", "application/json")
   358  			render.Status(r, http.StatusOK)
   359  		}
   360  
   361  		buf := &bytes.Buffer{}
   362  		var logEntry log.Entry
   363  		for {
   364  			select {
   365  			case <-done:
   366  				return
   367  			case logEntry = <-subscription:
   368  			}
   369  			if logEntry.Level > level {
   370  				continue
   371  			}
   372  			buf.Reset()
   373  			err = json.NewEncoder(buf).Encode(Log{
   374  				Type:    log.FormatLevel(logEntry.Level),
   375  				Payload: logEntry.Message,
   376  			})
   377  			if err != nil {
   378  				break
   379  			}
   380  			if wsConn == nil {
   381  				_, err = w.Write(buf.Bytes())
   382  				w.(http.Flusher).Flush()
   383  			} else {
   384  				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   385  			}
   386  
   387  			if err != nil {
   388  				break
   389  			}
   390  		}
   391  	}
   392  }
   393  
   394  func version(w http.ResponseWriter, r *http.Request) {
   395  	render.JSON(w, r, render.M{"version": "sing-box " + C.Version, "premium": true})
   396  }