github.com/anycable/anycable-go@v1.5.1/cli/cli.go (about)

     1  package cli
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log/slog"
     7  	"net/http"
     8  	"os"
     9  	"runtime"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/anycable/anycable-go/broadcast"
    15  	"github.com/anycable/anycable-go/broker"
    16  	"github.com/anycable/anycable-go/common"
    17  	"github.com/anycable/anycable-go/config"
    18  	"github.com/anycable/anycable-go/enats"
    19  	"github.com/anycable/anycable-go/identity"
    20  	"github.com/anycable/anycable-go/logger"
    21  	metricspkg "github.com/anycable/anycable-go/metrics"
    22  	"github.com/anycable/anycable-go/mrb"
    23  	"github.com/anycable/anycable-go/node"
    24  	"github.com/anycable/anycable-go/pubsub"
    25  	"github.com/anycable/anycable-go/router"
    26  	"github.com/anycable/anycable-go/server"
    27  	"github.com/anycable/anycable-go/sse"
    28  	"github.com/anycable/anycable-go/streams"
    29  	"github.com/anycable/anycable-go/telemetry"
    30  	"github.com/anycable/anycable-go/utils"
    31  	"github.com/anycable/anycable-go/version"
    32  	"github.com/anycable/anycable-go/ws"
    33  	"github.com/gorilla/websocket"
    34  	"github.com/joomcode/errorx"
    35  
    36  	"go.uber.org/automaxprocs/maxprocs"
    37  )
    38  
    39  type controllerFactory = func(*metricspkg.Metrics, *config.Config, *slog.Logger) (node.Controller, error)
    40  type disconnectorFactory = func(*node.Node, *config.Config, *slog.Logger) (node.Disconnector, error)
    41  type broadcastersFactory = func(broadcast.Handler, *config.Config, *slog.Logger) ([]broadcast.Broadcaster, error)
    42  type brokerFactory = func(broker.Broadcaster, *config.Config, *slog.Logger) (broker.Broker, error)
    43  type subscriberFactory = func(pubsub.Handler, *config.Config, *slog.Logger) (pubsub.Subscriber, error)
    44  type websocketHandler = func(*node.Node, *config.Config, *slog.Logger) (http.Handler, error)
    45  
    46  type Shutdownable interface {
    47  	Shutdown(ctx context.Context) error
    48  }
    49  
    50  type Runner struct {
    51  	options []Option
    52  
    53  	name   string
    54  	config *config.Config
    55  	log    *slog.Logger
    56  
    57  	controllerFactory       controllerFactory
    58  	disconnectorFactory     disconnectorFactory
    59  	subscriberFactory       subscriberFactory
    60  	brokerFactory           brokerFactory
    61  	websocketHandlerFactory websocketHandler
    62  
    63  	broadcastersFactory broadcastersFactory
    64  	websocketEndpoints  map[string]websocketHandler
    65  
    66  	router           *router.RouterController
    67  	metrics          *metricspkg.Metrics
    68  	telemetryEnabled bool
    69  
    70  	errChan       chan error
    71  	shutdownables []Shutdownable
    72  }
    73  
    74  // NewRunner creates returns new Runner structure
    75  func NewRunner(c *config.Config, options []Option) (*Runner, error) {
    76  	r := &Runner{
    77  		options:            options,
    78  		config:             c,
    79  		shutdownables:      []Shutdownable{},
    80  		websocketEndpoints: make(map[string]websocketHandler),
    81  		errChan:            make(chan error),
    82  	}
    83  
    84  	err := r.checkAndSetDefaults()
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	return r, nil
    90  }
    91  
    92  // checkAndSetDefaults applies passed options and checks that all required fields are set
    93  func (r *Runner) checkAndSetDefaults() error {
    94  	for _, o := range r.options {
    95  		err := o(r)
    96  		if err != nil {
    97  			return err
    98  		}
    99  	}
   100  
   101  	if r.log == nil {
   102  		_, err := logger.InitLogger(r.config.LogFormat, r.config.LogLevel)
   103  		if err != nil {
   104  			return errorx.Decorate(err, "failed to initialize default logger")
   105  		}
   106  
   107  		r.log = slog.With("nodeid", r.config.ID)
   108  	}
   109  
   110  	err := r.config.LoadPresets(r.log)
   111  
   112  	if err != nil {
   113  		return errorx.Decorate(err, "failed to load configuration presets")
   114  	}
   115  
   116  	server.SSL = &r.config.SSL
   117  	server.Host = r.config.Host
   118  	server.MaxConn = r.config.MaxConn
   119  	server.Logger = r.log
   120  
   121  	if r.name == "" {
   122  		return errorx.AssertionFailed.New("Name is blank, specify WithName()")
   123  	}
   124  
   125  	if r.controllerFactory == nil {
   126  		return errorx.AssertionFailed.New("Controller is blank, specify WithController()")
   127  	}
   128  
   129  	if r.brokerFactory == nil {
   130  		return errorx.AssertionFailed.New("Broker is blank, specify WithBroker()")
   131  	}
   132  
   133  	if r.subscriberFactory == nil {
   134  		return errorx.AssertionFailed.New("Subscriber is blank, specify WithSubscriber()")
   135  	}
   136  
   137  	if r.disconnectorFactory == nil {
   138  		r.disconnectorFactory = r.defaultDisconnector
   139  	}
   140  
   141  	if r.websocketHandlerFactory == nil {
   142  		r.websocketHandlerFactory = r.defaultWebSocketHandler
   143  	}
   144  
   145  	metrics, err := r.initMetrics(&r.config.Metrics)
   146  
   147  	if err != nil {
   148  		return errorx.Decorate(err, "failed to initialize metrics writer")
   149  	}
   150  
   151  	r.metrics = metrics
   152  
   153  	return nil
   154  }
   155  
   156  // Run starts the instance
   157  func (r *Runner) Run() error {
   158  	numProcs := r.setMaxProcs()
   159  	r.announceDebugMode()
   160  
   161  	mrubySupport := r.initMRuby()
   162  
   163  	r.log.Info(fmt.Sprintf("Starting %s %s%s (pid: %d, open file limit: %s, gomaxprocs: %d)", r.name, version.Version(), mrubySupport, os.Getpid(), utils.OpenFileLimit(), numProcs))
   164  
   165  	if r.config.IsPublic() {
   166  		r.log.Warn("Server is running in the public mode")
   167  	}
   168  
   169  	appNode, err := r.runNode()
   170  
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	wsServer, err := server.ForPort(strconv.Itoa(r.config.Port))
   176  	if err != nil {
   177  		return errorx.Decorate(err, "failed to initialize WebSocket server at %s:%d", r.config.Host, r.config.Port)
   178  	}
   179  
   180  	wsHandler, err := r.websocketHandlerFactory(appNode, r.config, r.log)
   181  	if err != nil {
   182  		return errorx.Decorate(err, "failed to initialize WebSocket handler")
   183  	}
   184  
   185  	for _, path := range r.config.Path {
   186  		wsServer.SetupHandler(path, wsHandler)
   187  		r.log.Info(fmt.Sprintf("Handle WebSocket connections at %s%s", wsServer.Address(), path))
   188  	}
   189  
   190  	for path, handlerFactory := range r.websocketEndpoints {
   191  		handler, err := handlerFactory(appNode, r.config, r.log)
   192  		if err != nil {
   193  			return errorx.Decorate(err, "failed to initialize WebSocket handler for %s", path)
   194  		}
   195  		wsServer.SetupHandler(path, handler)
   196  	}
   197  
   198  	wsServer.SetupHandler(r.config.HealthPath, http.HandlerFunc(server.HealthHandler))
   199  	r.log.Info(fmt.Sprintf("Handle health requests at %s%s", wsServer.Address(), r.config.HealthPath))
   200  
   201  	if r.config.SSE.Enabled {
   202  		r.log.Info(
   203  			fmt.Sprintf("Handle SSE requests at %s%s",
   204  				wsServer.Address(), r.config.SSE.Path),
   205  		)
   206  
   207  		sseHandler, err := r.defaultSSEHandler(appNode, wsServer.ShutdownCtx(), r.config)
   208  
   209  		if err != nil {
   210  			return errorx.Decorate(err, "failed to initialize SSE handler")
   211  		}
   212  
   213  		wsServer.SetupHandler(r.config.SSE.Path, sseHandler)
   214  	}
   215  
   216  	go r.startWSServer(wsServer)
   217  	go r.metrics.Run() // nolint:errcheck
   218  
   219  	// We MUST first stop the server (=stop accepting new connections), then gracefully disconnect active clients
   220  	r.shutdownables = append([]Shutdownable{wsServer}, r.shutdownables...)
   221  	r.setupSignalHandlers()
   222  
   223  	// Wait for an error (or none)
   224  	return <-r.errChan
   225  }
   226  
   227  func (r *Runner) runNode() (*node.Node, error) {
   228  	metrics := r.metrics
   229  
   230  	r.shutdownables = append(r.shutdownables, metrics)
   231  
   232  	controller, err := r.newController(metrics)
   233  	if err != nil {
   234  		return nil, err
   235  	}
   236  
   237  	appNode := node.NewNode(
   238  		&r.config.App,
   239  		node.WithController(controller),
   240  		node.WithInstrumenter(metrics),
   241  		node.WithLogger(r.log),
   242  		node.WithID(r.config.ID),
   243  	)
   244  
   245  	if r.telemetryEnabled {
   246  		telemetryConfig := telemetry.NewConfig()
   247  		tracker := telemetry.NewTracker(metrics, r.config, telemetryConfig)
   248  
   249  		r.log.With("context", "telemetry").Info(tracker.Announce())
   250  		go tracker.Collect()
   251  
   252  		r.shutdownables = append(r.shutdownables, tracker)
   253  	}
   254  
   255  	subscriber, err := r.subscriberFactory(appNode, r.config, r.log)
   256  
   257  	if err != nil {
   258  		return nil, errorx.Decorate(err, "couldn't configure pub/sub")
   259  	}
   260  
   261  	appBroker, err := r.brokerFactory(subscriber, r.config, r.log)
   262  	if err != nil {
   263  		return nil, errorx.Decorate(err, "failed to initialize broker")
   264  	}
   265  
   266  	if appBroker != nil {
   267  		r.log.Info(appBroker.Announce())
   268  		appNode.SetBroker(appBroker)
   269  	}
   270  
   271  	disconnector, err := r.disconnectorFactory(appNode, r.config, r.log)
   272  	if err != nil {
   273  		return nil, errorx.Decorate(err, "failed to initialize disconnector")
   274  	}
   275  
   276  	go disconnector.Run() // nolint:errcheck
   277  	appNode.SetDisconnector(disconnector)
   278  
   279  	if r.config.EmbedNats {
   280  		service, enatsErr := r.embedNATS(&r.config.EmbeddedNats)
   281  
   282  		if enatsErr != nil {
   283  			return nil, errorx.Decorate(enatsErr, "failed to start embedded NATS server")
   284  		}
   285  
   286  		desc := service.Description()
   287  
   288  		if desc != "" {
   289  			desc = fmt.Sprintf(" (%s)", desc)
   290  		}
   291  
   292  		r.log.Info(fmt.Sprintf("Embedded NATS server started: %s%s", r.config.EmbeddedNats.ServiceAddr, desc))
   293  
   294  		r.shutdownables = append(r.shutdownables, service)
   295  	}
   296  
   297  	err = appNode.Start()
   298  
   299  	if err != nil {
   300  		return nil, errorx.Decorate(err, "failed to initialize application")
   301  	}
   302  
   303  	err = subscriber.Start(r.errChan)
   304  	if err != nil {
   305  		return nil, errorx.Decorate(err, "failed to start subscriber")
   306  	}
   307  
   308  	if appBroker != nil {
   309  		err = appBroker.Start(r.errChan)
   310  		if err != nil {
   311  			return nil, errorx.Decorate(err, "failed to start broker")
   312  		}
   313  	}
   314  
   315  	r.shutdownables = append(r.shutdownables, subscriber)
   316  
   317  	if r.broadcastersFactory != nil {
   318  		broadcasters, berr := r.broadcastersFactory(appNode, r.config, r.log)
   319  
   320  		if berr != nil {
   321  			return nil, errorx.Decorate(err, "couldn't configure broadcasters")
   322  		}
   323  
   324  		for _, broadcaster := range broadcasters {
   325  			err = broadcaster.Start(r.errChan)
   326  			if err != nil {
   327  				return nil, errorx.Decorate(err, "failed to start broadcaster")
   328  			}
   329  
   330  			r.shutdownables = append(r.shutdownables, broadcaster)
   331  		}
   332  	}
   333  
   334  	err = controller.Start()
   335  	if err != nil {
   336  		return nil, errorx.Decorate(err, "failed to initialize RPC controller")
   337  	}
   338  
   339  	r.shutdownables = append([]Shutdownable{appNode, appBroker}, r.shutdownables...)
   340  
   341  	r.announceGoPools()
   342  	return appNode, nil
   343  }
   344  
   345  func (r *Runner) setMaxProcs() int {
   346  	// See https://github.com/uber-go/automaxprocs/issues/18
   347  	nopLog := func(string, ...interface{}) {}
   348  	maxprocs.Set(maxprocs.Logger(nopLog)) // nolint:errcheck
   349  
   350  	return runtime.GOMAXPROCS(0)
   351  }
   352  
   353  func (r *Runner) announceDebugMode() {
   354  	if r.config.Debug {
   355  		r.log.Debug("🔧 🔧 🔧 Debug mode is on 🔧 🔧 🔧")
   356  	}
   357  }
   358  
   359  func (r *Runner) initMetrics(c *metricspkg.Config) (*metricspkg.Metrics, error) {
   360  	m, err := metricspkg.NewFromConfig(c, r.log)
   361  
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	if c.Statsd.Enabled() {
   367  		sw := metricspkg.NewStatsdWriter(c.Statsd, c.Tags, r.log)
   368  		m.RegisterWriter(sw)
   369  	}
   370  
   371  	return m, nil
   372  }
   373  
   374  func (r *Runner) newController(metrics *metricspkg.Metrics) (node.Controller, error) {
   375  	controller, err := r.controllerFactory(metrics, r.config, r.log)
   376  	if err != nil {
   377  		return nil, errorx.Decorate(err, "!!! Failed to initialize controller !!!")
   378  	}
   379  
   380  	ids := []identity.Identifier{}
   381  
   382  	if r.config.JWT.Enabled() {
   383  		ids = append(ids, identity.NewJWTIdentifier(&r.config.JWT, r.log))
   384  		r.log.Info(fmt.Sprintf("JWT authentication is enabled (param: %s, enforced: %v)", r.config.JWT.Param, r.config.JWT.Force))
   385  	}
   386  
   387  	if r.config.SkipAuth {
   388  		ids = append(ids, identity.NewPublicIdentifier())
   389  		r.log.Info("connection authentication is disabled")
   390  	}
   391  
   392  	if len(ids) > 1 {
   393  		identifier := identity.NewIdentifierPipeline(ids...)
   394  		controller = identity.NewIdentifiableController(controller, identifier)
   395  	} else if len(ids) == 1 {
   396  		controller = identity.NewIdentifiableController(controller, ids[0])
   397  	}
   398  
   399  	if !r.Router().Empty() {
   400  		r.Router().SetDefault(controller)
   401  		controller = r.Router()
   402  		r.log.Info(fmt.Sprintf("Using channels router: %s", strings.Join(r.Router().Routes(), ", ")))
   403  	}
   404  
   405  	return controller, nil
   406  }
   407  
   408  func (r *Runner) startWSServer(wsServer *server.HTTPServer) {
   409  	go func() {
   410  		err := wsServer.StartAndAnnounce("WebSocket server")
   411  		if err != nil {
   412  			if !wsServer.Stopped() {
   413  				r.errChan <- fmt.Errorf("WebSocket server at %s stopped: %v", wsServer.Address(), err)
   414  			}
   415  		}
   416  	}()
   417  }
   418  
   419  func (r *Runner) startMetrics(metrics *metricspkg.Metrics) {
   420  	err := metrics.Run()
   421  	if err != nil {
   422  		r.errChan <- fmt.Errorf("!!! Metrics module failed to start !!!\n%v", err)
   423  	}
   424  }
   425  
   426  func (r *Runner) defaultDisconnector(n *node.Node, c *config.Config, l *slog.Logger) (node.Disconnector, error) {
   427  	if c.DisconnectorDisabled {
   428  		return node.NewNoopDisconnector(), nil
   429  	}
   430  	return node.NewDisconnectQueue(n, &c.DisconnectQueue, l), nil
   431  }
   432  
   433  func (r *Runner) defaultWebSocketHandler(n *node.Node, c *config.Config, l *slog.Logger) (http.Handler, error) {
   434  	extractor := server.DefaultHeadersExtractor{Headers: c.Headers, Cookies: c.Cookies}
   435  	return ws.WebsocketHandler(common.ActionCableProtocols(), &extractor, &c.WS, r.log, func(wsc *websocket.Conn, info *server.RequestInfo, callback func()) error {
   436  		wrappedConn := ws.NewConnection(wsc)
   437  
   438  		opts := []node.SessionOption{}
   439  		opts = append(opts, r.sessionOptionsFromProtocol(wsc.Subprotocol())...)
   440  		opts = append(opts, r.sessionOptionsFromParams(info)...)
   441  
   442  		session := node.NewSession(n, wrappedConn, info.URL, info.Headers, info.UID, opts...)
   443  
   444  		if session.AuthenticateOnConnect() {
   445  			_, err := n.Authenticate(session)
   446  
   447  			if err != nil {
   448  				return err
   449  			}
   450  		}
   451  
   452  		return session.Serve(callback)
   453  	}), nil
   454  }
   455  
   456  func (r *Runner) defaultSSEHandler(n *node.Node, ctx context.Context, c *config.Config) (http.Handler, error) {
   457  	extractor := server.DefaultHeadersExtractor{Headers: c.Headers, Cookies: c.Cookies}
   458  	handler := sse.SSEHandler(n, ctx, &extractor, &c.SSE, r.log)
   459  
   460  	return handler, nil
   461  }
   462  
   463  func (r *Runner) initMRuby() string {
   464  	if mrb.Supported() {
   465  		var mrbv string
   466  		mrbv, err := mrb.Version()
   467  		if err != nil {
   468  			r.log.Error(fmt.Sprintf("mruby failed to initialize: %v", err))
   469  		} else {
   470  			return " (with " + mrbv + ")"
   471  		}
   472  	}
   473  
   474  	return ""
   475  }
   476  
   477  func (r *Runner) Router() *router.RouterController {
   478  	if r.router == nil {
   479  		r.SetRouter(r.defaultRouter())
   480  	}
   481  
   482  	return r.router
   483  }
   484  
   485  func (r *Runner) SetRouter(router *router.RouterController) {
   486  	r.router = router
   487  }
   488  
   489  func (r *Runner) Instrumenter() metricspkg.Instrumenter {
   490  	return r.metrics
   491  }
   492  
   493  func (r *Runner) defaultRouter() *router.RouterController {
   494  	router := router.NewRouterController(nil)
   495  
   496  	if r.config.Streams.PubSubChannel != "" {
   497  		streamController := streams.NewStreamsController(&r.config.Streams, r.log)
   498  		router.Route(r.config.Streams.PubSubChannel, streamController) // nolint:errcheck
   499  	}
   500  
   501  	if r.config.Streams.Turbo && r.config.Streams.GetTurboSecret() != "" {
   502  		turboController := streams.NewTurboController(r.config.Streams.GetTurboSecret(), r.log)
   503  		router.Route("Turbo::StreamsChannel", turboController) // nolint:errcheck
   504  	}
   505  
   506  	if r.config.Streams.CableReady && r.config.Streams.GetCableReadySecret() != "" {
   507  		crController := streams.NewCableReadyController(r.config.Streams.GetCableReadySecret(), r.log)
   508  		router.Route("CableReady::Stream", crController) // nolint:errcheck
   509  	}
   510  
   511  	return router
   512  }
   513  
   514  func (r *Runner) announceGoPools() {
   515  	configs := make([]string, 0)
   516  	pools := utils.AllPools()
   517  
   518  	for _, pool := range pools {
   519  		configs = append(configs, fmt.Sprintf("%s: %d", pool.Name(), pool.Size()))
   520  	}
   521  
   522  	r.log.Debug(fmt.Sprintf("Go pools initialized (%s)", strings.Join(configs, ", ")))
   523  }
   524  
   525  func (r *Runner) setupSignalHandlers() {
   526  	s := utils.NewGracefulSignals(time.Duration(r.config.App.ShutdownTimeout) * time.Second)
   527  
   528  	s.HandleForceTerminate(func() {
   529  		r.log.Warn("Immediate termination requested. Stopped")
   530  		r.errChan <- nil
   531  	})
   532  
   533  	s.Handle(func(ctx context.Context) error {
   534  		r.log.Info(fmt.Sprintf("Shutting down... (hit Ctrl-C to stop immediately or wait for up to %ds for graceful shutdown)", r.config.App.ShutdownTimeout))
   535  		return nil
   536  	})
   537  
   538  	for _, shutdownable := range r.shutdownables {
   539  		s.Handle(shutdownable.Shutdown)
   540  	}
   541  
   542  	s.Handle(func(ctx context.Context) error {
   543  		r.errChan <- nil
   544  		return nil
   545  	})
   546  
   547  	s.Listen()
   548  }
   549  
   550  func (r *Runner) embedNATS(c *enats.Config) (*enats.Service, error) {
   551  	service := enats.NewService(c, r.log)
   552  
   553  	err := service.Start()
   554  
   555  	if err != nil {
   556  		return nil, err
   557  	}
   558  
   559  	return service, nil
   560  }