github.com/streamdal/segmentio-kafka-go@v0.4.47-streamdal/dialer.go (about) 1 package kafka 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/segmentio/kafka-go/sasl" 15 ) 16 17 // The Dialer type mirrors the net.Dialer API but is designed to open kafka 18 // connections instead of raw network connections. 19 type Dialer struct { 20 // Unique identifier for client connections established by this Dialer. 21 ClientID string 22 23 // Optionally specifies the function that the dialer uses to establish 24 // network connections. If nil, net.(*Dialer).DialContext is used instead. 25 // 26 // When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive 27 // are ignored. 28 DialFunc func(ctx context.Context, network string, address string) (net.Conn, error) 29 30 // Timeout is the maximum amount of time a dial will wait for a connect to 31 // complete. If Deadline is also set, it may fail earlier. 32 // 33 // The default is no timeout. 34 // 35 // When dialing a name with multiple IP addresses, the timeout may be 36 // divided between them. 37 // 38 // With or without a timeout, the operating system may impose its own 39 // earlier timeout. For instance, TCP timeouts are often around 3 minutes. 40 Timeout time.Duration 41 42 // Deadline is the absolute point in time after which dials will fail. 43 // If Timeout is set, it may fail earlier. 44 // Zero means no deadline, or dependent on the operating system as with the 45 // Timeout option. 46 Deadline time.Time 47 48 // LocalAddr is the local address to use when dialing an address. 49 // The address must be of a compatible type for the network being dialed. 50 // If nil, a local address is automatically chosen. 51 LocalAddr net.Addr 52 53 // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the 54 // network is "tcp" and the destination is a host name with both IPv4 and 55 // IPv6 addresses. This allows a client to tolerate networks where one 56 // address family is silently broken. 57 DualStack bool 58 59 // FallbackDelay specifies the length of time to wait before spawning a 60 // fallback connection, when DualStack is enabled. 61 // If zero, a default delay of 300ms is used. 62 FallbackDelay time.Duration 63 64 // KeepAlive specifies the keep-alive period for an active network 65 // connection. 66 // If zero, keep-alives are not enabled. Network protocols that do not 67 // support keep-alives ignore this field. 68 KeepAlive time.Duration 69 70 // Resolver optionally gives a hook to convert the broker address into an 71 // alternate host or IP address which is useful for custom service discovery. 72 // If a custom resolver returns any possible hosts, the first one will be 73 // used and the original discarded. If a port number is included with the 74 // resolved host, it will only be used if a port number was not previously 75 // specified. If no port is specified or resolved, the default of 9092 will be 76 // used. 77 Resolver Resolver 78 79 // TLS enables Dialer to open secure connections. If nil, standard net.Conn 80 // will be used. 81 TLS *tls.Config 82 83 // SASLMechanism configures the Dialer to use SASL authentication. If nil, 84 // no authentication will be performed. 85 SASLMechanism sasl.Mechanism 86 87 // The transactional id to use for transactional delivery. Idempotent 88 // deliver should be enabled if transactional id is configured. 89 // For more details look at transactional.id description here: http://kafka.apache.org/documentation.html#producerconfigs 90 // Empty string means that the connection will be non-transactional. 91 TransactionalID string 92 } 93 94 // Dial connects to the address on the named network. 95 func (d *Dialer) Dial(network string, address string) (*Conn, error) { 96 return d.DialContext(context.Background(), network, address) 97 } 98 99 // DialContext connects to the address on the named network using the provided 100 // context. 101 // 102 // The provided Context must be non-nil. If the context expires before the 103 // connection is complete, an error is returned. Once successfully connected, 104 // any expiration of the context will not affect the connection. 105 // 106 // When using TCP, and the host in the address parameter resolves to multiple 107 // network addresses, any dial timeout (from d.Timeout or ctx) is spread over 108 // each consecutive dial, such that each is given an appropriate fraction of the 109 // time to connect. For example, if a host has 4 IP addresses and the timeout is 110 // 1 minute, the connect to each single address will be given 15 seconds to 111 // complete before trying the next one. 112 func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) { 113 return d.connect( 114 ctx, 115 network, 116 address, 117 ConnConfig{ 118 ClientID: d.ClientID, 119 TransactionalID: d.TransactionalID, 120 }, 121 ) 122 } 123 124 // DialLeader opens a connection to the leader of the partition for a given 125 // topic. 126 // 127 // The address given to the DialContext method may not be the one that the 128 // connection will end up being established to, because the dialer will lookup 129 // the partition leader for the topic and return a connection to that server. 130 // The original address is only used as a mechanism to discover the 131 // configuration of the kafka cluster that we're connecting to. 132 func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) { 133 p, err := d.LookupPartition(ctx, network, address, topic, partition) 134 if err != nil { 135 return nil, err 136 } 137 return d.DialPartition(ctx, network, address, p) 138 } 139 140 // DialPartition opens a connection to the leader of the partition specified by partition 141 // descriptor. It's strongly advised to use descriptor of the partition that comes out of 142 // functions LookupPartition or LookupPartitions. 143 func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) { 144 return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{ 145 ClientID: d.ClientID, 146 Topic: partition.Topic, 147 Partition: partition.ID, 148 Broker: partition.Leader.ID, 149 Rack: partition.Leader.Rack, 150 TransactionalID: d.TransactionalID, 151 }) 152 } 153 154 // LookupLeader searches for the kafka broker that is the leader of the 155 // partition for a given topic, returning a Broker value representing it. 156 func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) { 157 p, err := d.LookupPartition(ctx, network, address, topic, partition) 158 return p.Leader, err 159 } 160 161 // LookupPartition searches for the description of specified partition id. 162 func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) { 163 c, err := d.DialContext(ctx, network, address) 164 if err != nil { 165 return Partition{}, err 166 } 167 defer c.Close() 168 169 brkch := make(chan Partition, 1) 170 errch := make(chan error, 1) 171 172 go func() { 173 for attempt := 0; true; attempt++ { 174 if attempt != 0 { 175 if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) { 176 errch <- ctx.Err() 177 return 178 } 179 } 180 181 partitions, err := c.ReadPartitions(topic) 182 if err != nil { 183 if isTemporary(err) { 184 continue 185 } 186 errch <- err 187 return 188 } 189 190 for _, p := range partitions { 191 if p.ID == partition { 192 brkch <- p 193 return 194 } 195 } 196 } 197 198 errch <- UnknownTopicOrPartition 199 }() 200 201 var prt Partition 202 select { 203 case prt = <-brkch: 204 case err = <-errch: 205 case <-ctx.Done(): 206 err = ctx.Err() 207 } 208 return prt, err 209 } 210 211 // LookupPartitions returns the list of partitions that exist for the given topic. 212 func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) { 213 conn, err := d.DialContext(ctx, network, address) 214 if err != nil { 215 return nil, err 216 } 217 defer conn.Close() 218 219 prtch := make(chan []Partition, 1) 220 errch := make(chan error, 1) 221 222 go func() { 223 if prt, err := conn.ReadPartitions(topic); err != nil { 224 errch <- err 225 } else { 226 prtch <- prt 227 } 228 }() 229 230 var prt []Partition 231 select { 232 case prt = <-prtch: 233 case err = <-errch: 234 case <-ctx.Done(): 235 err = ctx.Err() 236 } 237 return prt, err 238 } 239 240 // connectTLS returns a tls.Conn that has already completed the Handshake. 241 func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) { 242 tlsConn = tls.Client(conn, config) 243 errch := make(chan error) 244 245 go func() { 246 defer close(errch) 247 errch <- tlsConn.Handshake() 248 }() 249 250 select { 251 case <-ctx.Done(): 252 conn.Close() 253 tlsConn.Close() 254 <-errch // ignore possible error from Handshake 255 err = ctx.Err() 256 257 case err = <-errch: 258 } 259 260 return 261 } 262 263 // connect opens a socket connection to the broker, wraps it to create a 264 // kafka connection, and performs SASL authentication if configured to do so. 265 func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) { 266 if d.Timeout != 0 { 267 var cancel context.CancelFunc 268 ctx, cancel = context.WithTimeout(ctx, d.Timeout) 269 defer cancel() 270 } 271 272 if !d.Deadline.IsZero() { 273 var cancel context.CancelFunc 274 ctx, cancel = context.WithDeadline(ctx, d.Deadline) 275 defer cancel() 276 } 277 278 c, err := d.dialContext(ctx, network, address) 279 if err != nil { 280 return nil, fmt.Errorf("failed to dial: %w", err) 281 } 282 283 conn := NewConnWith(c, connCfg) 284 285 if d.SASLMechanism != nil { 286 host, port, err := splitHostPortNumber(address) 287 if err != nil { 288 return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err) 289 } 290 metadata := &sasl.Metadata{ 291 Host: host, 292 Port: port, 293 } 294 if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil { 295 _ = conn.Close() 296 return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err) 297 } 298 } 299 300 return conn, nil 301 } 302 303 // authenticateSASL performs all of the required requests to authenticate this 304 // connection. If any step fails, this function returns with an error. A nil 305 // error indicates successful authentication. 306 // 307 // In case of error, this function *does not* close the connection. That is the 308 // responsibility of the caller. 309 func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { 310 if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { 311 return fmt.Errorf("SASL handshake failed: %w", err) 312 } 313 314 sess, state, err := d.SASLMechanism.Start(ctx) 315 if err != nil { 316 return fmt.Errorf("SASL authentication process could not be started: %w", err) 317 } 318 319 for completed := false; !completed; { 320 challenge, err := conn.saslAuthenticate(state) 321 switch { 322 case err == nil: 323 case errors.Is(err, io.EOF): 324 // the broker may communicate a failed exchange by closing the 325 // connection (esp. in the case where we're passing opaque sasl 326 // data over the wire since there's no protocol info). 327 return SASLAuthenticationFailed 328 default: 329 return err 330 } 331 332 completed, state, err = sess.Next(ctx, challenge) 333 if err != nil { 334 return fmt.Errorf("SASL authentication process has failed: %w", err) 335 } 336 } 337 338 return nil 339 } 340 341 func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { 342 address, err := lookupHost(ctx, addr, d.Resolver) 343 if err != nil { 344 return nil, fmt.Errorf("failed to resolve host: %w", err) 345 } 346 347 dial := d.DialFunc 348 if dial == nil { 349 dial = (&net.Dialer{ 350 LocalAddr: d.LocalAddr, 351 DualStack: d.DualStack, 352 FallbackDelay: d.FallbackDelay, 353 KeepAlive: d.KeepAlive, 354 }).DialContext 355 } 356 357 conn, err := dial(ctx, network, address) 358 if err != nil { 359 return nil, fmt.Errorf("failed to open connection to %s: %w", address, err) 360 } 361 362 if d.TLS != nil { 363 c := d.TLS 364 // If no ServerName is set, infer the ServerName 365 // from the hostname we're connecting to. 366 if c.ServerName == "" { 367 c = d.TLS.Clone() 368 // Copied from tls.go in the standard library. 369 colonPos := strings.LastIndex(address, ":") 370 if colonPos == -1 { 371 colonPos = len(address) 372 } 373 hostname := address[:colonPos] 374 c.ServerName = hostname 375 } 376 return d.connectTLS(ctx, conn, c) 377 } 378 379 return conn, nil 380 } 381 382 // DefaultDialer is the default dialer used when none is specified. 383 var DefaultDialer = &Dialer{ 384 Timeout: 10 * time.Second, 385 DualStack: true, 386 } 387 388 // Dial is a convenience wrapper for DefaultDialer.Dial. 389 func Dial(network string, address string) (*Conn, error) { 390 return DefaultDialer.Dial(network, address) 391 } 392 393 // DialContext is a convenience wrapper for DefaultDialer.DialContext. 394 func DialContext(ctx context.Context, network string, address string) (*Conn, error) { 395 return DefaultDialer.DialContext(ctx, network, address) 396 } 397 398 // DialLeader is a convenience wrapper for DefaultDialer.DialLeader. 399 func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) { 400 return DefaultDialer.DialLeader(ctx, network, address, topic, partition) 401 } 402 403 // DialPartition is a convenience wrapper for DefaultDialer.DialPartition. 404 func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) { 405 return DefaultDialer.DialPartition(ctx, network, address, partition) 406 } 407 408 // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition. 409 func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) { 410 return DefaultDialer.LookupPartition(ctx, network, address, topic, partition) 411 } 412 413 // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions. 414 func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) { 415 return DefaultDialer.LookupPartitions(ctx, network, address, topic) 416 } 417 418 func sleep(ctx context.Context, duration time.Duration) bool { 419 if duration == 0 { 420 select { 421 default: 422 return true 423 case <-ctx.Done(): 424 return false 425 } 426 } 427 timer := time.NewTimer(duration) 428 defer timer.Stop() 429 select { 430 case <-timer.C: 431 return true 432 case <-ctx.Done(): 433 return false 434 } 435 } 436 437 func backoff(attempt int, min time.Duration, max time.Duration) time.Duration { 438 d := time.Duration(attempt*attempt) * min 439 if d > max { 440 d = max 441 } 442 return d 443 } 444 445 func canonicalAddress(s string) string { 446 return net.JoinHostPort(splitHostPort(s)) 447 } 448 449 func splitHostPort(s string) (host string, port string) { 450 host, port, _ = net.SplitHostPort(s) 451 if len(host) == 0 && len(port) == 0 { 452 host = s 453 port = "9092" 454 } 455 return 456 } 457 458 func splitHostPortNumber(s string) (host string, portNumber int, err error) { 459 host, port := splitHostPort(s) 460 portNumber, err = strconv.Atoi(port) 461 if err != nil { 462 return host, 0, fmt.Errorf("%s: %w", s, err) 463 } 464 return host, portNumber, nil 465 } 466 467 func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) { 468 host, port := splitHostPort(address) 469 470 if resolver != nil { 471 resolved, err := resolver.LookupHost(ctx, host) 472 if err != nil { 473 return "", fmt.Errorf("failed to resolve host %s: %w", host, err) 474 } 475 476 // if the resolver doesn't return anything, we'll fall back on the provided 477 // address instead 478 if len(resolved) > 0 { 479 resolvedHost, resolvedPort := splitHostPort(resolved[0]) 480 481 // we'll always prefer the resolved host 482 host = resolvedHost 483 484 // in the case of port though, the provided address takes priority, and we 485 // only use the resolved address to set the port when not specified 486 if port == "" { 487 port = resolvedPort 488 } 489 } 490 } 491 492 return net.JoinHostPort(host, port), nil 493 }