github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/balancer.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  
     8  	"google.golang.org/grpc"
     9  
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/config"
    11  	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials"
    15  	internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery"
    16  	discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    22  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
    23  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    24  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    25  )
    26  
// ErrNoEndpoints is the sentinel error wrapped into the error that getConn
// returns when no connection could be obtained from the current state.
var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints"))
    28  
// discoveryClient is the part of the discovery client that Balancer relies on:
// it can be closed (see Close) and it can list endpoints (see
// clusterDiscoveryAttempt).
type discoveryClient interface {
	closer.Closer

	// Discover returns a list of endpoints; Balancer applies the result as its
	// new set of endpoints.
	Discover(ctx context.Context) ([]endpoint.Endpoint, error)
}
    34  
// Balancer chooses a connection for every call (Invoke / NewStream) from the
// set of endpoints most recently applied by applyDiscoveredEndpoints.
type Balancer struct {
	// driverConfig is the driver-level configuration (endpoint, database,
	// credentials, meta, traces, dial timeout).
	driverConfig *config.Config
	// config is a copy of driverConfig.Balancer() taken in New, or the zero
	// value when that is nil.
	config balancerConfig.Config
	// pool is used to obtain connections for endpoints and to allow/ban them.
	pool *conn.Pool
	// discoveryClient is queried by clusterDiscoveryAttempt and closed by Close.
	discoveryClient discoveryClient
	// discoveryRepeater runs clusterDiscoveryAttempt in the background. New
	// leaves it nil in single-connection mode or when the discovery interval
	// is not positive.
	discoveryRepeater repeater.Repeater
	// localDCDetector is called with the discovered endpoints when
	// config.DetectLocalDC is set; New sets it to detectLocalDC.
	localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)

	// mu guards connectionsState and onApplyDiscoveredEndpoints.
	mu xsync.RWMutex
	// connectionsState is replaced as a whole by applyDiscoveredEndpoints.
	connectionsState *connectionsState

	// onApplyDiscoveredEndpoints holds the callbacks registered via OnUpdate;
	// applyDiscoveredEndpoints invokes them while holding mu.
	onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
}
    48  
    49  func (b *Balancer) HasNode(id uint32) bool {
    50  	if b.config.SingleConn {
    51  		return true
    52  	}
    53  	b.mu.RLock()
    54  	defer b.mu.RUnlock()
    55  	if _, has := b.connectionsState.connByNodeID[id]; has {
    56  		return true
    57  	}
    58  
    59  	return false
    60  }
    61  
    62  func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) {
    63  	b.mu.WithLock(func() {
    64  		b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints)
    65  	})
    66  }
    67  
    68  func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) {
    69  	return retry.Retry(
    70  		repeater.WithEvent(ctx, repeater.EventInit),
    71  		func(childCtx context.Context) (err error) {
    72  			if err = b.clusterDiscoveryAttempt(childCtx); err != nil {
    73  				if credentials.IsAccessError(err) {
    74  					return credentials.AccessError("cluster discovery failed", err,
    75  						credentials.WithEndpoint(b.driverConfig.Endpoint()),
    76  						credentials.WithDatabase(b.driverConfig.Database()),
    77  						credentials.WithCredentials(b.driverConfig.Credentials()),
    78  					)
    79  				}
    80  				// if got err but parent context is not done - mark error as retryable
    81  				if ctx.Err() == nil && xerrors.IsTimeoutError(err) {
    82  					return xerrors.WithStackTrace(xerrors.Retryable(err))
    83  				}
    84  
    85  				return xerrors.WithStackTrace(err)
    86  			}
    87  
    88  			return nil
    89  		},
    90  		retry.WithIdempotent(true),
    91  		retry.WithTrace(b.driverConfig.TraceRetry()),
    92  	)
    93  }
    94  
    95  func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
    96  	var (
    97  		address = "ydb:///" + b.driverConfig.Endpoint()
    98  		onDone  = trace.DriverOnBalancerClusterDiscoveryAttempt(
    99  			b.driverConfig.Trace(), &ctx,
   100  			stack.FunctionID(""),
   101  			address,
   102  		)
   103  		endpoints []endpoint.Endpoint
   104  		localDC   string
   105  		cancel    context.CancelFunc
   106  	)
   107  	defer func() {
   108  		onDone(err)
   109  	}()
   110  
   111  	if dialTimeout := b.driverConfig.DialTimeout(); dialTimeout > 0 {
   112  		ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout)
   113  	} else {
   114  		ctx, cancel = xcontext.WithCancel(ctx)
   115  	}
   116  	defer cancel()
   117  
   118  	endpoints, err = b.discoveryClient.Discover(ctx)
   119  	if err != nil {
   120  		return xerrors.WithStackTrace(err)
   121  	}
   122  
   123  	if b.config.DetectLocalDC {
   124  		localDC, err = b.localDCDetector(ctx, endpoints)
   125  		if err != nil {
   126  			return xerrors.WithStackTrace(err)
   127  		}
   128  	}
   129  
   130  	b.applyDiscoveredEndpoints(ctx, endpoints, localDC)
   131  
   132  	return nil
   133  }
   134  
   135  func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) (
   136  	nodes []trace.EndpointInfo,
   137  	added []trace.EndpointInfo,
   138  	dropped []trace.EndpointInfo,
   139  ) {
   140  	nodes = make([]trace.EndpointInfo, 0, len(newestEndpoints))
   141  	added = make([]trace.EndpointInfo, 0, len(previousConns))
   142  	dropped = make([]trace.EndpointInfo, 0, len(previousConns))
   143  	var (
   144  		newestMap   = make(map[string]struct{}, len(newestEndpoints))
   145  		previousMap = make(map[string]struct{}, len(previousConns))
   146  	)
   147  	sort.Slice(newestEndpoints, func(i, j int) bool {
   148  		return newestEndpoints[i].Address() < newestEndpoints[j].Address()
   149  	})
   150  	sort.Slice(previousConns, func(i, j int) bool {
   151  		return previousConns[i].Endpoint().Address() < previousConns[j].Endpoint().Address()
   152  	})
   153  	for _, e := range previousConns {
   154  		previousMap[e.Endpoint().Address()] = struct{}{}
   155  	}
   156  	for _, e := range newestEndpoints {
   157  		nodes = append(nodes, e.Copy())
   158  		newestMap[e.Address()] = struct{}{}
   159  		if _, has := previousMap[e.Address()]; !has {
   160  			added = append(added, e.Copy())
   161  		}
   162  	}
   163  	for _, c := range previousConns {
   164  		if _, has := newestMap[c.Endpoint().Address()]; !has {
   165  			dropped = append(dropped, c.Endpoint().Copy())
   166  		}
   167  	}
   168  
   169  	return nodes, added, dropped
   170  }
   171  
   172  func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []endpoint.Endpoint, localDC string) {
   173  	var (
   174  		onDone = trace.DriverOnBalancerUpdate(
   175  			b.driverConfig.Trace(), &ctx,
   176  			stack.FunctionID(""),
   177  			b.config.DetectLocalDC,
   178  		)
   179  		previousConns []conn.Conn
   180  	)
   181  	defer func() {
   182  		nodes, added, dropped := endpointsDiff(endpoints, previousConns)
   183  		onDone(nodes, added, dropped, localDC, nil)
   184  	}()
   185  
   186  	connections := endpointsToConnections(b.pool, endpoints)
   187  	for _, c := range connections {
   188  		b.pool.Allow(ctx, c)
   189  		c.Endpoint().Touch()
   190  	}
   191  
   192  	info := balancerConfig.Info{SelfLocation: localDC}
   193  	state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback)
   194  
   195  	endpointsInfo := make([]endpoint.Info, len(endpoints))
   196  	for i, e := range endpoints {
   197  		endpointsInfo[i] = e
   198  	}
   199  
   200  	b.mu.WithLock(func() {
   201  		if b.connectionsState != nil {
   202  			previousConns = b.connectionsState.all
   203  		}
   204  		b.connectionsState = state
   205  		for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
   206  			onApplyDiscoveredEndpoints(ctx, endpointsInfo)
   207  		}
   208  	})
   209  }
   210  
   211  func (b *Balancer) Close(ctx context.Context) (err error) {
   212  	onDone := trace.DriverOnBalancerClose(
   213  		b.driverConfig.Trace(), &ctx,
   214  		stack.FunctionID(""),
   215  	)
   216  	defer func() {
   217  		onDone(err)
   218  	}()
   219  
   220  	if b.discoveryRepeater != nil {
   221  		b.discoveryRepeater.Stop()
   222  	}
   223  
   224  	if err = b.discoveryClient.Close(ctx); err != nil {
   225  		return xerrors.WithStackTrace(err)
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  func New(
   232  	ctx context.Context,
   233  	driverConfig *config.Config,
   234  	pool *conn.Pool,
   235  	opts ...discoveryConfig.Option,
   236  ) (b *Balancer, finalErr error) {
   237  	var (
   238  		onDone = trace.DriverOnBalancerInit(
   239  			driverConfig.Trace(), &ctx,
   240  			stack.FunctionID(""),
   241  			driverConfig.Balancer().String(),
   242  		)
   243  		discoveryConfig = discoveryConfig.New(append(opts,
   244  			discoveryConfig.With(driverConfig.Common),
   245  			discoveryConfig.WithEndpoint(driverConfig.Endpoint()),
   246  			discoveryConfig.WithDatabase(driverConfig.Database()),
   247  			discoveryConfig.WithSecure(driverConfig.Secure()),
   248  			discoveryConfig.WithMeta(driverConfig.Meta()),
   249  		)...)
   250  	)
   251  	defer func() {
   252  		onDone(finalErr)
   253  	}()
   254  
   255  	b = &Balancer{
   256  		driverConfig:    driverConfig,
   257  		pool:            pool,
   258  		localDCDetector: detectLocalDC,
   259  	}
   260  	d, err := internalDiscovery.New(ctx, pool.Get(
   261  		endpoint.New(driverConfig.Endpoint()),
   262  	), discoveryConfig)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	b.discoveryClient = d
   268  
   269  	if config := driverConfig.Balancer(); config == nil {
   270  		b.config = balancerConfig.Config{}
   271  	} else {
   272  		b.config = *config
   273  	}
   274  
   275  	if b.config.SingleConn {
   276  		b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{
   277  			endpoint.New(driverConfig.Endpoint()),
   278  		}, "")
   279  	} else {
   280  		// initialization of balancer state
   281  		if err := b.clusterDiscovery(ctx); err != nil {
   282  			return nil, xerrors.WithStackTrace(err)
   283  		}
   284  		// run background discovering
   285  		if d := discoveryConfig.Interval(); d > 0 {
   286  			b.discoveryRepeater = repeater.New(xcontext.WithoutDeadline(ctx),
   287  				d, b.clusterDiscoveryAttempt,
   288  				repeater.WithName("discovery"),
   289  				repeater.WithTrace(b.driverConfig.Trace()),
   290  			)
   291  		}
   292  	}
   293  
   294  	return b, nil
   295  }
   296  
   297  func (b *Balancer) Invoke(
   298  	ctx context.Context,
   299  	method string,
   300  	args interface{},
   301  	reply interface{},
   302  	opts ...grpc.CallOption,
   303  ) error {
   304  	return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
   305  		return cc.Invoke(ctx, method, args, reply, opts...)
   306  	})
   307  }
   308  
   309  func (b *Balancer) NewStream(
   310  	ctx context.Context,
   311  	desc *grpc.StreamDesc,
   312  	method string,
   313  	opts ...grpc.CallOption,
   314  ) (_ grpc.ClientStream, err error) {
   315  	var client grpc.ClientStream
   316  	err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
   317  		client, err = cc.NewStream(ctx, desc, method, opts...)
   318  
   319  		return err
   320  	})
   321  	if err == nil {
   322  		return client, nil
   323  	}
   324  
   325  	return nil, err
   326  }
   327  
   328  func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) {
   329  	cc, err := b.getConn(ctx)
   330  	if err != nil {
   331  		return xerrors.WithStackTrace(err)
   332  	}
   333  
   334  	defer func() {
   335  		if err == nil {
   336  			if cc.GetState() == conn.Banned {
   337  				b.pool.Allow(ctx, cc)
   338  			}
   339  		} else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
   340  			b.pool.Ban(ctx, cc, err)
   341  		}
   342  	}()
   343  
   344  	if ctx, err = b.driverConfig.Meta().Context(ctx); err != nil {
   345  		return xerrors.WithStackTrace(err)
   346  	}
   347  
   348  	if err = f(ctx, cc); err != nil {
   349  		if conn.UseWrapping(ctx) {
   350  			if credentials.IsAccessError(err) {
   351  				err = credentials.AccessError("no access", err,
   352  					credentials.WithAddress(cc.Endpoint().String()),
   353  					credentials.WithNodeID(cc.Endpoint().NodeID()),
   354  					credentials.WithCredentials(b.driverConfig.Credentials()),
   355  				)
   356  			}
   357  
   358  			return xerrors.WithStackTrace(err)
   359  		}
   360  
   361  		return err
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func (b *Balancer) connections() *connectionsState {
   368  	b.mu.RLock()
   369  	defer b.mu.RUnlock()
   370  
   371  	return b.connectionsState
   372  }
   373  
   374  func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
   375  	onDone := trace.DriverOnBalancerChooseEndpoint(
   376  		b.driverConfig.Trace(), &ctx,
   377  		stack.FunctionID(""),
   378  	)
   379  	defer func() {
   380  		if err == nil {
   381  			onDone(c.Endpoint(), nil)
   382  		} else {
   383  			onDone(nil, err)
   384  		}
   385  	}()
   386  
   387  	if err = ctx.Err(); err != nil {
   388  		return nil, xerrors.WithStackTrace(err)
   389  	}
   390  
   391  	var (
   392  		state       = b.connections()
   393  		failedCount int
   394  	)
   395  
   396  	defer func() {
   397  		if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil {
   398  			b.discoveryRepeater.Force()
   399  		}
   400  	}()
   401  
   402  	c, failedCount = state.GetConnection(ctx)
   403  	if c == nil {
   404  		return nil, xerrors.WithStackTrace(
   405  			fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
   406  		)
   407  	}
   408  
   409  	return c, nil
   410  }
   411  
   412  func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
   413  	conns := make([]conn.Conn, 0, len(endpoints))
   414  	for _, e := range endpoints {
   415  		conns = append(conns, p.Get(e))
   416  	}
   417  
   418  	return conns
   419  }