github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/experimental/clashapi/server.go (about)

     1  package clashapi
     2  
import (
	"bytes"
	"context"
	"crypto/subtle"
	"errors"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/inazumav/sing-box/adapter"
	"github.com/inazumav/sing-box/common/json"
	"github.com/inazumav/sing-box/common/urltest"
	C "github.com/inazumav/sing-box/constant"
	"github.com/inazumav/sing-box/experimental"
	"github.com/inazumav/sing-box/experimental/clashapi/cachefile"
	"github.com/inazumav/sing-box/experimental/clashapi/trafficontrol"
	"github.com/inazumav/sing-box/log"
	"github.com/inazumav/sing-box/option"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	F "github.com/sagernet/sing/common/format"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/service"
	"github.com/sagernet/sing/service/filemanager"
	"github.com/sagernet/websocket"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/cors"
	"github.com/go-chi/render"
)
    34  
    35  func init() {
    36  	experimental.RegisterClashServerConstructor(NewServer)
    37  }
    38  
    39  var _ adapter.ClashServer = (*Server)(nil)
    40  
    41  type Server struct {
    42  	ctx            context.Context
    43  	router         adapter.Router
    44  	logger         log.Logger
    45  	httpServer     *http.Server
    46  	trafficManager *trafficontrol.Manager
    47  	urlTestHistory *urltest.HistoryStorage
    48  	mode           string
    49  	modeList       []string
    50  	modeUpdateHook chan<- struct{}
    51  	storeMode      bool
    52  	storeSelected  bool
    53  	storeFakeIP    bool
    54  	cacheFilePath  string
    55  	cacheID        string
    56  	cacheFile      adapter.ClashCacheFile
    57  
    58  	externalController       bool
    59  	externalUI               string
    60  	externalUIDownloadURL    string
    61  	externalUIDownloadDetour string
    62  }
    63  
    64  func NewServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
    65  	trafficManager := trafficontrol.NewManager()
    66  	chiRouter := chi.NewRouter()
    67  	server := &Server{
    68  		ctx:    ctx,
    69  		router: router,
    70  		logger: logFactory.NewLogger("clash-api"),
    71  		httpServer: &http.Server{
    72  			Addr:    options.ExternalController,
    73  			Handler: chiRouter,
    74  		},
    75  		trafficManager:           trafficManager,
    76  		modeList:                 options.ModeList,
    77  		externalController:       options.ExternalController != "",
    78  		storeMode:                options.StoreMode,
    79  		storeSelected:            options.StoreSelected,
    80  		externalController:       options.ExternalController != "",
    81  		storeFakeIP:              options.StoreFakeIP,
    82  		externalUIDownloadURL:    options.ExternalUIDownloadURL,
    83  		externalUIDownloadDetour: options.ExternalUIDownloadDetour,
    84  	}
    85  	server.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx)
    86  	if server.urlTestHistory == nil {
    87  		server.urlTestHistory = urltest.NewHistoryStorage()
    88  	}
    89  	defaultMode := "Rule"
    90  	if options.DefaultMode != "" {
    91  		defaultMode = options.DefaultMode
    92  	}
    93  	if !common.Contains(server.modeList, defaultMode) {
    94  		server.modeList = append([]string{defaultMode}, server.modeList...)
    95  	}
    96  	server.mode = defaultMode
    97  	if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.ExternalController == "" {
    98  		cachePath := os.ExpandEnv(options.CacheFile)
    99  		if cachePath == "" {
   100  			cachePath = "cache.db"
   101  		}
   102  		if foundPath, loaded := C.FindPath(cachePath); loaded {
   103  			cachePath = foundPath
   104  		} else {
   105  			cachePath = filemanager.BasePath(ctx, cachePath)
   106  		}
   107  		server.cacheFilePath = cachePath
   108  		server.cacheID = options.CacheID
   109  	}
   110  	cors := cors.New(cors.Options{
   111  		AllowedOrigins: []string{"*"},
   112  		AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
   113  		AllowedHeaders: []string{"Content-Type", "Authorization"},
   114  		MaxAge:         300,
   115  	})
   116  	chiRouter.Use(cors.Handler)
   117  	chiRouter.Group(func(r chi.Router) {
   118  		r.Use(authentication(options.Secret))
   119  		r.Get("/", hello(options.ExternalUI != ""))
   120  		r.Get("/logs", getLogs(logFactory))
   121  		r.Get("/traffic", traffic(trafficManager))
   122  		r.Get("/version", version)
   123  		r.Mount("/configs", configRouter(server, logFactory))
   124  		r.Mount("/proxies", proxyRouter(server, router))
   125  		r.Mount("/rules", ruleRouter(router))
   126  		r.Mount("/connections", connectionRouter(router, trafficManager))
   127  		r.Mount("/providers/proxies", proxyProviderRouter())
   128  		r.Mount("/providers/rules", ruleProviderRouter())
   129  		r.Mount("/script", scriptRouter())
   130  		r.Mount("/profile", profileRouter())
   131  		r.Mount("/cache", cacheRouter(router))
   132  		r.Mount("/dns", dnsRouter(router))
   133  
   134  		server.setupMetaAPI(r)
   135  	})
   136  	if options.ExternalUI != "" {
   137  		server.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI))
   138  		chiRouter.Group(func(r chi.Router) {
   139  			fs := http.StripPrefix("/ui", http.FileServer(http.Dir(server.externalUI)))
   140  			r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
   141  			r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
   142  				fs.ServeHTTP(w, r)
   143  			})
   144  		})
   145  	}
   146  	return server, nil
   147  }
   148  
   149  func (s *Server) PreStart() error {
   150  	if s.cacheFilePath != "" {
   151  		cacheFile, err := cachefile.Open(s.cacheFilePath, s.cacheID)
   152  		if err != nil {
   153  			return E.Cause(err, "open cache file")
   154  		}
   155  		s.cacheFile = cacheFile
   156  		if s.storeMode {
   157  			mode := s.cacheFile.LoadMode()
   158  			if common.Any(s.modeList, func(it string) bool {
   159  				return strings.EqualFold(it, mode)
   160  			}) {
   161  				s.mode = mode
   162  			}
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  func (s *Server) Start() error {
   169  	if s.externalController {
   170  		s.checkAndDownloadExternalUI()
   171  		listener, err := net.Listen("tcp", s.httpServer.Addr)
   172  		if err != nil {
   173  			return E.Cause(err, "external controller listen error")
   174  		}
   175  		s.logger.Info("restful api listening at ", listener.Addr())
   176  		go func() {
   177  			err = s.httpServer.Serve(listener)
   178  			if err != nil && !errors.Is(err, http.ErrServerClosed) {
   179  				s.logger.Error("external controller serve error: ", err)
   180  			}
   181  		}()
   182  	}
   183  	return nil
   184  }
   185  
   186  func (s *Server) Close() error {
   187  	return common.Close(
   188  		common.PtrOrNil(s.httpServer),
   189  		s.trafficManager,
   190  		s.cacheFile,
   191  		s.urlTestHistory,
   192  	)
   193  }
   194  
   195  func (s *Server) Mode() string {
   196  	return s.mode
   197  }
   198  
   199  func (s *Server) ModeList() []string {
   200  	return s.modeList
   201  }
   202  
   203  func (s *Server) SetModeUpdateHook(hook chan<- struct{}) {
   204  	s.modeUpdateHook = hook
   205  }
   206  
   207  func (s *Server) SetMode(newMode string) {
   208  	if !common.Contains(s.modeList, newMode) {
   209  		newMode = common.Find(s.modeList, func(it string) bool {
   210  			return strings.EqualFold(it, newMode)
   211  		})
   212  	}
   213  	if !common.Contains(s.modeList, newMode) {
   214  		return
   215  	}
   216  	if newMode == s.mode {
   217  		return
   218  	}
   219  	s.mode = newMode
   220  	if s.modeUpdateHook != nil {
   221  		select {
   222  		case s.modeUpdateHook <- struct{}{}:
   223  		default:
   224  		}
   225  	}
   226  	s.router.ClearDNSCache()
   227  	if s.storeMode {
   228  		err := s.cacheFile.StoreMode(newMode)
   229  		if err != nil {
   230  			s.logger.Error(E.Cause(err, "save mode"))
   231  		}
   232  	}
   233  	s.logger.Info("updated mode: ", newMode)
   234  }
   235  
   236  func (s *Server) StoreSelected() bool {
   237  	return s.storeSelected
   238  }
   239  
   240  func (s *Server) StoreFakeIP() bool {
   241  	return s.storeFakeIP
   242  }
   243  
   244  func (s *Server) CacheFile() adapter.ClashCacheFile {
   245  	return s.cacheFile
   246  }
   247  
   248  func (s *Server) HistoryStorage() *urltest.HistoryStorage {
   249  	return s.urlTestHistory
   250  }
   251  
   252  func (s *Server) TrafficManager() *trafficontrol.Manager {
   253  	return s.trafficManager
   254  }
   255  
   256  func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
   257  	tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   258  	return tracker, tracker
   259  }
   260  
   261  func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
   262  	tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
   263  	return tracker, tracker
   264  }
   265  
   266  func castMetadata(metadata adapter.InboundContext) trafficontrol.Metadata {
   267  	var inbound string
   268  	if metadata.Inbound != "" {
   269  		inbound = metadata.InboundType + "/" + metadata.Inbound
   270  	} else {
   271  		inbound = metadata.InboundType
   272  	}
   273  	var domain string
   274  	if metadata.Domain != "" {
   275  		domain = metadata.Domain
   276  	} else {
   277  		domain = metadata.Destination.Fqdn
   278  	}
   279  	var processPath string
   280  	if metadata.ProcessInfo != nil {
   281  		if metadata.ProcessInfo.ProcessPath != "" {
   282  			processPath = metadata.ProcessInfo.ProcessPath
   283  		} else if metadata.ProcessInfo.PackageName != "" {
   284  			processPath = metadata.ProcessInfo.PackageName
   285  		}
   286  		if processPath == "" {
   287  			if metadata.ProcessInfo.UserId != -1 {
   288  				processPath = F.ToString(metadata.ProcessInfo.UserId)
   289  			}
   290  		} else if metadata.ProcessInfo.User != "" {
   291  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.User, ")")
   292  		} else if metadata.ProcessInfo.UserId != -1 {
   293  			processPath = F.ToString(processPath, " (", metadata.ProcessInfo.UserId, ")")
   294  		}
   295  	}
   296  	return trafficontrol.Metadata{
   297  		NetWork:     metadata.Network,
   298  		Type:        inbound,
   299  		SrcIP:       metadata.Source.Addr,
   300  		DstIP:       metadata.Destination.Addr,
   301  		SrcPort:     F.ToString(metadata.Source.Port),
   302  		DstPort:     F.ToString(metadata.Destination.Port),
   303  		Host:        domain,
   304  		DNSMode:     "normal",
   305  		ProcessPath: processPath,
   306  	}
   307  }
   308  
   309  func authentication(serverSecret string) func(next http.Handler) http.Handler {
   310  	return func(next http.Handler) http.Handler {
   311  		fn := func(w http.ResponseWriter, r *http.Request) {
   312  			if serverSecret == "" {
   313  				next.ServeHTTP(w, r)
   314  				return
   315  			}
   316  
   317  			// Browser websocket not support custom header
   318  			if websocket.IsWebSocketUpgrade(r) && r.URL.Query().Get("token") != "" {
   319  				token := r.URL.Query().Get("token")
   320  				if token != serverSecret {
   321  					render.Status(r, http.StatusUnauthorized)
   322  					render.JSON(w, r, ErrUnauthorized)
   323  					return
   324  				}
   325  				next.ServeHTTP(w, r)
   326  				return
   327  			}
   328  
   329  			header := r.Header.Get("Authorization")
   330  			bearer, token, found := strings.Cut(header, " ")
   331  
   332  			hasInvalidHeader := bearer != "Bearer"
   333  			hasInvalidSecret := !found || token != serverSecret
   334  			if hasInvalidHeader || hasInvalidSecret {
   335  				render.Status(r, http.StatusUnauthorized)
   336  				render.JSON(w, r, ErrUnauthorized)
   337  				return
   338  			}
   339  			next.ServeHTTP(w, r)
   340  		}
   341  		return http.HandlerFunc(fn)
   342  	}
   343  }
   344  
   345  func hello(redirect bool) func(w http.ResponseWriter, r *http.Request) {
   346  	return func(w http.ResponseWriter, r *http.Request) {
   347  		if redirect {
   348  			http.Redirect(w, r, "/ui/", http.StatusTemporaryRedirect)
   349  		} else {
   350  			render.JSON(w, r, render.M{"hello": "clash"})
   351  		}
   352  	}
   353  }
   354  
   355  var upgrader = websocket.Upgrader{
   356  	CheckOrigin: func(r *http.Request) bool {
   357  		return true
   358  	},
   359  }
   360  
   361  type Traffic struct {
   362  	Up   int64 `json:"up"`
   363  	Down int64 `json:"down"`
   364  }
   365  
   366  func traffic(trafficManager *trafficontrol.Manager) func(w http.ResponseWriter, r *http.Request) {
   367  	return func(w http.ResponseWriter, r *http.Request) {
   368  		var wsConn *websocket.Conn
   369  		if websocket.IsWebSocketUpgrade(r) {
   370  			var err error
   371  			wsConn, err = upgrader.Upgrade(w, r, nil)
   372  			if err != nil {
   373  				return
   374  			}
   375  		}
   376  
   377  		if wsConn == nil {
   378  			w.Header().Set("Content-Type", "application/json")
   379  			render.Status(r, http.StatusOK)
   380  		}
   381  
   382  		tick := time.NewTicker(time.Second)
   383  		defer tick.Stop()
   384  		buf := &bytes.Buffer{}
   385  		var err error
   386  		for range tick.C {
   387  			buf.Reset()
   388  			up, down := trafficManager.Now()
   389  			if err := json.NewEncoder(buf).Encode(Traffic{
   390  				Up:   up,
   391  				Down: down,
   392  			}); err != nil {
   393  				break
   394  			}
   395  
   396  			if wsConn == nil {
   397  				_, err = w.Write(buf.Bytes())
   398  				w.(http.Flusher).Flush()
   399  			} else {
   400  				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   401  			}
   402  
   403  			if err != nil {
   404  				break
   405  			}
   406  		}
   407  	}
   408  }
   409  
   410  type Log struct {
   411  	Type    string `json:"type"`
   412  	Payload string `json:"payload"`
   413  }
   414  
   415  func getLogs(logFactory log.ObservableFactory) func(w http.ResponseWriter, r *http.Request) {
   416  	return func(w http.ResponseWriter, r *http.Request) {
   417  		levelText := r.URL.Query().Get("level")
   418  		if levelText == "" {
   419  			levelText = "info"
   420  		}
   421  
   422  		level, ok := log.ParseLevel(levelText)
   423  		if ok != nil {
   424  			render.Status(r, http.StatusBadRequest)
   425  			render.JSON(w, r, ErrBadRequest)
   426  			return
   427  		}
   428  
   429  		subscription, done, err := logFactory.Subscribe()
   430  		if err != nil {
   431  			render.Status(r, http.StatusNoContent)
   432  			return
   433  		}
   434  		defer logFactory.UnSubscribe(subscription)
   435  
   436  		var wsConn *websocket.Conn
   437  		if websocket.IsWebSocketUpgrade(r) {
   438  			var err error
   439  			wsConn, err = upgrader.Upgrade(w, r, nil)
   440  			if err != nil {
   441  				return
   442  			}
   443  		}
   444  
   445  		if wsConn == nil {
   446  			w.Header().Set("Content-Type", "application/json")
   447  			render.Status(r, http.StatusOK)
   448  		}
   449  
   450  		buf := &bytes.Buffer{}
   451  		var logEntry log.Entry
   452  		for {
   453  			select {
   454  			case <-done:
   455  				return
   456  			case logEntry = <-subscription:
   457  			}
   458  			if logEntry.Level > level {
   459  				continue
   460  			}
   461  			buf.Reset()
   462  			err = json.NewEncoder(buf).Encode(Log{
   463  				Type:    log.FormatLevel(logEntry.Level),
   464  				Payload: logEntry.Message,
   465  			})
   466  			if err != nil {
   467  				break
   468  			}
   469  			if wsConn == nil {
   470  				_, err = w.Write(buf.Bytes())
   471  				w.(http.Flusher).Flush()
   472  			} else {
   473  				err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes())
   474  			}
   475  
   476  			if err != nil {
   477  				break
   478  			}
   479  		}
   480  	}
   481  }
   482  
   483  func version(w http.ResponseWriter, r *http.Request) {
   484  	render.JSON(w, r, render.M{"version": "sing-box " + C.Version, "premium": true, "meta": true})
   485  }