github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/rate_limit_interceptor.go (about)

     1  // Licensed to the LF AI & Data foundation under one
     2  // or more contributor license agreements. See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership. The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License. You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package client
    18  
import (
	"context"
	"errors"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
    31  
// Reference implementation: https://github.com/grpc-ecosystem/go-grpc-middleware (retry package).
    33  
    34  type ctxKey int
    35  
    36  const (
    37  	RetryOnRateLimit ctxKey = iota
    38  )
    39  
    40  // RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
    41  func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
    42  	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    43  		if maxRetry == 0 {
    44  			return invoker(parentCtx, method, req, reply, cc, opts...)
    45  		}
    46  		var lastErr error
    47  		for attempt := uint(0); attempt < maxRetry; attempt++ {
    48  			_, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
    49  			if err != nil {
    50  				return err
    51  			}
    52  			lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
    53  			rspStatus := getResultStatus(reply)
    54  			if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
    55  				continue
    56  			}
    57  			return lastErr
    58  		}
    59  		return lastErr
    60  	}
    61  }
    62  
    63  func retryOnRateLimit(ctx context.Context) bool {
    64  	retry, ok := ctx.Value(RetryOnRateLimit).(bool)
    65  	if !ok {
    66  		return true // default true
    67  	}
    68  	return retry
    69  }
    70  
    71  // getResultStatus returns status of response.
    72  func getResultStatus(reply interface{}) *commonpb.Status {
    73  	switch r := reply.(type) {
    74  	case *commonpb.Status:
    75  		return r
    76  	case *milvuspb.MutationResult:
    77  		return r.GetStatus()
    78  	case *milvuspb.BoolResponse:
    79  		return r.GetStatus()
    80  	case *milvuspb.SearchResults:
    81  		return r.GetStatus()
    82  	case *milvuspb.QueryResults:
    83  		return r.GetStatus()
    84  	case *milvuspb.FlushResponse:
    85  		return r.GetStatus()
    86  	default:
    87  		return nil
    88  	}
    89  }
    90  
    91  func contextErrToGrpcErr(err error) error {
    92  	switch err {
    93  	case context.DeadlineExceeded:
    94  		return status.Error(codes.DeadlineExceeded, err.Error())
    95  	case context.Canceled:
    96  		return status.Error(codes.Canceled, err.Error())
    97  	default:
    98  		return status.Error(codes.Unknown, err.Error())
    99  	}
   100  }
   101  
   102  func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
   103  	var waitTime time.Duration
   104  	if attempt > 0 {
   105  		waitTime = backoffFunc(parentCtx, attempt)
   106  	}
   107  	if waitTime > 0 {
   108  		if waitTime > maxBackoff {
   109  			waitTime = maxBackoff
   110  		}
   111  		timer := time.NewTimer(waitTime)
   112  		select {
   113  		case <-parentCtx.Done():
   114  			timer.Stop()
   115  			return waitTime, contextErrToGrpcErr(parentCtx.Err())
   116  		case <-timer.C:
   117  		}
   118  	}
   119  	return waitTime, nil
   120  }