github.com/rbisecke/kafka-go@v0.4.27/dialer.go (about)

     1  package kafka
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/rbisecke/kafka-go/sasl"
    14  )
    15  
// Dialer mirrors the net.Dialer API but is designed to open kafka
// connections instead of raw network connections.
type Dialer struct {
	// Unique identifier for client connections established by this Dialer.
	ClientID string

	// Optionally specifies the function that the dialer uses to establish
	// network connections. If nil, net.(*Dialer).DialContext is used instead.
	//
	// When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive
	// are ignored.
	DialFunc func(ctx context.Context, network string, address string) (net.Conn, error)

	// Timeout is the maximum amount of time a dial will wait for a connect to
	// complete. If Deadline is also set, it may fail earlier.
	//
	// The default (zero) is no timeout.
	//
	// When dialing a name with multiple IP addresses, the timeout may be
	// divided between them.
	//
	// With or without a timeout, the operating system may impose its own
	// earlier timeout. For instance, TCP timeouts are often around 3 minutes.
	Timeout time.Duration

	// Deadline is the absolute point in time after which dials will fail.
	// If Timeout is set, it may fail earlier.
	// Zero means no deadline, or dependent on the operating system as with the
	// Timeout option.
	Deadline time.Time

	// LocalAddr is the local address to use when dialing an address.
	// The address must be of a compatible type for the network being dialed.
	// If nil, a local address is automatically chosen.
	LocalAddr net.Addr

	// DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the
	// network is "tcp" and the destination is a host name with both IPv4 and
	// IPv6 addresses. This allows a client to tolerate networks where one
	// address family is silently broken.
	//
	// The value is forwarded unchanged to the underlying net.Dialer.
	DualStack bool

	// FallbackDelay specifies the length of time to wait before spawning a
	// fallback connection, when DualStack is enabled.
	// If zero, a default delay of 300ms is used.
	FallbackDelay time.Duration

	// KeepAlive specifies the keep-alive period for an active network
	// connection.
	// If zero, keep-alives are not enabled. Network protocols that do not
	// support keep-alives ignore this field.
	KeepAlive time.Duration

	// Resolver optionally gives a hook to convert the broker address into an
	// alternate host or IP address which is useful for custom service discovery.
	// If a custom resolver returns any possible hosts, the first one will be
	// used and the original discarded. If a port number is included with the
	// resolved host, it will only be used if a port number was not previously
	// specified. If no port is specified or resolved, the default of 9092 will be
	// used.
	Resolver Resolver

	// TLS enables Dialer to open secure connections. If nil, a plain net.Conn
	// is used. When the config has no ServerName, a clone of it is used with
	// ServerName set to the host part of the address being dialed.
	TLS *tls.Config

	// SASLMechanism configures the Dialer to use SASL authentication. If nil,
	// no authentication will be performed.
	SASLMechanism sasl.Mechanism

	// TransactionalID is the transactional id to use for transactional
	// delivery. Idempotent delivery should be enabled if a transactional id is
	// configured. For more details see the transactional.id description at
	// http://kafka.apache.org/documentation.html#producerconfigs
	// An empty string means that the connection will be non-transactional.
	TransactionalID string
}
    92  
    93  // Dial connects to the address on the named network.
    94  func (d *Dialer) Dial(network string, address string) (*Conn, error) {
    95  	return d.DialContext(context.Background(), network, address)
    96  }
    97  
    98  // DialContext connects to the address on the named network using the provided
    99  // context.
   100  //
   101  // The provided Context must be non-nil. If the context expires before the
   102  // connection is complete, an error is returned. Once successfully connected,
   103  // any expiration of the context will not affect the connection.
   104  //
   105  // When using TCP, and the host in the address parameter resolves to multiple
   106  // network addresses, any dial timeout (from d.Timeout or ctx) is spread over
   107  // each consecutive dial, such that each is given an appropriate fraction of the
   108  // time to connect. For example, if a host has 4 IP addresses and the timeout is
   109  // 1 minute, the connect to each single address will be given 15 seconds to
   110  // complete before trying the next one.
   111  func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) {
   112  	return d.connect(
   113  		ctx,
   114  		network,
   115  		address,
   116  		ConnConfig{
   117  			ClientID:        d.ClientID,
   118  			TransactionalID: d.TransactionalID,
   119  		},
   120  	)
   121  }
   122  
   123  // DialLeader opens a connection to the leader of the partition for a given
   124  // topic.
   125  //
   126  // The address given to the DialContext method may not be the one that the
   127  // connection will end up being established to, because the dialer will lookup
   128  // the partition leader for the topic and return a connection to that server.
   129  // The original address is only used as a mechanism to discover the
   130  // configuration of the kafka cluster that we're connecting to.
   131  func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
   132  	p, err := d.LookupPartition(ctx, network, address, topic, partition)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	return d.DialPartition(ctx, network, address, p)
   137  }
   138  
   139  // DialPartition opens a connection to the leader of the partition specified by partition
   140  // descriptor. It's strongly advised to use descriptor of the partition that comes out of
   141  // functions LookupPartition or LookupPartitions.
   142  func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
   143  	return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{
   144  		ClientID:        d.ClientID,
   145  		Topic:           partition.Topic,
   146  		Partition:       partition.ID,
   147  		Broker:          partition.Leader.ID,
   148  		Rack:            partition.Leader.Rack,
   149  		TransactionalID: d.TransactionalID,
   150  	})
   151  }
   152  
   153  // LookupLeader searches for the kafka broker that is the leader of the
   154  // partition for a given topic, returning a Broker value representing it.
   155  func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) {
   156  	p, err := d.LookupPartition(ctx, network, address, topic, partition)
   157  	return p.Leader, err
   158  }
   159  
// LookupPartition searches for the description of specified partition id.
//
// It dials address, then reads the partition list of topic on a background
// goroutine, retrying with a growing delay until an entry whose ID equals
// partition is found, a non-temporary error is returned by ReadPartitions, or
// ctx is done. The connection used for the lookup is closed before returning.
//
// On failure the zero Partition is returned together with the error.
func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
	c, err := d.DialContext(ctx, network, address)
	if err != nil {
		return Partition{}, err
	}
	defer c.Close()

	// Both channels have room for one value and the goroutine sends at most
	// once before returning, so it never blocks on a send even when this
	// function has already returned because ctx was done.
	brkch := make(chan Partition, 1)
	errch := make(chan error, 1)

	go func() {
		// The loop condition is the constant true: the loop is only left
		// through one of the return statements below.
		for attempt := 0; true; attempt++ {
			if attempt != 0 {
				// Wait before retrying; sleep reports false when ctx ends first.
				if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) {
					errch <- ctx.Err()
					return
				}
			}

			partitions, err := c.ReadPartitions(topic)
			if err != nil {
				// Temporary errors are retried; anything else ends the lookup.
				if isTemporary(err) {
					continue
				}
				errch <- err
				return
			}

			for _, p := range partitions {
				if p.ID == partition {
					brkch <- p
					return
				}
			}
			// Partition not in the list: fall through to the next attempt.
		}

		// NOTE(review): never reached, because the loop above has no normal
		// exit. A missing partition is therefore retried until ctx is done
		// rather than reported as UnknownTopicOrPartition — confirm intent.
		errch <- UnknownTopicOrPartition
	}()

	var prt Partition
	select {
	case prt = <-brkch:
	case err = <-errch:
	case <-ctx.Done():
		err = ctx.Err()
	}
	return prt, err
}
   209  
   210  // LookupPartitions returns the list of partitions that exist for the given topic.
   211  func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
   212  	conn, err := d.DialContext(ctx, network, address)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	defer conn.Close()
   217  
   218  	prtch := make(chan []Partition, 1)
   219  	errch := make(chan error, 1)
   220  
   221  	go func() {
   222  		if prt, err := conn.ReadPartitions(topic); err != nil {
   223  			errch <- err
   224  		} else {
   225  			prtch <- prt
   226  		}
   227  	}()
   228  
   229  	var prt []Partition
   230  	select {
   231  	case prt = <-prtch:
   232  	case err = <-errch:
   233  	case <-ctx.Done():
   234  		err = ctx.Err()
   235  	}
   236  	return prt, err
   237  }
   238  
   239  // connectTLS returns a tls.Conn that has already completed the Handshake
   240  func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) {
   241  	tlsConn = tls.Client(conn, config)
   242  	errch := make(chan error)
   243  
   244  	go func() {
   245  		defer close(errch)
   246  		errch <- tlsConn.Handshake()
   247  	}()
   248  
   249  	select {
   250  	case <-ctx.Done():
   251  		conn.Close()
   252  		tlsConn.Close()
   253  		<-errch // ignore possible error from Handshake
   254  		err = ctx.Err()
   255  
   256  	case err = <-errch:
   257  	}
   258  
   259  	return
   260  }
   261  
   262  // connect opens a socket connection to the broker, wraps it to create a
   263  // kafka connection, and performs SASL authentication if configured to do so.
   264  func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) {
   265  	if d.Timeout != 0 {
   266  		var cancel context.CancelFunc
   267  		ctx, cancel = context.WithTimeout(ctx, d.Timeout)
   268  		defer cancel()
   269  	}
   270  
   271  	if !d.Deadline.IsZero() {
   272  		var cancel context.CancelFunc
   273  		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
   274  		defer cancel()
   275  	}
   276  
   277  	c, err := d.dialContext(ctx, network, address)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  
   282  	conn := NewConnWith(c, connCfg)
   283  
   284  	if d.SASLMechanism != nil {
   285  		host, port, err := splitHostPortNumber(address)
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  		metadata := &sasl.Metadata{
   290  			Host: host,
   291  			Port: port,
   292  		}
   293  		if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
   294  			_ = conn.Close()
   295  			return nil, err
   296  		}
   297  	}
   298  
   299  	return conn, nil
   300  }
   301  
   302  // authenticateSASL performs all of the required requests to authenticate this
   303  // connection.  If any step fails, this function returns with an error.  A nil
   304  // error indicates successful authentication.
   305  //
   306  // In case of error, this function *does not* close the connection.  That is the
   307  // responsibility of the caller.
   308  func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
   309  	if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
   310  		return err
   311  	}
   312  
   313  	sess, state, err := d.SASLMechanism.Start(ctx)
   314  	if err != nil {
   315  		return err
   316  	}
   317  
   318  	for completed := false; !completed; {
   319  		challenge, err := conn.saslAuthenticate(state)
   320  		switch err {
   321  		case nil:
   322  		case io.EOF:
   323  			// the broker may communicate a failed exchange by closing the
   324  			// connection (esp. in the case where we're passing opaque sasl
   325  			// data over the wire since there's no protocol info).
   326  			return SASLAuthenticationFailed
   327  		default:
   328  			return err
   329  		}
   330  
   331  		completed, state, err = sess.Next(ctx, challenge)
   332  		if err != nil {
   333  			return err
   334  		}
   335  	}
   336  
   337  	return nil
   338  }
   339  
   340  func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
   341  	address, err := lookupHost(ctx, addr, d.Resolver)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	dial := d.DialFunc
   347  	if dial == nil {
   348  		dial = (&net.Dialer{
   349  			LocalAddr:     d.LocalAddr,
   350  			DualStack:     d.DualStack,
   351  			FallbackDelay: d.FallbackDelay,
   352  			KeepAlive:     d.KeepAlive,
   353  		}).DialContext
   354  	}
   355  
   356  	conn, err := dial(ctx, network, address)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	if d.TLS != nil {
   362  		c := d.TLS
   363  		// If no ServerName is set, infer the ServerName
   364  		// from the hostname we're connecting to.
   365  		if c.ServerName == "" {
   366  			c = d.TLS.Clone()
   367  			// Copied from tls.go in the standard library.
   368  			colonPos := strings.LastIndex(address, ":")
   369  			if colonPos == -1 {
   370  				colonPos = len(address)
   371  			}
   372  			hostname := address[:colonPos]
   373  			c.ServerName = hostname
   374  		}
   375  		return d.connectTLS(ctx, conn, c)
   376  	}
   377  
   378  	return conn, nil
   379  }
   380  
// DefaultDialer is the default dialer used when none is specified. It is
// the dialer behind the package-level Dial, DialContext, DialLeader,
// DialPartition, LookupPartition and LookupPartitions functions.
var DefaultDialer = &Dialer{
	// Bound connection setup to ten seconds.
	Timeout:   10 * time.Second,
	DualStack: true,
}
   386  
   387  // Dial is a convenience wrapper for DefaultDialer.Dial.
   388  func Dial(network string, address string) (*Conn, error) {
   389  	return DefaultDialer.Dial(network, address)
   390  }
   391  
   392  // DialContext is a convenience wrapper for DefaultDialer.DialContext.
   393  func DialContext(ctx context.Context, network string, address string) (*Conn, error) {
   394  	return DefaultDialer.DialContext(ctx, network, address)
   395  }
   396  
   397  // DialLeader is a convenience wrapper for DefaultDialer.DialLeader.
   398  func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
   399  	return DefaultDialer.DialLeader(ctx, network, address, topic, partition)
   400  }
   401  
   402  // DialPartition is a convenience wrapper for DefaultDialer.DialPartition.
   403  func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
   404  	return DefaultDialer.DialPartition(ctx, network, address, partition)
   405  }
   406  
   407  // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition.
   408  func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
   409  	return DefaultDialer.LookupPartition(ctx, network, address, topic, partition)
   410  }
   411  
   412  // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions.
   413  func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
   414  	return DefaultDialer.LookupPartitions(ctx, network, address, topic)
   415  }
   416  
   417  func sleep(ctx context.Context, duration time.Duration) bool {
   418  	if duration == 0 {
   419  		select {
   420  		default:
   421  			return true
   422  		case <-ctx.Done():
   423  			return false
   424  		}
   425  	}
   426  	timer := time.NewTimer(duration)
   427  	defer timer.Stop()
   428  	select {
   429  	case <-timer.C:
   430  		return true
   431  	case <-ctx.Done():
   432  		return false
   433  	}
   434  }
   435  
   436  func backoff(attempt int, min time.Duration, max time.Duration) time.Duration {
   437  	d := time.Duration(attempt*attempt) * min
   438  	if d > max {
   439  		d = max
   440  	}
   441  	return d
   442  }
   443  
   444  func canonicalAddress(s string) string {
   445  	return net.JoinHostPort(splitHostPort(s))
   446  }
   447  
   448  func splitHostPort(s string) (host string, port string) {
   449  	host, port, _ = net.SplitHostPort(s)
   450  	if len(host) == 0 && len(port) == 0 {
   451  		host = s
   452  		port = "9092"
   453  	}
   454  	return
   455  }
   456  
   457  func splitHostPortNumber(s string) (host string, portNumber int, err error) {
   458  	host, port := splitHostPort(s)
   459  	portNumber, err = strconv.Atoi(port)
   460  	if err != nil {
   461  		return host, 0, fmt.Errorf("%s: %w", s, err)
   462  	}
   463  	return host, portNumber, nil
   464  }
   465  
   466  func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
   467  	host, port := splitHostPort(address)
   468  
   469  	if resolver != nil {
   470  		resolved, err := resolver.LookupHost(ctx, host)
   471  		if err != nil {
   472  			return "", err
   473  		}
   474  
   475  		// if the resolver doesn't return anything, we'll fall back on the provided
   476  		// address instead
   477  		if len(resolved) > 0 {
   478  			resolvedHost, resolvedPort := splitHostPort(resolved[0])
   479  
   480  			// we'll always prefer the resolved host
   481  			host = resolvedHost
   482  
   483  			// in the case of port though, the provided address takes priority, and we
   484  			// only use the resolved address to set the port when not specified
   485  			if port == "" {
   486  				port = resolvedPort
   487  			}
   488  		}
   489  	}
   490  
   491  	return net.JoinHostPort(host, port), nil
   492  }