github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/balancer/balancer.go

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"sync/atomic"
     8  
     9  	"google.golang.org/grpc"
    10  
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/config"
    12  	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials"
    16  	internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery"
    17  	discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/repeater"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    22  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    23  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices"
    24  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
    25  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    26  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    27  )
    28  
    29  var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints"))
    30  
    31  type discoveryClient interface {
    32  	closer.Closer
    33  
    34  	Discover(ctx context.Context) ([]endpoint.Endpoint, error)
    35  }
    36  
// Balancer routes gRPC calls (Invoke, NewStream) over connections taken from a
// conn.Pool, and keeps its set of connections up to date via cluster discovery.
type Balancer struct {
	// driverConfig is the driver configuration: endpoint, database, tracing,
	// retry settings, metadata and credentials.
	driverConfig *config.Config
	// config is the resolved balancer configuration. New stores the zero value
	// when the driver configuration has none.
	config balancerConfig.Config
	// pool supplies per-endpoint connections and tracks banned/allowed state.
	pool *conn.Pool
	// discoveryClient lists cluster endpoints. It is closed by Close.
	discoveryClient discoveryClient
	// discoveryRepeater re-runs discovery attempts in the background. It is nil
	// in SingleConn mode or when the discovery interval is not positive.
	discoveryRepeater repeater.Repeater
	// localDCDetector picks the local data center among discovered endpoints.
	// It is only consulted when config.DetectNearestDC is set.
	localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)

	// connectionsState is the current connections snapshot. It is replaced
	// atomically each time discovered endpoints are applied.
	connectionsState atomic.Pointer[connectionsState]

	// mu guards onApplyDiscoveredEndpoints.
	mu                         xsync.RWMutex
	// onApplyDiscoveredEndpoints are the callbacks registered via OnUpdate. They
	// are invoked, with mu held, after each applied discovery.
	onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
}
    50  
    51  func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) {
    52  	b.mu.WithLock(func() {
    53  		b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints)
    54  	})
    55  }
    56  
    57  func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) {
    58  	return retry.Retry(
    59  		repeater.WithEvent(ctx, repeater.EventInit),
    60  		func(childCtx context.Context) (err error) {
    61  			if err = b.clusterDiscoveryAttempt(childCtx); err != nil {
    62  				if credentials.IsAccessError(err) {
    63  					return credentials.AccessError("cluster discovery failed", err,
    64  						credentials.WithEndpoint(b.driverConfig.Endpoint()),
    65  						credentials.WithDatabase(b.driverConfig.Database()),
    66  						credentials.WithCredentials(b.driverConfig.Credentials()),
    67  					)
    68  				}
    69  				// if got err but parent context is not done - mark error as retryable
    70  				if ctx.Err() == nil && xerrors.IsTimeoutError(err) {
    71  					return xerrors.WithStackTrace(xerrors.Retryable(err))
    72  				}
    73  
    74  				return xerrors.WithStackTrace(err)
    75  			}
    76  
    77  			return nil
    78  		},
    79  		retry.WithIdempotent(true),
    80  		retry.WithTrace(b.driverConfig.TraceRetry()),
    81  		retry.WithBudget(b.driverConfig.RetryBudget()),
    82  	)
    83  }
    84  
    85  func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
    86  	var (
    87  		address = "ydb:///" + b.driverConfig.Endpoint()
    88  		onDone  = trace.DriverOnBalancerClusterDiscoveryAttempt(
    89  			b.driverConfig.Trace(), &ctx,
    90  			stack.FunctionID(
    91  				"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).clusterDiscoveryAttempt",
    92  			),
    93  			address,
    94  			b.driverConfig.Database(),
    95  		)
    96  		endpoints []endpoint.Endpoint
    97  		localDC   string
    98  		cancel    context.CancelFunc
    99  	)
   100  	defer func() {
   101  		onDone(err)
   102  	}()
   103  
   104  	if dialTimeout := b.driverConfig.DialTimeout(); dialTimeout > 0 {
   105  		ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout)
   106  	} else {
   107  		ctx, cancel = xcontext.WithCancel(ctx)
   108  	}
   109  	defer cancel()
   110  
   111  	endpoints, err = b.discoveryClient.Discover(ctx)
   112  	if err != nil {
   113  		return xerrors.WithStackTrace(err)
   114  	}
   115  
   116  	if b.config.DetectNearestDC {
   117  		localDC, err = b.localDCDetector(ctx, endpoints)
   118  		if err != nil {
   119  			return xerrors.WithStackTrace(err)
   120  		}
   121  	}
   122  
   123  	b.applyDiscoveredEndpoints(ctx, endpoints, localDC)
   124  
   125  	return nil
   126  }
   127  
   128  func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) {
   129  	var (
   130  		onDone = trace.DriverOnBalancerUpdate(
   131  			b.driverConfig.Trace(), &ctx,
   132  			stack.FunctionID(
   133  				"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
   134  			b.config.DetectNearestDC,
   135  			b.driverConfig.Database(),
   136  		)
   137  		previous = b.connections().All()
   138  	)
   139  	defer func() {
   140  		_, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int {
   141  			return strings.Compare(lhs.Address(), rhs.Address())
   142  		})
   143  		onDone(
   144  			xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
   145  			xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
   146  			xslices.Transform(dropped, func(t endpoint.Endpoint) trace.EndpointInfo { return t }),
   147  			localDC,
   148  		)
   149  	}()
   150  
   151  	connections := endpointsToConnections(b.pool, newest)
   152  	for _, c := range connections {
   153  		b.pool.Allow(ctx, c)
   154  		c.Endpoint().Touch()
   155  	}
   156  
   157  	info := balancerConfig.Info{SelfLocation: localDC}
   158  	state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback)
   159  
   160  	endpointsInfo := make([]endpoint.Info, len(newest))
   161  	for i, e := range newest {
   162  		endpointsInfo[i] = e
   163  	}
   164  
   165  	b.connectionsState.Store(state)
   166  
   167  	b.mu.WithLock(func() {
   168  		for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
   169  			onApplyDiscoveredEndpoints(ctx, endpointsInfo)
   170  		}
   171  	})
   172  }
   173  
   174  func (b *Balancer) Close(ctx context.Context) (err error) {
   175  	onDone := trace.DriverOnBalancerClose(
   176  		b.driverConfig.Trace(), &ctx,
   177  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).Close"),
   178  	)
   179  	defer func() {
   180  		onDone(err)
   181  	}()
   182  
   183  	if b.discoveryRepeater != nil {
   184  		b.discoveryRepeater.Stop()
   185  	}
   186  
   187  	if err = b.discoveryClient.Close(ctx); err != nil {
   188  		return xerrors.WithStackTrace(err)
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func New(
   195  	ctx context.Context,
   196  	driverConfig *config.Config,
   197  	pool *conn.Pool,
   198  	opts ...discoveryConfig.Option,
   199  ) (b *Balancer, finalErr error) {
   200  	var (
   201  		onDone = trace.DriverOnBalancerInit(
   202  			driverConfig.Trace(), &ctx,
   203  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.New"),
   204  			driverConfig.Balancer().String(),
   205  		)
   206  		discoveryConfig = discoveryConfig.New(append(opts,
   207  			discoveryConfig.With(driverConfig.Common),
   208  			discoveryConfig.WithEndpoint(driverConfig.Endpoint()),
   209  			discoveryConfig.WithDatabase(driverConfig.Database()),
   210  			discoveryConfig.WithSecure(driverConfig.Secure()),
   211  			discoveryConfig.WithMeta(driverConfig.Meta()),
   212  		)...)
   213  	)
   214  	defer func() {
   215  		onDone(finalErr)
   216  	}()
   217  
   218  	b = &Balancer{
   219  		driverConfig: driverConfig,
   220  		pool:         pool,
   221  		discoveryClient: internalDiscovery.New(ctx, pool.Get(
   222  			endpoint.New(driverConfig.Endpoint()),
   223  		), discoveryConfig),
   224  		localDCDetector: detectLocalDC,
   225  	}
   226  
   227  	if config := driverConfig.Balancer(); config == nil {
   228  		b.config = balancerConfig.Config{}
   229  	} else {
   230  		b.config = *config
   231  	}
   232  
   233  	if b.config.SingleConn {
   234  		b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{
   235  			endpoint.New(driverConfig.Endpoint()),
   236  		}, "")
   237  	} else {
   238  		// initialization of balancer state
   239  		if err := b.clusterDiscovery(ctx); err != nil {
   240  			return nil, xerrors.WithStackTrace(err)
   241  		}
   242  		// run background discovering
   243  		if d := discoveryConfig.Interval(); d > 0 {
   244  			b.discoveryRepeater = repeater.New(xcontext.ValueOnly(ctx),
   245  				d, b.clusterDiscoveryAttempt,
   246  				repeater.WithName("discovery"),
   247  				repeater.WithTrace(b.driverConfig.Trace()),
   248  			)
   249  		}
   250  	}
   251  
   252  	return b, nil
   253  }
   254  
   255  func (b *Balancer) Invoke(
   256  	ctx context.Context,
   257  	method string,
   258  	args interface{},
   259  	reply interface{},
   260  	opts ...grpc.CallOption,
   261  ) error {
   262  	return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
   263  		return cc.Invoke(ctx, method, args, reply, opts...)
   264  	})
   265  }
   266  
   267  func (b *Balancer) NewStream(
   268  	ctx context.Context,
   269  	desc *grpc.StreamDesc,
   270  	method string,
   271  	opts ...grpc.CallOption,
   272  ) (_ grpc.ClientStream, err error) {
   273  	var client grpc.ClientStream
   274  	err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
   275  		client, err = cc.NewStream(ctx, desc, method, opts...)
   276  
   277  		return err
   278  	})
   279  	if err == nil {
   280  		return client, nil
   281  	}
   282  
   283  	return nil, err
   284  }
   285  
   286  func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) {
   287  	cc, err := b.getConn(ctx)
   288  	if err != nil {
   289  		return xerrors.WithStackTrace(err)
   290  	}
   291  
   292  	defer func() {
   293  		if err == nil {
   294  			if cc.GetState() == conn.Banned {
   295  				b.pool.Allow(ctx, cc)
   296  			}
   297  		} else if conn.IsBadConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
   298  			b.pool.Ban(ctx, cc, err)
   299  		}
   300  	}()
   301  
   302  	if ctx, err = b.driverConfig.Meta().Context(ctx); err != nil {
   303  		return xerrors.WithStackTrace(err)
   304  	}
   305  
   306  	if err = f(ctx, cc); err != nil {
   307  		if conn.UseWrapping(ctx) {
   308  			if credentials.IsAccessError(err) {
   309  				err = credentials.AccessError("no access", err,
   310  					credentials.WithAddress(cc.Endpoint().String()),
   311  					credentials.WithNodeID(cc.Endpoint().NodeID()),
   312  					credentials.WithCredentials(b.driverConfig.Credentials()),
   313  				)
   314  			}
   315  
   316  			return xerrors.WithStackTrace(err)
   317  		}
   318  
   319  		return err
   320  	}
   321  
   322  	return nil
   323  }
   324  
   325  func (b *Balancer) connections() *connectionsState {
   326  	return b.connectionsState.Load()
   327  }
   328  
   329  func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
   330  	onDone := trace.DriverOnBalancerChooseEndpoint(
   331  		b.driverConfig.Trace(), &ctx,
   332  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).getConn"),
   333  	)
   334  	defer func() {
   335  		if err == nil {
   336  			onDone(c.Endpoint(), nil)
   337  		} else {
   338  			onDone(nil, err)
   339  		}
   340  	}()
   341  
   342  	if err = ctx.Err(); err != nil {
   343  		return nil, xerrors.WithStackTrace(err)
   344  	}
   345  
   346  	var (
   347  		state       = b.connections()
   348  		failedCount int
   349  	)
   350  
   351  	defer func() {
   352  		if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil {
   353  			b.discoveryRepeater.Force()
   354  		}
   355  	}()
   356  
   357  	c, failedCount = state.GetConnection(ctx)
   358  	if c == nil {
   359  		return nil, xerrors.WithStackTrace(
   360  			fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
   361  		)
   362  	}
   363  
   364  	return c, nil
   365  }
   366  
   367  func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
   368  	conns := make([]conn.Conn, 0, len(endpoints))
   369  	for _, e := range endpoints {
   370  		conns = append(conns, p.Get(e))
   371  	}
   372  
   373  	return conns
   374  }