google.golang.org/grpc@v1.74.2/balancer/leastrequest/leastrequest.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package leastrequest implements a least request load balancer.
    20  package leastrequest
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	rand "math/rand/v2"
    26  	"sync"
    27  	"sync/atomic"
    28  
    29  	"google.golang.org/grpc/balancer"
    30  	"google.golang.org/grpc/balancer/endpointsharding"
    31  	"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
    32  	"google.golang.org/grpc/connectivity"
    33  	"google.golang.org/grpc/grpclog"
    34  	internalgrpclog "google.golang.org/grpc/internal/grpclog"
    35  	"google.golang.org/grpc/resolver"
    36  	"google.golang.org/grpc/serviceconfig"
    37  )
    38  
    39  // Name is the name of the least request balancer.
    40  const Name = "least_request_experimental"
    41  
    42  var (
    43  	// randuint32 is a global to stub out in tests.
    44  	randuint32 = rand.Uint32
    45  	logger     = grpclog.Component("least-request")
    46  )
    47  
    48  func init() {
    49  	balancer.Register(bb{})
    50  }
    51  
    52  // LBConfig is the balancer config for least_request_experimental balancer.
    53  type LBConfig struct {
    54  	serviceconfig.LoadBalancingConfig `json:"-"`
    55  
    56  	// ChoiceCount is the number of random SubConns to sample to find the one
    57  	// with the fewest outstanding requests. If unset, defaults to 2. If set to
    58  	// < 2, the config will be rejected, and if set to > 10, will become 10.
    59  	ChoiceCount uint32 `json:"choiceCount,omitempty"`
    60  }
    61  
    62  type bb struct{}
    63  
    64  func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
    65  	lbConfig := &LBConfig{
    66  		ChoiceCount: 2,
    67  	}
    68  	if err := json.Unmarshal(s, lbConfig); err != nil {
    69  		return nil, fmt.Errorf("least-request: unable to unmarshal LBConfig: %v", err)
    70  	}
    71  	// "If `choice_count < 2`, the config will be rejected." - A48
    72  	if lbConfig.ChoiceCount < 2 { // sweet
    73  		return nil, fmt.Errorf("least-request: lbConfig.choiceCount: %v, must be >= 2", lbConfig.ChoiceCount)
    74  	}
    75  	// "If a LeastRequestLoadBalancingConfig with a choice_count > 10 is
    76  	// received, the least_request_experimental policy will set choice_count =
    77  	// 10." - A48
    78  	if lbConfig.ChoiceCount > 10 {
    79  		lbConfig.ChoiceCount = 10
    80  	}
    81  	return lbConfig, nil
    82  }
    83  
    84  func (bb) Name() string {
    85  	return Name
    86  }
    87  
    88  func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
    89  	b := &leastRequestBalancer{
    90  		ClientConn:        cc,
    91  		endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](),
    92  	}
    93  	b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
    94  	b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
    95  	b.logger.Infof("Created")
    96  	return b
    97  }
    98  
    99  type leastRequestBalancer struct {
   100  	// Embeds balancer.ClientConn because we need to intercept UpdateState
   101  	// calls from the child balancer.
   102  	balancer.ClientConn
   103  	child  balancer.Balancer
   104  	logger *internalgrpclog.PrefixLogger
   105  
   106  	mu          sync.Mutex
   107  	choiceCount uint32
   108  	// endpointRPCCounts holds RPC counts to keep track for subsequent picker
   109  	// updates.
   110  	endpointRPCCounts *resolver.EndpointMap[*atomic.Int32]
   111  }
   112  
   113  func (lrb *leastRequestBalancer) Close() {
   114  	lrb.child.Close()
   115  	lrb.endpointRPCCounts = nil
   116  }
   117  
   118  func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
   119  	lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
   120  }
   121  
   122  func (lrb *leastRequestBalancer) ResolverError(err error) {
   123  	// Will cause inline picker update from endpoint sharding.
   124  	lrb.child.ResolverError(err)
   125  }
   126  
   127  func (lrb *leastRequestBalancer) ExitIdle() {
   128  	lrb.child.ExitIdle()
   129  }
   130  
   131  func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
   132  	lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
   133  	if !ok {
   134  		logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
   135  		return balancer.ErrBadResolverState
   136  	}
   137  
   138  	lrb.mu.Lock()
   139  	lrb.choiceCount = lrCfg.ChoiceCount
   140  	lrb.mu.Unlock()
   141  	return lrb.child.UpdateClientConnState(balancer.ClientConnState{
   142  		// Enable the health listener in pickfirst children for client side health
   143  		// checks and outlier detection, if configured.
   144  		ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
   145  	})
   146  }
   147  
   148  type endpointState struct {
   149  	picker  balancer.Picker
   150  	numRPCs *atomic.Int32
   151  }
   152  
   153  func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
   154  	var readyEndpoints []endpointsharding.ChildState
   155  	for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
   156  		if child.State.ConnectivityState == connectivity.Ready {
   157  			readyEndpoints = append(readyEndpoints, child)
   158  		}
   159  	}
   160  
   161  	// If no ready pickers are present, simply defer to the round robin picker
   162  	// from endpoint sharding, which will round robin across the most relevant
   163  	// pick first children in the highest precedence connectivity state.
   164  	if len(readyEndpoints) == 0 {
   165  		lrb.ClientConn.UpdateState(state)
   166  		return
   167  	}
   168  
   169  	lrb.mu.Lock()
   170  	defer lrb.mu.Unlock()
   171  
   172  	if logger.V(2) {
   173  		lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
   174  	}
   175  
   176  	// Reconcile endpoints.
   177  	newEndpoints := resolver.NewEndpointMap[any]()
   178  	for _, child := range readyEndpoints {
   179  		newEndpoints.Set(child.Endpoint, nil)
   180  	}
   181  
   182  	// If endpoints are no longer ready, no need to count their active RPCs.
   183  	for _, endpoint := range lrb.endpointRPCCounts.Keys() {
   184  		if _, ok := newEndpoints.Get(endpoint); !ok {
   185  			lrb.endpointRPCCounts.Delete(endpoint)
   186  		}
   187  	}
   188  
   189  	// Copy refs to counters into picker.
   190  	endpointStates := make([]endpointState, 0, len(readyEndpoints))
   191  	for _, child := range readyEndpoints {
   192  		counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint)
   193  		if !ok {
   194  			// Create new counts if needed.
   195  			counter = new(atomic.Int32)
   196  			lrb.endpointRPCCounts.Set(child.Endpoint, counter)
   197  		}
   198  		endpointStates = append(endpointStates, endpointState{
   199  			picker:  child.State.Picker,
   200  			numRPCs: counter,
   201  		})
   202  	}
   203  
   204  	lrb.ClientConn.UpdateState(balancer.State{
   205  		Picker: &picker{
   206  			choiceCount:    lrb.choiceCount,
   207  			endpointStates: endpointStates,
   208  		},
   209  		ConnectivityState: connectivity.Ready,
   210  	})
   211  }
   212  
   213  type picker struct {
   214  	// choiceCount is the number of random endpoints to sample for choosing the
   215  	// one with the least requests.
   216  	choiceCount    uint32
   217  	endpointStates []endpointState
   218  }
   219  
   220  func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
   221  	var pickedEndpointState *endpointState
   222  	var pickedEndpointNumRPCs int32
   223  	for i := 0; i < int(p.choiceCount); i++ {
   224  		index := randuint32() % uint32(len(p.endpointStates))
   225  		endpointState := p.endpointStates[index]
   226  		n := endpointState.numRPCs.Load()
   227  		if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
   228  			pickedEndpointState = &endpointState
   229  			pickedEndpointNumRPCs = n
   230  		}
   231  	}
   232  	result, err := pickedEndpointState.picker.Pick(pInfo)
   233  	if err != nil {
   234  		return result, err
   235  	}
   236  	// "The counter for a subchannel should be atomically incremented by one
   237  	// after it has been successfully picked by the picker." - A48
   238  	pickedEndpointState.numRPCs.Add(1)
   239  	// "the picker should add a callback for atomically decrementing the
   240  	// subchannel counter once the RPC finishes (regardless of Status code)." -
   241  	// A48.
   242  	originalDone := result.Done
   243  	result.Done = func(info balancer.DoneInfo) {
   244  		pickedEndpointState.numRPCs.Add(-1)
   245  		if originalDone != nil {
   246  			originalDone(info)
   247  		}
   248  	}
   249  	return result, nil
   250  }