github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/balancer/balancer.go (about) 1 package balancer 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 "sync/atomic" 8 9 "google.golang.org/grpc" 10 11 "github.com/ydb-platform/ydb-go-sdk/v3/config" 12 balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials" 16 internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery" 17 discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config" 18 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 19 "github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater" 20 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 21 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 22 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 23 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices" 24 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" 25 "github.com/ydb-platform/ydb-go-sdk/v3/retry" 26 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 27 ) 28 29 var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints")) 30 31 type discoveryClient interface { 32 closer.Closer 33 34 Discover(ctx context.Context) ([]endpoint.Endpoint, error) 35 } 36 37 type Balancer struct { 38 driverConfig *config.Config 39 config balancerConfig.Config 40 pool *conn.Pool 41 discoveryClient discoveryClient 42 discoveryRepeater repeater.Repeater 43 localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) 44 45 connectionsState atomic.Pointer[connectionsState] 46 47 mu xsync.RWMutex 48 onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) 49 } 50 51 func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) { 52 b.mu.WithLock(func() { 53 b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints) 54 }) 55 } 56 57 func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) { 58 return retry.Retry( 59 repeater.WithEvent(ctx, repeater.EventInit), 60 func(childCtx context.Context) (err error) { 61 if err = b.clusterDiscoveryAttempt(childCtx); err != nil { 62 if credentials.IsAccessError(err) { 63 return credentials.AccessError("cluster discovery failed", err, 64 credentials.WithEndpoint(b.driverConfig.Endpoint()), 65 credentials.WithDatabase(b.driverConfig.Database()), 66 credentials.WithCredentials(b.driverConfig.Credentials()), 67 ) 68 } 69 // if got err but parent context is not done - mark error as retryable 70 if ctx.Err() == nil && xerrors.IsTimeoutError(err) { 71 return xerrors.WithStackTrace(xerrors.Retryable(err)) 72 } 73 74 return xerrors.WithStackTrace(err) 75 } 76 77 return nil 78 }, 79 retry.WithIdempotent(true), 80 retry.WithTrace(b.driverConfig.TraceRetry()), 81 retry.WithBudget(b.driverConfig.RetryBudget()), 82 ) 83 } 84 85 func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { 86 var ( 87 address = "ydb:///" + b.driverConfig.Endpoint() 88 onDone = trace.DriverOnBalancerClusterDiscoveryAttempt( 89 b.driverConfig.Trace(), &ctx, 90 stack.FunctionID( 91 "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).clusterDiscoveryAttempt", 92 ), 93 address, 94 b.driverConfig.Database(), 95 ) 96 endpoints []endpoint.Endpoint 97 localDC string 98 cancel context.CancelFunc 99 ) 100 defer func() { 101 onDone(err) 102 }() 103 104 if dialTimeout := b.driverConfig.DialTimeout(); dialTimeout > 0 { 105 ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) 106 } else { 107 ctx, cancel = xcontext.WithCancel(ctx) 108 } 109 defer cancel() 110 111 endpoints, err = b.discoveryClient.Discover(ctx) 112 if err != nil { 113 return xerrors.WithStackTrace(err) 114 } 115 116 if b.config.DetectNearestDC { 117 localDC, err = b.localDCDetector(ctx, endpoints) 118 if err != nil { 119 return xerrors.WithStackTrace(err) 120 } 121 } 122 123 b.applyDiscoveredEndpoints(ctx, endpoints, localDC) 124 125 return nil 126 } 127 128 func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) { 129 var ( 130 onDone = trace.DriverOnBalancerUpdate( 131 b.driverConfig.Trace(), &ctx, 132 stack.FunctionID( 133 "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), 134 b.config.DetectNearestDC, 135 b.driverConfig.Database(), 136 ) 137 previous = b.connections().All() 138 ) 139 defer func() { 140 _, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int { 141 return strings.Compare(lhs.Address(), rhs.Address()) 142 }) 143 onDone( 144 xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), 145 xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), 146 xslices.Transform(dropped, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), 147 localDC, 148 ) 149 }() 150 151 connections := endpointsToConnections(b.pool, newest) 152 for _, c := range connections { 153 b.pool.Allow(ctx, c) 154 c.Endpoint().Touch() 155 } 156 157 info := balancerConfig.Info{SelfLocation: localDC} 158 state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) 159 160 endpointsInfo := make([]endpoint.Info, len(newest)) 161 for i, e := range newest { 162 endpointsInfo[i] = e 163 } 164 165 b.connectionsState.Store(state) 166 167 b.mu.WithLock(func() { 168 for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints { 169 onApplyDiscoveredEndpoints(ctx, endpointsInfo) 170 } 171 }) 172 } 173 174 func (b *Balancer) Close(ctx context.Context) (err error) { 175 onDone := trace.DriverOnBalancerClose( 176 b.driverConfig.Trace(), &ctx, 177 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).Close"), 178 ) 179 defer func() { 180 onDone(err) 181 }() 182 183 if b.discoveryRepeater != nil { 184 b.discoveryRepeater.Stop() 185 } 186 187 if err = b.discoveryClient.Close(ctx); err != nil { 188 return xerrors.WithStackTrace(err) 189 } 190 191 return nil 192 } 193 194 func New( 195 ctx context.Context, 196 driverConfig *config.Config, 197 pool *conn.Pool, 198 opts ...discoveryConfig.Option, 199 ) (b *Balancer, finalErr error) { 200 var ( 201 onDone = trace.DriverOnBalancerInit( 202 driverConfig.Trace(), &ctx, 203 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.New"), 204 driverConfig.Balancer().String(), 205 ) 206 discoveryConfig = discoveryConfig.New(append(opts, 207 discoveryConfig.With(driverConfig.Common), 208 discoveryConfig.WithEndpoint(driverConfig.Endpoint()), 209 discoveryConfig.WithDatabase(driverConfig.Database()), 210 discoveryConfig.WithSecure(driverConfig.Secure()), 211 discoveryConfig.WithMeta(driverConfig.Meta()), 212 )...) 213 ) 214 defer func() { 215 onDone(finalErr) 216 }() 217 218 b = &Balancer{ 219 driverConfig: driverConfig, 220 pool: pool, 221 discoveryClient: internalDiscovery.New(ctx, pool.Get( 222 endpoint.New(driverConfig.Endpoint()), 223 ), discoveryConfig), 224 localDCDetector: detectLocalDC, 225 } 226 227 if config := driverConfig.Balancer(); config == nil { 228 b.config = balancerConfig.Config{} 229 } else { 230 b.config = *config 231 } 232 233 if b.config.SingleConn { 234 b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{ 235 endpoint.New(driverConfig.Endpoint()), 236 }, "") 237 } else { 238 // initialization of balancer state 239 if err := b.clusterDiscovery(ctx); err != nil { 240 return nil, xerrors.WithStackTrace(err) 241 } 242 // run background discovering 243 if d := discoveryConfig.Interval(); d > 0 { 244 b.discoveryRepeater = repeater.New(xcontext.ValueOnly(ctx), 245 d, b.clusterDiscoveryAttempt, 246 repeater.WithName("discovery"), 247 repeater.WithTrace(b.driverConfig.Trace()), 248 ) 249 } 250 } 251 252 return b, nil 253 } 254 255 func (b *Balancer) Invoke( 256 ctx context.Context, 257 method string, 258 args interface{}, 259 reply interface{}, 260 opts ...grpc.CallOption, 261 ) error { 262 return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { 263 return cc.Invoke(ctx, method, args, reply, opts...) 264 }) 265 } 266 267 func (b *Balancer) NewStream( 268 ctx context.Context, 269 desc *grpc.StreamDesc, 270 method string, 271 opts ...grpc.CallOption, 272 ) (_ grpc.ClientStream, err error) { 273 var client grpc.ClientStream 274 err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { 275 client, err = cc.NewStream(ctx, desc, method, opts...) 276 277 return err 278 }) 279 if err == nil { 280 return client, nil 281 } 282 283 return nil, err 284 } 285 286 func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) { 287 cc, err := b.getConn(ctx) 288 if err != nil { 289 return xerrors.WithStackTrace(err) 290 } 291 292 defer func() { 293 if err == nil { 294 if cc.GetState() == conn.Banned { 295 b.pool.Allow(ctx, cc) 296 } 297 } else if conn.IsBadConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) { 298 b.pool.Ban(ctx, cc, err) 299 } 300 }() 301 302 if ctx, err = b.driverConfig.Meta().Context(ctx); err != nil { 303 return xerrors.WithStackTrace(err) 304 } 305 306 if err = f(ctx, cc); err != nil { 307 if conn.UseWrapping(ctx) { 308 if credentials.IsAccessError(err) { 309 err = credentials.AccessError("no access", err, 310 credentials.WithAddress(cc.Endpoint().String()), 311 credentials.WithNodeID(cc.Endpoint().NodeID()), 312 credentials.WithCredentials(b.driverConfig.Credentials()), 313 ) 314 } 315 316 return xerrors.WithStackTrace(err) 317 } 318 319 return err 320 } 321 322 return nil 323 } 324 325 func (b *Balancer) connections() *connectionsState { 326 return b.connectionsState.Load() 327 } 328 329 func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { 330 onDone := trace.DriverOnBalancerChooseEndpoint( 331 b.driverConfig.Trace(), &ctx, 332 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).getConn"), 333 ) 334 defer func() { 335 if err == nil { 336 onDone(c.Endpoint(), nil) 337 } else { 338 onDone(nil, err) 339 } 340 }() 341 342 if err = ctx.Err(); err != nil { 343 return nil, xerrors.WithStackTrace(err) 344 } 345 346 var ( 347 state = b.connections() 348 failedCount int 349 ) 350 351 defer func() { 352 if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil { 353 b.discoveryRepeater.Force() 354 } 355 }() 356 357 c, failedCount = state.GetConnection(ctx) 358 if c == nil { 359 return nil, xerrors.WithStackTrace( 360 fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), 361 ) 362 } 363 364 return c, nil 365 } 366 367 func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { 368 conns := make([]conn.Conn, 0, len(endpoints)) 369 for _, e := range endpoints { 370 conns = append(conns, p.Get(e)) 371 } 372 373 return conns 374 }