github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/pkg/rpcutil/middleware_test.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package rpcutil_test
    15  
    16  import (
    17  	"context"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/golang/mock/gomock"
    23  	perrors "github.com/pingcap/errors"
    24  	"github.com/pingcap/tiflow/engine/pkg/rpcutil"
    25  	"github.com/pingcap/tiflow/engine/pkg/rpcutil/mock"
    26  	"github.com/pingcap/tiflow/pkg/errors"
    27  	"github.com/stretchr/testify/require"
    28  	"google.golang.org/genproto/googleapis/rpc/errdetails"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/status"
    32  )
    33  
    34  func TestToGRPCError(t *testing.T) {
    35  	t.Parallel()
    36  
    37  	// nil
    38  	require.NoError(t, rpcutil.ToGRPCError(nil))
    39  
    40  	// already a gRPC error
    41  	err := status.New(codes.NotFound, "not found").Err()
    42  	require.Equal(t, err, rpcutil.ToGRPCError(err))
    43  
    44  	// unknown error
    45  	err = errors.New("unknown error")
    46  	gerr := rpcutil.ToGRPCError(err)
    47  	require.Equal(t, codes.Unknown, status.Code(gerr))
    48  	st, ok := status.FromError(gerr)
    49  	require.True(t, ok)
    50  	require.Equal(t, err.Error(), st.Message())
    51  	require.Len(t, st.Details(), 1)
    52  	errInfo := st.Details()[0].(*errdetails.ErrorInfo)
    53  	require.Equal(t, errors.ErrUnknown.RFCCode(), perrors.RFCErrorCode(errInfo.Reason))
    54  
    55  	// job not found
    56  	err = errors.ErrJobNotFound.GenWithStackByArgs("job-1")
    57  	gerr = rpcutil.ToGRPCError(err)
    58  	require.Equal(t, codes.NotFound, status.Code(gerr))
    59  	st, ok = status.FromError(gerr)
    60  	require.True(t, ok)
    61  	require.Equal(t, "job job-1 is not found", st.Message())
    62  	require.Len(t, st.Details(), 1)
    63  	errInfo = st.Details()[0].(*errdetails.ErrorInfo)
    64  	require.Equal(t, errors.ErrJobNotFound.RFCCode(), perrors.RFCErrorCode(errInfo.Reason))
    65  
    66  	// create worker terminated
    67  	err = errors.ErrCreateWorkerTerminate.Wrap(perrors.New("invalid config")).GenWithStackByArgs()
    68  	gerr = rpcutil.ToGRPCError(err)
    69  	st, ok = status.FromError(gerr)
    70  	require.True(t, ok)
    71  	require.Equal(t, "create worker is terminated", st.Message())
    72  	require.Len(t, st.Details(), 1)
    73  	errInfo = st.Details()[0].(*errdetails.ErrorInfo)
    74  	require.Equal(t, errors.ErrCreateWorkerTerminate.RFCCode(), perrors.RFCErrorCode(errInfo.Reason))
    75  	require.Equal(t, "invalid config", errInfo.Metadata["cause"])
    76  }
    77  
    78  func TestFromGRPCError(t *testing.T) {
    79  	t.Parallel()
    80  
    81  	// nil
    82  	require.NoError(t, rpcutil.FromGRPCError(nil))
    83  
    84  	// not a gRPC error
    85  	err := errors.New("unknown error")
    86  	require.Equal(t, err, rpcutil.FromGRPCError(err))
    87  
    88  	// gRPC error
    89  	srvErr := errors.ErrJobNotFound.GenWithStackByArgs("job-1")
    90  	clientErr := rpcutil.FromGRPCError(rpcutil.ToGRPCError(srvErr))
    91  	require.True(t, errors.Is(clientErr, errors.ErrJobNotFound))
    92  	require.Equal(t, srvErr.Error(), clientErr.Error())
    93  
    94  	// create worker terminated
    95  	srvErr = errors.ErrCreateWorkerTerminate.Wrap(perrors.New("invalid config")).GenWithStackByArgs()
    96  	clientErr = rpcutil.FromGRPCError(rpcutil.ToGRPCError(srvErr))
    97  	require.True(t, errors.Is(clientErr, errors.ErrCreateWorkerTerminate))
    98  	require.Equal(t, srvErr.Error(), clientErr.Error())
    99  	cause := errors.Cause(clientErr)
   100  	require.Error(t, cause)
   101  	require.Equal(t, "invalid config", cause.Error())
   102  }
   103  
   104  type leaderClient struct {
   105  	heartbeat func(ctx context.Context, req any) (any, error)
   106  	status    func(ctx context.Context, req any) (any, error)
   107  }
   108  
   109  func (lc *leaderClient) Heartbeat(ctx context.Context, req any) (any, error) {
   110  	return lc.heartbeat(ctx, req)
   111  }
   112  
   113  func (lc *leaderClient) Status(ctx context.Context, req any) (any, error) {
   114  	return lc.status(ctx, req)
   115  }
   116  
   117  func TestForwardToLeader(t *testing.T) {
   118  	t.Parallel()
   119  
   120  	fc := mock.NewMockForwardChecker[*leaderClient](gomock.NewController(t))
   121  	mw := rpcutil.ForwardToLeader[*leaderClient](fc)
   122  
   123  	var (
   124  		local   atomic.Bool
   125  		forward atomic.Bool
   126  	)
   127  	fc.EXPECT().LeaderOnly("Heartbeat").AnyTimes().Return(true)
   128  	fc.EXPECT().LeaderOnly("Status").AnyTimes().Return(false)
   129  
   130  	handler := func(ctx context.Context, req any) (any, error) {
   131  		local.Store(true)
   132  		return nil, nil
   133  	}
   134  
   135  	// Current node is leader.
   136  	fc.EXPECT().IsLeader().Times(1).Return(true)
   137  	_, err := mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Heartbeat"}, handler)
   138  	require.NoError(t, err)
   139  	require.True(t, local.Load())
   140  	require.False(t, forward.Load())
   141  
   142  	// Method is not leader only.
   143  	fc.EXPECT().IsLeader().Times(1).Return(false)
   144  	local.Store(false)
   145  	forward.Store(false)
   146  	_, err = mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Status"}, handler)
   147  	require.NoError(t, err)
   148  	require.True(t, local.Load())
   149  	require.False(t, forward.Load())
   150  
   151  	// Forward to leader.
   152  	lc := &leaderClient{
   153  		heartbeat: func(ctx context.Context, req any) (any, error) {
   154  			forward.Store(true)
   155  			return nil, nil
   156  		},
   157  		status: func(ctx context.Context, req any) (any, error) {
   158  			forward.Store(true)
   159  			return nil, nil
   160  		},
   161  	}
   162  	fc.EXPECT().IsLeader().Times(1).Return(false)
   163  	fc.EXPECT().LeaderClient().Times(1).Return(lc, nil)
   164  	local.Store(false)
   165  	forward.Store(false)
   166  	_, err = mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Heartbeat"}, handler)
   167  	require.NoError(t, err)
   168  	require.False(t, local.Load())
   169  	require.True(t, forward.Load())
   170  
   171  	// Wait for leader.
   172  	const leaderDelay = time.Millisecond * 500
   173  	start := time.Now()
   174  	fc.EXPECT().IsLeader().Times(1).Return(false)
   175  	fc.EXPECT().LeaderClient().AnyTimes().DoAndReturn(func() (*leaderClient, error) {
   176  		if time.Since(start) < leaderDelay {
   177  			return nil, errors.ErrMasterNoLeader.GenWithStackByArgs()
   178  		}
   179  		return lc, nil
   180  	})
   181  	local.Store(false)
   182  	forward.Store(false)
   183  	_, err = mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Heartbeat"}, handler)
   184  	require.NoError(t, err)
   185  	require.False(t, local.Load())
   186  	require.True(t, forward.Load())
   187  }
   188  
   189  func TestCheckAvailable(t *testing.T) {
   190  	t.Parallel()
   191  
   192  	fc := mock.NewMockFeatureChecker(gomock.NewController(t))
   193  	mw := rpcutil.CheckAvailable(fc)
   194  
   195  	var handled atomic.Bool
   196  	handler := func(ctx context.Context, req any) (any, error) {
   197  		handled.Store(true)
   198  		return nil, nil
   199  	}
   200  
   201  	fc.EXPECT().Available("Heartbeat").Times(1).Return(false)
   202  	_, err := mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Heartbeat"}, handler)
   203  	err = rpcutil.FromGRPCError(err)
   204  	require.True(t, errors.Is(err, errors.ErrMasterNotReady))
   205  	require.False(t, handled.Load())
   206  
   207  	fc.EXPECT().Available("Heartbeat").Times(1).Return(true)
   208  	_, err = mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "Heartbeat"}, handler)
   209  	require.NoError(t, err)
   210  	require.True(t, handled.Load())
   211  }
   212  
   213  func TestNormalizeError(t *testing.T) {
   214  	t.Parallel()
   215  
   216  	mw := rpcutil.NormalizeError()
   217  	handler := func(ctx context.Context, req any) (any, error) {
   218  		return nil, errors.ErrJobNotFound.GenWithStackByArgs("job-1")
   219  	}
   220  	_, err := mw(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "GetJob"}, handler)
   221  	gerr, ok := status.FromError(err)
   222  	require.True(t, ok)
   223  	require.Equal(t, codes.NotFound, gerr.Code())
   224  	clientErr := rpcutil.FromGRPCError(err)
   225  	require.True(t, errors.Is(clientErr, errors.ErrJobNotFound))
   226  }