github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/balancer/connections_state.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  
     6  	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
     7  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand"
    10  )
    11  
    12  type connectionsState struct {
    13  	connByNodeID map[uint32]conn.Conn
    14  
    15  	prefer   []conn.Conn
    16  	fallback []conn.Conn
    17  	all      []conn.Conn
    18  
    19  	rand xrand.Rand
    20  }
    21  
    22  func newConnectionsState(
    23  	conns []conn.Conn,
    24  	filter balancerConfig.Filter,
    25  	info balancerConfig.Info,
    26  	allowFallback bool,
    27  ) *connectionsState {
    28  	res := &connectionsState{
    29  		connByNodeID: connsToNodeIDMap(conns),
    30  		rand:         xrand.New(xrand.WithLock()),
    31  	}
    32  
    33  	res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback)
    34  	if allowFallback {
    35  		res.all = conns
    36  	} else {
    37  		res.all = res.prefer
    38  	}
    39  
    40  	return res
    41  }
    42  
    43  func (s *connectionsState) PreferredCount() int {
    44  	return len(s.prefer)
    45  }
    46  
    47  func (s *connectionsState) All() (all []endpoint.Endpoint) {
    48  	if s == nil {
    49  		return nil
    50  	}
    51  
    52  	all = make([]endpoint.Endpoint, len(s.all))
    53  	for i, c := range s.all {
    54  		all[i] = c.Endpoint()
    55  	}
    56  
    57  	return all
    58  }
    59  
    60  func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) {
    61  	if err := ctx.Err(); err != nil {
    62  		return nil, 0
    63  	}
    64  
    65  	if c := s.preferConnection(ctx); c != nil {
    66  		return c, 0
    67  	}
    68  
    69  	try := func(conns []conn.Conn) conn.Conn {
    70  		c, tryFailed := s.selectRandomConnection(conns, false)
    71  		failedCount += tryFailed
    72  
    73  		return c
    74  	}
    75  
    76  	if c := try(s.prefer); c != nil {
    77  		return c, failedCount
    78  	}
    79  
    80  	if c := try(s.fallback); c != nil {
    81  		return c, failedCount
    82  	}
    83  
    84  	c, _ := s.selectRandomConnection(s.all, true)
    85  
    86  	return c, failedCount
    87  }
    88  
    89  func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn {
    90  	if nodeID, hasPreferEndpoint := endpoint.ContextNodeID(ctx); hasPreferEndpoint {
    91  		c := s.connByNodeID[nodeID]
    92  		if c != nil && isOkConnection(c, true) {
    93  			return c
    94  		}
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) {
   101  	connCount := len(conns)
   102  	if connCount == 0 {
   103  		// return for empty list need for prevent panic in fast path
   104  		return nil, 0
   105  	}
   106  
   107  	// fast path
   108  	if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) {
   109  		return c, 0
   110  	}
   111  
   112  	// shuffled indexes slices need for guarantee about every connection will check
   113  	indexes := make([]int, connCount)
   114  	for index := range indexes {
   115  		indexes[index] = index
   116  	}
   117  	s.rand.Shuffle(connCount, func(i, j int) {
   118  		indexes[i], indexes[j] = indexes[j], indexes[i]
   119  	})
   120  
   121  	for _, index := range indexes {
   122  		c := conns[index]
   123  		if isOkConnection(c, allowBanned) {
   124  			return c, 0
   125  		}
   126  		failedConns++
   127  	}
   128  
   129  	return nil, failedConns
   130  }
   131  
   132  func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) {
   133  	if len(conns) == 0 {
   134  		return nil
   135  	}
   136  	nodes = make(map[uint32]conn.Conn, len(conns))
   137  	for _, c := range conns {
   138  		nodes[c.Endpoint().NodeID()] = c
   139  	}
   140  
   141  	return nodes
   142  }
   143  
   144  func sortPreferConnections(
   145  	conns []conn.Conn,
   146  	filter balancerConfig.Filter,
   147  	info balancerConfig.Info,
   148  	allowFallback bool,
   149  ) (prefer, fallback []conn.Conn) {
   150  	if filter == nil {
   151  		return conns, nil
   152  	}
   153  
   154  	prefer = make([]conn.Conn, 0, len(conns))
   155  	if allowFallback {
   156  		fallback = make([]conn.Conn, 0, len(conns))
   157  	}
   158  
   159  	for _, c := range conns {
   160  		if filter.Allow(info, c.Endpoint()) {
   161  			prefer = append(prefer, c)
   162  		} else if allowFallback {
   163  			fallback = append(fallback, c)
   164  		}
   165  	}
   166  
   167  	return prefer, fallback
   168  }
   169  
   170  func isOkConnection(c conn.Conn, bannedIsOk bool) bool {
   171  	switch c.GetState() {
   172  	case conn.Online, conn.Created, conn.Offline:
   173  		return true
   174  	case conn.Banned:
   175  		return bannedIsOk
   176  	default:
   177  		return false
   178  	}
   179  }