github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/local_dc.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net"
     9  	"net/url"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    16  )
    17  
    18  const (
    19  	maxEndpointsCheckPerLocation = 5
    20  )
    21  
    22  func checkFastestAddress(ctx context.Context, addresses []string) string {
    23  	ctx, cancel := xcontext.WithCancel(ctx)
    24  	defer cancel()
    25  
    26  	type result struct {
    27  		address string
    28  		err     error
    29  	}
    30  	results := make(chan result, len(addresses))
    31  	defer close(results)
    32  
    33  	startDial := make(chan struct{})
    34  	var dialer net.Dialer
    35  
    36  	var wg sync.WaitGroup
    37  	defer wg.Wait()
    38  
    39  	for _, addr := range addresses {
    40  		wg.Add(1)
    41  		go func(address string) {
    42  			defer wg.Done()
    43  			<-startDial
    44  			conn, err := dialer.DialContext(ctx, "tcp", address)
    45  			if err == nil {
    46  				cancel()
    47  				_ = conn.Close()
    48  			}
    49  			results <- result{address: address, err: err}
    50  		}(addr)
    51  	}
    52  
    53  	close(startDial)
    54  
    55  	for range addresses {
    56  		res := <-results
    57  		if res.err == nil {
    58  			return res.address
    59  		}
    60  	}
    61  
    62  	return ""
    63  }
    64  
    65  func detectFastestEndpoint(ctx context.Context, endpoints []endpoint.Endpoint) (endpoint.Endpoint, error) {
    66  	if len(endpoints) == 0 {
    67  		return nil, xerrors.WithStackTrace(errors.New("empty endpoints list"))
    68  	}
    69  
    70  	var lastErr error
    71  	// common is 2 ip address for every fqdn: ipv4 + ipv6
    72  	initialAddressToEndpointCapacity := len(endpoints) * 2
    73  	addressToEndpoint := make(map[string]endpoint.Endpoint, initialAddressToEndpointCapacity)
    74  	for _, ep := range endpoints {
    75  		host, port, err := extractHostPort(ep.Address())
    76  		if err != nil {
    77  			lastErr = xerrors.WithStackTrace(err)
    78  
    79  			continue
    80  		}
    81  
    82  		addresses, err := net.DefaultResolver.LookupHost(ctx, host)
    83  		if err != nil {
    84  			lastErr = err
    85  
    86  			continue
    87  		}
    88  		if len(addresses) == 0 {
    89  			lastErr = xerrors.WithStackTrace(fmt.Errorf("no ips for fqdn: %q", host))
    90  
    91  			continue
    92  		}
    93  
    94  		for _, ip := range addresses {
    95  			address := net.JoinHostPort(ip, port)
    96  			addressToEndpoint[address] = ep
    97  		}
    98  	}
    99  	if len(addressToEndpoint) == 0 {
   100  		return nil, xerrors.WithStackTrace(lastErr)
   101  	}
   102  	addressesToPing := make([]string, 0, len(addressToEndpoint))
   103  	for ip := range addressToEndpoint {
   104  		addressesToPing = append(addressesToPing, ip)
   105  	}
   106  
   107  	fastestAddress := checkFastestAddress(ctx, addressesToPing)
   108  	if fastestAddress == "" {
   109  		return nil, xerrors.WithStackTrace(errors.New("failed to check fastest address"))
   110  	}
   111  
   112  	return addressToEndpoint[fastestAddress], nil
   113  }
   114  
   115  func detectLocalDC(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) {
   116  	if len(endpoints) == 0 {
   117  		return "", xerrors.WithStackTrace(ErrNoEndpoints)
   118  	}
   119  	endpointsByDc := splitEndpointsByLocation(endpoints)
   120  
   121  	if len(endpointsByDc) == 1 {
   122  		return endpoints[0].Location(), nil
   123  	}
   124  
   125  	endpointsToTest := make([]endpoint.Endpoint, 0, maxEndpointsCheckPerLocation*len(endpointsByDc))
   126  	for _, dcEndpoints := range endpointsByDc {
   127  		endpointsToTest = append(endpointsToTest, getRandomEndpoints(dcEndpoints, maxEndpointsCheckPerLocation)...)
   128  	}
   129  
   130  	fastest, err := detectFastestEndpoint(ctx, endpointsToTest)
   131  	if err == nil {
   132  		return fastest.Location(), nil
   133  	}
   134  
   135  	return "", err
   136  }
   137  
   138  func extractHostPort(address string) (host, port string, _ error) {
   139  	if !strings.Contains(address, "://") {
   140  		address = "stub://" + address
   141  	}
   142  
   143  	u, err := url.Parse(address)
   144  	if err != nil {
   145  		return "", "", xerrors.WithStackTrace(err)
   146  	}
   147  	host, port, err = net.SplitHostPort(u.Host)
   148  	if err != nil {
   149  		return "", "", xerrors.WithStackTrace(err)
   150  	}
   151  
   152  	return host, port, nil
   153  }
   154  
   155  func getRandomEndpoints(endpoints []endpoint.Endpoint, count int) []endpoint.Endpoint {
   156  	if len(endpoints) <= count {
   157  		return endpoints
   158  	}
   159  
   160  	got := make(map[int]bool, maxEndpointsCheckPerLocation)
   161  
   162  	res := make([]endpoint.Endpoint, 0, maxEndpointsCheckPerLocation)
   163  	for len(got) < count {
   164  		//nolint:gosec
   165  		index := rand.Intn(len(endpoints))
   166  		if got[index] {
   167  			continue
   168  		}
   169  
   170  		got[index] = true
   171  		res = append(res, endpoints[index])
   172  	}
   173  
   174  	return res
   175  }
   176  
   177  func splitEndpointsByLocation(endpoints []endpoint.Endpoint) map[string][]endpoint.Endpoint {
   178  	res := make(map[string][]endpoint.Endpoint)
   179  	for _, ep := range endpoints {
   180  		location := ep.Location()
   181  		res[location] = append(res[location], ep)
   182  	}
   183  
   184  	return res
   185  }