go.uber.org/yarpc@v1.72.1/transport/tchannel/transport.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"context"
    25  	"crypto/tls"
    26  	"errors"
    27  	"fmt"
    28  	"net"
    29  	"sync"
    30  	"time"
    31  
    32  	"github.com/opentracing/opentracing-go"
    33  	"github.com/uber/tchannel-go"
    34  	"go.uber.org/net/metrics"
    35  	backoffapi "go.uber.org/yarpc/api/backoff"
    36  	"go.uber.org/yarpc/api/peer"
    37  	"go.uber.org/yarpc/api/transport"
    38  	yarpctls "go.uber.org/yarpc/api/transport/tls"
    39  	"go.uber.org/yarpc/pkg/lifecycle"
    40  	"go.uber.org/yarpc/transport/internal/tls/dialer"
    41  	"go.uber.org/yarpc/transport/internal/tls/muxlistener"
    42  	"go.uber.org/zap"
    43  )
    44  
    45  type headerCase int
    46  
    47  const (
    48  	canonicalizedHeaderCase headerCase = iota
    49  	originalHeaderCase
    50  )
    51  
// Transport is a TChannel transport suitable for use with YARPC's peer
// selection system.
// The transport implements peer.Transport so multiple peer.List
// implementations can retain and release shared peers.
// The transport implements transport.Transport so it is suitable for lifecycle
// management.
type Transport struct {
	// lock guards peers and outboundChannels and is held for the whole of
	// start, which is where ch and addr are assigned.
	// NOTE(review): ListenAddr and stop read fields without taking lock.
	lock sync.Mutex
	// once runs start and stop on behalf of Start and Stop.
	once *lifecycle.Once

	// ch is the shared channel for inbound and outbound calls. It is only
	// assigned in start, so it is nil before Start has run.
	ch                *tchannel.Channel
	// router, tracer, logger (and headerCase below) configure the inbound
	// handler that start installs on ch.
	router            transport.Router
	tracer            opentracing.Tracer
	logger            *zap.Logger
	// meter is passed to the inbound TLS mux listener and the TLS dialers.
	meter             *metrics.Scope
	// name is the local service name used when creating ch.
	name              string
	// addr is the configured listen address; start overwrites it with the
	// host:port that ch actually serves on.
	addr              string
	// listener, when non-nil, is served instead of opening a new TCP socket.
	listener          net.Listener
	// dialer is passed to ch and to the TLS outbound dialers.
	dialer            func(ctx context.Context, network, hostPort string) (net.Conn, error)
	// newResponseWriter builds inbound response writers; newTransport sets
	// it to newHandlerWriter.
	newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter

	// connTimeout and connBackoffStrategy are not referenced in this file;
	// presumably they configure the per-peer connection loop — TODO confirm.
	connTimeout         time.Duration
	// connectorsGroup is incremented for every maintainConnection goroutine
	// started by getOrCreatePeer; stop waits on it.
	connectorsGroup     sync.WaitGroup
	connBackoffStrategy backoffapi.Strategy
	headerCase          headerCase

	// peers maps pid.Identifier() to the retained peer.
	peers map[string]*tchannelPeer

	// nativeTChannelMethods are registered directly on ch, and their names
	// are passed to ch as SkipHandlerMethods.
	nativeTChannelMethods          NativeTChannelMethods
	// excludeServiceHeaderInResponse is forwarded to the inbound handler.
	excludeServiceHeaderInResponse bool

	// When inboundTLSMode is non-nil and not Disabled, start wraps the
	// listener in a muxlistener built from inboundTLSConfig, which must then
	// be non-nil.
	inboundTLSConfig *tls.Config
	inboundTLSMode   *yarpctls.Mode

	// outboundTLSConfigProvider is not referenced in this file.
	outboundTLSConfigProvider yarpctls.OutboundTLSConfigProvider
	// outboundChannels are registered by createOutboundChannel before Start,
	// started by start and stopped by stop.
	outboundChannels          []*outboundChannel
}
    89  
    90  // NewTransport is a YARPC transport that facilitates sending and receiving
    91  // YARPC requests through TChannel.
    92  // It uses a shared TChannel Channel for both, incoming and outgoing requests,
    93  // ensuring reuse of connections and other resources.
    94  //
    95  // Either the local service name (with the ServiceName option) or a user-owned
    96  // TChannel (with the WithChannel option) MUST be specified.
    97  func NewTransport(opts ...TransportOption) (*Transport, error) {
    98  	options := newTransportOptions()
    99  
   100  	for _, opt := range opts {
   101  		opt(&options)
   102  	}
   103  
   104  	if options.ch != nil {
   105  		return nil, fmt.Errorf("NewTransport does not accept WithChannel, use NewChannelTransport")
   106  	}
   107  
   108  	return options.newTransport(), nil
   109  }
   110  
   111  func (o transportOptions) newTransport() *Transport {
   112  	logger := o.logger
   113  	if logger == nil {
   114  		logger = zap.NewNop()
   115  	}
   116  	headerCase := canonicalizedHeaderCase
   117  	if o.originalHeaders {
   118  		headerCase = originalHeaderCase
   119  	}
   120  	return &Transport{
   121  		once:                           lifecycle.NewOnce(),
   122  		name:                           o.name,
   123  		addr:                           o.addr,
   124  		listener:                       o.listener,
   125  		dialer:                         o.dialer,
   126  		connTimeout:                    o.connTimeout,
   127  		connBackoffStrategy:            o.connBackoffStrategy,
   128  		peers:                          make(map[string]*tchannelPeer),
   129  		tracer:                         o.tracer,
   130  		logger:                         logger,
   131  		meter:                          o.meter,
   132  		headerCase:                     headerCase,
   133  		newResponseWriter:              newHandlerWriter,
   134  		nativeTChannelMethods:          o.nativeTChannelMethods,
   135  		excludeServiceHeaderInResponse: o.excludeServiceHeaderInResponse,
   136  		inboundTLSConfig:               o.inboundTLSConfig,
   137  		inboundTLSMode:                 o.inboundTLSMode,
   138  		outboundTLSConfigProvider:      o.outboundTLSConfigProvider,
   139  	}
   140  }
   141  
   142  // ListenAddr exposes the listen address of the transport.
   143  func (t *Transport) ListenAddr() string {
   144  	return t.addr
   145  }
   146  
   147  // RetainPeer adds a peer subscriber (typically a peer chooser) and causes the
   148  // transport to maintain persistent connections with that peer.
   149  func (t *Transport) RetainPeer(pid peer.Identifier, sub peer.Subscriber) (peer.Peer, error) {
   150  	return t.retainPeer(pid, sub, t.ch)
   151  }
   152  
   153  func (t *Transport) retainPeer(pid peer.Identifier, sub peer.Subscriber, ch *tchannel.Channel) (peer.Peer, error) {
   154  	t.lock.Lock()
   155  	defer t.lock.Unlock()
   156  
   157  	p := t.getOrCreatePeer(pid, ch)
   158  	p.Subscribe(sub)
   159  	return p, nil
   160  }
   161  
   162  // **NOTE** should only be called while the lock write mutex is acquired
   163  func (t *Transport) getOrCreatePeer(pid peer.Identifier, ch *tchannel.Channel) *tchannelPeer {
   164  	addr := pid.Identifier()
   165  	if p, ok := t.peers[addr]; ok {
   166  		return p
   167  	}
   168  
   169  	p := newPeer(addr, t, ch)
   170  	t.peers[addr] = p
   171  	// Start a peer connection loop
   172  	t.connectorsGroup.Add(1)
   173  	go p.maintainConnection()
   174  
   175  	return p
   176  }
   177  
   178  // ReleasePeer releases a peer from the peer.Subscriber and removes that peer
   179  // from the Transport if nothing is listening to it.
   180  func (t *Transport) ReleasePeer(pid peer.Identifier, sub peer.Subscriber) error {
   181  	t.lock.Lock()
   182  	defer t.lock.Unlock()
   183  
   184  	p, ok := t.peers[pid.Identifier()]
   185  	if !ok {
   186  		return peer.ErrTransportHasNoReferenceToPeer{
   187  			TransportName:  "tchannel.Transport",
   188  			PeerIdentifier: pid.Identifier(),
   189  		}
   190  	}
   191  
   192  	if err := p.Unsubscribe(sub); err != nil {
   193  		return err
   194  	}
   195  
   196  	if p.NumSubscribers() == 0 {
   197  		// Release the peer so that the connection retention loop stops.
   198  		p.release()
   199  		delete(t.peers, pid.Identifier())
   200  	}
   201  
   202  	return nil
   203  }
   204  
   205  // Start starts the TChannel transport. This starts making connections and
   206  // accepting inbound requests. All inbounds must have been assigned a router
   207  // to accept inbound requests before this is called.
   208  func (t *Transport) Start() error {
   209  	return t.once.Start(t.start)
   210  }
   211  
   212  func (t *Transport) start() error {
   213  	t.lock.Lock()
   214  	defer t.lock.Unlock()
   215  
   216  	var skipHandlerMethods []string
   217  	if t.nativeTChannelMethods != nil {
   218  		skipHandlerMethods = t.nativeTChannelMethods.SkipMethodNames()
   219  	}
   220  
   221  	chopts := tchannel.ChannelOptions{
   222  		Tracer: t.tracer,
   223  		Handler: handler{
   224  			router:                         t.router,
   225  			tracer:                         t.tracer,
   226  			headerCase:                     t.headerCase,
   227  			logger:                         t.logger,
   228  			newResponseWriter:              t.newResponseWriter,
   229  			excludeServiceHeaderInResponse: t.excludeServiceHeaderInResponse,
   230  		},
   231  		OnPeerStatusChanged: t.onPeerStatusChanged,
   232  		Dialer:              t.dialer,
   233  		SkipHandlerMethods:  skipHandlerMethods,
   234  	}
   235  	ch, err := tchannel.NewChannel(t.name, &chopts)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	t.ch = ch
   240  
   241  	if t.nativeTChannelMethods != nil {
   242  		for name, handler := range t.nativeTChannelMethods.Methods() {
   243  			ch.Register(handler, name)
   244  		}
   245  	}
   246  
   247  	listener := t.listener
   248  	if listener == nil {
   249  		addr := t.addr
   250  		// Default to ListenIP if addr wasn't given.
   251  		if addr == "" {
   252  			listenIP, err := tchannel.ListenIP()
   253  			if err != nil {
   254  				return err
   255  			}
   256  
   257  			addr = listenIP.String() + ":0"
   258  			// TODO(abg): Find a way to export this to users
   259  		}
   260  
   261  		// TODO(abg): If addr was just the port (":4040"), we want to use
   262  		// ListenIP() + ":4040" rather than just ":4040".
   263  		listener, err = net.Listen("tcp", addr)
   264  		if err != nil {
   265  			return err
   266  		}
   267  	}
   268  
   269  	if t.inboundTLSMode != nil && *t.inboundTLSMode != yarpctls.Disabled {
   270  		if t.inboundTLSConfig == nil {
   271  			return errors.New("tchannel TLS enabled but configuration not provided")
   272  		}
   273  
   274  		listener = muxlistener.NewListener(muxlistener.Config{
   275  			Listener:      listener,
   276  			TLSConfig:     t.inboundTLSConfig,
   277  			ServiceName:   t.name,
   278  			TransportName: TransportName,
   279  			Meter:         t.meter,
   280  			Logger:        t.logger,
   281  			Mode:          *t.inboundTLSMode,
   282  		})
   283  	}
   284  
   285  	if err := t.ch.Serve(listener); err != nil {
   286  		return err
   287  	}
   288  	t.addr = t.ch.PeerInfo().HostPort
   289  
   290  	for _, outboundChannel := range t.outboundChannels {
   291  		if err := outboundChannel.start(); err != nil {
   292  			return err
   293  		}
   294  	}
   295  	return nil
   296  }
   297  
   298  // Stop stops the TChannel transport. It starts rejecting incoming requests
   299  // and draining connections before closing them.
   300  // In a future version of YARPC, Stop will block until the underlying channel
   301  // has closed completely.
   302  func (t *Transport) Stop() error {
   303  	return t.once.Stop(t.stop)
   304  }
   305  
   306  func (t *Transport) stop() error {
   307  	t.ch.Close()
   308  	for _, outboundChannel := range t.outboundChannels {
   309  		outboundChannel.stop()
   310  	}
   311  	t.connectorsGroup.Wait()
   312  	return nil
   313  }
   314  
   315  // IsRunning returns whether the TChannel transport is running.
   316  func (t *Transport) IsRunning() bool {
   317  	return t.once.IsRunning()
   318  }
   319  
   320  // onPeerStatusChanged receives notifications from TChannel Channel when any
   321  // peer's status changes.
   322  func (t *Transport) onPeerStatusChanged(tp *tchannel.Peer) {
   323  	t.lock.Lock()
   324  	defer t.lock.Unlock()
   325  
   326  	p, ok := t.peers[tp.HostPort()]
   327  	if !ok {
   328  		return
   329  	}
   330  	p.notifyConnectionStatusChanged()
   331  }
   332  
   333  // CreateTLSOutboundChannel creates a outbound channel for managing tls
   334  // connections with the given tls config and destination name.
   335  // Usage:
   336  // 	tr, _ := tchannel.NewTransport(...)
   337  //  outboundCh, _ := tr.CreateTLSOutboundChannel(tls-config, "dest-name")
   338  //  outbound := tr.NewOutbound(peer.NewSingle(id, outboundCh))
   339  func (t *Transport) CreateTLSOutboundChannel(tlsConfig *tls.Config, destinationName string) (peer.Transport, error) {
   340  	params := dialer.Params{
   341  		Config:        tlsConfig,
   342  		Meter:         t.meter,
   343  		Logger:        t.logger,
   344  		ServiceName:   t.name,
   345  		TransportName: TransportName,
   346  		Dest:          destinationName,
   347  		Dialer:        t.dialer,
   348  	}
   349  	return t.createOutboundChannel(dialer.NewTLSDialer(params).DialContext)
   350  }
   351  
   352  func (t *Transport) createOutboundChannel(dialerFunc dialerFunc) (peer.Transport, error) {
   353  	t.lock.Lock()
   354  	defer t.lock.Unlock()
   355  
   356  	if t.once.State() != lifecycle.Idle {
   357  		return nil, errors.New("tchannel outbound channel cannot be created after starting transport")
   358  	}
   359  	outboundChannel := newOutboundChannel(t, dialerFunc)
   360  	t.outboundChannels = append(t.outboundChannels, outboundChannel)
   361  	return outboundChannel, nil
   362  }