github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/local_dc.go (about) 1 package balancer 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "math/rand" 8 "net" 9 "net/url" 10 "strings" 11 "sync" 12 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 16 ) 17 18 const ( 19 maxEndpointsCheckPerLocation = 5 20 ) 21 22 func checkFastestAddress(ctx context.Context, addresses []string) string { 23 ctx, cancel := xcontext.WithCancel(ctx) 24 defer cancel() 25 26 type result struct { 27 address string 28 err error 29 } 30 results := make(chan result, len(addresses)) 31 defer close(results) 32 33 startDial := make(chan struct{}) 34 var dialer net.Dialer 35 36 var wg sync.WaitGroup 37 defer wg.Wait() 38 39 for _, addr := range addresses { 40 wg.Add(1) 41 go func(address string) { 42 defer wg.Done() 43 <-startDial 44 conn, err := dialer.DialContext(ctx, "tcp", address) 45 if err == nil { 46 cancel() 47 _ = conn.Close() 48 } 49 results <- result{address: address, err: err} 50 }(addr) 51 } 52 53 close(startDial) 54 55 for range addresses { 56 res := <-results 57 if res.err == nil { 58 return res.address 59 } 60 } 61 62 return "" 63 } 64 65 func detectFastestEndpoint(ctx context.Context, endpoints []endpoint.Endpoint) (endpoint.Endpoint, error) { 66 if len(endpoints) == 0 { 67 return nil, xerrors.WithStackTrace(errors.New("empty endpoints list")) 68 } 69 70 var lastErr error 71 // common is 2 ip address for every fqdn: ipv4 + ipv6 72 initialAddressToEndpointCapacity := len(endpoints) * 2 73 addressToEndpoint := make(map[string]endpoint.Endpoint, initialAddressToEndpointCapacity) 74 for _, ep := range endpoints { 75 host, port, err := extractHostPort(ep.Address()) 76 if err != nil { 77 lastErr = xerrors.WithStackTrace(err) 78 79 continue 80 } 81 82 addresses, err := net.DefaultResolver.LookupHost(ctx, host) 83 if err != nil { 84 lastErr = err 85 86 continue 87 } 88 if len(addresses) == 0 { 89 lastErr = xerrors.WithStackTrace(fmt.Errorf("no ips for fqdn: %q", host)) 90 91 continue 92 } 93 94 for _, ip := range addresses { 95 address := net.JoinHostPort(ip, port) 96 addressToEndpoint[address] = ep 97 } 98 } 99 if len(addressToEndpoint) == 0 { 100 return nil, xerrors.WithStackTrace(lastErr) 101 } 102 addressesToPing := make([]string, 0, len(addressToEndpoint)) 103 for ip := range addressToEndpoint { 104 addressesToPing = append(addressesToPing, ip) 105 } 106 107 fastestAddress := checkFastestAddress(ctx, addressesToPing) 108 if fastestAddress == "" { 109 return nil, xerrors.WithStackTrace(errors.New("failed to check fastest address")) 110 } 111 112 return addressToEndpoint[fastestAddress], nil 113 } 114 115 func detectLocalDC(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) { 116 if len(endpoints) == 0 { 117 return "", xerrors.WithStackTrace(ErrNoEndpoints) 118 } 119 endpointsByDc := splitEndpointsByLocation(endpoints) 120 121 if len(endpointsByDc) == 1 { 122 return endpoints[0].Location(), nil 123 } 124 125 endpointsToTest := make([]endpoint.Endpoint, 0, maxEndpointsCheckPerLocation*len(endpointsByDc)) 126 for _, dcEndpoints := range endpointsByDc { 127 endpointsToTest = append(endpointsToTest, getRandomEndpoints(dcEndpoints, maxEndpointsCheckPerLocation)...) 128 } 129 130 fastest, err := detectFastestEndpoint(ctx, endpointsToTest) 131 if err == nil { 132 return fastest.Location(), nil 133 } 134 135 return "", err 136 } 137 138 func extractHostPort(address string) (host, port string, _ error) { 139 if !strings.Contains(address, "://") { 140 address = "stub://" + address 141 } 142 143 u, err := url.Parse(address) 144 if err != nil { 145 return "", "", xerrors.WithStackTrace(err) 146 } 147 host, port, err = net.SplitHostPort(u.Host) 148 if err != nil { 149 return "", "", xerrors.WithStackTrace(err) 150 } 151 152 return host, port, nil 153 } 154 155 func getRandomEndpoints(endpoints []endpoint.Endpoint, count int) []endpoint.Endpoint { 156 if len(endpoints) <= count { 157 return endpoints 158 } 159 160 got := make(map[int]bool, maxEndpointsCheckPerLocation) 161 162 res := make([]endpoint.Endpoint, 0, maxEndpointsCheckPerLocation) 163 for len(got) < count { 164 //nolint:gosec 165 index := rand.Intn(len(endpoints)) 166 if got[index] { 167 continue 168 } 169 170 got[index] = true 171 res = append(res, endpoints[index]) 172 } 173 174 return res 175 } 176 177 func splitEndpointsByLocation(endpoints []endpoint.Endpoint) map[string][]endpoint.Endpoint { 178 res := make(map[string][]endpoint.Endpoint) 179 for _, ep := range endpoints { 180 location := ep.Location() 181 res[location] = append(res[location], ep) 182 } 183 184 return res 185 }