github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/gossip/resolver/resolver.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package resolver
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"net"
    17  	"os"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/base"
    20  	"github.com/cockroachdb/cockroach/pkg/util"
    21  	"github.com/cockroachdb/cockroach/pkg/util/log"
    22  	"github.com/cockroachdb/cockroach/pkg/util/netutil"
    23  	"github.com/cockroachdb/errors"
    24  )
    25  
    26  // Resolver is an interface which provides an abstract factory for
    27  // net.Addr addresses. Resolvers are not thread safe.
    28  type Resolver interface {
    29  	Type() string
    30  	Addr() string
    31  	GetAddress() (net.Addr, error)
    32  }
    33  
    34  // NewResolver takes an address and returns a new resolver.
    35  func NewResolver(address string) (Resolver, error) {
    36  	if len(address) == 0 {
    37  		return nil, errors.Errorf("invalid address value: %q", address)
    38  	}
    39  
    40  	// Ensure addr has port and host set.
    41  	address = ensureHostPort(address, base.DefaultPort)
    42  	return &socketResolver{typ: "tcp", addr: address}, nil
    43  }
    44  
    45  // SRV returns a slice of addresses from SRV record lookup
    46  func SRV(ctx context.Context, name string) ([]string, error) {
    47  	// Ignore port
    48  	name, _, err := netutil.SplitHostPort(name, base.DefaultPort)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if name == "" {
    54  		return nil, nil
    55  	}
    56  
    57  	// "" as the addr and proto forces the direct look up of the name
    58  	_, recs, err := lookupSRV("", "", name)
    59  	if err != nil {
    60  		if dnsErr := (*net.DNSError)(nil); errors.As(err, &dnsErr) && dnsErr.Err == "no such host" {
    61  			return nil, nil
    62  		}
    63  
    64  		if log.V(1) {
    65  			log.Infof(context.TODO(), "failed to lookup SRV record for %q: %v", name, err)
    66  		}
    67  
    68  		return nil, nil
    69  	}
    70  
    71  	addrs := []string{}
    72  	for _, r := range recs {
    73  		if r.Port != 0 {
    74  			addrs = append(addrs, net.JoinHostPort(r.Target, fmt.Sprintf("%d", r.Port)))
    75  		}
    76  	}
    77  
    78  	return addrs, nil
    79  }
    80  
    81  // NewResolverFromAddress takes a net.Addr and constructs a resolver.
    82  func NewResolverFromAddress(addr net.Addr) (Resolver, error) {
    83  	switch addr.Network() {
    84  	case "tcp":
    85  		return &socketResolver{typ: addr.Network(), addr: addr.String()}, nil
    86  	default:
    87  		return nil, errors.Errorf("unknown address network %q for %v", addr.Network(), addr)
    88  	}
    89  }
    90  
    91  // NewResolverFromUnresolvedAddr takes a util.UnresolvedAddr and constructs a resolver.
    92  func NewResolverFromUnresolvedAddr(addr util.UnresolvedAddr) (Resolver, error) {
    93  	return NewResolverFromAddress(&addr)
    94  }
    95  
    96  // ensureHostPort takes a host:port addr, where the host and port are optional. If host and port are
    97  // present, the output is equal to addr. If port is not present, defaultPort is used. If host is not
    98  // present, hostname (or "127.0.0.1" as a fallback) is used.
    99  func ensureHostPort(addr string, defaultPort string) string {
   100  	host, port, err := net.SplitHostPort(addr)
   101  	if err != nil {
   102  		return net.JoinHostPort(addr, defaultPort)
   103  	}
   104  	if host == "" {
   105  		host, err = os.Hostname()
   106  		if err != nil {
   107  			host = "127.0.0.1"
   108  		}
   109  	}
   110  	if port == "" {
   111  		port = defaultPort
   112  	}
   113  
   114  	return net.JoinHostPort(host, port)
   115  }
   116  
   117  var (
   118  	lookupSRV = net.LookupSRV
   119  )
   120  
   121  // TestingOverrideSRVLookupFn enables a test to temporarily override
   122  // the SRV lookup function.
   123  func TestingOverrideSRVLookupFn(
   124  	fn func(service, proto, name string) (cname string, addrs []*net.SRV, err error),
   125  ) func() {
   126  	prevFn := lookupSRV
   127  	lookupSRV = fn
   128  	return func() { lookupSRV = prevFn }
   129  }