github.com/sagernet/sing-box@v1.9.0-rc.20/experimental/clashapi/server.go (about)

     1  package clashapi
     2  
import (
	"bytes"
	"context"
	"crypto/subtle"
	"errors"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/common/urltest"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/experimental"
	"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	F "github.com/sagernet/sing/common/format"
	"github.com/sagernet/sing/common/json"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/service"
	"github.com/sagernet/sing/service/filemanager"
	"github.com/sagernet/ws"
	"github.com/sagernet/ws/wsutil"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/cors"
	"github.com/go-chi/render"
)
    34  
    35  func init() {
    36  	experimental.RegisterClashServerConstructor(NewServer)
    37  }
    38  
    39  var _ adapter.ClashServer = (*Server)(nil)
    40  
    41  type Server struct {
    42  	ctx            context.Context
    43  	router         adapter.Router
    44  	logger         log.Logger
    45  	httpServer     *http.Server
    46  	trafficManager *trafficontrol.Manager
    47  	urlTestHistory *urltest.HistoryStorage
    48  	mode           string
    49  	modeList       []string
    50  	modeUpdateHook chan<- struct{}
    51  
    52  	externalController       bool
    53  	externalUI               string
    54  	externalUIDownloadURL    string
    55  	externalUIDownloadDetour string
    56  }
    57  
    58  func NewServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
    59  	trafficManager := trafficontrol.NewManager()
    60  	chiRouter := chi.NewRouter()
    61  	server := &Server{
    62  		ctx:    ctx,
    63  		router: router,
    64  		logger: logFactory.NewLogger("clash-api"),
    65  		httpServer: &http.Server{
    66  			Addr:    options.ExternalController,
    67  			Handler: chiRouter,
    68  		},
    69  		trafficManager:           trafficManager,
    70  		modeList:                 options.ModeList,
    71  		externalController:       options.ExternalController != "",
    72  		externalUIDownloadURL:    options.ExternalUIDownloadURL,
    73  		externalUIDownloadDetour: options.ExternalUIDownloadDetour,
    74  	}
    75  	server.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx)
    76  	if server.urlTestHistory == nil {
    77  		server.urlTestHistory = urltest.NewHistoryStorage()
    78  	}
    79  	defaultMode := "Rule"
    80  	if options.DefaultMode != "" {
    81  		defaultMode = options.DefaultMode
    82  	}
    83  	if !common.Contains(server.modeList, defaultMode) {
    84  		server.modeList = append([]string{defaultMode}, server.modeList...)
    85  	}
    86  	server.mode = defaultMode
    87  	//goland:noinspection GoDeprecation
    88  	//nolint:staticcheck
    89  	if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" {
    90  		return nil, E.New("cache_file and related fields in Clash API is deprecated in sing-box 1.8.0, use experimental.cache_file instead.")
    91  	}
    92  	cors := cors.New(cors.Options{
    93  		AllowedOrigins: []string{"*"},
    94  		AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
    95  		AllowedHeaders: []string{"Content-Type", "Authorization"},
    96  		MaxAge:         300,
    97  	})
    98  	chiRouter.Use(cors.Handler)
    99  	chiRouter.Group(func(r chi.Router) {
   100  		r.Use(authentication(options.Secret))
   101  		r.Get("/", hello(options.ExternalUI != ""))
   102  		r.Get("/logs", getLogs(logFactory))
   103  		r.Get("/traffic", traffic(trafficManager))
   104  		r.Get("/version", version)
   105  		r.Mount("/configs", configRouter(server, logFactory))
   106  		r.Mount("/proxies", proxyRouter(server, router))
   107  		r.Mount("/rules", ruleRouter(router))
   108  		r.Mount("/connections", connectionRouter(router, trafficManager))
   109  		r.Mount("/providers/proxies", proxyProviderRouter())
   110  		r.Mount("/providers/rules", ruleProviderRouter())
   111  		r.Mount("/script", scriptRouter())
   112  		r.Mount("/profile", profileRouter())
   113  		r.Mount("/cache", cacheRouter(ctx))
   114  		r.Mount("/dns", dnsRouter(router))
   115  
   116  		server.setupMetaAPI(r)
   117  	})
   118  	if options.ExternalUI != "" {
   119  		server.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI))
   120  		chiRouter.Group(func(r chi.Router) {
   121  			fs := http.StripPrefix("/ui", http.FileServer(http.Dir(server.externalUI)))
   122  			r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
   123  			r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
   124  				fs.ServeHTTP(w, r)
   125  			})
   126  		})
   127  	}
   128  	return server, nil
   129  }
   130  
   131  func (s *Server) PreStart() error {
   132  	cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
   133  	if cacheFile != nil {
   134  		mode := cacheFile.LoadMode()
   135  		if common.Any(s.modeList, func(it string) bool {
   136  			return strings.EqualFold(it, mode)
   137  		}) {
   138  			s.mode = mode
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  func (s *Server) Start() error {
   145  	if s.externalController {
   146  		s.checkAndDownloadExternalUI()
   147  		listener, err := net.Listen("tcp", s.httpServer.Addr)
   148  		if err != nil {
   149  			return E.Cause(err, "external controller listen error")
   150  		}
   151  		s.logger.Info("restful api listening at ", listener.Addr())
   152  		go func() {
   153  			err = s.httpServer.Serve(listener)
   154  			if err != nil && !errors.Is(err, http.ErrServerClosed) {
   155  				s.logger.Error("external controller serve error: ", err)
   156  			}
   157  		}()
   158  	}
   159  	return nil
   160  }
   161  
   162  func (s *Server) Close() error {
   163  	return common.Close(
   164  		common.PtrOrNil(s.httpServer),
   165  		s.trafficManager,
   166  		s.urlTestHistory,
   167  	)
   168  }
   169  
   170  func (s *Server) Mode() string {
   171  	return s.mode
   172  }
   173  
   174  func (s *Server) ModeList() []string {
   175  	return s.modeList
   176  }
   177  
   178  func (s *Server) SetModeUpdateHook(hook chan<- struct{}) {
   179  	s.modeUpdateHook = hook
   180  }
   181  
   182  func (s *Server) SetMode(newMode string) {
   183  	if !common.Contains(s.modeList, newMode) {
   184  		newMode = common.Find(s.modeList, func(it string) bool {
   185  			return strings.EqualFold(it, newMode)
   186  		})
   187  	}
   188  	if !common.Contains(s.modeList, newMode) {
   189  		return
   190  	}
   191  	if newMode == s.mode {
   192  		return
   193  	}
   194  	s.mode = newMode
   195  	if s.modeUpdateHook != nil {
   196  		select {
   197  		case s.modeUpdateHook <- struct{}{}:
   198  		default:
   199  		}
   200  	}
   201  	s.router.ClearDNSCache()
   202  	cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
   203  	if cacheFile != nil {
   204  		err := cacheFile.StoreMode(newMode)
   205  		if err != nil {
   206  			s.logger.Error(E.Cause(err, "save mode"))
   207  		}
   208  	}
   209  	s.logger.Info("updated mode: ", newMode)
   210  }
   211  
   212  func (s *Server) HistoryStorage() *urltest.HistoryStorage {
   213  	return s.urlTestHistory
   214  }
   215  
   216  func (s *Server) TrafficManager() *trafficontrol.Manager {
   217  	return s.trafficManager
   218  }
   219  
   220  func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
   221  	tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   222  	return tracker, tracker
   223  }
   224  
   225  func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
   226  	tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   227  	return tracker, tracker
   228  }
   229  
   230  func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata {
   231  	var inbound string
   232  	if metadata.Inbound != "" {
   233  		inbound = metadata.InboundType + "/" + metadata.Inbound
   234  	} else {
   235  		inbound = metadata.InboundType
   236  	}
   237  	var domain string
   238  	if metadata.Domain != "" {
   239  		domain = metadata.Domain
   240  	} else {
   241  		domain = metadata.Destination.Fqdn
   242  	}
   243  	var processPath string
   244  	if metadata.ProcessInfo != nil {
   245  		if metadata.ProcessInfo.ProcessPath != "" {
   246  			processPath = metadata.ProcessInfo.ProcessPath
   247  		} else if metadata.ProcessInfo.PackageName != "" {
   248  			processPath = metadata.ProcessInfo.PackageName
   249  		}
   250  		if processPath == "" {
   251  			if metadata.ProcessInfo.UserId != -1 {
   252  				processPath = F.ToString(metadata.ProcessInfo.UserId)
   253  			}
   254  		} else if metadata.ProcessInfo.User != "" {
   255  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.User, ")")
   256  		} else if metadata.ProcessInfo.UserId != -1 {
   257  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.UserId, ")")
   258  		}
   259  	}
   260  	return trafficontrol.Metadata{
   261  		NetWork:     metadata.Network,
   262  		Type:        inbound,
   263  		SrcIP:       metadata.Source.Addr,
   264  		DstIP:       metadata.Destination.Addr,
   265  		SrcPort:     F.ToString(metadata.Source.Port),
   266  		DstPort:     F.ToString(metadata.Destination.Port),
   267  		Host:        domain,
   268  		DNSMode:     "normal",
   269  		ProcessPath: processPath,
   270  	}
   271  }
   272  
   273  func authentication(serverSecret string) func(next http.Handler) http.Handler {
   274  	return func(next http.Handler) http.Handler {
   275  		fn := func(w http.ResponseWriter, r *http.Request) {
   276  			if serverSecret == "" {
   277  				next.ServeHTTP(w, r)
   278  				return
   279  			}
   280  
   281  			// Browser websocket not support custom header
   282  			if r.Header.Get("Upgrade") == "websocket" && r.URL.Query().Get("token") != "" {
   283  				token := r.URL.Query().Get("token")
   284  				if token != serverSecret {
   285  					render.Status(r, http.StatusUnauthorized)
   286  					render.JSON(w, r, ErrUnauthorized)
   287  					return
   288  				}
   289  				next.ServeHTTP(w, r)
   290  				return
   291  			}
   292  
   293  			header := r.Header.Get("Authorization")
   294  			bearer, token, found := strings.Cut(header, " ")
   295  
   296  			hasInvalidHeader := bearer != "Bearer"
   297  			hasInvalidSecret := !found || token != serverSecret
   298  			if hasInvalidHeader || hasInvalidSecret {
   299  				render.Status(r, http.StatusUnauthorized)
   300  				render.JSON(w, r, ErrUnauthorized)
   301  				return
   302  			}
   303  			next.ServeHTTP(w, r)
   304  		}
   305  		return http.HandlerFunc(fn)
   306  	}
   307  }
   308  
   309  func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
   310  	return func(w http.ResponseWriter, r *http.Request) {
   311  		if redirect {
   312  			http.Redirect(w, r, "/ui/", http.StatusTemporaryRedirect)
   313  		} else {
   314  			render.JSON(w, r, render.M{"hello": "clash"})
   315  		}
   316  	}
   317  }
   318  
   319  type Traffic struct {
   320  	Up   int64 `json:"up"`
   321  	Down int64 `json:"down"`
   322  }
   323  
   324  func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
   325  	return func(w http.ResponseWriter, r *http.Request) {
   326  		var conn net.Conn
   327  		if r.Header.Get("Upgrade") == "websocket" {
   328  			var err error
   329  			conn, _, _, err = ws.UpgradeHTTP(r, w)
   330  			if err != nil {
   331  				return
   332  			}
   333  			defer conn.Close()
   334  		}
   335  
   336  		if conn == nil {
   337  			w.Header().Set("Content-Type", "application/json")
   338  			render.Status(r, http.StatusOK)
   339  		}
   340  
   341  		tick := time.NewTicker(time.Second)
   342  		defer tick.Stop()
   343  		buf := &bytes.Buffer{}
   344  		var err error
   345  		for range tick.C {
   346  			buf.Reset()
   347  			up, down := trafficManager.Now()
   348  			if err := json.NewEncoder(buf).Encode(Traffic{
   349  				Up:   up,
   350  				Down: down,
   351  			}); err != nil {
   352  				break
   353  			}
   354  
   355  			if conn == nil {
   356  				_, err = w.Write(buf.Bytes())
   357  				w.(http.Flusher).Flush()
   358  			} else {
   359  				err = wsutil.WriteServerText(conn, buf.Bytes())
   360  			}
   361  
   362  			if err != nil {
   363  				break
   364  			}
   365  		}
   366  	}
   367  }
   368  
   369  type Log struct {
   370  	Type    string `json:"type"`
   371  	Payload string `json:"payload"`
   372  }
   373  
   374  func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
   375  	return func(w http.ResponseWriter, r *http.Request) {
   376  		levelText := r.URL.Query().Get("level")
   377  		if levelText == "" {
   378  			levelText = "info"
   379  		}
   380  
   381  		level, ok := log.ParseLevel(levelText)
   382  		if ok != nil {
   383  			render.Status(r, http.StatusBadRequest)
   384  			render.JSON(w, r, ErrBadRequest)
   385  			return
   386  		}
   387  
   388  		subscription, done, err := logFactory.Subscribe()
   389  		if err != nil {
   390  			render.Status(r, http.StatusNoContent)
   391  			return
   392  		}
   393  		defer logFactory.UnSubscribe(subscription)
   394  
   395  		var conn net.Conn
   396  		if r.Header.Get("Upgrade") == "websocket" {
   397  			conn, _, _, err = ws.UpgradeHTTP(r, w)
   398  			if err != nil {
   399  				return
   400  			}
   401  			defer conn.Close()
   402  		}
   403  
   404  		if conn == nil {
   405  			w.Header().Set("Content-Type", "application/json")
   406  			render.Status(r, http.StatusOK)
   407  		}
   408  
   409  		buf := &bytes.Buffer{}
   410  		var logEntry log.Entry
   411  		for {
   412  			select {
   413  			case <-done:
   414  				return
   415  			case logEntry = <-subscription:
   416  			}
   417  			if logEntry.Level > level {
   418  				continue
   419  			}
   420  			buf.Reset()
   421  			err = json.NewEncoder(buf).Encode(Log{
   422  				Type:    log.FormatLevel(logEntry.Level),
   423  				Payload: logEntry.Message,
   424  			})
   425  			if err != nil {
   426  				break
   427  			}
   428  			if conn == nil {
   429  				_, err = w.Write(buf.Bytes())
   430  				w.(http.Flusher).Flush()
   431  			} else {
   432  				err = wsutil.WriteServerText(conn, buf.Bytes())
   433  			}
   434  
   435  			if err != nil {
   436  				break
   437  			}
   438  		}
   439  	}
   440  }
   441  
   442  func version(w http.ResponseWriter, r *http.Request) {
   443  	render.JSON(w, r, render.M{"version": "sing-box " + C.Version, "premium": true, "meta": true})
   444  }