google.golang.org/grpc@v1.74.2/balancer/leastrequest/leastrequest.go (about) 1 /* 2 * 3 * Copyright 2023 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // Package leastrequest implements a least request load balancer. 20 package leastrequest 21 22 import ( 23 "encoding/json" 24 "fmt" 25 rand "math/rand/v2" 26 "sync" 27 "sync/atomic" 28 29 "google.golang.org/grpc/balancer" 30 "google.golang.org/grpc/balancer/endpointsharding" 31 "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" 32 "google.golang.org/grpc/connectivity" 33 "google.golang.org/grpc/grpclog" 34 internalgrpclog "google.golang.org/grpc/internal/grpclog" 35 "google.golang.org/grpc/resolver" 36 "google.golang.org/grpc/serviceconfig" 37 ) 38 39 // Name is the name of the least request balancer. 40 const Name = "least_request_experimental" 41 42 var ( 43 // randuint32 is a global to stub out in tests. 44 randuint32 = rand.Uint32 45 logger = grpclog.Component("least-request") 46 ) 47 48 func init() { 49 balancer.Register(bb{}) 50 } 51 52 // LBConfig is the balancer config for least_request_experimental balancer. 53 type LBConfig struct { 54 serviceconfig.LoadBalancingConfig `json:"-"` 55 56 // ChoiceCount is the number of random SubConns to sample to find the one 57 // with the fewest outstanding requests. If unset, defaults to 2. If set to 58 // < 2, the config will be rejected, and if set to > 10, will become 10. 59 ChoiceCount uint32 `json:"choiceCount,omitempty"` 60 } 61 62 type bb struct{} 63 64 func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { 65 lbConfig := &LBConfig{ 66 ChoiceCount: 2, 67 } 68 if err := json.Unmarshal(s, lbConfig); err != nil { 69 return nil, fmt.Errorf("least-request: unable to unmarshal LBConfig: %v", err) 70 } 71 // "If `choice_count < 2`, the config will be rejected." - A48 72 if lbConfig.ChoiceCount < 2 { // sweet 73 return nil, fmt.Errorf("least-request: lbConfig.choiceCount: %v, must be >= 2", lbConfig.ChoiceCount) 74 } 75 // "If a LeastRequestLoadBalancingConfig with a choice_count > 10 is 76 // received, the least_request_experimental policy will set choice_count = 77 // 10." - A48 78 if lbConfig.ChoiceCount > 10 { 79 lbConfig.ChoiceCount = 10 80 } 81 return lbConfig, nil 82 } 83 84 func (bb) Name() string { 85 return Name 86 } 87 88 func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { 89 b := &leastRequestBalancer{ 90 ClientConn: cc, 91 endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](), 92 } 93 b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{}) 94 b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b)) 95 b.logger.Infof("Created") 96 return b 97 } 98 99 type leastRequestBalancer struct { 100 // Embeds balancer.ClientConn because we need to intercept UpdateState 101 // calls from the child balancer. 102 balancer.ClientConn 103 child balancer.Balancer 104 logger *internalgrpclog.PrefixLogger 105 106 mu sync.Mutex 107 choiceCount uint32 108 // endpointRPCCounts holds RPC counts to keep track for subsequent picker 109 // updates. 110 endpointRPCCounts *resolver.EndpointMap[*atomic.Int32] 111 } 112 113 func (lrb *leastRequestBalancer) Close() { 114 lrb.child.Close() 115 lrb.endpointRPCCounts = nil 116 } 117 118 func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { 119 lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state) 120 } 121 122 func (lrb *leastRequestBalancer) ResolverError(err error) { 123 // Will cause inline picker update from endpoint sharding. 124 lrb.child.ResolverError(err) 125 } 126 127 func (lrb *leastRequestBalancer) ExitIdle() { 128 lrb.child.ExitIdle() 129 } 130 131 func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { 132 lrCfg, ok := ccs.BalancerConfig.(*LBConfig) 133 if !ok { 134 logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig) 135 return balancer.ErrBadResolverState 136 } 137 138 lrb.mu.Lock() 139 lrb.choiceCount = lrCfg.ChoiceCount 140 lrb.mu.Unlock() 141 return lrb.child.UpdateClientConnState(balancer.ClientConnState{ 142 // Enable the health listener in pickfirst children for client side health 143 // checks and outlier detection, if configured. 144 ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState), 145 }) 146 } 147 148 type endpointState struct { 149 picker balancer.Picker 150 numRPCs *atomic.Int32 151 } 152 153 func (lrb *leastRequestBalancer) UpdateState(state balancer.State) { 154 var readyEndpoints []endpointsharding.ChildState 155 for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) { 156 if child.State.ConnectivityState == connectivity.Ready { 157 readyEndpoints = append(readyEndpoints, child) 158 } 159 } 160 161 // If no ready pickers are present, simply defer to the round robin picker 162 // from endpoint sharding, which will round robin across the most relevant 163 // pick first children in the highest precedence connectivity state. 164 if len(readyEndpoints) == 0 { 165 lrb.ClientConn.UpdateState(state) 166 return 167 } 168 169 lrb.mu.Lock() 170 defer lrb.mu.Unlock() 171 172 if logger.V(2) { 173 lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints) 174 } 175 176 // Reconcile endpoints. 177 newEndpoints := resolver.NewEndpointMap[any]() 178 for _, child := range readyEndpoints { 179 newEndpoints.Set(child.Endpoint, nil) 180 } 181 182 // If endpoints are no longer ready, no need to count their active RPCs. 183 for _, endpoint := range lrb.endpointRPCCounts.Keys() { 184 if _, ok := newEndpoints.Get(endpoint); !ok { 185 lrb.endpointRPCCounts.Delete(endpoint) 186 } 187 } 188 189 // Copy refs to counters into picker. 190 endpointStates := make([]endpointState, 0, len(readyEndpoints)) 191 for _, child := range readyEndpoints { 192 counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint) 193 if !ok { 194 // Create new counts if needed. 195 counter = new(atomic.Int32) 196 lrb.endpointRPCCounts.Set(child.Endpoint, counter) 197 } 198 endpointStates = append(endpointStates, endpointState{ 199 picker: child.State.Picker, 200 numRPCs: counter, 201 }) 202 } 203 204 lrb.ClientConn.UpdateState(balancer.State{ 205 Picker: &picker{ 206 choiceCount: lrb.choiceCount, 207 endpointStates: endpointStates, 208 }, 209 ConnectivityState: connectivity.Ready, 210 }) 211 } 212 213 type picker struct { 214 // choiceCount is the number of random endpoints to sample for choosing the 215 // one with the least requests. 216 choiceCount uint32 217 endpointStates []endpointState 218 } 219 220 func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) { 221 var pickedEndpointState *endpointState 222 var pickedEndpointNumRPCs int32 223 for i := 0; i < int(p.choiceCount); i++ { 224 index := randuint32() % uint32(len(p.endpointStates)) 225 endpointState := p.endpointStates[index] 226 n := endpointState.numRPCs.Load() 227 if pickedEndpointState == nil || n < pickedEndpointNumRPCs { 228 pickedEndpointState = &endpointState 229 pickedEndpointNumRPCs = n 230 } 231 } 232 result, err := pickedEndpointState.picker.Pick(pInfo) 233 if err != nil { 234 return result, err 235 } 236 // "The counter for a subchannel should be atomically incremented by one 237 // after it has been successfully picked by the picker." - A48 238 pickedEndpointState.numRPCs.Add(1) 239 // "the picker should add a callback for atomically decrementing the 240 // subchannel counter once the RPC finishes (regardless of Status code)." - 241 // A48. 242 originalDone := result.Done 243 result.Done = func(info balancer.DoneInfo) { 244 pickedEndpointState.numRPCs.Add(-1) 245 if originalDone != nil { 246 originalDone(info) 247 } 248 } 249 return result, nil 250 }