go.uber.org/yarpc@v1.72.1/dispatcher.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package yarpc
    22  
    23  import (
    24  	"context"
    25  	"fmt"
    26  
    27  	"go.uber.org/multierr"
    28  	"go.uber.org/net/metrics"
    29  	"go.uber.org/yarpc/api/middleware"
    30  	"go.uber.org/yarpc/api/transport"
    31  	"go.uber.org/yarpc/internal"
    32  	"go.uber.org/yarpc/internal/firstoutboundmiddleware"
    33  	"go.uber.org/yarpc/internal/inboundmiddleware"
    34  	"go.uber.org/yarpc/internal/observability"
    35  	"go.uber.org/yarpc/internal/outboundmiddleware"
    36  	"go.uber.org/yarpc/internal/request"
    37  	"go.uber.org/yarpc/pkg/lifecycle"
    38  	"go.uber.org/zap"
    39  )
    40  
    41  // Inbounds contains a list of inbound transports. Each inbound transport
    42  // specifies a source through which incoming requests are received.
    43  type Inbounds []transport.Inbound
    44  
    45  // Outbounds provides access to outbounds for a remote service. Outbounds
    46  // define how requests are sent from this service to the remote service.
    47  type Outbounds map[string]transport.Outbounds
    48  
    49  // OutboundMiddleware contains the different types of outbound middlewares.
    50  type OutboundMiddleware struct {
    51  	Unary  middleware.UnaryOutbound
    52  	Oneway middleware.OnewayOutbound
    53  	Stream middleware.StreamOutbound
    54  }
    55  
    56  // InboundMiddleware contains the different types of inbound middlewares.
    57  type InboundMiddleware struct {
    58  	Unary  middleware.UnaryInbound
    59  	Oneway middleware.OnewayInbound
    60  	Stream middleware.StreamInbound
    61  }
    62  
    63  // RouterMiddleware wraps the Router middleware
    64  type RouterMiddleware middleware.Router
    65  
    66  // NewDispatcher builds a new Dispatcher using the specified Config. At
    67  // minimum, a service name must be specified.
    68  //
    69  // Invalid configurations or errors in constructing the Dispatcher will cause
    70  // panics.
    71  func NewDispatcher(cfg Config) *Dispatcher {
    72  	if cfg.Name == "" {
    73  		panic("yarpc.NewDispatcher expects a service name")
    74  	}
    75  	if err := internal.ValidateServiceName(cfg.Name); err != nil {
    76  		panic("yarpc.NewDispatcher expects a valid service name: " + err.Error())
    77  	}
    78  
    79  	logger := cfg.Logging.logger(cfg.Name)
    80  	extractor := cfg.Logging.extractor()
    81  
    82  	meter, stopMeter := cfg.Metrics.scope(cfg.Name, logger)
    83  	cfg = addObservingMiddleware(cfg, meter, logger, extractor)
    84  	cfg = addFirstOutboundMiddleware(cfg)
    85  
    86  	return &Dispatcher{
    87  		name:              cfg.Name,
    88  		table:             middleware.ApplyRouteTable(NewMapRouter(cfg.Name), cfg.RouterMiddleware),
    89  		inbounds:          cfg.Inbounds,
    90  		outbounds:         convertOutbounds(cfg.Outbounds, cfg.OutboundMiddleware),
    91  		transports:        collectTransports(cfg.Inbounds, cfg.Outbounds),
    92  		inboundMiddleware: cfg.InboundMiddleware,
    93  		log:               logger,
    94  		meter:             meter,
    95  		stopMeter:         stopMeter,
    96  		once:              lifecycle.NewOnce(),
    97  	}
    98  }
    99  
   100  func addObservingMiddleware(cfg Config, meter *metrics.Scope, logger *zap.Logger, extractor observability.ContextExtractor) Config {
   101  	if cfg.DisableAutoObservabilityMiddleware {
   102  		return cfg
   103  	}
   104  
   105  	observer := observability.NewMiddleware(observability.Config{
   106  		Logger:              logger,
   107  		Scope:               meter,
   108  		ContextExtractor:    extractor,
   109  		MetricTagsBlocklist: cfg.Metrics.TagsBlocklist,
   110  		Levels: observability.LevelsConfig{
   111  			Default: observability.DirectionalLevelsConfig{
   112  				Success:          cfg.Logging.Levels.Success,
   113  				Failure:          cfg.Logging.Levels.Failure,
   114  				ApplicationError: cfg.Logging.Levels.ApplicationError,
   115  				ServerError:      cfg.Logging.Levels.ServerError,
   116  				ClientError:      cfg.Logging.Levels.ClientError,
   117  			},
   118  			Inbound: observability.DirectionalLevelsConfig{
   119  				Success:          cfg.Logging.Levels.Inbound.Success,
   120  				Failure:          cfg.Logging.Levels.Inbound.Failure,
   121  				ApplicationError: cfg.Logging.Levels.Inbound.ApplicationError,
   122  				ServerError:      cfg.Logging.Levels.Inbound.ServerError,
   123  				ClientError:      cfg.Logging.Levels.Inbound.ClientError,
   124  			},
   125  			Outbound: observability.DirectionalLevelsConfig{
   126  				Success:          cfg.Logging.Levels.Outbound.Success,
   127  				Failure:          cfg.Logging.Levels.Outbound.Failure,
   128  				ApplicationError: cfg.Logging.Levels.Outbound.ApplicationError,
   129  				ServerError:      cfg.Logging.Levels.Outbound.ServerError,
   130  				ClientError:      cfg.Logging.Levels.Outbound.ClientError,
   131  			},
   132  		},
   133  	})
   134  
   135  	cfg.InboundMiddleware.Unary = inboundmiddleware.UnaryChain(observer, cfg.InboundMiddleware.Unary)
   136  	cfg.InboundMiddleware.Oneway = inboundmiddleware.OnewayChain(observer, cfg.InboundMiddleware.Oneway)
   137  	cfg.InboundMiddleware.Stream = inboundmiddleware.StreamChain(observer, cfg.InboundMiddleware.Stream)
   138  
   139  	cfg.OutboundMiddleware.Unary = outboundmiddleware.UnaryChain(cfg.OutboundMiddleware.Unary, observer)
   140  	cfg.OutboundMiddleware.Oneway = outboundmiddleware.OnewayChain(cfg.OutboundMiddleware.Oneway, observer)
   141  	cfg.OutboundMiddleware.Stream = outboundmiddleware.StreamChain(cfg.OutboundMiddleware.Stream, observer)
   142  
   143  	return cfg
   144  }
   145  
   146  // Add the first outbound middleware, which ensures that `transport.Request`
   147  // will have appropriate fields.
   148  func addFirstOutboundMiddleware(cfg Config) Config {
   149  	first := firstoutboundmiddleware.New()
   150  	cfg.OutboundMiddleware.Unary = outboundmiddleware.UnaryChain(first, cfg.OutboundMiddleware.Unary)
   151  	cfg.OutboundMiddleware.Oneway = outboundmiddleware.OnewayChain(first, cfg.OutboundMiddleware.Oneway)
   152  	cfg.OutboundMiddleware.Stream = outboundmiddleware.StreamChain(first, cfg.OutboundMiddleware.Stream)
   153  	return cfg
   154  }
   155  
   156  // convertOutbounds applies outbound middleware and creates validator outbounds
   157  func convertOutbounds(outbounds Outbounds, mw OutboundMiddleware) Outbounds {
   158  	outboundSpecs := make(Outbounds, len(outbounds))
   159  
   160  	for outboundKey, outs := range outbounds {
   161  		if outs.Unary == nil && outs.Oneway == nil && outs.Stream == nil {
   162  			panic(fmt.Sprintf("no outbound set for outbound key %q in dispatcher", outboundKey))
   163  		}
   164  
   165  		var (
   166  			unaryOutbound  transport.UnaryOutbound
   167  			onewayOutbound transport.OnewayOutbound
   168  			streamOutbound transport.StreamOutbound
   169  		)
   170  		serviceName := outboundKey
   171  
   172  		// apply outbound middleware and create ValidatorOutbounds
   173  
   174  		if outs.Unary != nil {
   175  			unaryOutbound = middleware.ApplyUnaryOutbound(outs.Unary, mw.Unary)
   176  			unaryOutbound = request.UnaryValidatorOutbound{UnaryOutbound: unaryOutbound, Namer: namerOrNil(unaryOutbound)}
   177  		}
   178  
   179  		if outs.Oneway != nil {
   180  			onewayOutbound = middleware.ApplyOnewayOutbound(outs.Oneway, mw.Oneway)
   181  			onewayOutbound = request.OnewayValidatorOutbound{OnewayOutbound: onewayOutbound, Namer: namerOrNil(onewayOutbound)}
   182  		}
   183  
   184  		if outs.Stream != nil {
   185  			streamOutbound = middleware.ApplyStreamOutbound(outs.Stream, mw.Stream)
   186  			streamOutbound = request.StreamValidatorOutbound{StreamOutbound: streamOutbound, Namer: namerOrNil(streamOutbound)}
   187  		}
   188  
   189  		if outs.ServiceName != "" {
   190  			serviceName = outs.ServiceName
   191  		}
   192  
   193  		outboundSpecs[outboundKey] = transport.Outbounds{
   194  			ServiceName: serviceName,
   195  			Unary:       unaryOutbound,
   196  			Oneway:      onewayOutbound,
   197  			Stream:      streamOutbound,
   198  		}
   199  	}
   200  
   201  	return outboundSpecs
   202  }
   203  
   204  func namerOrNil(o transport.Outbound) (namer transport.Namer) {
   205  	if n, ok := o.(transport.Namer); ok {
   206  		namer = n
   207  	}
   208  	return
   209  }
   210  
   211  // collectTransports iterates over all inbounds and outbounds and collects all
   212  // of their unique underlying transports. Multiple inbounds and outbounds may
   213  // share a transport, and we only want the dispatcher to manage their lifecycle
   214  // once.
   215  func collectTransports(inbounds Inbounds, outbounds Outbounds) []transport.Transport {
   216  	// Collect all unique transports from inbounds and outbounds.
   217  	transports := make(map[transport.Transport]struct{})
   218  	for _, inbound := range inbounds {
   219  		for _, transport := range inbound.Transports() {
   220  			transports[transport] = struct{}{}
   221  		}
   222  	}
   223  	for _, outbound := range outbounds {
   224  		if unary := outbound.Unary; unary != nil {
   225  			for _, transport := range unary.Transports() {
   226  				transports[transport] = struct{}{}
   227  			}
   228  		}
   229  		if oneway := outbound.Oneway; oneway != nil {
   230  			for _, transport := range oneway.Transports() {
   231  				transports[transport] = struct{}{}
   232  			}
   233  		}
   234  		if stream := outbound.Stream; stream != nil {
   235  			for _, transport := range stream.Transports() {
   236  				transports[transport] = struct{}{}
   237  			}
   238  		}
   239  	}
   240  	keys := make([]transport.Transport, 0, len(transports))
   241  	for key := range transports {
   242  		keys = append(keys, key)
   243  	}
   244  	return keys
   245  }
   246  
   247  // Dispatcher encapsulates a YARPC application. It acts as the entry point to
   248  // send and receive YARPC requests in a transport and encoding agnostic way.
   249  type Dispatcher struct {
   250  	table      transport.RouteTable
   251  	name       string
   252  	inbounds   Inbounds
   253  	outbounds  Outbounds
   254  	transports []transport.Transport
   255  
   256  	inboundMiddleware InboundMiddleware
   257  
   258  	log       *zap.Logger
   259  	meter     *metrics.Scope
   260  	stopMeter context.CancelFunc
   261  
   262  	once *lifecycle.Once
   263  }
   264  
   265  // Inbounds returns a copy of the list of inbounds for this RPC object.
   266  //
   267  // The Inbounds will be returned in the same order that was used in the
   268  // configuration.
   269  func (d *Dispatcher) Inbounds() Inbounds {
   270  	inbounds := make(Inbounds, len(d.inbounds))
   271  	copy(inbounds, d.inbounds)
   272  	return inbounds
   273  }
   274  
   275  // Outbounds returns a copy of the list of outbounds for this RPC object.
   276  func (d *Dispatcher) Outbounds() Outbounds {
   277  	outbounds := make(Outbounds, len(d.outbounds))
   278  	for k, v := range d.outbounds {
   279  		outbounds[k] = v
   280  	}
   281  	return outbounds
   282  }
   283  
   284  // ClientConfig provides the configuration needed to talk to the given
   285  // service through an outboundKey. This configuration may be directly
   286  // passed into encoding-specific RPC clients.
   287  //
   288  // 	keyvalueClient := json.New(dispatcher.ClientConfig("keyvalue"))
   289  //
   290  // This function panics if the outboundKey is not known.
   291  func (d *Dispatcher) ClientConfig(outboundKey string) transport.ClientConfig {
   292  	return d.MustOutboundConfig(outboundKey)
   293  }
   294  
   295  // MustOutboundConfig provides the configuration needed to talk to the given
   296  // service through an outboundKey. This configuration may be directly
   297  // passed into encoding-specific RPC clients.
   298  //
   299  // 	keyvalueClient := json.New(dispatcher.MustOutboundConfig("keyvalue"))
   300  //
   301  // This function panics if the outboundKey is not known.
   302  func (d *Dispatcher) MustOutboundConfig(outboundKey string) *transport.OutboundConfig {
   303  	if oc, ok := d.OutboundConfig(outboundKey); ok {
   304  		return oc
   305  	}
   306  	panic(fmt.Sprintf("no configured outbound transport for outbound key %q", outboundKey))
   307  }
   308  
   309  // OutboundConfig provides the configuration needed to talk to the given
   310  // service through an outboundKey. This configuration may be directly
   311  // passed into encoding-specific RPC clients.
   312  //
   313  //  outboundConfig, ok := dispatcher.OutboundConfig("keyvalue")
   314  //  if !ok {
   315  //    // do something
   316  //  }
   317  // 	keyvalueClient := json.New(outboundConfig)
   318  func (d *Dispatcher) OutboundConfig(outboundKey string) (oc *transport.OutboundConfig, ok bool) {
   319  	if out, ok := d.outbounds[outboundKey]; ok {
   320  		return &transport.OutboundConfig{
   321  			CallerName: d.name,
   322  			Outbounds:  out,
   323  		}, true
   324  	}
   325  	return nil, false
   326  }
   327  
   328  // InboundMiddleware returns the middleware applied to all inbound handlers.
   329  // Router middleware and fallback handlers can use the InboundMiddleware to
   330  // wrap custom handlers.
   331  func (d *Dispatcher) InboundMiddleware() InboundMiddleware {
   332  	return d.inboundMiddleware
   333  }
   334  
   335  // Register registers zero or more procedures with this dispatcher. Incoming
   336  // requests to these procedures will be routed to the handlers specified in
   337  // the given Procedures.
   338  func (d *Dispatcher) Register(rs []transport.Procedure) {
   339  	procedures := make([]transport.Procedure, 0, len(rs))
   340  
   341  	for _, r := range rs {
   342  		switch r.HandlerSpec.Type() {
   343  		case transport.Unary:
   344  			h := middleware.ApplyUnaryInbound(r.HandlerSpec.Unary(),
   345  				d.inboundMiddleware.Unary)
   346  			r.HandlerSpec = transport.NewUnaryHandlerSpec(h)
   347  		case transport.Oneway:
   348  			h := middleware.ApplyOnewayInbound(r.HandlerSpec.Oneway(),
   349  				d.inboundMiddleware.Oneway)
   350  			r.HandlerSpec = transport.NewOnewayHandlerSpec(h)
   351  		case transport.Streaming:
   352  			h := middleware.ApplyStreamInbound(r.HandlerSpec.Stream(),
   353  				d.inboundMiddleware.Stream)
   354  			r.HandlerSpec = transport.NewStreamHandlerSpec(h)
   355  		default:
   356  			panic(fmt.Sprintf("unknown handler type %q for service %q, procedure %q",
   357  				r.HandlerSpec.Type(), r.Service, r.Name))
   358  		}
   359  
   360  		procedures = append(procedures, r)
   361  		d.log.Info("Registration succeeded.", zap.Object("registeredProcedure", r))
   362  	}
   363  
   364  	d.table.Register(procedures)
   365  }
   366  
   367  // Start starts the Dispatcher, allowing it to accept and process new incoming
   368  // requests. This starts all inbounds and outbounds configured on this
   369  // Dispatcher.
   370  //
   371  // This function returns immediately after everything has been started.
   372  // Servers should add a `select {}` to block to process all incoming requests.
   373  //
   374  // 	if err := dispatcher.Start(); err != nil {
   375  // 		log.Fatal(err)
   376  // 	}
   377  // 	defer dispatcher.Stop()
   378  //
   379  // 	select {}
   380  //
   381  // Start and PhasedStart are mutually exclusive. See the PhasedStart
   382  // documentation for details.
   383  func (d *Dispatcher) Start() error {
   384  	starter := &PhasedStarter{
   385  		dispatcher: d,
   386  		log:        d.log,
   387  	}
   388  	return d.once.Start(func() error {
   389  		d.log.Info("starting dispatcher")
   390  		starter.setRouters()
   391  		if err := starter.StartTransports(); err != nil {
   392  			return err
   393  		}
   394  		if err := starter.StartOutbounds(); err != nil {
   395  			return err
   396  		}
   397  		if err := starter.StartInbounds(); err != nil {
   398  			return err
   399  		}
   400  		d.log.Info("dispatcher startup complete")
   401  		return nil
   402  	})
   403  }
   404  
   405  // PhasedStart is a more granular alternative to Start, and is intended only
   406  // for advanced users. Rather than starting all transports, inbounds, and
   407  // outbounds at once, it lets the user start them separately.
   408  //
   409  // Start and PhasedStart are mutually exclusive. If Start is called first,
   410  // PhasedStart is a no-op and returns the same error (if any) that Start
   411  // returned. If PhasedStart is called first, Start is a no-op and always
   412  // returns a nil error; the caller is responsible for using the PhasedStarter
   413  // to complete startup.
   414  func (d *Dispatcher) PhasedStart() (*PhasedStarter, error) {
   415  	starter := &PhasedStarter{
   416  		dispatcher: d,
   417  		log:        d.log,
   418  	}
   419  	if err := d.once.Start(func() error {
   420  		starter.log.Info("beginning phased dispatcher start")
   421  		starter.setRouters()
   422  		return nil
   423  	}); err != nil {
   424  		return nil, err
   425  	}
   426  	return starter, nil
   427  }
   428  
   429  // Stop stops the Dispatcher, shutting down all inbounds, outbounds, and
   430  // transports. This function returns after everything has been stopped.
   431  //
   432  // Stop and PhasedStop are mutually exclusive. See the PhasedStop
   433  // documentation for details.
   434  func (d *Dispatcher) Stop() error {
   435  	stopper := &PhasedStopper{
   436  		dispatcher: d,
   437  		log:        d.log,
   438  	}
   439  	return d.once.Stop(func() error {
   440  		d.log.Info("shutting down dispatcher")
   441  		return multierr.Combine(
   442  			stopper.StopInbounds(),
   443  			stopper.StopOutbounds(),
   444  			stopper.StopTransports(),
   445  		)
   446  	})
   447  }
   448  
   449  // PhasedStop is a more granular alternative to Stop, and is intended only for
   450  // advanced users. Rather than stopping all inbounds, outbounds, and
   451  // transports at once, it lets the user stop them separately.
   452  //
   453  // Stop and PhasedStop are mutually exclusive. If Stop is called first,
   454  // PhasedStop is a no-op and returns the same error (if any) that Stop
   455  // returned. If PhasedStop is called first, Stop is a no-op and always returns
   456  // a nil error; the caller is responsible for using the PhasedStopper to
   457  // complete shutdown.
   458  func (d *Dispatcher) PhasedStop() (*PhasedStopper, error) {
   459  	if err := d.once.Stop(func() error { return nil }); err != nil {
   460  		return nil, err
   461  	}
   462  	return &PhasedStopper{
   463  		dispatcher: d,
   464  		log:        d.log,
   465  	}, nil
   466  }
   467  
   468  // Router returns the procedure router.
   469  func (d *Dispatcher) Router() transport.Router {
   470  	return d.table
   471  }
   472  
   473  // Name returns the name of the dispatcher.
   474  func (d *Dispatcher) Name() string {
   475  	return d.name
   476  }