vitess.io/vitess@v0.16.2/go/netutil/netutil.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package netutil contains network-related utility functions.
    18  package netutil
    19  
    20  import (
    21  	"bytes"
    22  	"fmt"
    23  	"math/rand"
    24  	"net"
    25  	"os"
    26  	"sort"
    27  	"strconv"
    28  	"strings"
    29  	"time"
    30  )
    31  
    32  func init() {
    33  	rand.Seed(time.Now().UnixNano())
    34  }
    35  
    36  // byPriorityWeight sorts records by ascending priority and weight.
    37  type byPriorityWeight []*net.SRV
    38  
    39  func (addrs byPriorityWeight) Len() int { return len(addrs) }
    40  
    41  func (addrs byPriorityWeight) Swap(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }
    42  
    43  func (addrs byPriorityWeight) Less(i, j int) bool {
    44  	return addrs[i].Priority < addrs[j].Priority ||
    45  		(addrs[i].Priority == addrs[j].Priority && addrs[i].Weight < addrs[j].Weight)
    46  }
    47  
    48  // shuffleByWeight shuffles SRV records by weight using the algorithm
    49  // described in RFC 2782.
    50  // NOTE(msolo) This is disabled when the weights are zero.
    51  func (addrs byPriorityWeight) shuffleByWeight() {
    52  	sum := 0
    53  	for _, addr := range addrs {
    54  		sum += int(addr.Weight)
    55  	}
    56  	for sum > 0 && len(addrs) > 1 {
    57  		s := 0
    58  		n := rand.Intn(sum)
    59  		for i := range addrs {
    60  			s += int(addrs[i].Weight)
    61  			if s > n {
    62  				if i > 0 {
    63  					t := addrs[i]
    64  					copy(addrs[1:i+1], addrs[0:i])
    65  					addrs[0] = t
    66  				}
    67  				break
    68  			}
    69  		}
    70  		sum -= int(addrs[0].Weight)
    71  		addrs = addrs[1:]
    72  	}
    73  }
    74  
    75  func (addrs byPriorityWeight) sortRfc2782() {
    76  	sort.Sort(addrs)
    77  	i := 0
    78  	for j := 1; j < len(addrs); j++ {
    79  		if addrs[i].Priority != addrs[j].Priority {
    80  			addrs[i:j].shuffleByWeight()
    81  			i = j
    82  		}
    83  	}
    84  	addrs[i:].shuffleByWeight()
    85  }
    86  
    87  // SortRfc2782 reorders SRV records as specified in RFC 2782.
    88  func SortRfc2782(srvs []*net.SRV) {
    89  	byPriorityWeight(srvs).sortRfc2782()
    90  }
    91  
    92  // SplitHostPort is an alternative to net.SplitHostPort that also parses the
    93  // integer port. In addition, it is more tolerant of improperly escaped IPv6
    94  // addresses, such as "::1:456", which should actually be "[::1]:456".
    95  func SplitHostPort(addr string) (string, int, error) {
    96  	host, port, err := net.SplitHostPort(addr)
    97  	if err != nil {
    98  		// If the above proper parsing fails, fall back on a naive split.
    99  		i := strings.LastIndex(addr, ":")
   100  		if i < 0 {
   101  			return "", 0, fmt.Errorf("SplitHostPort: missing port in %q", addr)
   102  		}
   103  		host = addr[:i]
   104  		port = addr[i+1:]
   105  	}
   106  	p, err := strconv.ParseUint(port, 10, 16)
   107  	if err != nil {
   108  		return "", 0, fmt.Errorf("SplitHostPort: can't parse port %q: %v", port, err)
   109  	}
   110  	return host, int(p), nil
   111  }
   112  
   113  // JoinHostPort is an extension to net.JoinHostPort that also formats the
   114  // integer port.
   115  func JoinHostPort(host string, port int32) string {
   116  	return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
   117  }
   118  
   119  // FullyQualifiedHostname returns the FQDN of the machine.
   120  func FullyQualifiedHostname() (string, error) {
   121  	// The machine hostname (which is also returned by os.Hostname()) may not be
   122  	// set to the FQDN, but only the first part of it e.g. "localhost" instead of
   123  	// "localhost.localdomain".
   124  	// To get the full FQDN, we do the following:
   125  
   126  	// 1. Get the machine hostname. Example: localhost
   127  	hostname, err := os.Hostname()
   128  	if err != nil {
   129  		return "", fmt.Errorf("FullyQualifiedHostname: failed to retrieve the hostname of this machine: %v", err)
   130  	}
   131  
   132  	// 2. Look up the IP address for that hostname. Example: 127.0.0.1
   133  	ips, err := net.LookupHost(hostname)
   134  	if err != nil {
   135  		return "", fmt.Errorf("FullyQualifiedHostname: failed to lookup the IP of this machine's hostname (%v): %v", hostname, err)
   136  	}
   137  	if len(ips) == 0 {
   138  		return "", fmt.Errorf("FullyQualifiedHostname: lookup of the IP of this machine's hostname (%v) did not return any IP address", hostname)
   139  	}
   140  	// If multiple IPs are returned, we only look at the first one.
   141  	localIP := ips[0]
   142  
   143  	// 3. Reverse lookup the IP. Example: localhost.localdomain
   144  	resolvedHostnames, err := net.LookupAddr(localIP)
   145  	if err != nil {
   146  		return "", fmt.Errorf("FullyQualifiedHostname: failed to reverse lookup this machine's local IP (%v): %v", localIP, err)
   147  	}
   148  	if len(resolvedHostnames) == 0 {
   149  		return "", fmt.Errorf("FullyQualifiedHostname: reverse lookup of this machine's local IP (%v) did not return any hostnames", localIP)
   150  	}
   151  	// If multiple hostnames are found, we return only the first one.
   152  	// If multiple hostnames are listed e.g. in an entry in the /etc/hosts file,
   153  	// the current Go implementation returns them in that order.
   154  	// Example for an /etc/hosts entry:
   155  	//   127.0.0.1	localhost.localdomain localhost
   156  	// If the FQDN isn't returned by this function, check the order in the entry
   157  	// in your /etc/hosts file.
   158  	return strings.TrimSuffix(resolvedHostnames[0], "."), nil
   159  }
   160  
   161  // FullyQualifiedHostnameOrPanic is the same as FullyQualifiedHostname
   162  // but panics in case of an error.
   163  func FullyQualifiedHostnameOrPanic() string {
   164  	hostname, err := FullyQualifiedHostname()
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	return hostname
   169  }
   170  
   171  // ResolveIPv4Addrs resolves the address:port part into IP address:port pairs
   172  func ResolveIPv4Addrs(addr string) ([]string, error) {
   173  	host, port, err := net.SplitHostPort(addr)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	ipAddrs, err := net.LookupIP(host)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	result := make([]string, 0, len(ipAddrs))
   182  	for _, ipAddr := range ipAddrs {
   183  		ipv4 := ipAddr.To4()
   184  		if ipv4 != nil {
   185  			result = append(result, net.JoinHostPort(ipv4.String(), port))
   186  		}
   187  	}
   188  	if len(result) == 0 {
   189  		return nil, fmt.Errorf("no IPv4addr for name %v", host)
   190  	}
   191  	return result, nil
   192  }
   193  
   194  func dnsLookup(host string) ([]net.IP, error) {
   195  	addrs, err := net.LookupHost(host)
   196  	if err != nil {
   197  		return nil, fmt.Errorf("Error looking up dns name [%v]: (%v)", host, err)
   198  	}
   199  	naddr := make([]net.IP, len(addrs))
   200  	for i, a := range addrs {
   201  		naddr[i] = net.ParseIP(a)
   202  	}
   203  	sort.Slice(naddr, func(i, j int) bool {
   204  		return bytes.Compare(naddr[i], naddr[j]) < 0
   205  	})
   206  	return naddr, nil
   207  }
   208  
   209  // DNSTracker is a closure that persists state for
   210  //
   211  //	tracking changes in the DNS resolution of a target dns name
   212  //	returns true if the DNS name resolution has changed
   213  //	If there is a lookup problem, we pretend nothing has changed
   214  func DNSTracker(host string) func() (bool, error) {
   215  	dnsName := host
   216  	var addrs []net.IP
   217  	if dnsName != "" {
   218  		addrs, _ = dnsLookup(dnsName)
   219  	}
   220  
   221  	return func() (bool, error) {
   222  		if dnsName == "" {
   223  			return false, nil
   224  		}
   225  		newaddrs, err := dnsLookup(dnsName)
   226  		if err != nil {
   227  			return false, err
   228  		}
   229  		if len(newaddrs) == 0 { // Should not happen, but just in case
   230  			return false, fmt.Errorf("Connection DNS for %s reporting as empty, ignoring", dnsName)
   231  		}
   232  		if !addrEqual(addrs, newaddrs) {
   233  			oldaddr := addrs
   234  			addrs = newaddrs // Update the closure variable
   235  			return true, fmt.Errorf("Connection DNS for %s has changed; old: %v  new: %v", dnsName, oldaddr, newaddrs)
   236  		}
   237  		return false, nil
   238  	}
   239  }
   240  
   241  func addrEqual(a, b []net.IP) bool {
   242  	if len(a) != len(b) {
   243  		return false
   244  	}
   245  	for idx, v := range a {
   246  		if !net.IP.Equal(v, b[idx]) {
   247  			return false
   248  		}
   249  	}
   250  	return true
   251  }
   252  
   253  // NormalizeIP normalizes loopback addresses to avoid spurious errors when
   254  // communicating to different representations of the loopback.
   255  //
   256  // Note: this also maps IPv6 localhost to IPv4 localhost, as
   257  // TabletManagerClient.GetReplicas() (the only place this function is used on)
   258  // will return only IPv4 addresses.
   259  func NormalizeIP(s string) string {
   260  	if ip := net.ParseIP(s); ip != nil && ip.IsLoopback() {
   261  		return "127.0.0.1"
   262  	}
   263  
   264  	return s
   265  }