amuz.es/src/go/misc@v1.0.1/networking/failback.go (about)

     1  package networking
     2  
     3  import (
     4  	"context"
     5  	"github.com/pkg/errors"
     6  	"log"
     7  	"math/rand"
     8  	"net"
     9  	"strconv"
    10  )
    11  
    12  // RoundrobinDialer is a wrapper for the DialContext function.
    13  type (
    14  	RoundrobinDialer struct {
    15  		Dialer       *net.Dialer
    16  		FallbackHost net.TCPAddr
    17  	}
    18  )
    19  
    20  // DialContext is a connector method with Fail-Over approach.
    21  func (d *RoundrobinDialer) DialContext(ctx context.Context, network string, hosts ...string) (net.Conn, error) {
    22  	list := rand.Perm(len(hosts))
    23  	for _, idx := range list {
    24  		host := hosts[idx]
    25  		for _, resolvedIP := range d.resolveHost(ctx, network, host) {
    26  			var (
    27  				conn net.Conn
    28  				err  error
    29  			)
    30  			if conn, err = d.Dialer.DialContext(ctx, resolvedIP.Network(), resolvedIP.AddrPort().String()); err != nil {
    31  				log.Printf("failed to connected to %s => %s : %s", host, resolvedIP.String(), err)
    32  			} else {
    33  				return conn, nil
    34  			}
    35  		}
    36  	}
    37  
    38  	if !d.FallbackHost.IP.IsUnspecified() {
    39  		log.Printf("attempting to connect fallback address(%s)\n", d.FallbackHost.String())
    40  		return d.Dialer.DialContext(ctx, d.FallbackHost.Network(), d.FallbackHost.AddrPort().String())
    41  	}
    42  	return nil, errors.New("name resolve failure")
    43  }
    44  
    45  // resolveHost is actually resolve network addresses from a given hostname.
    46  func (d *RoundrobinDialer) resolveHost(ctx context.Context, network, connectAddr string) (addrs []net.TCPAddr) {
    47  	log.Printf("attempting to connect %s", connectAddr)
    48  	var (
    49  		host        string
    50  		portString  string
    51  		resolvedIPs []net.IP
    52  		tcpPort     int
    53  		err         error
    54  	)
    55  	if host, portString, err = net.SplitHostPort(connectAddr); err != nil {
    56  		log.Printf("invalid connection string format: %s", err)
    57  		return
    58  	} else if port, _ := strconv.ParseInt(portString, 10, 32); port < 0 || port > 65535 {
    59  		log.Printf("invalid port format : %s", portString)
    60  		return
    61  	} else {
    62  		tcpPort = int(port)
    63  	}
    64  
    65  	if resolvedIPs, err = d.Dialer.Resolver.LookupIP(ctx, network, host); err != nil {
    66  		log.Printf("cannot resolve host %s(%s) : %s", host, network, err)
    67  		return
    68  	}
    69  
    70  	addrs = make([]net.TCPAddr, 0, len(resolvedIPs))
    71  	for _, resolvedIP := range resolvedIPs {
    72  		addrs = append(addrs, net.TCPAddr{IP: resolvedIP, Port: tcpPort})
    73  	}
    74  	return
    75  }