google.golang.org/grpc@v1.72.2/balancer/leastrequest/leastrequest.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package leastrequest implements a least request load balancer.
    20  package leastrequest
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	rand "math/rand/v2"
    26  	"sync"
    27  	"sync/atomic"
    28  
    29  	"google.golang.org/grpc/balancer"
    30  	"google.golang.org/grpc/balancer/endpointsharding"
    31  	"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
    32  	"google.golang.org/grpc/connectivity"
    33  	"google.golang.org/grpc/grpclog"
    34  	internalgrpclog "google.golang.org/grpc/internal/grpclog"
    35  	"google.golang.org/grpc/resolver"
    36  	"google.golang.org/grpc/serviceconfig"
    37  )
    38  
    39  // Name is the name of the least request balancer.
    40  const Name = "least_request_experimental"
    41  
    42  var (
    43  	// randuint32 is a global to stub out in tests.
    44  	randuint32 = rand.Uint32
    45  	logger     = grpclog.Component("least-request")
    46  )
    47  
    48  func init() {
    49  	balancer.Register(bb{})
    50  }
    51  
    52  // LBConfig is the balancer config for least_request_experimental balancer.
    53  type LBConfig struct {
    54  	serviceconfig.LoadBalancingConfig `json:"-"`
    55  
    56  	// ChoiceCount is the number of random SubConns to sample to find the one
    57  	// with the fewest outstanding requests. If unset, defaults to 2. If set to
    58  	// < 2, the config will be rejected, and if set to > 10, will become 10.
    59  	ChoiceCount uint32 `json:"choiceCount,omitempty"`
    60  }
    61  
    62  type bb struct{}
    63  
    64  func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
    65  	lbConfig := &LBConfig{
    66  		ChoiceCount: 2,
    67  	}
    68  	if err := json.Unmarshal(s, lbConfig); err != nil {
    69  		return nil, fmt.Errorf("least-request: unable to unmarshal LBConfig: %v", err)
    70  	}
    71  	// "If `choice_count < 2`, the config will be rejected." - A48
    72  	if lbConfig.ChoiceCount < 2 { // sweet
    73  		return nil, fmt.Errorf("least-request: lbConfig.choiceCount: %v, must be >= 2", lbConfig.ChoiceCount)
    74  	}
    75  	// "If a LeastRequestLoadBalancingConfig with a choice_count > 10 is
    76  	// received, the least_request_experimental policy will set choice_count =
    77  	// 10." - A48
    78  	if lbConfig.ChoiceCount > 10 {
    79  		lbConfig.ChoiceCount = 10
    80  	}
    81  	return lbConfig, nil
    82  }
    83  
    84  func (bb) Name() string {
    85  	return Name
    86  }
    87  
    88  func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
    89  	b := &leastRequestBalancer{
    90  		ClientConn:        cc,
    91  		endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](),
    92  	}
    93  	b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
    94  	b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
    95  	b.logger.Infof("Created")
    96  	return b
    97  }
    98  
    99  type leastRequestBalancer struct {
   100  	// Embeds balancer.ClientConn because we need to intercept UpdateState
   101  	// calls from the child balancer.
   102  	balancer.ClientConn
   103  	child  balancer.Balancer
   104  	logger *internalgrpclog.PrefixLogger
   105  
   106  	mu          sync.Mutex
   107  	choiceCount uint32
   108  	// endpointRPCCounts holds RPC counts to keep track for subsequent picker
   109  	// updates.
   110  	endpointRPCCounts *resolver.EndpointMap[*atomic.Int32]
   111  }
   112  
   113  func (lrb *leastRequestBalancer) Close() {
   114  	lrb.child.Close()
   115  	lrb.endpointRPCCounts = nil
   116  }
   117  
   118  func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
   119  	lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
   120  }
   121  
   122  func (lrb *leastRequestBalancer) ResolverError(err error) {
   123  	// Will cause inline picker update from endpoint sharding.
   124  	lrb.child.ResolverError(err)
   125  }
   126  
   127  func (lrb *leastRequestBalancer) ExitIdle() {
   128  	if ei, ok := lrb.child.(balancer.ExitIdler); ok { // Should always be ok, as child is endpoint sharding.
   129  		ei.ExitIdle()
   130  	}
   131  }
   132  
   133  func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
   134  	lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
   135  	if !ok {
   136  		logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
   137  		return balancer.ErrBadResolverState
   138  	}
   139  
   140  	lrb.mu.Lock()
   141  	lrb.choiceCount = lrCfg.ChoiceCount
   142  	lrb.mu.Unlock()
   143  	return lrb.child.UpdateClientConnState(balancer.ClientConnState{
   144  		// Enable the health listener in pickfirst children for client side health
   145  		// checks and outlier detection, if configured.
   146  		ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
   147  	})
   148  }
   149  
   150  type endpointState struct {
   151  	picker  balancer.Picker
   152  	numRPCs *atomic.Int32
   153  }
   154  
   155  func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
   156  	var readyEndpoints []endpointsharding.ChildState
   157  	for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
   158  		if child.State.ConnectivityState == connectivity.Ready {
   159  			readyEndpoints = append(readyEndpoints, child)
   160  		}
   161  	}
   162  
   163  	// If no ready pickers are present, simply defer to the round robin picker
   164  	// from endpoint sharding, which will round robin across the most relevant
   165  	// pick first children in the highest precedence connectivity state.
   166  	if len(readyEndpoints) == 0 {
   167  		lrb.ClientConn.UpdateState(state)
   168  		return
   169  	}
   170  
   171  	lrb.mu.Lock()
   172  	defer lrb.mu.Unlock()
   173  
   174  	if logger.V(2) {
   175  		lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
   176  	}
   177  
   178  	// Reconcile endpoints.
   179  	newEndpoints := resolver.NewEndpointMap[any]()
   180  	for _, child := range readyEndpoints {
   181  		newEndpoints.Set(child.Endpoint, nil)
   182  	}
   183  
   184  	// If endpoints are no longer ready, no need to count their active RPCs.
   185  	for _, endpoint := range lrb.endpointRPCCounts.Keys() {
   186  		if _, ok := newEndpoints.Get(endpoint); !ok {
   187  			lrb.endpointRPCCounts.Delete(endpoint)
   188  		}
   189  	}
   190  
   191  	// Copy refs to counters into picker.
   192  	endpointStates := make([]endpointState, 0, len(readyEndpoints))
   193  	for _, child := range readyEndpoints {
   194  		counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint)
   195  		if !ok {
   196  			// Create new counts if needed.
   197  			counter = new(atomic.Int32)
   198  			lrb.endpointRPCCounts.Set(child.Endpoint, counter)
   199  		}
   200  		endpointStates = append(endpointStates, endpointState{
   201  			picker:  child.State.Picker,
   202  			numRPCs: counter,
   203  		})
   204  	}
   205  
   206  	lrb.ClientConn.UpdateState(balancer.State{
   207  		Picker: &picker{
   208  			choiceCount:    lrb.choiceCount,
   209  			endpointStates: endpointStates,
   210  		},
   211  		ConnectivityState: connectivity.Ready,
   212  	})
   213  }
   214  
   215  type picker struct {
   216  	// choiceCount is the number of random endpoints to sample for choosing the
   217  	// one with the least requests.
   218  	choiceCount    uint32
   219  	endpointStates []endpointState
   220  }
   221  
   222  func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
   223  	var pickedEndpointState *endpointState
   224  	var pickedEndpointNumRPCs int32
   225  	for i := 0; i < int(p.choiceCount); i++ {
   226  		index := randuint32() % uint32(len(p.endpointStates))
   227  		endpointState := p.endpointStates[index]
   228  		n := endpointState.numRPCs.Load()
   229  		if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
   230  			pickedEndpointState = &endpointState
   231  			pickedEndpointNumRPCs = n
   232  		}
   233  	}
   234  	result, err := pickedEndpointState.picker.Pick(pInfo)
   235  	if err != nil {
   236  		return result, err
   237  	}
   238  	// "The counter for a subchannel should be atomically incremented by one
   239  	// after it has been successfully picked by the picker." - A48
   240  	pickedEndpointState.numRPCs.Add(1)
   241  	// "the picker should add a callback for atomically decrementing the
   242  	// subchannel counter once the RPC finishes (regardless of Status code)." -
   243  	// A48.
   244  	originalDone := result.Done
   245  	result.Done = func(info balancer.DoneInfo) {
   246  		pickedEndpointState.numRPCs.Add(-1)
   247  		if originalDone != nil {
   248  			originalDone(info)
   249  		}
   250  	}
   251  	return result, nil
   252  }