github.com/rbisecke/kafka-go@v0.4.27/transport.go

package kafka

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"runtime/pprof"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/rbisecke/kafka-go/protocol"
	"github.com/rbisecke/kafka-go/protocol/apiversions"
	"github.com/rbisecke/kafka-go/protocol/createtopics"
	"github.com/rbisecke/kafka-go/protocol/findcoordinator"
	meta "github.com/rbisecke/kafka-go/protocol/metadata"
	"github.com/rbisecke/kafka-go/protocol/saslauthenticate"
	"github.com/rbisecke/kafka-go/protocol/saslhandshake"
	"github.com/rbisecke/kafka-go/sasl"
)

// Request is an interface implemented by types that represent messages sent
// from kafka clients to brokers.
type Request = protocol.Message

// Response is an interface implemented by types that represent messages sent
// from kafka brokers in response to client requests.
type Response = protocol.Message

// RoundTripper is an interface implemented by types which support interacting
// with kafka brokers.
type RoundTripper interface {
	// RoundTrip sends a request to a kafka broker and returns the response that
	// was received, or a non-nil error.
	//
	// The context passed as first argument can be used to asynchronously abort
	// the call if needed.
	RoundTrip(context.Context, net.Addr, Request) (Response, error)
}
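
// A program can also substitute its own RoundTripper, for example to add
// instrumentation around DefaultTransport. A minimal sketch (the logging
// wrapper below is an illustration, not part of this package):
//
//	type loggingTransport struct{ next kafka.RoundTripper }
//
//	func (t *loggingTransport) RoundTrip(ctx context.Context, addr net.Addr, req kafka.Request) (kafka.Response, error) {
//		log.Printf("sending %T to %s", req, addr)
//		return t.next.RoundTrip(ctx, addr, req)
//	}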

// Transport is an implementation of the RoundTripper interface.
//
// Transport values manage a pool of connections and automatically discover the
// cluster layout to route requests to the appropriate brokers.
//
// Transport values are safe to use concurrently from multiple goroutines.
//
// Note: The intent is for the Transport to become the underlying layer of the
// kafka.Reader and kafka.Writer types.
type Transport struct {
	// A function used to establish connections to the kafka cluster.
	Dial func(context.Context, string, string) (net.Conn, error)

	// Time limit set for establishing connections to the kafka cluster. This
	// limit includes all round trips done to establish the connections (TLS
	// handshake, SASL negotiation, etc...).
	//
	// Defaults to 5s.
	DialTimeout time.Duration

	// Maximum amount of time that connections will remain open and unused.
	// The transport automatically closes connections that have been idle for
	// too long, and re-opens them on demand when the transport is used again.
	//
	// Defaults to 30s.
	IdleTimeout time.Duration

	// TTL for the metadata cached by this transport. Note that the value
	// configured here is an upper bound; the transport randomizes the TTLs to
	// avoid getting into states where multiple clients end up synchronized and
	// cause bursts of requests to the kafka brokers.
	//
	// Defaults to 6s.
	MetadataTTL time.Duration

	// Unique identifier that the transport communicates to the brokers when it
	// sends requests.
	ClientID string

	// An optional configuration for TLS connections established by this
	// transport.
	//
	// If the ServerName is left empty, it is defaulted to the host name of the
	// broker that the connection is established to.
	TLS *tls.Config

	// SASL configures the Transport to use SASL authentication.
	SASL sasl.Mechanism

	// An optional resolver used to translate broker host names into network
	// addresses.
	//
	// The resolver will be called for every request (not every connection),
	// making it possible to implement ACL policies by validating that the
	// program is allowed to connect to the kafka broker. This also means that
	// the resolver should probably provide a caching layer to avoid storming
	// the service discovery backend with requests.
	//
	// When set, the Dial function is not responsible for performing name
	// resolution, and is always called with a pre-resolved address.
	Resolver BrokerResolver

	// The background context used to control goroutines started internally by
	// the transport.
	//
	// If nil, context.Background() is used instead.
	Context context.Context

	mutex sync.RWMutex
	pools map[networkAddress]*connPool
}
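
// As an illustration, a program that needs settings other than the defaults
// can construct its own transport; the sasl/plain mechanism and the
// credentials below are assumptions for the example, not requirements of
// this type:
//
//	transport := &kafka.Transport{
//		DialTimeout: 10 * time.Second,
//		IdleTimeout: time.Minute,
//		TLS:         &tls.Config{},
//		SASL:        plain.Mechanism{Username: "user", Password: "pass"},
//	}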

// DefaultTransport is the default transport used by kafka clients in this
// package.
var DefaultTransport RoundTripper = &Transport{
	Dial: (&net.Dialer{
		Timeout:   3 * time.Second,
		DualStack: true,
	}).DialContext,
}
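
// For example (a sketch assuming a broker at localhost:9092), a raw metadata
// request can be sent through the default transport with:
//
//	addr := kafka.TCP("localhost:9092")
//	res, err := kafka.DefaultTransport.RoundTrip(ctx, addr, &metadata.Request{})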

// CloseIdleConnections closes all idle connections immediately, and marks all
// connections that are in use to be closed when they become idle again.
func (t *Transport) CloseIdleConnections() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	for _, pool := range t.pools {
		pool.unref()
	}

	for k := range t.pools {
		delete(t.pools, k)
	}
}

// RoundTrip sends a request to a kafka cluster and returns the response, or an
// error if no responses were received.
//
// Message types are available in sub-packages of the protocol package. Each
// kafka API is implemented in a different sub-package. For example, the request
// and response types for the Fetch API are available in the protocol/fetch
// package.
//
// The type of the response message will match the type of the request. For
// example, if RoundTrip was called with a *fetch.Request as argument, the value
// returned will be of type *fetch.Response. It is safe for the program to do a
// type assertion after checking that no error was returned.
//
// This example illustrates the way this method is expected to be used:
//
//	r, err := transport.RoundTrip(ctx, addr, &fetch.Request{ ... })
//	if err != nil {
//		...
//	} else {
//		res := r.(*fetch.Response)
//		...
//	}
//
// The transport automatically selects the highest version of the API that is
// supported by both the kafka-go package and the kafka broker. The negotiation
// happens transparently once when connections are established.
//
// This API was introduced in version 0.4 as a way to leverage the lower-level
// features of the kafka protocol, but also provide a more efficient way of
// managing connections to kafka brokers.
func (t *Transport) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) {
	p := t.grabPool(addr)
	defer p.unref()
	return p.roundTrip(ctx, req)
}

func (t *Transport) dial() func(context.Context, string, string) (net.Conn, error) {
	if t.Dial != nil {
		return t.Dial
	}
	return defaultDialer.DialContext
}

func (t *Transport) dialTimeout() time.Duration {
	if t.DialTimeout > 0 {
		return t.DialTimeout
	}
	return 5 * time.Second
}

func (t *Transport) idleTimeout() time.Duration {
	if t.IdleTimeout > 0 {
		return t.IdleTimeout
	}
	return 30 * time.Second
}

func (t *Transport) metadataTTL() time.Duration {
	if t.MetadataTTL > 0 {
		return t.MetadataTTL
	}
	return 6 * time.Second
}

func (t *Transport) grabPool(addr net.Addr) *connPool {
	k := networkAddress{
		network: addr.Network(),
		address: addr.String(),
	}

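	// Fast path: look up an existing pool under the read lock, and only fall
	// back to the write lock (re-checking the map) when no pool exists yet.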
	t.mutex.RLock()
	p := t.pools[k]
	if p != nil {
		p.ref()
	}
	t.mutex.RUnlock()

	if p != nil {
		return p
	}

	t.mutex.Lock()
	defer t.mutex.Unlock()

	if p := t.pools[k]; p != nil {
		p.ref()
		return p
	}

	ctx, cancel := context.WithCancel(t.context())

	p = &connPool{
		refc: 2,

		dial:        t.dial(),
		dialTimeout: t.dialTimeout(),
		idleTimeout: t.idleTimeout(),
		metadataTTL: t.metadataTTL(),
		clientID:    t.ClientID,
		tls:         t.TLS,
		sasl:        t.SASL,
		resolver:    t.Resolver,

		ready:  make(event),
		wake:   make(chan event),
		conns:  make(map[int32]*connGroup),
		cancel: cancel,
	}

	p.ctrl = p.newConnGroup(addr)
	go p.discover(ctx, p.wake)

	if t.pools == nil {
		t.pools = make(map[networkAddress]*connPool)
	}
	t.pools[k] = p
	return p
}

func (t *Transport) context() context.Context {
	if t.Context != nil {
		return t.Context
	}
	return context.Background()
}

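// event is a broadcast signal: triggering it closes the channel, which
// unblocks every goroutine waiting to receive from it.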
type event chan struct{}

func (e event) trigger() { close(e) }

type connPool struct {
	refc uintptr
	// Immutable fields of the connection pool. Connections access these fields
	// on their parent pool in a read-only fashion, so no synchronization is
	// required.
	dial        func(context.Context, string, string) (net.Conn, error)
	dialTimeout time.Duration
	idleTimeout time.Duration
	metadataTTL time.Duration
	clientID    string
	tls         *tls.Config
	sasl        sasl.Mechanism
	resolver    BrokerResolver
	// Signaling mechanisms to orchestrate communications between the pool and
	// the rest of the program.
	once   sync.Once  // ensure that `ready` is triggered only once
	ready  event      // triggered after the first metadata update
	wake   chan event // used to force metadata updates
	cancel context.CancelFunc
	// Mutable fields of the connection pool, access must be synchronized.
	mutex sync.RWMutex
	conns map[int32]*connGroup // data connections used for produce/fetch/etc...
	ctrl  *connGroup           // control connections used for metadata requests
	state atomic.Value         // cached cluster state
}

type connPoolState struct {
	metadata *meta.Response   // last metadata response seen by the pool
	err      error            // last error from metadata requests
	layout   protocol.Cluster // cluster layout built from metadata response
}

func (p *connPool) grabState() connPoolState {
	state, _ := p.state.Load().(connPoolState)
	return state
}

func (p *connPool) setState(state connPoolState) {
	p.state.Store(state)
}

func (p *connPool) ref() {
	atomic.AddUintptr(&p.refc, +1)
}

func (p *connPool) unref() {
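	// Adding ^uintptr(0) (all bits set) wraps around and acts as an atomic
	// decrement; the unref that brings the count to zero tears the pool down.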
	if atomic.AddUintptr(&p.refc, ^uintptr(0)) == 0 {
		p.mutex.Lock()
		defer p.mutex.Unlock()

		for _, conns := range p.conns {
			conns.closeIdleConns()
		}

		p.ctrl.closeIdleConns()
		p.cancel()
	}
}

func (p *connPool) roundTrip(ctx context.Context, req Request) (Response, error) {
	// This first select should never block after the first metadata response
	// that would mark the pool as `ready`.
	select {
	case <-p.ready:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	state := p.grabState()
	var response promise

	switch m := req.(type) {
	case *meta.Request:
		// We serve metadata requests directly from the transport cache unless
		// we would like to auto create a topic that isn't in our cache.
		//
		// This reduces the number of round trips to kafka brokers while keeping
		// the logic simple when applying partitioning strategies.
		if state.err != nil {
			return nil, state.err
		}

		cachedMeta := filterMetadataResponse(m, state.metadata)
		// requestNeeded indicates if we need to send this metadata request to the server.
		// It's true when we want to auto-create topics and we don't have the topic in our
		// cache.
		var requestNeeded bool
		if m.AllowAutoTopicCreation {
			for _, topic := range cachedMeta.Topics {
				if topic.ErrorCode == int16(UnknownTopicOrPartition) {
					requestNeeded = true
					break
				}
			}
		}

		if !requestNeeded {
			return cachedMeta, nil
		}

	case protocol.Splitter:
		// Messages that implement the Splitter interface trigger the creation
		// of multiple requests that are all merged back into a single result
		// by a merger.
		messages, merger, err := m.Split(state.layout)
		if err != nil {
			return nil, err
		}
		promises := make([]promise, len(messages))
		for i, m := range messages {
			promises[i] = p.sendRequest(ctx, m, state)
		}
		response = join(promises, messages, merger)
	}

	if response == nil {
		response = p.sendRequest(ctx, req, state)
	}

	r, err := response.await(ctx)
	if err != nil {
		return r, err
	}

	switch resp := r.(type) {
	case *createtopics.Response:
		// Force an update of the metadata when adding topics,
		// otherwise the cached state would get out of sync.
		topicsToRefresh := make([]string, 0, len(resp.Topics))
		for _, topic := range resp.Topics {
			// Fixes issue 672: don't refresh topics that failed to create;
			// doing so causes the library to hang indefinitely.
			if topic.ErrorCode != 0 {
				continue
			}

			topicsToRefresh = append(topicsToRefresh, topic.Name)
		}

		p.refreshMetadata(ctx, topicsToRefresh)
	case *meta.Response:
		m := req.(*meta.Request)
		// If we get here with allow auto topic creation then
		// we didn't have that topic in our cache so we should update
		// the cache.
		if m.AllowAutoTopicCreation {
			p.refreshMetadata(ctx, m.TopicNames)
		}
	}

	return r, nil
}

// refreshMetadata forces an update of the cached cluster metadata, and waits
// for the given list of topics to appear. This waiting mechanism is necessary
// to account for the fact that topic creation is asynchronous in kafka, and
// causes subsequent requests to fail while the cluster state is propagated to
// all the brokers.
func (p *connPool) refreshMetadata(ctx context.Context, expectTopics []string) {
	minBackoff := 100 * time.Millisecond
	maxBackoff := 2 * time.Second
	cancel := ctx.Done()

	for ctx.Err() == nil {
		notify := make(event)
		select {
		case <-cancel:
			return
		case p.wake <- notify:
			select {
			case <-notify:
			case <-cancel:
				return
			}
		}

		state := p.grabState()
		found := 0

		for _, topic := range expectTopics {
			if _, ok := state.layout.Topics[topic]; ok {
				found++
			}
		}

		if found == len(expectTopics) {
			return
		}

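		// Sleep for a randomized exponential backoff before retrying, so that
		// concurrent clients waiting on the same topics don't wake in lockstep.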
		if delay := time.Duration(rand.Int63n(int64(minBackoff))); delay > 0 {
			timer := time.NewTimer(delay)
			select {
			case <-cancel:
			case <-timer.C:
			}
			timer.Stop()

			if minBackoff *= 2; minBackoff > maxBackoff {
				minBackoff = maxBackoff
			}
		}
	}
}

func (p *connPool) setReady() {
	p.once.Do(p.ready.trigger)
}

// update is called periodically by the goroutine running the discover method
// to refresh the cluster layout information used by the transport to route
// requests to brokers.
func (p *connPool) update(ctx context.Context, metadata *meta.Response, err error) {
	var layout protocol.Cluster

	if metadata != nil {
		metadata.ThrottleTimeMs = 0

		// Normalize the lists so we can apply binary search on them.
		sortMetadataBrokers(metadata.Brokers)
		sortMetadataTopics(metadata.Topics)

		for i := range metadata.Topics {
			t := &metadata.Topics[i]
			sortMetadataPartitions(t.Partitions)
		}

		layout = makeLayout(metadata)
	}

	state := p.grabState()
	addBrokers := make(map[int32]struct{})
	delBrokers := make(map[int32]struct{})

	if err != nil {
		// Only update the error on the transport if the cluster layout was
		// unknown. This ensures that we prioritize a previously known state
		// of the cluster to reduce the impact of transient failures.
		if state.metadata != nil {
			return
		}
		state.err = err
	} else {
		for id, b2 := range layout.Brokers {
			if b1, ok := state.layout.Brokers[id]; !ok {
				addBrokers[id] = struct{}{}
			} else if b1 != b2 {
				addBrokers[id] = struct{}{}
				delBrokers[id] = struct{}{}
			}
		}

		for id := range state.layout.Brokers {
			if _, ok := layout.Brokers[id]; !ok {
				delBrokers[id] = struct{}{}
			}
		}

		state.metadata, state.layout = metadata, layout
		state.err = nil
	}

	defer p.setReady()
	defer p.setState(state)

	if len(addBrokers) != 0 || len(delBrokers) != 0 {
		// Only acquire the lock when there is a change of layout. This is an
		// infrequent event, so locking here instead of on every update avoids
		// introducing regular contention on the mutex.
		p.mutex.Lock()
		defer p.mutex.Unlock()

		if ctx.Err() != nil {
			return // the pool has been closed, no need to update
		}

		for id := range delBrokers {
			if broker := p.conns[id]; broker != nil {
				broker.closeIdleConns()
				delete(p.conns, id)
			}
		}

		for id := range addBrokers {
			broker := layout.Brokers[id]
			p.conns[id] = p.newBrokerConnGroup(Broker{
				Rack: broker.Rack,
				Host: broker.Host,
				Port: int(broker.Port),
				ID:   int(broker.ID),
			})
		}
	}
}

// discover is the entry point of an internal goroutine for the transport which
// periodically requests updates of the cluster metadata and refreshes the
// transport cached cluster layout.
func (p *connPool) discover(ctx context.Context, wake <-chan event) {
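	// Updates are scheduled with a randomized TTL in [0, p.metadataTTL) so
	// that multiple clients sharing a cluster don't refresh their metadata
	// caches in lockstep (see the MetadataTTL documentation on Transport).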
	prng := rand.New(rand.NewSource(time.Now().UnixNano()))
	metadataTTL := func() time.Duration {
		return time.Duration(prng.Int63n(int64(p.metadataTTL)))
	}

	timer := time.NewTimer(metadataTTL())
	defer timer.Stop()

	var notify event
	done := ctx.Done()

	for {
		c, err := p.grabClusterConn(ctx)
		if err != nil {
			p.update(ctx, nil, err)
		} else {
			res := make(async, 1)
			req := &meta.Request{
				IncludeClusterAuthorizedOperations: true,
				IncludeTopicAuthorizedOperations:   true,
			}
			deadline, cancel := context.WithTimeout(ctx, p.metadataTTL)
			c.reqs <- connRequest{
				ctx: deadline,
				req: req,
				res: res,
			}
			r, err := res.await(deadline)
			cancel()
			if err != nil && err == ctx.Err() {
				return
			}
			ret, _ := r.(*meta.Response)
			p.update(ctx, ret, err)
		}

		if notify != nil {
			notify.trigger()
			notify = nil
		}

		select {
		case <-timer.C:
			timer.Reset(metadataTTL())
		case <-done:
			return
		case notify = <-wake:
		}
	}
}

// grabBrokerConn returns a connection to a specific broker represented by the
// broker id passed as argument. If the broker id was not known, an error is
// returned.
func (p *connPool) grabBrokerConn(ctx context.Context, brokerID int32) (*conn, error) {
	p.mutex.RLock()
	g := p.conns[brokerID]
	p.mutex.RUnlock()
	if g == nil {
		return nil, BrokerNotAvailable
	}
	return g.grabConnOrConnect(ctx)
}

// grabClusterConn returns the connection to the kafka cluster that the pool is
// configured to connect to.
//
// The transport uses a shared `control` connection to the cluster for any
// requests that aren't supposed to be sent to specific brokers (e.g. Metadata
// requests). Requests intended to be routed to specific brokers (e.g. Fetch or
// Produce requests) are dispatched on a separate pool of connections that the
// transport maintains. This split helps avoid head-of-line blocking situations
// where control requests like Metadata would be queued behind large responses
// from Fetch requests for example.
//
// In either case, the requests are multiplexed so we can keep a minimal number
// of connections open (N+1, where N is the number of brokers in the cluster).
func (p *connPool) grabClusterConn(ctx context.Context) (*conn, error) {
	return p.ctrl.grabConnOrConnect(ctx)
}

func (p *connPool) sendRequest(ctx context.Context, req Request, state connPoolState) promise {
	brokerID := int32(-1)

	switch m := req.(type) {
	case protocol.BrokerMessage:
		// Some requests are supposed to be sent to specific brokers (e.g. the
		// partition leaders). They implement the BrokerMessage interface to
		// delegate the routing decision to each message type.
		broker, err := m.Broker(state.layout)
		if err != nil {
			return reject(err)
		}
		brokerID = broker.ID

	case protocol.GroupMessage:
		// Some requests are supposed to be sent to a group coordinator,
		// look up which broker is currently the coordinator for the group
		// so we can get a connection to that broker.
		//
		// TODO: should we cache the coordinator info?
		p := p.sendRequest(ctx, &findcoordinator.Request{Key: m.Group()}, state)
		r, err := p.await(ctx)
		if err != nil {
			return reject(err)
		}
		brokerID = r.(*findcoordinator.Response).NodeID

	case protocol.TransactionalMessage:
		p := p.sendRequest(ctx, &findcoordinator.Request{
			Key:     m.Transaction(),
			KeyType: int8(CoordinatorKeyTypeTransaction),
		}, state)
		r, err := p.await(ctx)
		if err != nil {
			return reject(err)
		}
		brokerID = r.(*findcoordinator.Response).NodeID
	}

	var c *conn
	var err error
	if brokerID >= 0 {
		c, err = p.grabBrokerConn(ctx, brokerID)
	} else {
		c, err = p.grabClusterConn(ctx)
	}
	if err != nil {
		return reject(err)
	}

	res := make(async, 1)

	c.reqs <- connRequest{
		ctx: ctx,
		req: req,
		res: res,
	}

	return res
}

func filterMetadataResponse(req *meta.Request, res *meta.Response) *meta.Response {
	ret := *res

	if req.TopicNames != nil {
		ret.Topics = make([]meta.ResponseTopic, len(req.TopicNames))

		for i, topicName := range req.TopicNames {
			j, ok := findMetadataTopic(res.Topics, topicName)
			if ok {
				ret.Topics[i] = res.Topics[j]
			} else {
				ret.Topics[i] = meta.ResponseTopic{
					ErrorCode: int16(UnknownTopicOrPartition),
					Name:      topicName,
				}
			}
		}
	}

	return &ret
}

func findMetadataTopic(topics []meta.ResponseTopic, topicName string) (int, bool) {
	i := sort.Search(len(topics), func(i int) bool {
		return topics[i].Name >= topicName
	})
	return i, i >= 0 && i < len(topics) && topics[i].Name == topicName
}

func sortMetadataBrokers(brokers []meta.ResponseBroker) {
	sort.Slice(brokers, func(i, j int) bool {
		return brokers[i].NodeID < brokers[j].NodeID
	})
}

func sortMetadataTopics(topics []meta.ResponseTopic) {
	sort.Slice(topics, func(i, j int) bool {
		return topics[i].Name < topics[j].Name
	})
}

func sortMetadataPartitions(partitions []meta.ResponsePartition) {
	sort.Slice(partitions, func(i, j int) bool {
		return partitions[i].PartitionIndex < partitions[j].PartitionIndex
	})
}

func makeLayout(metadataResponse *meta.Response) protocol.Cluster {
	layout := protocol.Cluster{
		Controller: metadataResponse.ControllerID,
		Brokers:    make(map[int32]protocol.Broker),
		Topics:     make(map[string]protocol.Topic),
	}

	for _, broker := range metadataResponse.Brokers {
		layout.Brokers[broker.NodeID] = protocol.Broker{
			Rack: broker.Rack,
			Host: broker.Host,
			Port: broker.Port,
			ID:   broker.NodeID,
		}
	}

	for _, topic := range metadataResponse.Topics {
		if topic.IsInternal {
			continue // TODO: do we need to expose those?
		}
		layout.Topics[topic.Name] = protocol.Topic{
			Name:       topic.Name,
			Error:      topic.ErrorCode,
			Partitions: makePartitions(topic.Partitions),
		}
	}

	return layout
}

func makePartitions(metadataPartitions []meta.ResponsePartition) map[int32]protocol.Partition {
	protocolPartitions := make(map[int32]protocol.Partition, len(metadataPartitions))
	numBrokerIDs := 0

	for _, p := range metadataPartitions {
		numBrokerIDs += len(p.ReplicaNodes) + len(p.IsrNodes) + len(p.OfflineReplicas)
	}

	// Reduce the memory footprint a bit by allocating a single buffer to write
	// all broker ids.
	brokerIDs := make([]int32, 0, numBrokerIDs)

	for _, p := range metadataPartitions {
		var rep, isr, off []int32
		brokerIDs, rep = appendBrokerIDs(brokerIDs, p.ReplicaNodes)
		brokerIDs, isr = appendBrokerIDs(brokerIDs, p.IsrNodes)
		brokerIDs, off = appendBrokerIDs(brokerIDs, p.OfflineReplicas)

		protocolPartitions[p.PartitionIndex] = protocol.Partition{
			ID:       p.PartitionIndex,
			Error:    p.ErrorCode,
			Leader:   p.LeaderID,
			Replicas: rep,
			ISR:      isr,
			Offline:  off,
		}
	}

	return protocolPartitions
}

func appendBrokerIDs(ids, brokers []int32) ([]int32, []int32) {
	i := len(ids)
	ids = append(ids, brokers...)
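	// The full slice expression caps the returned sub-slice at its own length,
	// so appending to it later cannot overwrite ids entries written after it.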
	return ids, ids[i:len(ids):len(ids)]
}

func (p *connPool) newConnGroup(a net.Addr) *connGroup {
	return &connGroup{
		addr: a,
		pool: p,
		broker: Broker{
			ID: -1,
		},
	}
}

func (p *connPool) newBrokerConnGroup(broker Broker) *connGroup {
	return &connGroup{
		addr: &networkAddress{
			network: "tcp",
			address: net.JoinHostPort(broker.Host, strconv.Itoa(broker.Port)),
		},
		pool:   p,
		broker: broker,
	}
}

type connRequest struct {
	ctx context.Context
	req Request
	res async
}

// The promise interface is used as a message passing abstraction to coordinate
// between goroutines that handle requests and responses.
type promise interface {
	// Waits until the promise is resolved, rejected, or the context canceled.
	await(context.Context) (Response, error)
}

// async is an implementation of the promise interface which supports resolving
// or rejecting the await call asynchronously.
type async chan interface{}

func (p async) await(ctx context.Context) (Response, error) {
	select {
	case x := <-p:
		switch v := x.(type) {
		case nil:
			return nil, nil // A nil response is ok (e.g. when RequiredAcks is None)
		case Response:
			return v, nil
		case error:
			return nil, v
		default:
			panic(fmt.Errorf("BUG: promise resolved with impossible value of type %T", v))
		}
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func (p async) resolve(res Response) { p <- res }

func (p async) reject(err error) { p <- err }
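
// An illustrative flow: the goroutine owning the connection resolves or
// rejects the promise exactly once while the caller blocks in await; the
// one-element buffer lets the producer complete without waiting:
//
//	res := make(async, 1)
//	go func() { res.resolve(response) }()
//	r, err := res.await(ctx)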

// rejected is an implementation of the promise interface which always returns
// an error. Values of this type are constructed using the reject function.
type rejected struct{ err error }

func reject(err error) promise { return &rejected{err: err} }

func (p *rejected) await(ctx context.Context) (Response, error) {
	return nil, p.err
}

// joined is an implementation of the promise interface which merges results
// from multiple promises into one await call using a merger.
type joined struct {
	promises []promise
	requests []Request
	merger   protocol.Merger
}

func join(promises []promise, requests []Request, merger protocol.Merger) promise {
	return &joined{
		promises: promises,
		requests: requests,
		merger:   merger,
	}
}

func (p *joined) await(ctx context.Context) (Response, error) {
	results := make([]interface{}, len(p.promises))

	for i, sub := range p.promises {
		m, err := sub.await(ctx)
		if err != nil {
			results[i] = err
		} else {
			results[i] = m
		}
	}

	return p.merger.Merge(p.requests, results)
}

// Default dialer used by the transport connections when no Dial function
// was configured by the program.
var defaultDialer = net.Dialer{
	Timeout:   3 * time.Second,
	DualStack: true,
}

// connGroup represents a logical connection group to a kafka broker. The
// actual network connections are lazily opened before sending requests, and
// closed if they are unused for longer than the idle timeout.
type connGroup struct {
	addr   net.Addr
	broker Broker
	// Immutable state of the connection group.
	pool *connPool
	// Shared state of the connection group; access to these fields must be
	// synchronized on the mutex.
	mutex     sync.Mutex
	closed    bool
	idleConns []*conn // stack of idle connections
}

func (g *connGroup) closeIdleConns() {
	g.mutex.Lock()
	conns := g.idleConns
	g.idleConns = nil
	g.closed = true
	g.mutex.Unlock()

	for _, c := range conns {
		c.close()
	}
}

func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) {
	rslv := g.pool.resolver
	addr := g.addr
	var c *conn

	if rslv == nil {
		c = g.grabConn()
	} else {
		var err error
		broker := g.broker

		if broker.ID < 0 {
			host, port, err := splitHostPortNumber(addr.String())
			if err != nil {
				return nil, err
			}
			broker.Host = host
			broker.Port = port
		}

		ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker)
		if err != nil {
			return nil, err
		}

		for _, ipAddr := range ipAddrs {
			network := addr.Network()
			address := net.JoinHostPort(ipAddr.String(), strconv.Itoa(broker.Port))

			if c = g.grabConnTo(network, address); c != nil {
				break
			}
		}
	}

	if c == nil {
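		// Dial from a separate goroutine so that a canceled context does not
		// leak the connection: if the caller gives up while the dial is in
		// flight, the new connection is parked in the pool or closed.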
		connChan := make(chan *conn)
		errChan := make(chan error)

		go func() {
			c, err := g.connect(ctx, addr)
			if err != nil {
				select {
				case errChan <- err:
				case <-ctx.Done():
				}
			} else {
				select {
				case connChan <- c:
				case <-ctx.Done():
					if !g.releaseConn(c) {
						c.close()
					}
				}
			}
		}()

		select {
		case c = <-connChan:
		case err := <-errChan:
			return nil, err
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}

	return c, nil
}

func (g *connGroup) grabConnTo(network, address string) *conn {
	g.mutex.Lock()
	defer g.mutex.Unlock()

	for i := len(g.idleConns) - 1; i >= 0; i-- {
		c := g.idleConns[i]

		if c.network == network && c.address == address {
			copy(g.idleConns[i:], g.idleConns[i+1:])
			n := len(g.idleConns) - 1
			g.idleConns[n] = nil
			g.idleConns = g.idleConns[:n]

			if c.timer != nil {
				c.timer.Stop()
			}

			return c
		}
	}

	return nil
}

func (g *connGroup) grabConn() *conn {
	g.mutex.Lock()
	defer g.mutex.Unlock()

	if len(g.idleConns) == 0 {
		return nil
	}

	n := len(g.idleConns) - 1
	c := g.idleConns[n]
	g.idleConns[n] = nil
	g.idleConns = g.idleConns[:n]

	if c.timer != nil {
		c.timer.Stop()
	}

	return c
}

func (g *connGroup) removeConn(c *conn) bool {
	g.mutex.Lock()
	defer g.mutex.Unlock()

	if c.timer != nil {
		c.timer.Stop()
	}

	for i, x := range g.idleConns {
		if x == c {
			copy(g.idleConns[i:], g.idleConns[i+1:])
			n := len(g.idleConns) - 1
			g.idleConns[n] = nil
			g.idleConns = g.idleConns[:n]
			return true
		}
	}

	return false
}

func (g *connGroup) releaseConn(c *conn) bool {
	idleTimeout := g.pool.idleTimeout

	g.mutex.Lock()
	defer g.mutex.Unlock()

	if g.closed {
		return false
	}

	if c.timer != nil {
		c.timer.Reset(idleTimeout)
	} else {
		c.timer = time.AfterFunc(idleTimeout, func() {
			if g.removeConn(c) {
				c.close()
			}
		})
	}

	g.idleConns = append(g.idleConns, c)
	return true
}

func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {
	deadline := time.Now().Add(g.pool.dialTimeout)

	ctx, cancel := context.WithDeadline(ctx, deadline)
	defer cancel()

	network := strings.Split(addr.Network(), ",")
	address := strings.Split(addr.String(), ",")
	var netConn net.Conn
	var netAddr net.Addr
	var err error

	if len(address) > 1 {
		// Shuffle the list of addresses to randomize the order in which
		// connections are attempted. This prevents always routing connections
		// to the first address in the list (where dialing would usually
		// succeed).
		rand.Shuffle(len(address), func(i, j int) {
			network[i], network[j] = network[j], network[i]
			address[i], address[j] = address[j], address[i]
		})
	}

	for i := range address {
		netConn, err = g.pool.dial(ctx, network[i], address[i])
		if err == nil {
			netAddr = &networkAddress{
				network: network[i],
				address: address[i],
			}
			break
		}
	}

	if err != nil {
		return nil, err
	}

	defer func() {
		if netConn != nil {
			netConn.Close()
		}
	}()

	if tlsConfig := g.pool.tls; tlsConfig != nil {
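		// Default the ServerName to the broker host so that certificate
		// verification works without requiring explicit configuration.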
		if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
			host, _ := splitHostPort(netAddr.String())
			tlsConfig = tlsConfig.Clone()
			tlsConfig.ServerName = host
		}
		netConn = tls.Client(netConn, tlsConfig)
	}

	pc := protocol.NewConn(netConn, g.pool.clientID)
	pc.SetDeadline(deadline)

	r, err := pc.RoundTrip(new(apiversions.Request))
	if err != nil {
		return nil, err
	}
	res := r.(*apiversions.Response)
	ver := make(map[protocol.ApiKey]int16, len(res.ApiKeys))

	if res.ErrorCode != 0 {
		return nil, fmt.Errorf("negotiating API versions with kafka broker at %s: %w", g.addr, Error(res.ErrorCode))
	}

	for _, r := range res.ApiKeys {
		apiKey := protocol.ApiKey(r.ApiKey)
		ver[apiKey] = apiKey.SelectVersion(r.MinVersion, r.MaxVersion)
	}

	pc.SetVersions(ver)
	pc.SetDeadline(time.Time{})

	if g.pool.sasl != nil {
		host, port, err := splitHostPortNumber(netAddr.String())
		if err != nil {
			return nil, err
		}
		metadata := &sasl.Metadata{
			Host: host,
			Port: port,
		}
		if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, g.pool.sasl); err != nil {
			return nil, err
		}
	}

	reqs := make(chan connRequest)
	c := &conn{
		network: netAddr.Network(),
		address: netAddr.String(),
		reqs:    reqs,
		group:   g,
	}
	go c.run(pc, reqs)

	netConn = nil
	return c, nil
}

type conn struct {
	reqs    chan<- connRequest
	network string
	address string
	once    sync.Once
	group   *connGroup
	timer   *time.Timer
}

func (c *conn) close() {
	c.once.Do(func() { close(c.reqs) })
}

func (c *conn) run(pc *protocol.Conn, reqs <-chan connRequest) {
	defer pc.Close()

	for cr := range reqs {
		r, err := c.roundTrip(cr.ctx, pc, cr.req)
		if err != nil {
			cr.res.reject(err)
			if !errors.Is(err, protocol.ErrNoRecord) {
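				// After any error other than a missing record the connection
				// state is unknown, so it is discarded instead of being
				// returned to the idle pool.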
				break
			}
		} else {
			cr.res.resolve(r)
		}
		if !c.group.releaseConn(c) {
			break
		}
	}
}

func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (Response, error) {
	pprof.SetGoroutineLabels(ctx)
	defer pprof.SetGoroutineLabels(context.Background())

	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		pc.SetDeadline(deadline)
		defer pc.SetDeadline(time.Time{})
	}

	return pc.RoundTrip(req)
}

// authenticateSASL performs all of the required requests to authenticate this
// connection.  If any step fails, this function returns with an error.  A nil
// error indicates successful authentication.
func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error {
	if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil {
		return err
	}

	sess, state, err := mechanism.Start(ctx)
	if err != nil {
		return err
	}

	for completed := false; !completed; {
		challenge, err := saslAuthenticateRoundTrip(pc, state)
		switch err {
		case nil:
		case io.EOF:
			// the broker may communicate a failed exchange by closing the
			// connection (esp. in the case where we're passing opaque sasl
			// data over the wire since there's no protocol info).
			return SASLAuthenticationFailed
		default:
			return err
		}

		completed, state, err = sess.Next(ctx, challenge)
		if err != nil {
			return err
		}
	}

	return nil
}

// saslHandshakeRoundTrip sends the SASL handshake message.  This will
// determine whether the Mechanism is supported by the cluster.  If it's not,
// this function will error out with UnsupportedSASLMechanism.
//
// If the mechanism is unsupported, the handshake request will reply with the
// list of the cluster's configured mechanisms, which could potentially be used
// to facilitate negotiation.  At the moment, we are not negotiating the
// mechanism as we believe that brokers are usually known to the client, and
// therefore the client should already know which mechanisms are supported.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake
func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error {
	msg, err := pc.RoundTrip(&saslhandshake.Request{
		Mechanism: mechanism,
	})
	if err != nil {
		return err
	}
	res := msg.(*saslhandshake.Response)
	if res.ErrorCode != 0 {
		err = Error(res.ErrorCode)
	}
	return err
}

// saslAuthenticateRoundTrip sends the SASL authenticate message.  This
// function must be immediately preceded by a successful saslHandshake.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) {
	msg, err := pc.RoundTrip(&saslauthenticate.Request{
		AuthBytes: data,
	})
	if err != nil {
		return nil, err
	}
	res := msg.(*saslauthenticate.Response)
	if res.ErrorCode != 0 {
		err = makeError(res.ErrorCode, res.ErrorMessage)
	}
	return res.AuthBytes, err
}

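// Compile-time check that *Transport satisfies the RoundTripper interface.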
var _ RoundTripper = (*Transport)(nil)