github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/balancer/cluster/cluster.go (about)

     1  package cluster
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
     7  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices"
    10  )
    11  
    12  type (
    13  	Cluster struct {
    14  		filter        func(e endpoint.Info) bool
    15  		allowFallback bool
    16  
    17  		index map[uint32]endpoint.Endpoint
    18  
    19  		prefer   []endpoint.Endpoint
    20  		fallback []endpoint.Endpoint
    21  		all      []endpoint.Endpoint
    22  
    23  		rand xrand.Rand
    24  	}
    25  	option func(s *Cluster)
    26  )
    27  
    28  func WithFilter(filter func(e endpoint.Info) bool) option {
    29  	return func(s *Cluster) {
    30  		s.filter = filter
    31  	}
    32  }
    33  
    34  func WithFallback(allowFallback bool) option {
    35  	return func(s *Cluster) {
    36  		s.allowFallback = allowFallback
    37  	}
    38  }
    39  
    40  func New(endpoints []endpoint.Endpoint, opts ...option) *Cluster {
    41  	s := &Cluster{
    42  		filter: func(e endpoint.Info) bool {
    43  			return true
    44  		},
    45  	}
    46  
    47  	for _, opt := range opts {
    48  		opt(s)
    49  	}
    50  
    51  	if s.rand == nil {
    52  		s.rand = xrand.New(xrand.WithLock())
    53  	}
    54  
    55  	s.prefer, s.fallback = xslices.Split(endpoints, func(e endpoint.Endpoint) bool {
    56  		return s.filter(e)
    57  	})
    58  
    59  	if s.allowFallback {
    60  		s.all = endpoints
    61  		s.index = xslices.Map(endpoints, func(e endpoint.Endpoint) uint32 { return e.NodeID() })
    62  	} else {
    63  		s.all = s.prefer
    64  		s.fallback = nil
    65  		s.index = xslices.Map(s.prefer, func(e endpoint.Endpoint) uint32 { return e.NodeID() })
    66  	}
    67  
    68  	return s
    69  }
    70  
    71  func (s *Cluster) All() (all []endpoint.Endpoint) {
    72  	if s == nil {
    73  		return nil
    74  	}
    75  
    76  	return s.all
    77  }
    78  
    79  func Without(s *Cluster, endpoints ...endpoint.Endpoint) *Cluster {
    80  	prefer := make([]endpoint.Endpoint, 0, len(s.prefer))
    81  	fallback := s.fallback
    82  	for _, endpoint := range endpoints {
    83  		for i := range s.prefer {
    84  			if s.prefer[i].Address() != endpoint.Address() {
    85  				prefer = append(prefer, s.prefer[i])
    86  			} else {
    87  				fallback = append(fallback, s.prefer[i])
    88  			}
    89  		}
    90  	}
    91  
    92  	return &Cluster{
    93  		filter:        s.filter,
    94  		allowFallback: s.allowFallback,
    95  		index:         s.index,
    96  		prefer:        prefer,
    97  		fallback:      fallback,
    98  		all:           s.all,
    99  		rand:          s.rand,
   100  	}
   101  }
   102  
   103  func (s *Cluster) Next(ctx context.Context) (endpoint.Endpoint, error) {
   104  	if s == nil {
   105  		return nil, ErrNilPtr
   106  	}
   107  
   108  	if err := ctx.Err(); err != nil {
   109  		return nil, xerrors.WithStackTrace(err)
   110  	}
   111  
   112  	if nodeID, wantEndpointByNodeID := endpoint.ContextNodeID(ctx); wantEndpointByNodeID {
   113  		e, has := s.index[nodeID]
   114  		if has {
   115  			return e, nil
   116  		}
   117  	}
   118  
   119  	if l := len(s.prefer); l > 0 {
   120  		return s.prefer[s.rand.Int(l)], nil
   121  	}
   122  
   123  	if l := len(s.fallback); l > 0 {
   124  		return s.fallback[s.rand.Int(l)], nil
   125  	}
   126  
   127  	return nil, xerrors.WithStackTrace(ErrNoEndpoints)
   128  }