github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/pkg/rpcutil/middleware.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package rpcutil
    15  
    16  import (
    17  	"context"
    18  	"reflect"
    19  	"strings"
    20  	"time"
    21  
    22  	perrors "github.com/pingcap/errors"
    23  	"github.com/pingcap/log"
    24  	"github.com/pingcap/tiflow/pkg/errors"
    25  	"go.uber.org/zap"
    26  	"golang.org/x/time/rate"
    27  	"google.golang.org/genproto/googleapis/rpc/errdetails"
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/codes"
    30  	"google.golang.org/grpc/status"
    31  )
    32  
    33  const metadataCauseKey = "cause"
    34  
    35  // ToGRPCError converts an error to a gRPC error.
    36  func ToGRPCError(errIn error) error {
    37  	if errIn == nil {
    38  		return nil
    39  	}
    40  	if _, ok := status.FromError(errIn); ok {
    41  		return errIn
    42  	}
    43  
    44  	var (
    45  		normalizedErr *perrors.Error
    46  		metadata      map[string]string
    47  		rfcCode       perrors.RFCErrorCode
    48  		errMsg        string
    49  	)
    50  	if errors.As(errIn, &normalizedErr) {
    51  		rfcCode = normalizedErr.RFCCode()
    52  		if cause := normalizedErr.Cause(); cause != nil {
    53  			metadata = map[string]string{
    54  				metadataCauseKey: cause.Error(),
    55  			}
    56  		}
    57  		errMsg = normalizedErr.GetMsg()
    58  	} else {
    59  		rfcCode = errors.ErrUnknown.RFCCode()
    60  		errMsg = errIn.Error()
    61  	}
    62  
    63  	code := errors.GRPCStatusCode(errIn)
    64  	st, err := status.New(code, errMsg).
    65  		WithDetails(&errdetails.ErrorInfo{
    66  			Reason:   string(rfcCode),
    67  			Metadata: metadata,
    68  		})
    69  	if err != nil {
    70  		return status.New(code, errMsg).Err()
    71  	}
    72  	return st.Err()
    73  }
    74  
    75  // FromGRPCError converts a gRPC error to a normalized error.
    76  func FromGRPCError(errIn error) error {
    77  	if errIn == nil {
    78  		return nil
    79  	}
    80  	st, ok := status.FromError(errIn)
    81  	if !ok {
    82  		return errIn
    83  	}
    84  	var errInfo *errdetails.ErrorInfo
    85  	for _, detail := range st.Details() {
    86  		if ei, ok := detail.(*errdetails.ErrorInfo); ok {
    87  			errInfo = ei
    88  			break
    89  		}
    90  	}
    91  	if errInfo == nil || errInfo.Reason == "" {
    92  		return errors.ErrUnknown.GenWithStack(st.Message())
    93  	}
    94  
    95  	normalizedErr := perrors.Normalize(st.Message(), perrors.RFCCodeText(errInfo.Reason))
    96  	if causeMsg := errInfo.Metadata[metadataCauseKey]; causeMsg != "" {
    97  		return normalizedErr.Wrap(perrors.New(causeMsg)).GenWithStackByArgs()
    98  	}
    99  	return normalizedErr.GenWithStackByArgs()
   100  }
   101  
   102  // ForwardToLeader is a gRPC middleware that forwards the request to the leader if the current node is not the leader.
   103  func ForwardToLeader[T any](fc ForwardChecker[T]) grpc.UnaryServerInterceptor {
   104  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, _ error) {
   105  		method := extractMethod(info.FullMethod)
   106  		if fc.IsLeader() || !fc.LeaderOnly(method) {
   107  			return handler(ctx, req)
   108  		}
   109  
   110  		leaderCli, err := waitForLeader(ctx, fc)
   111  		if err != nil {
   112  			// Return gRPC error to avoid depending on NormalizeError middleware.
   113  			return nil, ToGRPCError(err)
   114  		}
   115  
   116  		fv := reflect.ValueOf(leaderCli).MethodByName(method)
   117  		if fv.IsZero() {
   118  			return nil, status.Errorf(codes.Unimplemented, "method %s not implemented", method)
   119  		}
   120  		results := fv.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(req)})
   121  		if len(results) != 2 {
   122  			log.Panic("invalid method signature", zap.String("method", method))
   123  		}
   124  		errI := results[1].Interface()
   125  		if errI != nil {
   126  			return nil, errI.(error)
   127  		}
   128  		return results[0].Interface(), nil
   129  	}
   130  }
   131  
   132  const (
   133  	waitForLeaderTimeout = 3 * time.Second
   134  	waitForLeaderTick    = 300 * time.Millisecond
   135  )
   136  
   137  func waitForLeader[T any](ctx context.Context, fc ForwardChecker[T]) (leaderCli T, _ error) {
   138  	leaderCli, err := fc.LeaderClient()
   139  	if err == nil {
   140  		return leaderCli, nil
   141  	}
   142  	if !errors.Is(err, errors.ErrMasterNoLeader) {
   143  		return leaderCli, err
   144  	}
   145  
   146  	timer := time.NewTimer(waitForLeaderTimeout)
   147  	defer timer.Stop()
   148  
   149  	ticker := time.NewTicker(waitForLeaderTick)
   150  	defer ticker.Stop()
   151  
   152  	for {
   153  		select {
   154  		case <-ctx.Done():
   155  			return leaderCli, errors.Trace(ctx.Err())
   156  		case <-ticker.C:
   157  			leaderCli, err = fc.LeaderClient()
   158  			if err == nil {
   159  				return leaderCli, nil
   160  			}
   161  			if !errors.Is(err, errors.ErrMasterNoLeader) {
   162  				return leaderCli, err
   163  			}
   164  		case <-time.After(waitForLeaderTimeout):
   165  			return leaderCli, errors.ErrMasterNoLeader.GenWithStackByArgs()
   166  		}
   167  	}
   168  }
   169  
   170  // CheckAvailable is a gRPC middleware that checks whether a method is ready to serve.
   171  func CheckAvailable(fc FeatureChecker) grpc.UnaryServerInterceptor {
   172  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, _ error) {
   173  		method := extractMethod(info.FullMethod)
   174  		if !fc.Available(method) {
   175  			// Return gRPC error to avoid depending on NormalizeError middleware.
   176  			return nil, ToGRPCError(errors.ErrMasterNotReady.GenWithStackByArgs())
   177  		}
   178  		return handler(ctx, req)
   179  	}
   180  }
   181  
   182  // NormalizeError is a gRPC middleware that normalizes the error.
   183  func NormalizeError() grpc.UnaryServerInterceptor {
   184  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, _ error) {
   185  		method := extractMethod(info.FullMethod)
   186  		resp, err := handler(ctx, req)
   187  		if err != nil {
   188  			errOut := ToGRPCError(err)
   189  			s, _ := status.FromError(errOut)
   190  			logger := log.L().With(zap.String("method", method), zap.Error(err), zap.Any("request", req))
   191  			switch s.Code() {
   192  			case codes.Unknown:
   193  				logger.Warn("request handled with an unknown error")
   194  			case codes.Internal:
   195  				logger.Warn("request handled with an internal error")
   196  			default:
   197  				logger.Debug("request handled with an error")
   198  			}
   199  			return nil, errOut
   200  		}
   201  
   202  		log.Debug("request handled successfully", zap.String("method", method), zap.Any("request", req), zap.Any("response", resp))
   203  		return resp, nil
   204  	}
   205  }
   206  
   207  // Logger is a gRPC middleware that logs the request and response.
   208  // allowList is a list of methods that will be logged. limiter is used to limit the log rate.
   209  func Logger(allowList []string, limiter *rate.Limiter) grpc.UnaryServerInterceptor {
   210  	allow := func(method string) bool {
   211  		for _, m := range allowList {
   212  			if m == method {
   213  				return true
   214  			}
   215  		}
   216  		return limiter.Allow()
   217  	}
   218  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, _ error) {
   219  		method := extractMethod(info.FullMethod)
   220  		if allow(method) {
   221  			log.Info("", zap.Any("request", req), zap.String("method", method))
   222  		}
   223  		return handler(ctx, req)
   224  	}
   225  }
   226  
   227  // extract method name from full method name. fullMethod is the full RPC method string, i.e., /package.service/method.
   228  func extractMethod(fullMethod string) string {
   229  	return fullMethod[strings.LastIndexByte(fullMethod, '/')+1:]
   230  }