github.com/deanMdreon/kafka-go@v0.4.32/transport.go (about)

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"runtime/pprof"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"github.com/deanMdreon/kafka-go/protocol"
    20  	"github.com/deanMdreon/kafka-go/protocol/apiversions"
    21  	"github.com/deanMdreon/kafka-go/protocol/createtopics"
    22  	"github.com/deanMdreon/kafka-go/protocol/findcoordinator"
    23  	meta "github.com/deanMdreon/kafka-go/protocol/metadata"
    24  	"github.com/deanMdreon/kafka-go/protocol/saslauthenticate"
    25  	"github.com/deanMdreon/kafka-go/protocol/saslhandshake"
    26  	"github.com/deanMdreon/kafka-go/sasl"
    27  )
    28  
    29  // Request is an interface implemented by types that represent messages sent
    30  // from kafka clients to brokers.
    31  type Request = protocol.Message
    32  
    33  // Response is an interface implemented by types that represent messages sent
    34  // from kafka brokers in response to client requests.
    35  type Response = protocol.Message
    36  
// RoundTripper is an interface implemented by types which support interacting
// with kafka brokers.
type RoundTripper interface {
	// RoundTrip sends a request to a kafka broker and returns the response that
	// was received, or a non-nil error.
	//
	// The context passed as first argument can be used to asynchronously abort
	// the call if needed.
	RoundTrip(context.Context, net.Addr, Request) (Response, error)
}
    47  
// Transport is an implementation of the RoundTripper interface.
//
// Transport values manage a pool of connections and automatically discover the
// cluster's layout to route requests to the appropriate brokers.
//
// Transport values are safe to use concurrently from multiple goroutines.
//
// Note: The intent is for the Transport to become the underlying layer of the
// kafka.Reader and kafka.Writer types.
type Transport struct {
	// A function used to establish connections to the kafka cluster.
	//
	// When nil, the package's default dialer is used.
	Dial func(context.Context, string, string) (net.Conn, error)

	// Time limit set for establishing connections to the kafka cluster. This
	// limit includes all round trips done to establish the connections (TLS
	// handshake, SASL negotiation, etc...).
	//
	// Defaults to 5s.
	DialTimeout time.Duration

	// Maximum amount of time that connections will remain open and unused.
	// The transport will manage to automatically close connections that have
	// been idle for too long, and re-open them on demand when the transport is
	// used again.
	//
	// Defaults to 30s.
	IdleTimeout time.Duration

	// TTL for the metadata cached by this transport. Note that the value
	// configured here is an upper bound, the transport randomizes the TTLs to
	// avoid getting into states where multiple clients end up synchronized and
	// cause bursts of requests to the kafka broker.
	//
	// Defaults to 6s.
	MetadataTTL time.Duration

	// Unique identifier that the transport communicates to the brokers when it
	// sends requests.
	ClientID string

	// An optional configuration for TLS connections established by this
	// transport.
	TLS *tls.Config

	// SASL configures the Transport to use SASL authentication.
	SASL sasl.Mechanism

	// An optional resolver used to translate broker host names into network
	// addresses.
	//
	// The resolver will be called for every request (not every connection),
	// making it possible to implement ACL policies by validating that the
	// program is allowed to connect to the kafka broker. This also means that
	// the resolver should probably provide a caching layer to avoid storming
	// the service discovery backend with requests.
	//
	// When set, the Dial function is not responsible for performing name
	// resolution, and is always called with a pre-resolved address.
	Resolver BrokerResolver

	// The background context used to control goroutines started internally by
	// the transport.
	//
	// If nil, context.Background() is used instead.
	Context context.Context

	// mutex guards the pools map.
	mutex sync.RWMutex
	// pools holds one connection pool per (network, address) pair that was
	// passed to RoundTrip; the map is created lazily on first use.
	pools map[networkAddress]*connPool
}
   119  
// DefaultTransport is the default transport used by kafka clients in this
// package.
//
// It dials with a net.Dialer configured with a 3s timeout; every other
// Transport setting is left at its zero value and therefore falls back to the
// defaults.
var DefaultTransport RoundTripper = &Transport{
	Dial: (&net.Dialer{
		Timeout:   3 * time.Second,
		DualStack: true,
	}).DialContext,
}
   128  
   129  // CloseIdleConnections closes all idle connections immediately, and marks all
   130  // connections that are in use to be closed when they become idle again.
   131  func (t *Transport) CloseIdleConnections() {
   132  	t.mutex.Lock()
   133  	defer t.mutex.Unlock()
   134  
   135  	for _, pool := range t.pools {
   136  		pool.unref()
   137  	}
   138  
   139  	for k := range t.pools {
   140  		delete(t.pools, k)
   141  	}
   142  }
   143  
   144  // RoundTrip sends a request to a kafka cluster and returns the response, or an
   145  // error if no responses were received.
   146  //
   147  // Message types are available in sub-packages of the protocol package. Each
   148  // kafka API is implemented in a different sub-package. For example, the request
   149  // and response types for the Fetch API are available in the protocol/fetch
   150  // package.
   151  //
   152  // The type of the response message will match the type of the request. For
   153  // exmple, if RoundTrip was called with a *fetch.Request as argument, the value
   154  // returned will be of type *fetch.Response. It is safe for the program to do a
   155  // type assertion after checking that no error was returned.
   156  //
   157  // This example illustrates the way this method is expected to be used:
   158  //
   159  //	r, err := transport.RoundTrip(ctx, addr, &fetch.Request{ ... })
   160  //	if err != nil {
   161  //		...
   162  //	} else {
   163  //		res := r.(*fetch.Response)
   164  //		...
   165  //	}
   166  //
   167  // The transport automatically selects the highest version of the API that is
   168  // supported by both the kafka-go package and the kafka broker. The negotiation
   169  // happens transparently once when connections are established.
   170  //
   171  // This API was introduced in version 0.4 as a way to leverage the lower-level
   172  // features of the kafka protocol, but also provide a more efficient way of
   173  // managing connections to kafka brokers.
   174  func (t *Transport) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) {
   175  	p := t.grabPool(addr)
   176  	defer p.unref()
   177  	return p.roundTrip(ctx, req)
   178  }
   179  
   180  func (t *Transport) dial() func(context.Context, string, string) (net.Conn, error) {
   181  	if t.Dial != nil {
   182  		return t.Dial
   183  	}
   184  	return defaultDialer.DialContext
   185  }
   186  
   187  func (t *Transport) dialTimeout() time.Duration {
   188  	if t.DialTimeout > 0 {
   189  		return t.DialTimeout
   190  	}
   191  	return 5 * time.Second
   192  }
   193  
   194  func (t *Transport) idleTimeout() time.Duration {
   195  	if t.IdleTimeout > 0 {
   196  		return t.IdleTimeout
   197  	}
   198  	return 30 * time.Second
   199  }
   200  
   201  func (t *Transport) metadataTTL() time.Duration {
   202  	if t.MetadataTTL > 0 {
   203  		return t.MetadataTTL
   204  	}
   205  	return 6 * time.Second
   206  }
   207  
   208  func (t *Transport) grabPool(addr net.Addr) *connPool {
   209  	k := networkAddress{
   210  		network: addr.Network(),
   211  		address: addr.String(),
   212  	}
   213  
   214  	t.mutex.RLock()
   215  	p := t.pools[k]
   216  	if p != nil {
   217  		p.ref()
   218  	}
   219  	t.mutex.RUnlock()
   220  
   221  	if p != nil {
   222  		return p
   223  	}
   224  
   225  	t.mutex.Lock()
   226  	defer t.mutex.Unlock()
   227  
   228  	if p := t.pools[k]; p != nil {
   229  		p.ref()
   230  		return p
   231  	}
   232  
   233  	ctx, cancel := context.WithCancel(t.context())
   234  
   235  	p = &connPool{
   236  		refc: 2,
   237  
   238  		dial:        t.dial(),
   239  		dialTimeout: t.dialTimeout(),
   240  		idleTimeout: t.idleTimeout(),
   241  		metadataTTL: t.metadataTTL(),
   242  		clientID:    t.ClientID,
   243  		tls:         t.TLS,
   244  		sasl:        t.SASL,
   245  		resolver:    t.Resolver,
   246  
   247  		ready:  make(event),
   248  		wake:   make(chan event),
   249  		conns:  make(map[int32]*connGroup),
   250  		cancel: cancel,
   251  	}
   252  
   253  	p.ctrl = p.newConnGroup(addr)
   254  	go p.discover(ctx, p.wake)
   255  
   256  	if t.pools == nil {
   257  		t.pools = make(map[networkAddress]*connPool)
   258  	}
   259  	t.pools[k] = p
   260  	return p
   261  }
   262  
   263  func (t *Transport) context() context.Context {
   264  	if t.Context != nil {
   265  		return t.Context
   266  	}
   267  	return context.Background()
   268  }
   269  
   270  type event chan struct{}
   271  
   272  func (e event) trigger() { close(e) }
   273  
// connPool holds the connections and the cached cluster state that a
// Transport maintains for one cluster address.
type connPool struct {
	// Reference count, manipulated atomically. When it drops to zero the pool
	// closes its idle connections and cancels its context.
	refc uintptr
	// Immutable fields of the connection pool. Connections access these fields
	// on their parent pool in a read-only fashion, so no synchronization is
	// required.
	dial        func(context.Context, string, string) (net.Conn, error)
	dialTimeout time.Duration
	idleTimeout time.Duration
	metadataTTL time.Duration
	clientID    string
	tls         *tls.Config
	sasl        sasl.Mechanism
	resolver    BrokerResolver
	// Signaling mechanisms to orchestrate communications between the pool and
	// the rest of the program.
	once   sync.Once  // ensure that `ready` is triggered only once
	ready  event      // triggered after the first metadata update
	wake   chan event // used to force metadata updates
	cancel context.CancelFunc
	// Mutable fields of the connection pool, access must be synchronized.
	mutex sync.RWMutex
	conns map[int32]*connGroup // data connections used for produce/fetch/etc...
	ctrl  *connGroup           // control connections used for metadata requests
	state atomic.Value         // cached cluster state
}
   299  
// connPoolState is the snapshot of cluster information cached by a connPool.
// Values are stored in and loaded from the pool's atomic state field.
type connPoolState struct {
	metadata *meta.Response   // last metadata response seen by the pool
	err      error            // last error from metadata requests
	layout   protocol.Cluster // cluster layout built from metadata response
}
   305  
   306  func (p *connPool) grabState() connPoolState {
   307  	state, _ := p.state.Load().(connPoolState)
   308  	return state
   309  }
   310  
// setState atomically replaces the cached cluster state of the pool.
func (p *connPool) setState(state connPoolState) {
	p.state.Store(state)
}
   314  
// ref atomically increments the reference count of the pool.
func (p *connPool) ref() {
	atomic.AddUintptr(&p.refc, +1)
}
   318  
   319  func (p *connPool) unref() {
   320  	if atomic.AddUintptr(&p.refc, ^uintptr(0)) == 0 {
   321  		p.mutex.Lock()
   322  		defer p.mutex.Unlock()
   323  
   324  		for _, conns := range p.conns {
   325  			conns.closeIdleConns()
   326  		}
   327  
   328  		p.ctrl.closeIdleConns()
   329  		p.cancel()
   330  	}
   331  }
   332  
// roundTrip sends req through the pool and returns the response.
//
// The call first waits for the pool to be ready (or for ctx to be done).
// Metadata requests are then answered from the cached state when possible,
// requests implementing protocol.Splitter are split into several requests
// whose results are merged, and everything else is sent as a single request.
//
// After a successful createtopics response, or a metadata response obtained
// with auto topic creation enabled, the cached metadata is refreshed until the
// topics reported without error are visible.
func (p *connPool) roundTrip(ctx context.Context, req Request) (Response, error) {
	// This first select should never block after the first metadata response
	// that would mark the pool as `ready`.
	select {
	case <-p.ready:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	state := p.grabState()
	var response promise

	switch m := req.(type) {
	case *meta.Request:
		// We serve metadata requests directly from the transport cache unless
		// we would like to auto create a topic that isn't in our cache.
		//
		// This reduces the number of round trips to kafka brokers while keeping
		// the logic simple when applying partitioning strategies.
		if state.err != nil {
			return nil, state.err
		}

		cachedMeta := filterMetadataResponse(m, state.metadata)
		// requestNeeded indicates if we need to send this metadata request to the server.
		// It's true when we want to auto-create topics and we don't have the topic in our
		// cache.
		var requestNeeded bool
		if m.AllowAutoTopicCreation {
			for _, topic := range cachedMeta.Topics {
				if topic.ErrorCode == int16(UnknownTopicOrPartition) {
					requestNeeded = true
					break
				}
			}
		}

		if !requestNeeded {
			return cachedMeta, nil
		}

	case protocol.Splitter:
		// Messages that implement the Splitter interface trigger the creation of
		// multiple requests that are all merged back into a single result by
		// a merger.
		messages, merger, err := m.Split(state.layout)
		if err != nil {
			return nil, err
		}
		promises := make([]promise, len(messages))
		for i, m := range messages {
			promises[i] = p.sendRequest(ctx, m, state)
		}
		response = join(promises, messages, merger)
	}

	// No promise was set up by the cases above: send the request as is.
	if response == nil {
		response = p.sendRequest(ctx, req, state)
	}

	r, err := response.await(ctx)
	if err != nil {
		return r, err
	}

	switch resp := r.(type) {
	case *createtopics.Response:
		// Force an update of the metadata when adding topics,
		// otherwise the cached state would get out of sync.
		topicsToRefresh := make([]string, 0, len(resp.Topics))
		for _, topic := range resp.Topics {
			// fixes issue 672: don't refresh topics that failed to create, it causes the library to hang indefinitely
			if topic.ErrorCode != 0 {
				continue
			}

			topicsToRefresh = append(topicsToRefresh, topic.Name)
		}

		p.refreshMetadata(ctx, topicsToRefresh)
	case *meta.Response:
		// NOTE(review): this assertion is unchecked and panics if a
		// *meta.Response is ever returned for a request that is not a
		// *meta.Request — confirm this cannot happen.
		m := req.(*meta.Request)
		// If we get here with allow auto topic creation then
		// we didn't have that topic in our cache so we should update
		// the cache.
		if m.AllowAutoTopicCreation {
			topicsToRefresh := make([]string, 0, len(resp.Topics))
			for _, topic := range resp.Topics {
				// fixes issue 806: don't refresh topics that failed to create,
				// it may mean kafka doesn't enable auto topic creation.
				// This causes the library to hang indefinitely, same as createtopics process.
				if topic.ErrorCode != 0 {
					continue
				}

				topicsToRefresh = append(topicsToRefresh, topic.Name)
			}
			p.refreshMetadata(ctx, topicsToRefresh)
		}
	}

	return r, nil
}
   436  
   437  // refreshMetadata forces an update of the cached cluster metadata, and waits
   438  // for the given list of topics to appear. This waiting mechanism is necessary
   439  // to account for the fact that topic creation is asynchronous in kafka, and
   440  // causes subsequent requests to fail while the cluster state is propagated to
   441  // all the brokers.
   442  func (p *connPool) refreshMetadata(ctx context.Context, expectTopics []string) {
   443  	minBackoff := 100 * time.Millisecond
   444  	maxBackoff := 2 * time.Second
   445  	cancel := ctx.Done()
   446  
   447  	for ctx.Err() == nil {
   448  		notify := make(event)
   449  		select {
   450  		case <-cancel:
   451  			return
   452  		case p.wake <- notify:
   453  			select {
   454  			case <-notify:
   455  			case <-cancel:
   456  				return
   457  			}
   458  		}
   459  
   460  		state := p.grabState()
   461  		found := 0
   462  
   463  		for _, topic := range expectTopics {
   464  			if _, ok := state.layout.Topics[topic]; ok {
   465  				found++
   466  			}
   467  		}
   468  
   469  		if found == len(expectTopics) {
   470  			return
   471  		}
   472  
   473  		if delay := time.Duration(rand.Int63n(int64(minBackoff))); delay > 0 {
   474  			timer := time.NewTimer(minBackoff)
   475  			select {
   476  			case <-cancel:
   477  			case <-timer.C:
   478  			}
   479  			timer.Stop()
   480  
   481  			if minBackoff *= 2; minBackoff > maxBackoff {
   482  				minBackoff = maxBackoff
   483  			}
   484  		}
   485  	}
   486  }
   487  
// setReady triggers the pool's ready event. It may be called any number of
// times; the sync.Once ensures the event is triggered only on the first call.
func (p *connPool) setReady() {
	p.once.Do(p.ready.trigger)
}
   491  
// update is called periodically by the goroutine running the discover method
// to refresh the cluster layout information used by the transport to route
// requests to brokers.
//
// metadata is the response that was received, if any; err is the error from
// the metadata request, if any. On success the cached state is replaced and
// the per-broker connection groups are reconciled with the new layout. On
// error the state is only changed when no metadata was ever cached.
func (p *connPool) update(ctx context.Context, metadata *meta.Response, err error) {
	var layout protocol.Cluster

	if metadata != nil {
		metadata.ThrottleTimeMs = 0

		// Normalize the lists so we can apply binary search on them.
		sortMetadataBrokers(metadata.Brokers)
		sortMetadataTopics(metadata.Topics)

		for i := range metadata.Topics {
			t := &metadata.Topics[i]
			sortMetadataPartitions(t.Partitions)
		}

		layout = makeLayout(metadata)
	}

	state := p.grabState()
	// Sets of broker ids whose connection groups must be created / dropped.
	addBrokers := make(map[int32]struct{})
	delBrokers := make(map[int32]struct{})

	if err != nil {
		// Only update the error on the transport if the cluster layout was
		// unknown. This ensures that we prioritize a previously known state
		// of the cluster to reduce the impact of transient failures.
		if state.metadata != nil {
			return
		}
		state.err = err
	} else {
		for id, b2 := range layout.Brokers {
			if b1, ok := state.layout.Brokers[id]; !ok {
				addBrokers[id] = struct{}{}
			} else if b1 != b2 {
				// The broker changed: drop the old group and create a new one.
				addBrokers[id] = struct{}{}
				delBrokers[id] = struct{}{}
			}
		}

		for id := range state.layout.Brokers {
			if _, ok := layout.Brokers[id]; !ok {
				delBrokers[id] = struct{}{}
			}
		}

		state.metadata, state.layout = metadata, layout
		state.err = nil
	}

	// Deferred calls run in reverse order: the mutex (when taken below) is
	// released first, then the new state is published, and only then is the
	// pool marked ready.
	defer p.setReady()
	defer p.setState(state)

	if len(addBrokers) != 0 || len(delBrokers) != 0 {
		// Only acquire the lock when there is a change of layout. This is an
		// infrequent event so we don't risk introducing regular contention on
		// the mutex if we were to lock it on every update.
		p.mutex.Lock()
		defer p.mutex.Unlock()

		if ctx.Err() != nil {
			return // the pool has been closed, no need to update
		}

		for id := range delBrokers {
			if broker := p.conns[id]; broker != nil {
				broker.closeIdleConns()
				delete(p.conns, id)
			}
		}

		for id := range addBrokers {
			broker := layout.Brokers[id]
			p.conns[id] = p.newBrokerConnGroup(Broker{
				Rack: broker.Rack,
				Host: broker.Host,
				Port: int(broker.Port),
				ID:   int(broker.ID),
			})
		}
	}
}
   577  
// discover is the entry point of an internal goroutine for the transport which
// periodically requests updates of the cluster metadata and refreshes the
// transport cached cluster layout.
//
// Each iteration sends a metadata request on the control connection and passes
// the outcome to update. It then waits for the randomized TTL timer to fire,
// for an event received on wake (which is triggered after the next attempt),
// or for ctx to be done, in which case the goroutine exits.
func (p *connPool) discover(ctx context.Context, wake <-chan event) {
	prng := rand.New(rand.NewSource(time.Now().UnixNano()))
	// metadataTTL picks a random duration in [0, p.metadataTTL), so the
	// configured TTL acts as an upper bound.
	metadataTTL := func() time.Duration {
		return time.Duration(prng.Int63n(int64(p.metadataTTL)))
	}

	timer := time.NewTimer(metadataTTL())
	defer timer.Stop()

	// notify holds the event received from wake, if any, until the following
	// metadata attempt has completed.
	var notify event
	done := ctx.Done()

	for {
		c, err := p.grabClusterConn(ctx)
		if err != nil {
			p.update(ctx, nil, err)
		} else {
			res := make(async, 1)
			req := &meta.Request{}
			// The metadata request itself is bounded by the full TTL.
			deadline, cancel := context.WithTimeout(ctx, p.metadataTTL)
			c.reqs <- connRequest{
				ctx: deadline,
				req: req,
				res: res,
			}
			r, err := res.await(deadline)
			cancel()
			// Stop when the failure is the cancellation of the pool's own
			// context rather than the per-request deadline.
			if err != nil && err == ctx.Err() {
				return
			}
			ret, _ := r.(*meta.Response)
			p.update(ctx, ret, err)
		}

		if notify != nil {
			notify.trigger()
			notify = nil
		}

		select {
		case <-timer.C:
			timer.Reset(metadataTTL())
		case <-done:
			return
		case notify = <-wake:
		}
	}
}
   629  
   630  // grabBrokerConn returns a connection to a specific broker represented by the
   631  // broker id passed as argument. If the broker id was not known, an error is
   632  // returned.
   633  func (p *connPool) grabBrokerConn(ctx context.Context, brokerID int32) (*conn, error) {
   634  	p.mutex.RLock()
   635  	g := p.conns[brokerID]
   636  	p.mutex.RUnlock()
   637  	if g == nil {
   638  		return nil, BrokerNotAvailable
   639  	}
   640  	return g.grabConnOrConnect(ctx)
   641  }
   642  
// grabClusterConn returns the connection to the kafka cluster that the pool is
// configured to connect to.
//
// The transport uses a shared `control` connection to the cluster for any
// requests that aren't supposed to be sent to specific brokers. Requests
// intended to be routed to specific brokers (e.g. Fetch or Produce requests)
// are dispatched on a separate pool of connections that the transport
// maintains. This split helps avoid head-of-line blocking situations where
// control requests like Metadata would be queued behind large responses from
// Fetch requests for example.
//
// In either case, the requests are multiplexed so we can keep a minimal number
// of connections open (N+1, where N is the number of brokers in the cluster).
func (p *connPool) grabClusterConn(ctx context.Context) (*conn, error) {
	return p.ctrl.grabConnOrConnect(ctx)
}
   659  
   660  func (p *connPool) sendRequest(ctx context.Context, req Request, state connPoolState) promise {
   661  	brokerID := int32(-1)
   662  
   663  	switch m := req.(type) {
   664  	case protocol.BrokerMessage:
   665  		// Some requests are supposed to be sent to specific brokers (e.g. the
   666  		// partition leaders). They implement the BrokerMessage interface to
   667  		// delegate the routing decision to each message type.
   668  		broker, err := m.Broker(state.layout)
   669  		if err != nil {
   670  			return reject(err)
   671  		}
   672  		brokerID = broker.ID
   673  
   674  	case protocol.GroupMessage:
   675  		// Some requests are supposed to be sent to a group coordinator,
   676  		// look up which broker is currently the coordinator for the group
   677  		// so we can get a connection to that broker.
   678  		//
   679  		// TODO: should we cache the coordinator info?
   680  		p := p.sendRequest(ctx, &findcoordinator.Request{Key: m.Group()}, state)
   681  		r, err := p.await(ctx)
   682  		if err != nil {
   683  			return reject(err)
   684  		}
   685  		brokerID = r.(*findcoordinator.Response).NodeID
   686  	case protocol.TransactionalMessage:
   687  		p := p.sendRequest(ctx, &findcoordinator.Request{
   688  			Key:     m.Transaction(),
   689  			KeyType: int8(CoordinatorKeyTypeTransaction),
   690  		}, state)
   691  		r, err := p.await(ctx)
   692  		if err != nil {
   693  			return reject(err)
   694  		}
   695  		brokerID = r.(*findcoordinator.Response).NodeID
   696  	}
   697  
   698  	var c *conn
   699  	var err error
   700  	if brokerID >= 0 {
   701  		c, err = p.grabBrokerConn(ctx, brokerID)
   702  	} else {
   703  		c, err = p.grabClusterConn(ctx)
   704  	}
   705  	if err != nil {
   706  		return reject(err)
   707  	}
   708  
   709  	res := make(async, 1)
   710  
   711  	c.reqs <- connRequest{
   712  		ctx: ctx,
   713  		req: req,
   714  		res: res,
   715  	}
   716  
   717  	return res
   718  }
   719  
   720  func filterMetadataResponse(req *meta.Request, res *meta.Response) *meta.Response {
   721  	ret := *res
   722  
   723  	if req.TopicNames != nil {
   724  		ret.Topics = make([]meta.ResponseTopic, len(req.TopicNames))
   725  
   726  		for i, topicName := range req.TopicNames {
   727  			j, ok := findMetadataTopic(res.Topics, topicName)
   728  			if ok {
   729  				ret.Topics[i] = res.Topics[j]
   730  			} else {
   731  				ret.Topics[i] = meta.ResponseTopic{
   732  					ErrorCode: int16(UnknownTopicOrPartition),
   733  					Name:      topicName,
   734  				}
   735  			}
   736  		}
   737  	}
   738  
   739  	return &ret
   740  }
   741  
   742  func findMetadataTopic(topics []meta.ResponseTopic, topicName string) (int, bool) {
   743  	i := sort.Search(len(topics), func(i int) bool {
   744  		return topics[i].Name >= topicName
   745  	})
   746  	return i, i >= 0 && i < len(topics) && topics[i].Name == topicName
   747  }
   748  
   749  func sortMetadataBrokers(brokers []meta.ResponseBroker) {
   750  	sort.Slice(brokers, func(i, j int) bool {
   751  		return brokers[i].NodeID < brokers[j].NodeID
   752  	})
   753  }
   754  
   755  func sortMetadataTopics(topics []meta.ResponseTopic) {
   756  	sort.Slice(topics, func(i, j int) bool {
   757  		return topics[i].Name < topics[j].Name
   758  	})
   759  }
   760  
   761  func sortMetadataPartitions(partitions []meta.ResponsePartition) {
   762  	sort.Slice(partitions, func(i, j int) bool {
   763  		return partitions[i].PartitionIndex < partitions[j].PartitionIndex
   764  	})
   765  }
   766  
   767  func makeLayout(metadataResponse *meta.Response) protocol.Cluster {
   768  	layout := protocol.Cluster{
   769  		Controller: metadataResponse.ControllerID,
   770  		Brokers:    make(map[int32]protocol.Broker),
   771  		Topics:     make(map[string]protocol.Topic),
   772  	}
   773  
   774  	for _, broker := range metadataResponse.Brokers {
   775  		layout.Brokers[broker.NodeID] = protocol.Broker{
   776  			Rack: broker.Rack,
   777  			Host: broker.Host,
   778  			Port: broker.Port,
   779  			ID:   broker.NodeID,
   780  		}
   781  	}
   782  
   783  	for _, topic := range metadataResponse.Topics {
   784  		if topic.IsInternal {
   785  			continue // TODO: do we need to expose those?
   786  		}
   787  		layout.Topics[topic.Name] = protocol.Topic{
   788  			Name:       topic.Name,
   789  			Error:      topic.ErrorCode,
   790  			Partitions: makePartitions(topic.Partitions),
   791  		}
   792  	}
   793  
   794  	return layout
   795  }
   796  
   797  func makePartitions(metadataPartitions []meta.ResponsePartition) map[int32]protocol.Partition {
   798  	protocolPartitions := make(map[int32]protocol.Partition, len(metadataPartitions))
   799  	numBrokerIDs := 0
   800  
   801  	for _, p := range metadataPartitions {
   802  		numBrokerIDs += len(p.ReplicaNodes) + len(p.IsrNodes) + len(p.OfflineReplicas)
   803  	}
   804  
   805  	// Reduce the memory footprint a bit by allocating a single buffer to write
   806  	// all broker ids.
   807  	brokerIDs := make([]int32, 0, numBrokerIDs)
   808  
   809  	for _, p := range metadataPartitions {
   810  		var rep, isr, off []int32
   811  		brokerIDs, rep = appendBrokerIDs(brokerIDs, p.ReplicaNodes)
   812  		brokerIDs, isr = appendBrokerIDs(brokerIDs, p.IsrNodes)
   813  		brokerIDs, off = appendBrokerIDs(brokerIDs, p.OfflineReplicas)
   814  
   815  		protocolPartitions[p.PartitionIndex] = protocol.Partition{
   816  			ID:       p.PartitionIndex,
   817  			Error:    p.ErrorCode,
   818  			Leader:   p.LeaderID,
   819  			Replicas: rep,
   820  			ISR:      isr,
   821  			Offline:  off,
   822  		}
   823  	}
   824  
   825  	return protocolPartitions
   826  }
   827  
   828  func appendBrokerIDs(ids, brokers []int32) ([]int32, []int32) {
   829  	i := len(ids)
   830  	ids = append(ids, brokers...)
   831  	return ids, ids[i:len(ids):len(ids)]
   832  }
   833  
   834  func (p *connPool) newConnGroup(a net.Addr) *connGroup {
   835  	return &connGroup{
   836  		addr: a,
   837  		pool: p,
   838  		broker: Broker{
   839  			ID: -1,
   840  		},
   841  	}
   842  }
   843  
   844  func (p *connPool) newBrokerConnGroup(broker Broker) *connGroup {
   845  	return &connGroup{
   846  		addr: &networkAddress{
   847  			network: "tcp",
   848  			address: net.JoinHostPort(broker.Host, strconv.Itoa(broker.Port)),
   849  		},
   850  		pool:   p,
   851  		broker: broker,
   852  	}
   853  }
   854  
// connRequest is the unit of work sent on a connection's request channel.
type connRequest struct {
	ctx context.Context // context of the caller that issued the request
	req Request         // the message to send
	res async           // promise on which the outcome is delivered
}
   860  
// The promise interface is used as a message passing abstraction to coordinate
// between goroutines that handle requests and responses.
//
// Implementations in this file are async, rejected and joined.
type promise interface {
	// Waits until the promise is resolved, rejected, or the context canceled.
	await(context.Context) (Response, error)
}
   867  
   868  // async is an implementation of the promise interface which supports resolving
   869  // or rejecting the await call asynchronously.
   870  type async chan interface{}
   871  
   872  func (p async) await(ctx context.Context) (Response, error) {
   873  	select {
   874  	case x := <-p:
   875  		switch v := x.(type) {
   876  		case nil:
   877  			return nil, nil // A nil response is ok (e.g. when RequiredAcks is None)
   878  		case Response:
   879  			return v, nil
   880  		case error:
   881  			return nil, v
   882  		default:
   883  			panic(fmt.Errorf("BUG: promise resolved with impossible value of type %T", v))
   884  		}
   885  	case <-ctx.Done():
   886  		return nil, ctx.Err()
   887  	}
   888  }
   889  
   890  func (p async) resolve(res Response) { p <- res }
   891  
   892  func (p async) reject(err error) { p <- err }
   893  
   894  // rejected is an implementation of the promise interface which is always
   895  // returns an error. Values of this type are constructed using the reject
   896  // function.
   897  type rejected struct{ err error }
   898  
   899  func reject(err error) promise { return &rejected{err: err} }
   900  
   901  func (p *rejected) await(ctx context.Context) (Response, error) {
   902  	return nil, p.err
   903  }
   904  
// joined is an implementation of the promise interface which merges results
// from multiple promises into one await call using a merger.
type joined struct {
	promises []promise       // pending sub-requests, awaited in order
	requests []Request       // requests handed to the merger along with the results
	merger   protocol.Merger // combines the individual results into one response
}
   912  
   913  func join(promises []promise, requests []Request, merger protocol.Merger) promise {
   914  	return &joined{
   915  		promises: promises,
   916  		requests: requests,
   917  		merger:   merger,
   918  	}
   919  }
   920  
   921  func (p *joined) await(ctx context.Context) (Response, error) {
   922  	results := make([]interface{}, len(p.promises))
   923  
   924  	for i, sub := range p.promises {
   925  		m, err := sub.await(ctx)
   926  		if err != nil {
   927  			results[i] = err
   928  		} else {
   929  			results[i] = m
   930  		}
   931  	}
   932  
   933  	return p.merger.Merge(p.requests, results)
   934  }
   935  
   936  // Default dialer used by the transport connections when no Dial function
   937  // was configured by the program.
   938  var defaultDialer = net.Dialer{
   939  	Timeout:   3 * time.Second,
   940  	DualStack: true,
   941  }
   942  
   943  // connGroup represents a logical connection group to a kafka broker. The
   944  // actual network connections are lazily open before sending requests, and
   945  // closed if they are unused for longer than the idle timeout.
   946  type connGroup struct {
   947  	addr   net.Addr
   948  	broker Broker
   949  	// Immutable state of the connection.
   950  	pool *connPool
   951  	// Shared state of the connection, this is synchronized on the mutex through
   952  	// calls to the synchronized method. Both goroutines of the connection share
   953  	// the state maintained in these fields.
   954  	mutex     sync.Mutex
   955  	closed    bool
   956  	idleConns []*conn // stack of idle connections
   957  }
   958  
   959  func (g *connGroup) closeIdleConns() {
   960  	g.mutex.Lock()
   961  	conns := g.idleConns
   962  	g.idleConns = nil
   963  	g.closed = true
   964  	g.mutex.Unlock()
   965  
   966  	for _, c := range conns {
   967  		c.close()
   968  	}
   969  }
   970  
   971  func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) {
   972  	rslv := g.pool.resolver
   973  	addr := g.addr
   974  	var c *conn
   975  
   976  	if rslv == nil {
   977  		c = g.grabConn()
   978  	} else {
   979  		var err error
   980  		broker := g.broker
   981  
   982  		if broker.ID < 0 {
   983  			host, port, err := splitHostPortNumber(addr.String())
   984  			if err != nil {
   985  				return nil, err
   986  			}
   987  			broker.Host = host
   988  			broker.Port = port
   989  		}
   990  
   991  		ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker)
   992  		if err != nil {
   993  			return nil, err
   994  		}
   995  
   996  		for _, ipAddr := range ipAddrs {
   997  			network := addr.Network()
   998  			address := net.JoinHostPort(ipAddr.String(), strconv.Itoa(broker.Port))
   999  
  1000  			if c = g.grabConnTo(network, address); c != nil {
  1001  				break
  1002  			}
  1003  		}
  1004  	}
  1005  
  1006  	if c == nil {
  1007  		connChan := make(chan *conn)
  1008  		errChan := make(chan error)
  1009  
  1010  		go func() {
  1011  			c, err := g.connect(ctx, addr)
  1012  			if err != nil {
  1013  				select {
  1014  				case errChan <- err:
  1015  				case <-ctx.Done():
  1016  				}
  1017  			} else {
  1018  				select {
  1019  				case connChan <- c:
  1020  				case <-ctx.Done():
  1021  					if !g.releaseConn(c) {
  1022  						c.close()
  1023  					}
  1024  				}
  1025  			}
  1026  		}()
  1027  
  1028  		select {
  1029  		case c = <-connChan:
  1030  		case err := <-errChan:
  1031  			return nil, err
  1032  		case <-ctx.Done():
  1033  			return nil, ctx.Err()
  1034  		}
  1035  	}
  1036  
  1037  	return c, nil
  1038  }
  1039  
  1040  func (g *connGroup) grabConnTo(network, address string) *conn {
  1041  	g.mutex.Lock()
  1042  	defer g.mutex.Unlock()
  1043  
  1044  	for i := len(g.idleConns) - 1; i >= 0; i-- {
  1045  		c := g.idleConns[i]
  1046  
  1047  		if c.network == network && c.address == address {
  1048  			copy(g.idleConns[i:], g.idleConns[i+1:])
  1049  			n := len(g.idleConns) - 1
  1050  			g.idleConns[n] = nil
  1051  			g.idleConns = g.idleConns[:n]
  1052  
  1053  			if c.timer != nil {
  1054  				c.timer.Stop()
  1055  			}
  1056  
  1057  			return c
  1058  		}
  1059  	}
  1060  
  1061  	return nil
  1062  }
  1063  
  1064  func (g *connGroup) grabConn() *conn {
  1065  	g.mutex.Lock()
  1066  	defer g.mutex.Unlock()
  1067  
  1068  	if len(g.idleConns) == 0 {
  1069  		return nil
  1070  	}
  1071  
  1072  	n := len(g.idleConns) - 1
  1073  	c := g.idleConns[n]
  1074  	g.idleConns[n] = nil
  1075  	g.idleConns = g.idleConns[:n]
  1076  
  1077  	if c.timer != nil {
  1078  		c.timer.Stop()
  1079  	}
  1080  
  1081  	return c
  1082  }
  1083  
  1084  func (g *connGroup) removeConn(c *conn) bool {
  1085  	g.mutex.Lock()
  1086  	defer g.mutex.Unlock()
  1087  
  1088  	if c.timer != nil {
  1089  		c.timer.Stop()
  1090  	}
  1091  
  1092  	for i, x := range g.idleConns {
  1093  		if x == c {
  1094  			copy(g.idleConns[i:], g.idleConns[i+1:])
  1095  			n := len(g.idleConns) - 1
  1096  			g.idleConns[n] = nil
  1097  			g.idleConns = g.idleConns[:n]
  1098  			return true
  1099  		}
  1100  	}
  1101  
  1102  	return false
  1103  }
  1104  
  1105  func (g *connGroup) releaseConn(c *conn) bool {
  1106  	idleTimeout := g.pool.idleTimeout
  1107  
  1108  	g.mutex.Lock()
  1109  	defer g.mutex.Unlock()
  1110  
  1111  	if g.closed {
  1112  		return false
  1113  	}
  1114  
  1115  	if c.timer != nil {
  1116  		c.timer.Reset(idleTimeout)
  1117  	} else {
  1118  		c.timer = time.AfterFunc(idleTimeout, func() {
  1119  			if g.removeConn(c) {
  1120  				c.close()
  1121  			}
  1122  		})
  1123  	}
  1124  
  1125  	g.idleConns = append(g.idleConns, c)
  1126  	return true
  1127  }
  1128  
  1129  func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {
  1130  	deadline := time.Now().Add(g.pool.dialTimeout)
  1131  
  1132  	ctx, cancel := context.WithDeadline(ctx, deadline)
  1133  	defer cancel()
  1134  
  1135  	network := strings.Split(addr.Network(), ",")
  1136  	address := strings.Split(addr.String(), ",")
  1137  	var netConn net.Conn
  1138  	var netAddr net.Addr
  1139  	var err error
  1140  
  1141  	if len(address) > 1 {
  1142  		// Shuffle the list of addresses to randomize the order in which
  1143  		// connections are attempted. This prevents routing all connections
  1144  		// to the first broker (which will usually succeed).
  1145  		rand.Shuffle(len(address), func(i, j int) {
  1146  			network[i], network[j] = network[j], network[i]
  1147  			address[i], address[j] = address[j], address[i]
  1148  		})
  1149  	}
  1150  
  1151  	for i := range address {
  1152  		netConn, err = g.pool.dial(ctx, network[i], address[i])
  1153  		if err == nil {
  1154  			netAddr = &networkAddress{
  1155  				network: network[i],
  1156  				address: address[i],
  1157  			}
  1158  			break
  1159  		}
  1160  	}
  1161  
  1162  	if err != nil {
  1163  		return nil, err
  1164  	}
  1165  
  1166  	defer func() {
  1167  		if netConn != nil {
  1168  			netConn.Close()
  1169  		}
  1170  	}()
  1171  
  1172  	if tlsConfig := g.pool.tls; tlsConfig != nil {
  1173  		if tlsConfig.ServerName == "" {
  1174  			host, _ := splitHostPort(netAddr.String())
  1175  			tlsConfig = tlsConfig.Clone()
  1176  			tlsConfig.ServerName = host
  1177  		}
  1178  		netConn = tls.Client(netConn, tlsConfig)
  1179  	}
  1180  
  1181  	pc := protocol.NewConn(netConn, g.pool.clientID)
  1182  	pc.SetDeadline(deadline)
  1183  
  1184  	r, err := pc.RoundTrip(new(apiversions.Request))
  1185  	if err != nil {
  1186  		return nil, err
  1187  	}
  1188  	res := r.(*apiversions.Response)
  1189  	ver := make(map[protocol.ApiKey]int16, len(res.ApiKeys))
  1190  
  1191  	if res.ErrorCode != 0 {
  1192  		return nil, fmt.Errorf("negotating API versions with kafka broker at %s: %w", g.addr, Error(res.ErrorCode))
  1193  	}
  1194  
  1195  	for _, r := range res.ApiKeys {
  1196  		apiKey := protocol.ApiKey(r.ApiKey)
  1197  		ver[apiKey] = apiKey.SelectVersion(r.MinVersion, r.MaxVersion)
  1198  	}
  1199  
  1200  	pc.SetVersions(ver)
  1201  	pc.SetDeadline(time.Time{})
  1202  
  1203  	if g.pool.sasl != nil {
  1204  		host, port, err := splitHostPortNumber(netAddr.String())
  1205  		if err != nil {
  1206  			return nil, err
  1207  		}
  1208  		metadata := &sasl.Metadata{
  1209  			Host: host,
  1210  			Port: port,
  1211  		}
  1212  		if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, g.pool.sasl); err != nil {
  1213  			return nil, err
  1214  		}
  1215  	}
  1216  
  1217  	reqs := make(chan connRequest)
  1218  	c := &conn{
  1219  		network: netAddr.Network(),
  1220  		address: netAddr.String(),
  1221  		reqs:    reqs,
  1222  		group:   g,
  1223  	}
  1224  	go c.run(pc, reqs)
  1225  
  1226  	netConn = nil
  1227  	return c, nil
  1228  }
  1229  
  1230  type conn struct {
  1231  	reqs    chan<- connRequest
  1232  	network string
  1233  	address string
  1234  	once    sync.Once
  1235  	group   *connGroup
  1236  	timer   *time.Timer
  1237  }
  1238  
  1239  func (c *conn) close() {
  1240  	c.once.Do(func() { close(c.reqs) })
  1241  }
  1242  
  1243  func (c *conn) run(pc *protocol.Conn, reqs <-chan connRequest) {
  1244  	defer pc.Close()
  1245  
  1246  	for cr := range reqs {
  1247  		r, err := c.roundTrip(cr.ctx, pc, cr.req)
  1248  		if err != nil {
  1249  			cr.res.reject(err)
  1250  			if !errors.Is(err, protocol.ErrNoRecord) {
  1251  				break
  1252  			}
  1253  		} else {
  1254  			cr.res.resolve(r)
  1255  		}
  1256  		if !c.group.releaseConn(c) {
  1257  			break
  1258  		}
  1259  	}
  1260  }
  1261  
  1262  func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (Response, error) {
  1263  	pprof.SetGoroutineLabels(ctx)
  1264  	defer pprof.SetGoroutineLabels(context.Background())
  1265  
  1266  	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
  1267  		pc.SetDeadline(deadline)
  1268  		defer pc.SetDeadline(time.Time{})
  1269  	}
  1270  
  1271  	return pc.RoundTrip(req)
  1272  }
  1273  
  1274  // authenticateSASL performs all of the required requests to authenticate this
  1275  // connection.  If any step fails, this function returns with an error.  A nil
  1276  // error indicates successful authentication.
  1277  func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error {
  1278  	if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil {
  1279  		return err
  1280  	}
  1281  
  1282  	sess, state, err := mechanism.Start(ctx)
  1283  	if err != nil {
  1284  		return err
  1285  	}
  1286  
  1287  	for completed := false; !completed; {
  1288  		challenge, err := saslAuthenticateRoundTrip(pc, state)
  1289  		switch err {
  1290  		case nil:
  1291  		case io.EOF:
  1292  			// the broker may communicate a failed exchange by closing the
  1293  			// connection (esp. in the case where we're passing opaque sasl
  1294  			// data over the wire since there's no protocol info).
  1295  			return SASLAuthenticationFailed
  1296  		default:
  1297  			return err
  1298  		}
  1299  
  1300  		completed, state, err = sess.Next(ctx, challenge)
  1301  		if err != nil {
  1302  			return err
  1303  		}
  1304  	}
  1305  
  1306  	return nil
  1307  }
  1308  
  1309  // saslHandshake sends the SASL handshake message.  This will determine whether
  1310  // the Mechanism is supported by the cluster.  If it's not, this function will
  1311  // error out with UnsupportedSASLMechanism.
  1312  //
  1313  // If the mechanism is unsupported, the handshake request will reply with the
  1314  // list of the cluster's configured mechanisms, which could potentially be used
  1315  // to facilitate negotiation.  At the moment, we are not negotiating the
  1316  // mechanism as we believe that brokers are usually known to the client, and
  1317  // therefore the client should already know which mechanisms are supported.
  1318  //
  1319  // See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake
  1320  func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error {
  1321  	msg, err := pc.RoundTrip(&saslhandshake.Request{
  1322  		Mechanism: mechanism,
  1323  	})
  1324  	if err != nil {
  1325  		return err
  1326  	}
  1327  	res := msg.(*saslhandshake.Response)
  1328  	if res.ErrorCode != 0 {
  1329  		err = Error(res.ErrorCode)
  1330  	}
  1331  	return err
  1332  }
  1333  
  1334  // saslAuthenticate sends the SASL authenticate message.  This function must
  1335  // be immediately preceded by a successful saslHandshake.
  1336  //
  1337  // See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
  1338  func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) {
  1339  	msg, err := pc.RoundTrip(&saslauthenticate.Request{
  1340  		AuthBytes: data,
  1341  	})
  1342  	if err != nil {
  1343  		return nil, err
  1344  	}
  1345  	res := msg.(*saslauthenticate.Response)
  1346  	if res.ErrorCode != 0 {
  1347  		err = makeError(res.ErrorCode, res.ErrorMessage)
  1348  	}
  1349  	return res.AuthBytes, err
  1350  }
  1351  
// Compile-time check that *Transport implements the RoundTripper interface.
var _ RoundTripper = (*Transport)(nil)