github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/connections_state.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  
     6  	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
     7  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand"
     9  )
    10  
    11  type connectionsState struct {
    12  	connByNodeID map[uint32]conn.Conn
    13  
    14  	prefer   []conn.Conn
    15  	fallback []conn.Conn
    16  	all      []conn.Conn
    17  
    18  	rand xrand.Rand
    19  }
    20  
    21  func newConnectionsState(
    22  	conns []conn.Conn,
    23  	filter balancerConfig.Filter,
    24  	info balancerConfig.Info,
    25  	allowFallback bool,
    26  ) *connectionsState {
    27  	res := &connectionsState{
    28  		connByNodeID: connsToNodeIDMap(conns),
    29  		rand:         xrand.New(xrand.WithLock()),
    30  	}
    31  
    32  	res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback)
    33  	if allowFallback {
    34  		res.all = conns
    35  	} else {
    36  		res.all = res.prefer
    37  	}
    38  
    39  	return res
    40  }
    41  
    42  func (s *connectionsState) PreferredCount() int {
    43  	return len(s.prefer)
    44  }
    45  
    46  func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) {
    47  	if err := ctx.Err(); err != nil {
    48  		return nil, 0
    49  	}
    50  
    51  	if c := s.preferConnection(ctx); c != nil {
    52  		return c, 0
    53  	}
    54  
    55  	try := func(conns []conn.Conn) conn.Conn {
    56  		c, tryFailed := s.selectRandomConnection(conns, false)
    57  		failedCount += tryFailed
    58  
    59  		return c
    60  	}
    61  
    62  	if c := try(s.prefer); c != nil {
    63  		return c, failedCount
    64  	}
    65  
    66  	if c := try(s.fallback); c != nil {
    67  		return c, failedCount
    68  	}
    69  
    70  	c, _ := s.selectRandomConnection(s.all, true)
    71  
    72  	return c, failedCount
    73  }
    74  
    75  func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn {
    76  	if e, hasPreferEndpoint := ContextEndpoint(ctx); hasPreferEndpoint {
    77  		c := s.connByNodeID[e.NodeID()]
    78  		if c != nil && isOkConnection(c, true) {
    79  			return c
    80  		}
    81  	}
    82  
    83  	return nil
    84  }
    85  
    86  func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) {
    87  	connCount := len(conns)
    88  	if connCount == 0 {
    89  		// return for empty list need for prevent panic in fast path
    90  		return nil, 0
    91  	}
    92  
    93  	// fast path
    94  	if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) {
    95  		return c, 0
    96  	}
    97  
    98  	// shuffled indexes slices need for guarantee about every connection will check
    99  	indexes := make([]int, connCount)
   100  	for index := range indexes {
   101  		indexes[index] = index
   102  	}
   103  	s.rand.Shuffle(connCount, func(i, j int) {
   104  		indexes[i], indexes[j] = indexes[j], indexes[i]
   105  	})
   106  
   107  	for _, index := range indexes {
   108  		c := conns[index]
   109  		if isOkConnection(c, allowBanned) {
   110  			return c, 0
   111  		}
   112  		failedConns++
   113  	}
   114  
   115  	return nil, failedConns
   116  }
   117  
   118  func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) {
   119  	if len(conns) == 0 {
   120  		return nil
   121  	}
   122  	nodes = make(map[uint32]conn.Conn, len(conns))
   123  	for _, c := range conns {
   124  		nodes[c.Endpoint().NodeID()] = c
   125  	}
   126  
   127  	return nodes
   128  }
   129  
   130  func sortPreferConnections(
   131  	conns []conn.Conn,
   132  	filter balancerConfig.Filter,
   133  	info balancerConfig.Info,
   134  	allowFallback bool,
   135  ) (prefer, fallback []conn.Conn) {
   136  	if filter == nil {
   137  		return conns, nil
   138  	}
   139  
   140  	prefer = make([]conn.Conn, 0, len(conns))
   141  	if allowFallback {
   142  		fallback = make([]conn.Conn, 0, len(conns))
   143  	}
   144  
   145  	for _, c := range conns {
   146  		if filter.Allow(info, c) {
   147  			prefer = append(prefer, c)
   148  		} else if allowFallback {
   149  			fallback = append(fallback, c)
   150  		}
   151  	}
   152  
   153  	return prefer, fallback
   154  }
   155  
   156  func isOkConnection(c conn.Conn, bannedIsOk bool) bool {
   157  	switch c.GetState() {
   158  	case conn.Online, conn.Created, conn.Offline:
   159  		return true
   160  	case conn.Banned:
   161  		return bannedIsOk
   162  	default:
   163  		return false
   164  	}
   165  }