github.com/hoveychen/kafka-go@v0.4.42/dialer.go (about)

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/hoveychen/kafka-go/sasl"
    15  )
    16  
    17  // The Dialer type mirrors the net.Dialer API but is designed to open kafka
    18  // connections instead of raw network connections.
    19  type Dialer struct {
    20  	// Unique identifier for client connections established by this Dialer.
    21  	ClientID string
    22  
    23  	// Optionally specifies the function that the dialer uses to establish
    24  	// network connections. If nil, net.(*Dialer).DialContext is used instead.
    25  	//
    26  	// When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive
    27  	// are ignored.
    28  	DialFunc func(ctx context.Context, network string, address string) (net.Conn, error)
    29  
    30  	// Timeout is the maximum amount of time a dial will wait for a connect to
    31  	// complete. If Deadline is also set, it may fail earlier.
    32  	//
    33  	// The default is no timeout.
    34  	//
    35  	// When dialing a name with multiple IP addresses, the timeout may be
    36  	// divided between them.
    37  	//
    38  	// With or without a timeout, the operating system may impose its own
    39  	// earlier timeout. For instance, TCP timeouts are often around 3 minutes.
    40  	Timeout time.Duration
    41  
    42  	// Deadline is the absolute point in time after which dials will fail.
    43  	// If Timeout is set, it may fail earlier.
    44  	// Zero means no deadline, or dependent on the operating system as with the
    45  	// Timeout option.
    46  	Deadline time.Time
    47  
    48  	// LocalAddr is the local address to use when dialing an address.
    49  	// The address must be of a compatible type for the network being dialed.
    50  	// If nil, a local address is automatically chosen.
    51  	LocalAddr net.Addr
    52  
    53  	// DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the
    54  	// network is "tcp" and the destination is a host name with both IPv4 and
    55  	// IPv6 addresses. This allows a client to tolerate networks where one
    56  	// address family is silently broken.
    57  	DualStack bool
    58  
    59  	// FallbackDelay specifies the length of time to wait before spawning a
    60  	// fallback connection, when DualStack is enabled.
    61  	// If zero, a default delay of 300ms is used.
    62  	FallbackDelay time.Duration
    63  
    64  	// KeepAlive specifies the keep-alive period for an active network
    65  	// connection.
    66  	// If zero, keep-alives are not enabled. Network protocols that do not
    67  	// support keep-alives ignore this field.
    68  	KeepAlive time.Duration
    69  
    70  	// Resolver optionally gives a hook to convert the broker address into an
    71  	// alternate host or IP address which is useful for custom service discovery.
    72  	// If a custom resolver returns any possible hosts, the first one will be
    73  	// used and the original discarded. If a port number is included with the
    74  	// resolved host, it will only be used if a port number was not previously
    75  	// specified. If no port is specified or resolved, the default of 9092 will be
    76  	// used.
    77  	Resolver Resolver
    78  
    79  	// TLS enables Dialer to open secure connections.  If nil, standard net.Conn
    80  	// will be used.
    81  	TLS *tls.Config
    82  
    83  	// SASLMechanism configures the Dialer to use SASL authentication.  If nil,
    84  	// no authentication will be performed.
    85  	SASLMechanism sasl.Mechanism
    86  
    87  	// The transactional id to use for transactional delivery. Idempotent
    88  	// deliver should be enabled if transactional id is configured.
    89  	// For more details look at transactional.id description here: http://kafka.apache.org/documentation.html#producerconfigs
    90  	// Empty string means that the connection will be non-transactional.
    91  	TransactionalID string
    92  }
    93  
    94  // Dial connects to the address on the named network.
    95  func (d *Dialer) Dial(network string, address string) (*Conn, error) {
    96  	return d.DialContext(context.Background(), network, address)
    97  }
    98  
    99  // DialContext connects to the address on the named network using the provided
   100  // context.
   101  //
   102  // The provided Context must be non-nil. If the context expires before the
   103  // connection is complete, an error is returned. Once successfully connected,
   104  // any expiration of the context will not affect the connection.
   105  //
   106  // When using TCP, and the host in the address parameter resolves to multiple
   107  // network addresses, any dial timeout (from d.Timeout or ctx) is spread over
   108  // each consecutive dial, such that each is given an appropriate fraction of the
   109  // time to connect. For example, if a host has 4 IP addresses and the timeout is
   110  // 1 minute, the connect to each single address will be given 15 seconds to
   111  // complete before trying the next one.
   112  func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) {
   113  	return d.connect(
   114  		ctx,
   115  		network,
   116  		address,
   117  		ConnConfig{
   118  			ClientID:        d.ClientID,
   119  			TransactionalID: d.TransactionalID,
   120  		},
   121  	)
   122  }
   123  
   124  // DialLeader opens a connection to the leader of the partition for a given
   125  // topic.
   126  //
   127  // The address given to the DialContext method may not be the one that the
   128  // connection will end up being established to, because the dialer will lookup
   129  // the partition leader for the topic and return a connection to that server.
   130  // The original address is only used as a mechanism to discover the
   131  // configuration of the kafka cluster that we're connecting to.
   132  func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
   133  	p, err := d.LookupPartition(ctx, network, address, topic, partition)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	return d.DialPartition(ctx, network, address, p)
   138  }
   139  
   140  // DialPartition opens a connection to the leader of the partition specified by partition
   141  // descriptor. It's strongly advised to use descriptor of the partition that comes out of
   142  // functions LookupPartition or LookupPartitions.
   143  func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
   144  	return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{
   145  		ClientID:        d.ClientID,
   146  		Topic:           partition.Topic,
   147  		Partition:       partition.ID,
   148  		Broker:          partition.Leader.ID,
   149  		Rack:            partition.Leader.Rack,
   150  		TransactionalID: d.TransactionalID,
   151  	})
   152  }
   153  
   154  // LookupLeader searches for the kafka broker that is the leader of the
   155  // partition for a given topic, returning a Broker value representing it.
   156  func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) {
   157  	p, err := d.LookupPartition(ctx, network, address, topic, partition)
   158  	return p.Leader, err
   159  }
   160  
   161  // LookupPartition searches for the description of specified partition id.
   162  func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
   163  	c, err := d.DialContext(ctx, network, address)
   164  	if err != nil {
   165  		return Partition{}, err
   166  	}
   167  	defer c.Close()
   168  
   169  	brkch := make(chan Partition, 1)
   170  	errch := make(chan error, 1)
   171  
   172  	go func() {
   173  		for attempt := 0; true; attempt++ {
   174  			if attempt != 0 {
   175  				if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) {
   176  					errch <- ctx.Err()
   177  					return
   178  				}
   179  			}
   180  
   181  			partitions, err := c.ReadPartitions(topic)
   182  			if err != nil {
   183  				if isTemporary(err) {
   184  					continue
   185  				}
   186  				errch <- err
   187  				return
   188  			}
   189  
   190  			for _, p := range partitions {
   191  				if p.ID == partition {
   192  					brkch <- p
   193  					return
   194  				}
   195  			}
   196  		}
   197  
   198  		errch <- UnknownTopicOrPartition
   199  	}()
   200  
   201  	var prt Partition
   202  	select {
   203  	case prt = <-brkch:
   204  	case err = <-errch:
   205  	case <-ctx.Done():
   206  		err = ctx.Err()
   207  	}
   208  	return prt, err
   209  }
   210  
   211  // LookupPartitions returns the list of partitions that exist for the given topic.
   212  func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
   213  	conn, err := d.DialContext(ctx, network, address)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	defer conn.Close()
   218  
   219  	prtch := make(chan []Partition, 1)
   220  	errch := make(chan error, 1)
   221  
   222  	go func() {
   223  		if prt, err := conn.ReadPartitions(topic); err != nil {
   224  			errch <- err
   225  		} else {
   226  			prtch <- prt
   227  		}
   228  	}()
   229  
   230  	var prt []Partition
   231  	select {
   232  	case prt = <-prtch:
   233  	case err = <-errch:
   234  	case <-ctx.Done():
   235  		err = ctx.Err()
   236  	}
   237  	return prt, err
   238  }
   239  
   240  // connectTLS returns a tls.Conn that has already completed the Handshake.
   241  func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) {
   242  	tlsConn = tls.Client(conn, config)
   243  	errch := make(chan error)
   244  
   245  	go func() {
   246  		defer close(errch)
   247  		errch <- tlsConn.Handshake()
   248  	}()
   249  
   250  	select {
   251  	case <-ctx.Done():
   252  		conn.Close()
   253  		tlsConn.Close()
   254  		<-errch // ignore possible error from Handshake
   255  		err = ctx.Err()
   256  
   257  	case err = <-errch:
   258  	}
   259  
   260  	return
   261  }
   262  
   263  // connect opens a socket connection to the broker, wraps it to create a
   264  // kafka connection, and performs SASL authentication if configured to do so.
   265  func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) {
   266  	if d.Timeout != 0 {
   267  		var cancel context.CancelFunc
   268  		ctx, cancel = context.WithTimeout(ctx, d.Timeout)
   269  		defer cancel()
   270  	}
   271  
   272  	if !d.Deadline.IsZero() {
   273  		var cancel context.CancelFunc
   274  		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
   275  		defer cancel()
   276  	}
   277  
   278  	c, err := d.dialContext(ctx, network, address)
   279  	if err != nil {
   280  		return nil, fmt.Errorf("failed to dial: %w", err)
   281  	}
   282  
   283  	conn := NewConnWith(c, connCfg)
   284  
   285  	if d.SASLMechanism != nil {
   286  		host, port, err := splitHostPortNumber(address)
   287  		if err != nil {
   288  			return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err)
   289  		}
   290  		metadata := &sasl.Metadata{
   291  			Host: host,
   292  			Port: port,
   293  		}
   294  		if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
   295  			_ = conn.Close()
   296  			return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err)
   297  		}
   298  	}
   299  
   300  	return conn, nil
   301  }
   302  
   303  // authenticateSASL performs all of the required requests to authenticate this
   304  // connection.  If any step fails, this function returns with an error.  A nil
   305  // error indicates successful authentication.
   306  //
   307  // In case of error, this function *does not* close the connection.  That is the
   308  // responsibility of the caller.
   309  func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
   310  	if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
   311  		return fmt.Errorf("SASL handshake failed: %w", err)
   312  	}
   313  
   314  	sess, state, err := d.SASLMechanism.Start(ctx)
   315  	if err != nil {
   316  		return fmt.Errorf("SASL authentication process could not be started: %w", err)
   317  	}
   318  
   319  	for completed := false; !completed; {
   320  		challenge, err := conn.saslAuthenticate(state)
   321  		switch {
   322  		case err == nil:
   323  		case errors.Is(err, io.EOF):
   324  			// the broker may communicate a failed exchange by closing the
   325  			// connection (esp. in the case where we're passing opaque sasl
   326  			// data over the wire since there's no protocol info).
   327  			return SASLAuthenticationFailed
   328  		default:
   329  			return err
   330  		}
   331  
   332  		completed, state, err = sess.Next(ctx, challenge)
   333  		if err != nil {
   334  			return fmt.Errorf("SASL authentication process has failed: %w", err)
   335  		}
   336  	}
   337  
   338  	return nil
   339  }
   340  
   341  func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
   342  	address, err := lookupHost(ctx, addr, d.Resolver)
   343  	if err != nil {
   344  		return nil, fmt.Errorf("failed to resolve host: %w", err)
   345  	}
   346  
   347  	dial := d.DialFunc
   348  	if dial == nil {
   349  		dial = (&net.Dialer{
   350  			LocalAddr:     d.LocalAddr,
   351  			DualStack:     d.DualStack,
   352  			FallbackDelay: d.FallbackDelay,
   353  			KeepAlive:     d.KeepAlive,
   354  		}).DialContext
   355  	}
   356  
   357  	conn, err := dial(ctx, network, address)
   358  	if err != nil {
   359  		return nil, fmt.Errorf("failed to open connection to %s: %w", address, err)
   360  	}
   361  
   362  	if d.TLS != nil {
   363  		c := d.TLS
   364  		// If no ServerName is set, infer the ServerName
   365  		// from the hostname we're connecting to.
   366  		if c.ServerName == "" {
   367  			c = d.TLS.Clone()
   368  			// Copied from tls.go in the standard library.
   369  			colonPos := strings.LastIndex(address, ":")
   370  			if colonPos == -1 {
   371  				colonPos = len(address)
   372  			}
   373  			hostname := address[:colonPos]
   374  			c.ServerName = hostname
   375  		}
   376  		return d.connectTLS(ctx, conn, c)
   377  	}
   378  
   379  	return conn, nil
   380  }
   381  
   382  // DefaultDialer is the default dialer used when none is specified.
   383  var DefaultDialer = &Dialer{
   384  	Timeout:   10 * time.Second,
   385  	DualStack: true,
   386  }
   387  
   388  // Dial is a convenience wrapper for DefaultDialer.Dial.
   389  func Dial(network string, address string) (*Conn, error) {
   390  	return DefaultDialer.Dial(network, address)
   391  }
   392  
   393  // DialContext is a convenience wrapper for DefaultDialer.DialContext.
   394  func DialContext(ctx context.Context, network string, address string) (*Conn, error) {
   395  	return DefaultDialer.DialContext(ctx, network, address)
   396  }
   397  
   398  // DialLeader is a convenience wrapper for DefaultDialer.DialLeader.
   399  func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
   400  	return DefaultDialer.DialLeader(ctx, network, address, topic, partition)
   401  }
   402  
   403  // DialPartition is a convenience wrapper for DefaultDialer.DialPartition.
   404  func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
   405  	return DefaultDialer.DialPartition(ctx, network, address, partition)
   406  }
   407  
   408  // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition.
   409  func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
   410  	return DefaultDialer.LookupPartition(ctx, network, address, topic, partition)
   411  }
   412  
   413  // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions.
   414  func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
   415  	return DefaultDialer.LookupPartitions(ctx, network, address, topic)
   416  }
   417  
   418  func sleep(ctx context.Context, duration time.Duration) bool {
   419  	if duration == 0 {
   420  		select {
   421  		default:
   422  			return true
   423  		case <-ctx.Done():
   424  			return false
   425  		}
   426  	}
   427  	timer := time.NewTimer(duration)
   428  	defer timer.Stop()
   429  	select {
   430  	case <-timer.C:
   431  		return true
   432  	case <-ctx.Done():
   433  		return false
   434  	}
   435  }
   436  
   437  func backoff(attempt int, min time.Duration, max time.Duration) time.Duration {
   438  	d := time.Duration(attempt*attempt) * min
   439  	if d > max {
   440  		d = max
   441  	}
   442  	return d
   443  }
   444  
   445  func canonicalAddress(s string) string {
   446  	return net.JoinHostPort(splitHostPort(s))
   447  }
   448  
   449  func splitHostPort(s string) (host string, port string) {
   450  	host, port, _ = net.SplitHostPort(s)
   451  	if len(host) == 0 && len(port) == 0 {
   452  		host = s
   453  		port = "9092"
   454  	}
   455  	return
   456  }
   457  
   458  func splitHostPortNumber(s string) (host string, portNumber int, err error) {
   459  	host, port := splitHostPort(s)
   460  	portNumber, err = strconv.Atoi(port)
   461  	if err != nil {
   462  		return host, 0, fmt.Errorf("%s: %w", s, err)
   463  	}
   464  	return host, portNumber, nil
   465  }
   466  
   467  func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
   468  	host, port := splitHostPort(address)
   469  
   470  	if resolver != nil {
   471  		resolved, err := resolver.LookupHost(ctx, host)
   472  		if err != nil {
   473  			return "", fmt.Errorf("failed to resolve host %s: %w", host, err)
   474  		}
   475  
   476  		// if the resolver doesn't return anything, we'll fall back on the provided
   477  		// address instead
   478  		if len(resolved) > 0 {
   479  			resolvedHost, resolvedPort, err := net.SplitHostPort(resolved[0])
   480  			if err == nil {
   481  				//no error if "fully qualified" (host:[port])
   482  				if resolvedPort != "" {
   483  					//use explicitly set port
   484  					port = resolvedPort
   485  				}
   486  				host = resolvedHost
   487  			} else {
   488  				//fallback to resolved host
   489  				host = resolved[0]
   490  			}
   491  		}
   492  	}
   493  
   494  	return net.JoinHostPort(host, port), nil
   495  }