github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/balancer.go (about) 1 package balancer 2 3 import ( 4 "context" 5 "fmt" 6 "sort" 7 8 "google.golang.org/grpc" 9 10 "github.com/ydb-platform/ydb-go-sdk/v3/config" 11 balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials" 15 internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery" 16 discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config" 17 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 18 "github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater" 19 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 20 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 21 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 22 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" 23 "github.com/ydb-platform/ydb-go-sdk/v3/retry" 24 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 25 ) 26 27 var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints")) 28 29 type discoveryClient interface { 30 closer.Closer 31 32 Discover(ctx context.Context) ([]endpoint.Endpoint, error) 33 } 34 35 type Balancer struct { 36 driverConfig *config.Config 37 config balancerConfig.Config 38 pool *conn.Pool 39 discoveryClient discoveryClient 40 discoveryRepeater repeater.Repeater 41 localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) 42 43 mu xsync.RWMutex 44 connectionsState *connectionsState 45 46 onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) 47 } 48 49 func (b *Balancer) HasNode(id uint32) bool { 50 if b.config.SingleConn { 51 return true 52 } 53 b.mu.RLock() 54 defer b.mu.RUnlock() 55 if _, has := b.connectionsState.connByNodeID[id]; has { 56 return true 57 } 58 59 return false 60 } 61 62 func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) { 63 b.mu.WithLock(func() { 64 b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints) 65 }) 66 } 67 68 func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) { 69 return retry.Retry( 70 repeater.WithEvent(ctx, repeater.EventInit), 71 func(childCtx context.Context) (err error) { 72 if err = b.clusterDiscoveryAttempt(childCtx); err != nil { 73 if credentials.IsAccessError(err) { 74 return credentials.AccessError("cluster discovery failed", err, 75 credentials.WithEndpoint(b.driverConfig.Endpoint()), 76 credentials.WithDatabase(b.driverConfig.Database()), 77 credentials.WithCredentials(b.driverConfig.Credentials()), 78 ) 79 } 80 // if got err but parent context is not done - mark error as retryable 81 if ctx.Err() == nil && xerrors.IsTimeoutError(err) { 82 return xerrors.WithStackTrace(xerrors.Retryable(err)) 83 } 84 85 return xerrors.WithStackTrace(err) 86 } 87 88 return nil 89 }, 90 retry.WithIdempotent(true), 91 retry.WithTrace(b.driverConfig.TraceRetry()), 92 ) 93 } 94 95 func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { 96 var ( 97 address = "ydb:///" + b.driverConfig.Endpoint() 98 onDone = trace.DriverOnBalancerClusterDiscoveryAttempt( 99 b.driverConfig.Trace(), &ctx, 100 stack.FunctionID(""), 101 address, 102 ) 103 endpoints []endpoint.Endpoint 104 localDC string 105 cancel context.CancelFunc 106 ) 107 defer func() { 108 onDone(err) 109 }() 110 111 if dialTimeout := b.driverConfig.DialTimeout(); dialTimeout > 0 { 112 ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) 113 } else { 114 ctx, cancel = xcontext.WithCancel(ctx) 115 } 116 defer cancel() 117 118 endpoints, err = b.discoveryClient.Discover(ctx) 119 if err != nil { 120 return xerrors.WithStackTrace(err) 121 } 122 123 if b.config.DetectLocalDC { 124 localDC, err = b.localDCDetector(ctx, endpoints) 125 if err != nil { 126 return xerrors.WithStackTrace(err) 127 } 128 } 129 130 b.applyDiscoveredEndpoints(ctx, endpoints, localDC) 131 132 return nil 133 } 134 135 func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) ( 136 nodes []trace.EndpointInfo, 137 added []trace.EndpointInfo, 138 dropped []trace.EndpointInfo, 139 ) { 140 nodes = make([]trace.EndpointInfo, 0, len(newestEndpoints)) 141 added = make([]trace.EndpointInfo, 0, len(previousConns)) 142 dropped = make([]trace.EndpointInfo, 0, len(previousConns)) 143 var ( 144 newestMap = make(map[string]struct{}, len(newestEndpoints)) 145 previousMap = make(map[string]struct{}, len(previousConns)) 146 ) 147 sort.Slice(newestEndpoints, func(i, j int) bool { 148 return newestEndpoints[i].Address() < newestEndpoints[j].Address() 149 }) 150 sort.Slice(previousConns, func(i, j int) bool { 151 return previousConns[i].Endpoint().Address() < previousConns[j].Endpoint().Address() 152 }) 153 for _, e := range previousConns { 154 previousMap[e.Endpoint().Address()] = struct{}{} 155 } 156 for _, e := range newestEndpoints { 157 nodes = append(nodes, e.Copy()) 158 newestMap[e.Address()] = struct{}{} 159 if _, has := previousMap[e.Address()]; !has { 160 added = append(added, e.Copy()) 161 } 162 } 163 for _, c := range previousConns { 164 if _, has := newestMap[c.Endpoint().Address()]; !has { 165 dropped = append(dropped, c.Endpoint().Copy()) 166 } 167 } 168 169 return nodes, added, dropped 170 } 171 172 func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []endpoint.Endpoint, localDC string) { 173 var ( 174 onDone = trace.DriverOnBalancerUpdate( 175 b.driverConfig.Trace(), &ctx, 176 stack.FunctionID(""), 177 b.config.DetectLocalDC, 178 ) 179 previousConns []conn.Conn 180 ) 181 defer func() { 182 nodes, added, dropped := endpointsDiff(endpoints, previousConns) 183 onDone(nodes, added, dropped, localDC, nil) 184 }() 185 186 connections := endpointsToConnections(b.pool, endpoints) 187 for _, c := range connections { 188 b.pool.Allow(ctx, c) 189 c.Endpoint().Touch() 190 } 191 192 info := balancerConfig.Info{SelfLocation: localDC} 193 state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) 194 195 endpointsInfo := make([]endpoint.Info, len(endpoints)) 196 for i, e := range endpoints { 197 endpointsInfo[i] = e 198 } 199 200 b.mu.WithLock(func() { 201 if b.connectionsState != nil { 202 previousConns = b.connectionsState.all 203 } 204 b.connectionsState = state 205 for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints { 206 onApplyDiscoveredEndpoints(ctx, endpointsInfo) 207 } 208 }) 209 } 210 211 func (b *Balancer) Close(ctx context.Context) (err error) { 212 onDone := trace.DriverOnBalancerClose( 213 b.driverConfig.Trace(), &ctx, 214 stack.FunctionID(""), 215 ) 216 defer func() { 217 onDone(err) 218 }() 219 220 if b.discoveryRepeater != nil { 221 b.discoveryRepeater.Stop() 222 } 223 224 if err = b.discoveryClient.Close(ctx); err != nil { 225 return xerrors.WithStackTrace(err) 226 } 227 228 return nil 229 } 230 231 func New( 232 ctx context.Context, 233 driverConfig *config.Config, 234 pool *conn.Pool, 235 opts ...discoveryConfig.Option, 236 ) (b *Balancer, finalErr error) { 237 var ( 238 onDone = trace.DriverOnBalancerInit( 239 driverConfig.Trace(), &ctx, 240 stack.FunctionID(""), 241 driverConfig.Balancer().String(), 242 ) 243 discoveryConfig = discoveryConfig.New(append(opts, 244 discoveryConfig.With(driverConfig.Common), 245 discoveryConfig.WithEndpoint(driverConfig.Endpoint()), 246 discoveryConfig.WithDatabase(driverConfig.Database()), 247 discoveryConfig.WithSecure(driverConfig.Secure()), 248 discoveryConfig.WithMeta(driverConfig.Meta()), 249 )...) 250 ) 251 defer func() { 252 onDone(finalErr) 253 }() 254 255 b = &Balancer{ 256 driverConfig: driverConfig, 257 pool: pool, 258 localDCDetector: detectLocalDC, 259 } 260 d, err := internalDiscovery.New(ctx, pool.Get( 261 endpoint.New(driverConfig.Endpoint()), 262 ), discoveryConfig) 263 if err != nil { 264 return nil, err 265 } 266 267 b.discoveryClient = d 268 269 if config := driverConfig.Balancer(); config == nil { 270 b.config = balancerConfig.Config{} 271 } else { 272 b.config = *config 273 } 274 275 if b.config.SingleConn { 276 b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{ 277 endpoint.New(driverConfig.Endpoint()), 278 }, "") 279 } else { 280 // initialization of balancer state 281 if err := b.clusterDiscovery(ctx); err != nil { 282 return nil, xerrors.WithStackTrace(err) 283 } 284 // run background discovering 285 if d := discoveryConfig.Interval(); d > 0 { 286 b.discoveryRepeater = repeater.New(xcontext.WithoutDeadline(ctx), 287 d, b.clusterDiscoveryAttempt, 288 repeater.WithName("discovery"), 289 repeater.WithTrace(b.driverConfig.Trace()), 290 ) 291 } 292 } 293 294 return b, nil 295 } 296 297 func (b *Balancer) Invoke( 298 ctx context.Context, 299 method string, 300 args interface{}, 301 reply interface{}, 302 opts ...grpc.CallOption, 303 ) error { 304 return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { 305 return cc.Invoke(ctx, method, args, reply, opts...) 306 }) 307 } 308 309 func (b *Balancer) NewStream( 310 ctx context.Context, 311 desc *grpc.StreamDesc, 312 method string, 313 opts ...grpc.CallOption, 314 ) (_ grpc.ClientStream, err error) { 315 var client grpc.ClientStream 316 err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { 317 client, err = cc.NewStream(ctx, desc, method, opts...) 318 319 return err 320 }) 321 if err == nil { 322 return client, nil 323 } 324 325 return nil, err 326 } 327 328 func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) { 329 cc, err := b.getConn(ctx) 330 if err != nil { 331 return xerrors.WithStackTrace(err) 332 } 333 334 defer func() { 335 if err == nil { 336 if cc.GetState() == conn.Banned { 337 b.pool.Allow(ctx, cc) 338 } 339 } else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) { 340 b.pool.Ban(ctx, cc, err) 341 } 342 }() 343 344 if ctx, err = b.driverConfig.Meta().Context(ctx); err != nil { 345 return xerrors.WithStackTrace(err) 346 } 347 348 if err = f(ctx, cc); err != nil { 349 if conn.UseWrapping(ctx) { 350 if credentials.IsAccessError(err) { 351 err = credentials.AccessError("no access", err, 352 credentials.WithAddress(cc.Endpoint().String()), 353 credentials.WithNodeID(cc.Endpoint().NodeID()), 354 credentials.WithCredentials(b.driverConfig.Credentials()), 355 ) 356 } 357 358 return xerrors.WithStackTrace(err) 359 } 360 361 return err 362 } 363 364 return nil 365 } 366 367 func (b *Balancer) connections() *connectionsState { 368 b.mu.RLock() 369 defer b.mu.RUnlock() 370 371 return b.connectionsState 372 } 373 374 func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { 375 onDone := trace.DriverOnBalancerChooseEndpoint( 376 b.driverConfig.Trace(), &ctx, 377 stack.FunctionID(""), 378 ) 379 defer func() { 380 if err == nil { 381 onDone(c.Endpoint(), nil) 382 } else { 383 onDone(nil, err) 384 } 385 }() 386 387 if err = ctx.Err(); err != nil { 388 return nil, xerrors.WithStackTrace(err) 389 } 390 391 var ( 392 state = b.connections() 393 failedCount int 394 ) 395 396 defer func() { 397 if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil { 398 b.discoveryRepeater.Force() 399 } 400 }() 401 402 c, failedCount = state.GetConnection(ctx) 403 if c == nil { 404 return nil, xerrors.WithStackTrace( 405 fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), 406 ) 407 } 408 409 return c, nil 410 } 411 412 func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { 413 conns := make([]conn.Conn, 0, len(endpoints)) 414 for _, e := range endpoints { 415 conns = append(conns, p.Get(e)) 416 } 417 418 return conns 419 }