github.com/streamdal/segmentio-kafka-go@v0.4.47-streamdal/transport.go

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"runtime/pprof"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"github.com/segmentio/kafka-go/protocol"
    20  	"github.com/segmentio/kafka-go/protocol/apiversions"
    21  	"github.com/segmentio/kafka-go/protocol/createtopics"
    22  	"github.com/segmentio/kafka-go/protocol/findcoordinator"
    23  	meta "github.com/segmentio/kafka-go/protocol/metadata"
    24  	"github.com/segmentio/kafka-go/protocol/saslauthenticate"
    25  	"github.com/segmentio/kafka-go/protocol/saslhandshake"
    26  	"github.com/segmentio/kafka-go/sasl"
    27  )
    28  
    29  // Request is an interface implemented by types that represent messages sent
    30  // from kafka clients to brokers.
    31  type Request = protocol.Message
    32  
    33  // Response is an interface implemented by types that represent messages sent
    34  // from kafka brokers in response to client requests.
    35  type Response = protocol.Message
    36  
    37  // RoundTripper is an interface implemented by types which support interacting
    38  // with kafka brokers.
    39  type RoundTripper interface {
    40  	// RoundTrip sends a request to a kafka broker and returns the response that
    41  	// was received, or a non-nil error.
    42  	//
    43  	// The context passed as the first argument can be used to asynchronously abort
    44  	// the call if needed.
    45  	RoundTrip(context.Context, net.Addr, Request) (Response, error)
    46  }
    47  
    48  // Transport is an implementation of the RoundTripper interface.
    49  //
    50  // Transport values manage a pool of connections and automatically discover the
    51  // cluster's layout to route requests to the appropriate brokers.
    52  //
    53  // Transport values are safe to use concurrently from multiple goroutines.
    54  //
    55  // Note: The intent is for the Transport to become the underlying layer of the
    56  // kafka.Reader and kafka.Writer types.
    57  type Transport struct {
    58  	// A function used to establish connections to the kafka cluster.
    59  	Dial func(context.Context, string, string) (net.Conn, error)
    60  
    61  	// Time limit set for establishing connections to the kafka cluster. This
    62  	// limit includes all round trips done to establish the connections (TLS
    63  	// handshake, SASL negotiation, etc...).
    64  	//
    65  	// Defaults to 5s.
    66  	DialTimeout time.Duration
    67  
    68  	// Maximum amount of time that connections will remain open and unused.
    69  	// The transport automatically closes connections that have been idle for
    70  	// too long, and re-opens them on demand when the transport is used
    71  	// again.
    72  	//
    73  	// Defaults to 30s.
    74  	IdleTimeout time.Duration
    75  
    76  	// TTL for the metadata cached by this transport. Note that the value
    77  	// configured here is an upper bound; the transport randomizes the TTLs to
    78  	// avoid getting into states where multiple clients end up synchronized and
    79  	// cause bursts of requests to the kafka broker.
    80  	//
    81  	// Defaults to 6s.
    82  	MetadataTTL time.Duration
    83  
    84  	// Topic names for the metadata cached by this transport. If this field is left
    85  	// empty, metadata for all topics in the cluster will be retrieved.
    86  	MetadataTopics []string
    87  
    88  	// Unique identifier that the transport communicates to the brokers when it
    89  	// sends requests.
    90  	ClientID string
    91  
    92  	// An optional configuration for TLS connections established by this
    93  	// transport.
    94  	//
    95  	// If ServerName is left blank, it defaults to the host of the broker being dialed.
    96  	TLS *tls.Config
    97  
    98  	// SASL configures the Transport to use SASL authentication.
    99  	SASL sasl.Mechanism
   100  
   101  	// An optional resolver used to translate broker host names into network
   102  	// addresses.
   103  	//
   104  	// The resolver will be called for every request (not every connection),
   105  	// making it possible to implement ACL policies by validating that the
   106  	// program is allowed to connect to the kafka broker. This also means that
   107  	// the resolver should probably provide a caching layer to avoid storming
   108  	// the service discovery backend with requests.
   109  	//
   110  	// When set, the Dial function is not responsible for performing name
   111  	// resolution, and is always called with a pre-resolved address.
   112  	Resolver BrokerResolver
   113  
   114  	// The background context used to control goroutines started internally by
   115  	// the transport.
   116  	//
   117  	// If nil, context.Background() is used instead.
   118  	Context context.Context
   119  
   120  	mutex sync.RWMutex
   121  	pools map[networkAddress]*connPool
   122  }
   123  
   124  // DefaultTransport is the default transport used by kafka clients in this
   125  // package.
   126  var DefaultTransport RoundTripper = &Transport{
   127  	Dial: (&net.Dialer{
   128  		Timeout:   3 * time.Second,
   129  		DualStack: true,
   130  	}).DialContext,
   131  }
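
        // As an editorial sketch of typical usage (the timeouts, client ID, and the
        // idea of sharing one transport between clients are illustrative, not
        // requirements of this package):
        //
        //	transport := &Transport{
        //		DialTimeout: 10 * time.Second,
        //		IdleTimeout: 60 * time.Second,
        //		MetadataTTL: 10 * time.Second,
        //		ClientID:    "example-service",
        //	}
        //	defer transport.CloseIdleConnections()
        //
        //	// The transport is safe for concurrent use and can be shared by several
        //	// clients so that connections and cached metadata are reused.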
   132  
   133  // CloseIdleConnections closes all idle connections immediately, and marks all
   134  // connections that are in use to be closed when they become idle again.
   135  func (t *Transport) CloseIdleConnections() {
   136  	t.mutex.Lock()
   137  	defer t.mutex.Unlock()
   138  
   139  	for _, pool := range t.pools {
   140  		pool.unref()
   141  	}
   142  
   143  	for k := range t.pools {
   144  		delete(t.pools, k)
   145  	}
   146  }
   147  
   148  // RoundTrip sends a request to a kafka cluster and returns the response, or an
   149  // error if no response was received.
   150  //
   151  // Message types are available in sub-packages of the protocol package. Each
   152  // kafka API is implemented in a different sub-package. For example, the request
   153  // and response types for the Fetch API are available in the protocol/fetch
   154  // package.
   155  //
   156  // The type of the response message will match the type of the request. For
   157  // example, if RoundTrip was called with a *fetch.Request as argument, the value
   158  // returned will be of type *fetch.Response. It is safe for the program to do a
   159  // type assertion after checking that no error was returned.
   160  //
   161  // This example illustrates the way this method is expected to be used:
   162  //
   163  //	r, err := transport.RoundTrip(ctx, addr, &fetch.Request{ ... })
   164  //	if err != nil {
   165  //		...
   166  //	} else {
   167  //		res := r.(*fetch.Response)
   168  //		...
   169  //	}
   170  //
   171  // The transport automatically selects the highest version of the API that is
   172  // supported by both the kafka-go package and the kafka broker. The negotiation
   173  // happens transparently once when connections are established.
   174  //
   175  // This API was introduced in version 0.4 as a way to leverage the lower-level
   176  // features of the kafka protocol, but also provide a more efficient way of
   177  // managing connections to kafka brokers.
   178  func (t *Transport) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) {
   179  	p := t.grabPool(addr)
   180  	defer p.unref()
   181  	return p.roundTrip(ctx, req)
   182  }
   183  
   184  func (t *Transport) dial() func(context.Context, string, string) (net.Conn, error) {
   185  	if t.Dial != nil {
   186  		return t.Dial
   187  	}
   188  	return defaultDialer.DialContext
   189  }
   190  
   191  func (t *Transport) dialTimeout() time.Duration {
   192  	if t.DialTimeout > 0 {
   193  		return t.DialTimeout
   194  	}
   195  	return 5 * time.Second
   196  }
   197  
   198  func (t *Transport) idleTimeout() time.Duration {
   199  	if t.IdleTimeout > 0 {
   200  		return t.IdleTimeout
   201  	}
   202  	return 30 * time.Second
   203  }
   204  
   205  func (t *Transport) metadataTTL() time.Duration {
   206  	if t.MetadataTTL > 0 {
   207  		return t.MetadataTTL
   208  	}
   209  	return 6 * time.Second
   210  }
   211  
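        // grabPool returns the connection pool for the given broker address, creating
        // it on first use. It follows a double-checked locking pattern: a fast path
        // under the read lock, then a re-check under the write lock before building a
        // new pool. A freshly created pool starts with a reference count of 2, one
        // reference held by the transport's pools map and one handed to the caller
        // (released by the deferred unref in RoundTrip).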
   212  func (t *Transport) grabPool(addr net.Addr) *connPool {
   213  	k := networkAddress{
   214  		network: addr.Network(),
   215  		address: addr.String(),
   216  	}
   217  
   218  	t.mutex.RLock()
   219  	p := t.pools[k]
   220  	if p != nil {
   221  		p.ref()
   222  	}
   223  	t.mutex.RUnlock()
   224  
   225  	if p != nil {
   226  		return p
   227  	}
   228  
   229  	t.mutex.Lock()
   230  	defer t.mutex.Unlock()
   231  
   232  	if p := t.pools[k]; p != nil {
   233  		p.ref()
   234  		return p
   235  	}
   236  
   237  	ctx, cancel := context.WithCancel(t.context())
   238  
   239  	p = &connPool{
   240  		refc: 2,
   241  
   242  		dial:           t.dial(),
   243  		dialTimeout:    t.dialTimeout(),
   244  		idleTimeout:    t.idleTimeout(),
   245  		metadataTTL:    t.metadataTTL(),
   246  		metadataTopics: t.MetadataTopics,
   247  		clientID:       t.ClientID,
   248  		tls:            t.TLS,
   249  		sasl:           t.SASL,
   250  		resolver:       t.Resolver,
   251  
   252  		ready:  make(event),
   253  		wake:   make(chan event),
   254  		conns:  make(map[int32]*connGroup),
   255  		cancel: cancel,
   256  	}
   257  
   258  	p.ctrl = p.newConnGroup(addr)
   259  	go p.discover(ctx, p.wake)
   260  
   261  	if t.pools == nil {
   262  		t.pools = make(map[networkAddress]*connPool)
   263  	}
   264  	t.pools[k] = p
   265  	return p
   266  }
   267  
   268  func (t *Transport) context() context.Context {
   269  	if t.Context != nil {
   270  		return t.Context
   271  	}
   272  	return context.Background()
   273  }
   274  
   275  type event chan struct{}
   276  
   277  func (e event) trigger() { close(e) }
   278  
   279  type connPool struct {
   280  	refc uintptr
   281  	// Immutable fields of the connection pool. Connections access these fields
   282  	// on their parent pool in a read-only fashion, so no synchronization is
   283  	// required.
   284  	dial           func(context.Context, string, string) (net.Conn, error)
   285  	dialTimeout    time.Duration
   286  	idleTimeout    time.Duration
   287  	metadataTTL    time.Duration
   288  	metadataTopics []string
   289  	clientID       string
   290  	tls            *tls.Config
   291  	sasl           sasl.Mechanism
   292  	resolver       BrokerResolver
   293  	// Signaling mechanisms to orchestrate communications between the pool and
   294  	// the rest of the program.
   295  	once   sync.Once  // ensure that `ready` is triggered only once
   296  	ready  event      // triggered after the first metadata update
   297  	wake   chan event // used to force metadata updates
   298  	cancel context.CancelFunc
   299  	// Mutable fields of the connection pool, access must be synchronized.
   300  	mutex sync.RWMutex
   301  	conns map[int32]*connGroup // data connections used for produce/fetch/etc...
   302  	ctrl  *connGroup           // control connections used for metadata requests
   303  	state atomic.Value         // cached cluster state
   304  }
   305  
   306  type connPoolState struct {
   307  	metadata *meta.Response   // last metadata response seen by the pool
   308  	err      error            // last error from metadata requests
   309  	layout   protocol.Cluster // cluster layout built from metadata response
   310  }
   311  
   312  func (p *connPool) grabState() connPoolState {
   313  	state, _ := p.state.Load().(connPoolState)
   314  	return state
   315  }
   316  
   317  func (p *connPool) setState(state connPoolState) {
   318  	p.state.Store(state)
   319  }
   320  
   321  func (p *connPool) ref() {
   322  	atomic.AddUintptr(&p.refc, +1)
   323  }
   324  
   325  func (p *connPool) unref() {
   326  	if atomic.AddUintptr(&p.refc, ^uintptr(0)) == 0 {
   327  		p.mutex.Lock()
   328  		defer p.mutex.Unlock()
   329  
   330  		for _, conns := range p.conns {
   331  			conns.closeIdleConns()
   332  		}
   333  
   334  		p.ctrl.closeIdleConns()
   335  		p.cancel()
   336  	}
   337  }
   338  
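        // roundTrip dispatches a single request for the pool: it waits for the first
        // metadata update to complete, serves Metadata requests from the cached
        // cluster state whenever possible, splits messages that implement
        // protocol.Splitter across brokers and merges their results, and forces a
        // metadata refresh after topic creation so the cached layout does not go
        // stale.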
   339  func (p *connPool) roundTrip(ctx context.Context, req Request) (Response, error) {
   340  	// This first select should never block after the first metadata response
   341  	// that would mark the pool as `ready`.
   342  	select {
   343  	case <-p.ready:
   344  	case <-ctx.Done():
   345  		return nil, ctx.Err()
   346  	}
   347  
   348  	state := p.grabState()
   349  	var response promise
   350  
   351  	switch m := req.(type) {
   352  	case *meta.Request:
   353  		// We serve metadata requests directly from the transport cache unless
   354  		// we would like to auto create a topic that isn't in our cache.
   355  		//
   356  		// This reduces the number of round trips to kafka brokers while keeping
   357  		// the logic simple when applying partitioning strategies.
   358  		if state.err != nil {
   359  			return nil, state.err
   360  		}
   361  
   362  		cachedMeta := filterMetadataResponse(m, state.metadata)
   363  		// requestNeeded indicates if we need to send this metadata request to the server.
   364  		// It's true when we want to auto-create topics and we don't have the topic in our
   365  		// cache.
   366  		var requestNeeded bool
   367  		if m.AllowAutoTopicCreation {
   368  			for _, topic := range cachedMeta.Topics {
   369  				if topic.ErrorCode == int16(UnknownTopicOrPartition) {
   370  					requestNeeded = true
   371  					break
   372  				}
   373  			}
   374  		}
   375  
   376  		if !requestNeeded {
   377  			return cachedMeta, nil
   378  		}
   379  
   380  	case protocol.Splitter:
   381  		// Messages that implement the Splitter interface trigger the creation of
   382  		// multiple requests that are all merged back into a single result by
   383  		// a merger.
   384  		messages, merger, err := m.Split(state.layout)
   385  		if err != nil {
   386  			return nil, err
   387  		}
   388  		promises := make([]promise, len(messages))
   389  		for i, m := range messages {
   390  			promises[i] = p.sendRequest(ctx, m, state)
   391  		}
   392  		response = join(promises, messages, merger)
   393  	}
   394  
   395  	if response == nil {
   396  		response = p.sendRequest(ctx, req, state)
   397  	}
   398  
   399  	r, err := response.await(ctx)
   400  	if err != nil {
   401  		return r, err
   402  	}
   403  
   404  	switch resp := r.(type) {
   405  	case *createtopics.Response:
   406  		// Force an update of the metadata when adding topics;
   407  		// otherwise the cached state would get out of sync.
   408  		topicsToRefresh := make([]string, 0, len(resp.Topics))
   409  		for _, topic := range resp.Topics {
   410  			// fixes issue 672: don't refresh topics that failed to create, since doing so causes the library to hang indefinitely
   411  			if topic.ErrorCode != 0 {
   412  				continue
   413  			}
   414  
   415  			topicsToRefresh = append(topicsToRefresh, topic.Name)
   416  		}
   417  
   418  		p.refreshMetadata(ctx, topicsToRefresh)
   419  	case *meta.Response:
   420  		m := req.(*meta.Request)
   421  		// If we get here with allow auto topic creation then
   422  		// we didn't have that topic in our cache, so we should update
   423  		// the cache.
   424  		if m.AllowAutoTopicCreation {
   425  			topicsToRefresh := make([]string, 0, len(resp.Topics))
   426  			for _, topic := range resp.Topics {
   427  				// Don't refresh topics that failed to create, since that may
   428  				// mean that automatic topic creation is not enabled.
   429  				// Refreshing them would cause the library to hang
   430  				// indefinitely, just as in the createtopics case handled
   431  				// above. Fixes issue 806.
   432  				if topic.ErrorCode != 0 {
   433  					continue
   434  				}
   435  
   436  				topicsToRefresh = append(topicsToRefresh, topic.Name)
   437  			}
   438  			p.refreshMetadata(ctx, topicsToRefresh)
   439  		}
   440  	}
   441  
   442  	return r, nil
   443  }
   444  
   445  // refreshMetadata forces an update of the cached cluster metadata, and waits
   446  // for the given list of topics to appear. This waiting mechanism is necessary
   447  // to account for the fact that topic creation is asynchronous in kafka, and
   448  // causes subsequent requests to fail while the cluster state is propagated to
   449  // all the brokers.
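        //
        // Between attempts the wait starts at 100ms and doubles after each retry, up
        // to a 2s cap, so a topic that is slow to propagate does not turn into a
        // tight polling loop against the metadata API.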
   450  func (p *connPool) refreshMetadata(ctx context.Context, expectTopics []string) {
   451  	minBackoff := 100 * time.Millisecond
   452  	maxBackoff := 2 * time.Second
   453  	cancel := ctx.Done()
   454  
   455  	for ctx.Err() == nil {
   456  		notify := make(event)
   457  		select {
   458  		case <-cancel:
   459  			return
   460  		case p.wake <- notify:
   461  			select {
   462  			case <-notify:
   463  			case <-cancel:
   464  				return
   465  			}
   466  		}
   467  
   468  		state := p.grabState()
   469  		found := 0
   470  
   471  		for _, topic := range expectTopics {
   472  			if _, ok := state.layout.Topics[topic]; ok {
   473  				found++
   474  			}
   475  		}
   476  
   477  		if found == len(expectTopics) {
   478  			return
   479  		}
   480  
   481  		if delay := time.Duration(rand.Int63n(int64(minBackoff))); delay > 0 {
   482  			timer := time.NewTimer(minBackoff)
   483  			select {
   484  			case <-cancel:
   485  			case <-timer.C:
   486  			}
   487  			timer.Stop()
   488  
   489  			if minBackoff *= 2; minBackoff > maxBackoff {
   490  				minBackoff = maxBackoff
   491  			}
   492  		}
   493  	}
   494  }
   495  
   496  func (p *connPool) setReady() {
   497  	p.once.Do(p.ready.trigger)
   498  }
   499  
   500  // update is called periodically by the goroutine running the discover method
   501  // to refresh the cluster layout information used by the transport to route
   502  // requests to brokers.
   503  func (p *connPool) update(ctx context.Context, metadata *meta.Response, err error) {
   504  	var layout protocol.Cluster
   505  
   506  	if metadata != nil {
   507  		metadata.ThrottleTimeMs = 0
   508  
   509  		// Normalize the lists so we can apply binary search on them.
   510  		sortMetadataBrokers(metadata.Brokers)
   511  		sortMetadataTopics(metadata.Topics)
   512  
   513  		for i := range metadata.Topics {
   514  			t := &metadata.Topics[i]
   515  			sortMetadataPartitions(t.Partitions)
   516  		}
   517  
   518  		layout = makeLayout(metadata)
   519  	}
   520  
   521  	state := p.grabState()
   522  	addBrokers := make(map[int32]struct{})
   523  	delBrokers := make(map[int32]struct{})
   524  
   525  	if err != nil {
   526  		// Only update the error on the transport if the cluster layout was
   527  		// unknown. This ensures that we prioritize a previously known state
   528  		// of the cluster to reduce the impact of transient failures.
   529  		if state.metadata != nil {
   530  			return
   531  		}
   532  		state.err = err
   533  	} else {
   534  		for id, b2 := range layout.Brokers {
   535  			if b1, ok := state.layout.Brokers[id]; !ok {
   536  				addBrokers[id] = struct{}{}
   537  			} else if b1 != b2 {
   538  				addBrokers[id] = struct{}{}
   539  				delBrokers[id] = struct{}{}
   540  			}
   541  		}
   542  
   543  		for id := range state.layout.Brokers {
   544  			if _, ok := layout.Brokers[id]; !ok {
   545  				delBrokers[id] = struct{}{}
   546  			}
   547  		}
   548  
   549  		state.metadata, state.layout = metadata, layout
   550  		state.err = nil
   551  	}
   552  
   553  	defer p.setReady()
   554  	defer p.setState(state)
   555  
   556  	if len(addBrokers) != 0 || len(delBrokers) != 0 {
   557  		// Only acquire the lock when there is a change of layout. This is an
   558  		// infrequent event so we don't risk introducing regular contention on
   559  		// the mutex if we were to lock it on every update.
   560  		p.mutex.Lock()
   561  		defer p.mutex.Unlock()
   562  
   563  		if ctx.Err() != nil {
   564  			return // the pool has been closed, no need to update
   565  		}
   566  
   567  		for id := range delBrokers {
   568  			if broker := p.conns[id]; broker != nil {
   569  				broker.closeIdleConns()
   570  				delete(p.conns, id)
   571  			}
   572  		}
   573  
   574  		for id := range addBrokers {
   575  			broker := layout.Brokers[id]
   576  			p.conns[id] = p.newBrokerConnGroup(Broker{
   577  				Rack: broker.Rack,
   578  				Host: broker.Host,
   579  				Port: int(broker.Port),
   580  				ID:   int(broker.ID),
   581  			})
   582  		}
   583  	}
   584  }
   585  
   586  // discover is the entry point of an internal goroutine for the transport which
   587  // periodically requests updates of the cluster metadata and refreshes the
   588  // transport's cached cluster layout.
   589  func (p *connPool) discover(ctx context.Context, wake <-chan event) {
   590  	prng := rand.New(rand.NewSource(time.Now().UnixNano()))
   591  	metadataTTL := func() time.Duration {
   592  		return time.Duration(prng.Int63n(int64(p.metadataTTL)))
   593  	}
   594  
   595  	timer := time.NewTimer(metadataTTL())
   596  	defer timer.Stop()
   597  
   598  	var notify event
   599  	done := ctx.Done()
   600  
   601  	req := &meta.Request{
   602  		TopicNames: p.metadataTopics,
   603  	}
   604  
   605  	for {
   606  		c, err := p.grabClusterConn(ctx)
   607  		if err != nil {
   608  			p.update(ctx, nil, err)
   609  		} else {
   610  			res := make(async, 1)
   611  			deadline, cancel := context.WithTimeout(ctx, p.metadataTTL)
   612  			c.reqs <- connRequest{
   613  				ctx: deadline,
   614  				req: req,
   615  				res: res,
   616  			}
   617  			r, err := res.await(deadline)
   618  			cancel()
   619  			if err != nil && errors.Is(err, ctx.Err()) {
   620  				return
   621  			}
   622  			ret, _ := r.(*meta.Response)
   623  			p.update(ctx, ret, err)
   624  		}
   625  
   626  		if notify != nil {
   627  			notify.trigger()
   628  			notify = nil
   629  		}
   630  
   631  		select {
   632  		case <-timer.C:
   633  			timer.Reset(metadataTTL())
   634  		case <-done:
   635  			return
   636  		case notify = <-wake:
   637  		}
   638  	}
   639  }
   640  
   641  // grabBrokerConn returns a connection to a specific broker represented by the
   642  // broker id passed as argument. If the broker id is not known, an error is
   643  // returned.
   644  func (p *connPool) grabBrokerConn(ctx context.Context, brokerID int32) (*conn, error) {
   645  	p.mutex.RLock()
   646  	g := p.conns[brokerID]
   647  	p.mutex.RUnlock()
   648  	if g == nil {
   649  		return nil, BrokerNotAvailable
   650  	}
   651  	return g.grabConnOrConnect(ctx)
   652  }
   653  
   654  // grabClusterConn returns the connection to the kafka cluster that the pool is
   655  // configured to connect to.
   656  //
   657  // The transport uses a shared `control` connection to the cluster for any
   658  // requests that aren't supposed to be sent to specific brokers (e.g. Metadata
   659  // requests). Requests intended to be routed to specific brokers (e.g. Fetch or
   660  // Produce requests) are dispatched on a separate pool of connections that the
   661  // transport maintains. This split helps avoid head-of-line blocking situations
   662  // where control requests like Metadata would otherwise be queued behind large
   663  // responses from Fetch requests.
   664  //
   665  // In either case, the requests are multiplexed so we can keep a minimal number
   666  // of connections open (N+1, where N is the number of brokers in the cluster).
   667  func (p *connPool) grabClusterConn(ctx context.Context) (*conn, error) {
   668  	return p.ctrl.grabConnOrConnect(ctx)
   669  }
   670  
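        // sendRequest routes a single request to the right connection: messages that
        // implement protocol.BrokerMessage are sent to the broker they designate
        // (typically a partition leader), group and transactional messages are first
        // resolved to their coordinator with a FindCoordinator request, and all other
        // requests go out on the shared control connection. The returned promise is
        // resolved or rejected by the connection's run loop.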
   671  func (p *connPool) sendRequest(ctx context.Context, req Request, state connPoolState) promise {
   672  	brokerID := int32(-1)
   673  
   674  	switch m := req.(type) {
   675  	case protocol.BrokerMessage:
   676  		// Some requests are supposed to be sent to specific brokers (e.g. the
   677  		// partition leaders). They implement the BrokerMessage interface to
   678  		// delegate the routing decision to each message type.
   679  		broker, err := m.Broker(state.layout)
   680  		if err != nil {
   681  			return reject(err)
   682  		}
   683  		brokerID = broker.ID
   684  
   685  	case protocol.GroupMessage:
   686  		// Some requests are supposed to be sent to a group coordinator,
   687  		// look up which broker is currently the coordinator for the group
   688  		// so we can get a connection to that broker.
   689  		//
   690  		// TODO: should we cache the coordinator info?
   691  		p := p.sendRequest(ctx, &findcoordinator.Request{Key: m.Group()}, state)
   692  		r, err := p.await(ctx)
   693  		if err != nil {
   694  			return reject(err)
   695  		}
   696  		brokerID = r.(*findcoordinator.Response).NodeID
   697  	case protocol.TransactionalMessage:
   698  		p := p.sendRequest(ctx, &findcoordinator.Request{
   699  			Key:     m.Transaction(),
   700  			KeyType: int8(CoordinatorKeyTypeTransaction),
   701  		}, state)
   702  		r, err := p.await(ctx)
   703  		if err != nil {
   704  			return reject(err)
   705  		}
   706  		brokerID = r.(*findcoordinator.Response).NodeID
   707  	}
   708  
   709  	var c *conn
   710  	var err error
   711  	if brokerID >= 0 {
   712  		c, err = p.grabBrokerConn(ctx, brokerID)
   713  	} else {
   714  		c, err = p.grabClusterConn(ctx)
   715  	}
   716  	if err != nil {
   717  		return reject(err)
   718  	}
   719  
   720  	res := make(async, 1)
   721  
   722  	c.reqs <- connRequest{
   723  		ctx: ctx,
   724  		req: req,
   725  		res: res,
   726  	}
   727  
   728  	return res
   729  }
   730  
   731  func filterMetadataResponse(req *meta.Request, res *meta.Response) *meta.Response {
   732  	ret := *res
   733  
   734  	if req.TopicNames != nil {
   735  		ret.Topics = make([]meta.ResponseTopic, len(req.TopicNames))
   736  
   737  		for i, topicName := range req.TopicNames {
   738  			j, ok := findMetadataTopic(res.Topics, topicName)
   739  			if ok {
   740  				ret.Topics[i] = res.Topics[j]
   741  			} else {
   742  				ret.Topics[i] = meta.ResponseTopic{
   743  					ErrorCode: int16(UnknownTopicOrPartition),
   744  					Name:      topicName,
   745  				}
   746  			}
   747  		}
   748  	}
   749  
   750  	return &ret
   751  }
   752  
   753  func findMetadataTopic(topics []meta.ResponseTopic, topicName string) (int, bool) {
   754  	i := sort.Search(len(topics), func(i int) bool {
   755  		return topics[i].Name >= topicName
   756  	})
   757  	return i, i >= 0 && i < len(topics) && topics[i].Name == topicName
   758  }
   759  
   760  func sortMetadataBrokers(brokers []meta.ResponseBroker) {
   761  	sort.Slice(brokers, func(i, j int) bool {
   762  		return brokers[i].NodeID < brokers[j].NodeID
   763  	})
   764  }
   765  
   766  func sortMetadataTopics(topics []meta.ResponseTopic) {
   767  	sort.Slice(topics, func(i, j int) bool {
   768  		return topics[i].Name < topics[j].Name
   769  	})
   770  }
   771  
   772  func sortMetadataPartitions(partitions []meta.ResponsePartition) {
   773  	sort.Slice(partitions, func(i, j int) bool {
   774  		return partitions[i].PartitionIndex < partitions[j].PartitionIndex
   775  	})
   776  }
   777  
   778  func makeLayout(metadataResponse *meta.Response) protocol.Cluster {
   779  	layout := protocol.Cluster{
   780  		Controller: metadataResponse.ControllerID,
   781  		Brokers:    make(map[int32]protocol.Broker),
   782  		Topics:     make(map[string]protocol.Topic),
   783  	}
   784  
   785  	for _, broker := range metadataResponse.Brokers {
   786  		layout.Brokers[broker.NodeID] = protocol.Broker{
   787  			Rack: broker.Rack,
   788  			Host: broker.Host,
   789  			Port: broker.Port,
   790  			ID:   broker.NodeID,
   791  		}
   792  	}
   793  
   794  	for _, topic := range metadataResponse.Topics {
   795  		if topic.IsInternal {
   796  			continue // TODO: do we need to expose those?
   797  		}
   798  		layout.Topics[topic.Name] = protocol.Topic{
   799  			Name:       topic.Name,
   800  			Error:      topic.ErrorCode,
   801  			Partitions: makePartitions(topic.Partitions),
   802  		}
   803  	}
   804  
   805  	return layout
   806  }
   807  
   808  func makePartitions(metadataPartitions []meta.ResponsePartition) map[int32]protocol.Partition {
   809  	protocolPartitions := make(map[int32]protocol.Partition, len(metadataPartitions))
   810  	numBrokerIDs := 0
   811  
   812  	for _, p := range metadataPartitions {
   813  		numBrokerIDs += len(p.ReplicaNodes) + len(p.IsrNodes) + len(p.OfflineReplicas)
   814  	}
   815  
   816  	// Reduce the memory footprint a bit by allocating a single buffer to write
   817  	// all broker ids.
   818  	brokerIDs := make([]int32, 0, numBrokerIDs)
   819  
   820  	for _, p := range metadataPartitions {
   821  		var rep, isr, off []int32
   822  		brokerIDs, rep = appendBrokerIDs(brokerIDs, p.ReplicaNodes)
   823  		brokerIDs, isr = appendBrokerIDs(brokerIDs, p.IsrNodes)
   824  		brokerIDs, off = appendBrokerIDs(brokerIDs, p.OfflineReplicas)
   825  
   826  		protocolPartitions[p.PartitionIndex] = protocol.Partition{
   827  			ID:       p.PartitionIndex,
   828  			Error:    p.ErrorCode,
   829  			Leader:   p.LeaderID,
   830  			Replicas: rep,
   831  			ISR:      isr,
   832  			Offline:  off,
   833  		}
   834  	}
   835  
   836  	return protocolPartitions
   837  }
   838  
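        // appendBrokerIDs appends the broker ids to the shared buffer and returns both
        // the grown buffer and a sub-slice covering only the newly appended ids. The
        // full slice expression caps the sub-slice's capacity at its length, so a later
        // append on a returned sub-slice copies instead of writing over ids that other
        // partitions stored in the shared buffer. A small illustration with made-up
        // ids:
        //
        //	var ids []int32
        //	ids, rep := appendBrokerIDs(ids, []int32{1, 2}) // rep == [1 2]
        //	_, isr := appendBrokerIDs(ids, []int32{3})      // isr == [3]
        //	_ = append(rep, 4)                              // copies; isr still == [3]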
   839  func appendBrokerIDs(ids, brokers []int32) ([]int32, []int32) {
   840  	i := len(ids)
   841  	ids = append(ids, brokers...)
   842  	return ids, ids[i:len(ids):len(ids)]
   843  }
   844  
   845  func (p *connPool) newConnGroup(a net.Addr) *connGroup {
   846  	return &connGroup{
   847  		addr: a,
   848  		pool: p,
   849  		broker: Broker{
   850  			ID: -1,
   851  		},
   852  	}
   853  }
   854  
   855  func (p *connPool) newBrokerConnGroup(broker Broker) *connGroup {
   856  	return &connGroup{
   857  		addr: &networkAddress{
   858  			network: "tcp",
   859  			address: net.JoinHostPort(broker.Host, strconv.Itoa(broker.Port)),
   860  		},
   861  		pool:   p,
   862  		broker: broker,
   863  	}
   864  }
   865  
   866  type connRequest struct {
   867  	ctx context.Context
   868  	req Request
   869  	res async
   870  }
   871  
   872  // The promise interface is used as a message passing abstraction to coordinate
   873  // between goroutines that handle requests and responses.
   874  type promise interface {
   875  	// Waits until the promise is resolved, rejected, or the context is canceled.
   876  	await(context.Context) (Response, error)
   877  }
   878  
   879  // async is an implementation of the promise interface which supports resolving
   880  // or rejecting the await call asynchronously.
   881  type async chan interface{}
   882  
   883  func (p async) await(ctx context.Context) (Response, error) {
   884  	select {
   885  	case x := <-p:
   886  		switch v := x.(type) {
   887  		case nil:
   888  			return nil, nil // A nil response is ok (e.g. when RequiredAcks is None)
   889  		case Response:
   890  			return v, nil
   891  		case error:
   892  			return nil, v
   893  		default:
   894  			panic(fmt.Errorf("BUG: promise resolved with impossible value of type %T", v))
   895  		}
   896  	case <-ctx.Done():
   897  		return nil, ctx.Err()
   898  	}
   899  }
   900  
   901  func (p async) resolve(res Response) { p <- res }
   902  
   903  func (p async) reject(err error) { p <- err }
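
        // A minimal sketch of how these pieces are used inside the transport (the
        // goroutine and the response value below are purely illustrative):
        //
        //	res := make(async, 1)
        //	go func() { res.resolve(&meta.Response{}) }()
        //	r, err := res.await(context.Background())
        //	_, _ = r, err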
   904  
   905  // rejected is an implementation of the promise interface which always
   906  // returns an error. Values of this type are constructed using the reject
   907  // function.
   908  type rejected struct{ err error }
   909  
   910  func reject(err error) promise { return &rejected{err: err} }
   911  
   912  func (p *rejected) await(ctx context.Context) (Response, error) {
   913  	return nil, p.err
   914  }
   915  
   916  // joined is an implementation of the promise interface which merges results
   917  // from multiple promises into one await call using a merger.
   918  type joined struct {
   919  	promises []promise
   920  	requests []Request
   921  	merger   protocol.Merger
   922  }
   923  
   924  func join(promises []promise, requests []Request, merger protocol.Merger) promise {
   925  	return &joined{
   926  		promises: promises,
   927  		requests: requests,
   928  		merger:   merger,
   929  	}
   930  }
   931  
   932  func (p *joined) await(ctx context.Context) (Response, error) {
   933  	results := make([]interface{}, len(p.promises))
   934  
   935  	for i, sub := range p.promises {
   936  		m, err := sub.await(ctx)
   937  		if err != nil {
   938  			results[i] = err
   939  		} else {
   940  			results[i] = m
   941  		}
   942  	}
   943  
   944  	return p.merger.Merge(p.requests, results)
   945  }
   946  
   947  // Default dialer used by the transport connections when no Dial function
   948  // was configured by the program.
   949  var defaultDialer = net.Dialer{
   950  	Timeout:   3 * time.Second,
   951  	DualStack: true,
   952  }
   953  
   954  // connGroup represents a logical connection group to a kafka broker. The
   955  // actual network connections are lazily opened before sending requests, and
   956  // closed if they are unused for longer than the idle timeout.
   957  type connGroup struct {
   958  	addr   net.Addr
   959  	broker Broker
   960  	// Immutable state of the connection.
   961  	pool *connPool
   962  	// Shared state of the connection; this is synchronized on the mutex. Both
   963  	// goroutines of the connection share the state maintained in these
   964  	// fields.
   965  	mutex     sync.Mutex
   966  	closed    bool
   967  	idleConns []*conn // stack of idle connections
   968  }
   969  
   970  func (g *connGroup) closeIdleConns() {
   971  	g.mutex.Lock()
   972  	conns := g.idleConns
   973  	g.idleConns = nil
   974  	g.closed = true
   975  	g.mutex.Unlock()
   976  
   977  	for _, c := range conns {
   978  		c.close()
   979  	}
   980  }
   981  
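        // grabConnOrConnect returns an idle connection from the group, or dials a new
        // one when none is available. If a BrokerResolver is configured, the broker
        // host is resolved first and idle connections are matched against each
        // resolved address. The dial runs in a separate goroutine so the call can
        // honor context cancellation without leaking a connection that the dialer
        // returns late.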
   982  func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) {
   983  	rslv := g.pool.resolver
   984  	addr := g.addr
   985  	var c *conn
   986  
   987  	if rslv == nil {
   988  		c = g.grabConn()
   989  	} else {
   990  		var err error
   991  		broker := g.broker
   992  
   993  		if broker.ID < 0 {
   994  			host, port, err := splitHostPortNumber(addr.String())
   995  			if err != nil {
   996  				return nil, err
   997  			}
   998  			broker.Host = host
   999  			broker.Port = port
  1000  		}
  1001  
  1002  		ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker)
  1003  		if err != nil {
  1004  			return nil, err
  1005  		}
  1006  
  1007  		for _, ipAddr := range ipAddrs {
  1008  			network := addr.Network()
  1009  			address := net.JoinHostPort(ipAddr.String(), strconv.Itoa(broker.Port))
  1010  
  1011  			if c = g.grabConnTo(network, address); c != nil {
  1012  				break
  1013  			}
  1014  		}
  1015  	}
  1016  
  1017  	if c == nil {
  1018  		connChan := make(chan *conn)
  1019  		errChan := make(chan error)
  1020  
  1021  		go func() {
  1022  			c, err := g.connect(ctx, addr)
  1023  			if err != nil {
  1024  				select {
  1025  				case errChan <- err:
  1026  				case <-ctx.Done():
  1027  				}
  1028  			} else {
  1029  				select {
  1030  				case connChan <- c:
  1031  				case <-ctx.Done():
  1032  					if !g.releaseConn(c) {
  1033  						c.close()
  1034  					}
  1035  				}
  1036  			}
  1037  		}()
  1038  
  1039  		select {
  1040  		case c = <-connChan:
  1041  		case err := <-errChan:
  1042  			return nil, err
  1043  		case <-ctx.Done():
  1044  			return nil, ctx.Err()
  1045  		}
  1046  	}
  1047  
  1048  	return c, nil
  1049  }
  1050  
  1051  func (g *connGroup) grabConnTo(network, address string) *conn {
  1052  	g.mutex.Lock()
  1053  	defer g.mutex.Unlock()
  1054  
  1055  	for i := len(g.idleConns) - 1; i >= 0; i-- {
  1056  		c := g.idleConns[i]
  1057  
  1058  		if c.network == network && c.address == address {
  1059  			copy(g.idleConns[i:], g.idleConns[i+1:])
  1060  			n := len(g.idleConns) - 1
  1061  			g.idleConns[n] = nil
  1062  			g.idleConns = g.idleConns[:n]
  1063  
  1064  			if c.timer != nil {
  1065  				c.timer.Stop()
  1066  			}
  1067  
  1068  			return c
  1069  		}
  1070  	}
  1071  
  1072  	return nil
  1073  }
  1074  
  1075  func (g *connGroup) grabConn() *conn {
  1076  	g.mutex.Lock()
  1077  	defer g.mutex.Unlock()
  1078  
  1079  	if len(g.idleConns) == 0 {
  1080  		return nil
  1081  	}
  1082  
  1083  	n := len(g.idleConns) - 1
  1084  	c := g.idleConns[n]
  1085  	g.idleConns[n] = nil
  1086  	g.idleConns = g.idleConns[:n]
  1087  
  1088  	if c.timer != nil {
  1089  		c.timer.Stop()
  1090  	}
  1091  
  1092  	return c
  1093  }
  1094  
  1095  func (g *connGroup) removeConn(c *conn) bool {
  1096  	g.mutex.Lock()
  1097  	defer g.mutex.Unlock()
  1098  
  1099  	if c.timer != nil {
  1100  		c.timer.Stop()
  1101  	}
  1102  
  1103  	for i, x := range g.idleConns {
  1104  		if x == c {
  1105  			copy(g.idleConns[i:], g.idleConns[i+1:])
  1106  			n := len(g.idleConns) - 1
  1107  			g.idleConns[n] = nil
  1108  			g.idleConns = g.idleConns[:n]
  1109  			return true
  1110  		}
  1111  	}
  1112  
  1113  	return false
  1114  }
  1115  
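        // releaseConn puts the connection back on the group's idle stack and (re)arms
        // its idle timer, so the connection is closed if it stays unused past the
        // pool's idle timeout. It reports false when the group has already been
        // closed, in which case the caller is responsible for closing the connection.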
  1116  func (g *connGroup) releaseConn(c *conn) bool {
  1117  	idleTimeout := g.pool.idleTimeout
  1118  
  1119  	g.mutex.Lock()
  1120  	defer g.mutex.Unlock()
  1121  
  1122  	if g.closed {
  1123  		return false
  1124  	}
  1125  
  1126  	if c.timer != nil {
  1127  		c.timer.Reset(idleTimeout)
  1128  	} else {
  1129  		c.timer = time.AfterFunc(idleTimeout, func() {
  1130  			if g.removeConn(c) {
  1131  				c.close()
  1132  			}
  1133  		})
  1134  	}
  1135  
  1136  	g.idleConns = append(g.idleConns, c)
  1137  	return true
  1138  }
  1139  
  1140  func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {
  1141  	deadline := time.Now().Add(g.pool.dialTimeout)
  1142  
  1143  	ctx, cancel := context.WithDeadline(ctx, deadline)
  1144  	defer cancel()
  1145  
  1146  	network := strings.Split(addr.Network(), ",")
  1147  	address := strings.Split(addr.String(), ",")
  1148  	var netConn net.Conn
  1149  	var netAddr net.Addr
  1150  	var err error
  1151  
  1152  	if len(address) > 1 {
  1153  		// Shuffle the list of addresses to randomize the order in which
  1154  		// connections are attempted. This prevents routing all connections
  1155  		// to the first broker (which will usually succeed).
  1156  		rand.Shuffle(len(address), func(i, j int) {
  1157  			network[i], network[j] = network[j], network[i]
  1158  			address[i], address[j] = address[j], address[i]
  1159  		})
  1160  	}
  1161  
  1162  	for i := range address {
  1163  		netConn, err = g.pool.dial(ctx, network[i], address[i])
  1164  		if err == nil {
  1165  			netAddr = &networkAddress{
  1166  				network: network[i],
  1167  				address: address[i],
  1168  			}
  1169  			break
  1170  		}
  1171  	}
  1172  
  1173  	if err != nil {
  1174  		return nil, err
  1175  	}
  1176  
  1177  	defer func() {
  1178  		if netConn != nil {
  1179  			netConn.Close()
  1180  		}
  1181  	}()
  1182  
  1183  	if tlsConfig := g.pool.tls; tlsConfig != nil {
  1184  		if tlsConfig.ServerName == "" {
  1185  			host, _ := splitHostPort(netAddr.String())
  1186  			tlsConfig = tlsConfig.Clone()
  1187  			tlsConfig.ServerName = host
  1188  		}
  1189  		netConn = tls.Client(netConn, tlsConfig)
  1190  	}
  1191  
  1192  	pc := protocol.NewConn(netConn, g.pool.clientID)
  1193  	pc.SetDeadline(deadline)
  1194  
  1195  	r, err := pc.RoundTrip(new(apiversions.Request))
  1196  	if err != nil {
  1197  		return nil, err
  1198  	}
  1199  	res := r.(*apiversions.Response)
  1200  	ver := make(map[protocol.ApiKey]int16, len(res.ApiKeys))
  1201  
  1202  	if res.ErrorCode != 0 {
  1203  		return nil, fmt.Errorf("negotiating API versions with kafka broker at %s: %w", g.addr, Error(res.ErrorCode))
  1204  	}
  1205  
  1206  	for _, r := range res.ApiKeys {
  1207  		apiKey := protocol.ApiKey(r.ApiKey)
  1208  		ver[apiKey] = apiKey.SelectVersion(r.MinVersion, r.MaxVersion)
  1209  	}
  1210  
  1211  	pc.SetVersions(ver)
  1212  	pc.SetDeadline(time.Time{})
  1213  
  1214  	if g.pool.sasl != nil {
  1215  		host, port, err := splitHostPortNumber(netAddr.String())
  1216  		if err != nil {
  1217  			return nil, err
  1218  		}
  1219  		metadata := &sasl.Metadata{
  1220  			Host: host,
  1221  			Port: port,
  1222  		}
  1223  		if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, g.pool.sasl); err != nil {
  1224  			return nil, err
  1225  		}
  1226  	}
  1227  
  1228  	reqs := make(chan connRequest)
  1229  	c := &conn{
  1230  		network: netAddr.Network(),
  1231  		address: netAddr.String(),
  1232  		reqs:    reqs,
  1233  		group:   g,
  1234  	}
  1235  	go c.run(pc, reqs)
  1236  
  1237  	netConn = nil
  1238  	return c, nil
  1239  }
  1240  
  1241  type conn struct {
  1242  	reqs    chan<- connRequest
  1243  	network string
  1244  	address string
  1245  	once    sync.Once
  1246  	group   *connGroup
  1247  	timer   *time.Timer
  1248  }
  1249  
  1250  func (c *conn) close() {
  1251  	c.once.Do(func() { close(c.reqs) })
  1252  }
  1253  
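        // run is the connection's serving loop: it takes requests from the reqs
        // channel one at a time, performs the round trip, and resolves or rejects the
        // request's promise. The loop exits, closing the underlying protocol
        // connection, when a request fails with anything other than
        // protocol.ErrNoRecord or when the group refuses to take the connection back.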
  1254  func (c *conn) run(pc *protocol.Conn, reqs <-chan connRequest) {
  1255  	defer pc.Close()
  1256  
  1257  	for cr := range reqs {
  1258  		r, err := c.roundTrip(cr.ctx, pc, cr.req)
  1259  		if err != nil {
  1260  			cr.res.reject(err)
  1261  			if !errors.Is(err, protocol.ErrNoRecord) {
  1262  				break
  1263  			}
  1264  		} else {
  1265  			cr.res.resolve(r)
  1266  		}
  1267  		if !c.group.releaseConn(c) {
  1268  			break
  1269  		}
  1270  	}
  1271  }
  1272  
  1273  func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (Response, error) {
  1274  	pprof.SetGoroutineLabels(ctx)
  1275  	defer pprof.SetGoroutineLabels(context.Background())
  1276  
  1277  	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
  1278  		pc.SetDeadline(deadline)
  1279  		defer pc.SetDeadline(time.Time{})
  1280  	}
  1281  
  1282  	return pc.RoundTrip(req)
  1283  }
  1284  
  1285  // authenticateSASL performs all of the required requests to authenticate this
  1286  // connection.  If any step fails, this function returns with an error.  A nil
  1287  // error indicates successful authentication.
  1288  func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error {
  1289  	if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil {
  1290  		return err
  1291  	}
  1292  
  1293  	sess, state, err := mechanism.Start(ctx)
  1294  	if err != nil {
  1295  		return err
  1296  	}
  1297  
  1298  	for completed := false; !completed; {
  1299  		challenge, err := saslAuthenticateRoundTrip(pc, state)
  1300  		if err != nil {
  1301  			if errors.Is(err, io.EOF) {
  1302  				// the broker may communicate a failed exchange by closing the
  1303  				// connection (esp. in the case where we're passing opaque sasl
  1304  				// data over the wire since there's no protocol info).
  1305  				return SASLAuthenticationFailed
  1306  			}
  1307  
  1308  			return err
  1309  		}
  1310  
  1311  		completed, state, err = sess.Next(ctx, challenge)
  1312  		if err != nil {
  1313  			return err
  1314  		}
  1315  	}
  1316  
  1317  	return nil
  1318  }
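
        // For reference, a SASL mechanism reaches this code through the Transport's
        // SASL field and is exercised once per new connection. The plain mechanism
        // below is an assumption used for illustration; it is provided by the
        // sasl/plain sub-package rather than by this file:
        //
        //	transport := &Transport{
        //		SASL: plain.Mechanism{Username: "username", Password: "password"},
        //	}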
  1319  
  1320  // saslHandshakeRoundTrip sends the SASL handshake message. This determines
  1321  // whether the Mechanism is supported by the cluster. If it's not, this
  1322  // function will error out with UnsupportedSASLMechanism.
  1323  //
  1324  // If the mechanism is unsupported, the handshake request will reply with the
  1325  // list of the cluster's configured mechanisms, which could potentially be used
  1326  // to facilitate negotiation.  At the moment, we are not negotiating the
  1327  // mechanism as we believe that brokers are usually known to the client, and
  1328  // therefore the client should already know which mechanisms are supported.
  1329  //
  1330  // See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake
  1331  func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error {
  1332  	msg, err := pc.RoundTrip(&saslhandshake.Request{
  1333  		Mechanism: mechanism,
  1334  	})
  1335  	if err != nil {
  1336  		return err
  1337  	}
  1338  	res := msg.(*saslhandshake.Response)
  1339  	if res.ErrorCode != 0 {
  1340  		err = Error(res.ErrorCode)
  1341  	}
  1342  	return err
  1343  }
  1344  
  1345  // saslAuthenticateRoundTrip sends the SASL authenticate message. This function
  1346  // must be immediately preceded by a successful saslHandshakeRoundTrip.
  1347  //
  1348  // See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
  1349  func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) {
  1350  	msg, err := pc.RoundTrip(&saslauthenticate.Request{
  1351  		AuthBytes: data,
  1352  	})
  1353  	if err != nil {
  1354  		return nil, err
  1355  	}
  1356  	res := msg.(*saslauthenticate.Response)
  1357  	if res.ErrorCode != 0 {
  1358  		err = makeError(res.ErrorCode, res.ErrorMessage)
  1359  	}
  1360  	return res.AuthBytes, err
  1361  }
  1362  
  1363  var _ RoundTripper = (*Transport)(nil)