github.com/cloudwego/kitex@v0.9.0/pkg/retry/retryer_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package retry
    18  
    19  import (
    20  	"context"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/cloudwego/kitex/internal/test"
    27  	"github.com/cloudwego/kitex/pkg/discovery"
    28  	"github.com/cloudwego/kitex/pkg/kerrors"
    29  	"github.com/cloudwego/kitex/pkg/remote"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    31  	"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
    32  	"github.com/cloudwego/kitex/pkg/stats"
    33  )
    34  
    35  var (
    36  	remoteTagKey   = "k"
    37  	remoteTagValue = "v"
    38  	remoteTags     = map[string]string{remoteTagKey: remoteTagValue}
    39  )
    40  
    41  // test new retry container
    42  func TestNewRetryContainer(t *testing.T) {
    43  	rc := NewRetryContainerWithCB(nil, nil)
    44  	rc.NotifyPolicyChange(method, Policy{
    45  		Enable:        true,
    46  		BackupPolicy:  NewBackupPolicy(10),
    47  		FailurePolicy: NewFailurePolicy(),
    48  	})
    49  	r := rc.getRetryer(genRPCInfo())
    50  	_, ok := r.(*failureRetryer)
    51  	test.Assert(t, ok)
    52  
    53  	rc.NotifyPolicyChange(method, Policy{
    54  		Enable:        false,
    55  		BackupPolicy:  NewBackupPolicy(10),
    56  		FailurePolicy: NewFailurePolicy(),
    57  	})
    58  	_, ok = r.(*failureRetryer)
    59  	test.Assert(t, ok)
    60  	_, allow := r.AllowRetry(context.Background())
    61  	test.Assert(t, !allow)
    62  
    63  	rc.NotifyPolicyChange(method, Policy{
    64  		Enable:        true,
    65  		BackupPolicy:  NewBackupPolicy(10),
    66  		FailurePolicy: NewFailurePolicy(),
    67  	})
    68  	r = rc.getRetryer(genRPCInfo())
    69  	_, ok = r.(*failureRetryer)
    70  	test.Assert(t, ok)
    71  	_, allow = r.AllowRetry(context.Background())
    72  	test.Assert(t, allow)
    73  
    74  	rc.NotifyPolicyChange(method, Policy{
    75  		Enable:       true,
    76  		Type:         1,
    77  		BackupPolicy: NewBackupPolicy(20),
    78  	})
    79  	r = rc.getRetryer(genRPCInfo())
    80  	_, ok = r.(*backupRetryer)
    81  	test.Assert(t, ok)
    82  	_, allow = r.AllowRetry(context.Background())
    83  	test.Assert(t, allow)
    84  
    85  	rc.NotifyPolicyChange(method, Policy{
    86  		Enable:       false,
    87  		Type:         1,
    88  		BackupPolicy: NewBackupPolicy(20),
    89  	})
    90  	r = rc.getRetryer(genRPCInfo())
    91  	_, ok = r.(*backupRetryer)
    92  	test.Assert(t, ok)
    93  	_, allow = r.AllowRetry(context.Background())
    94  	test.Assert(t, !allow)
    95  
    96  	// backupPolicy is nil
    97  	rc = NewRetryContainer()
    98  	rc.NotifyPolicyChange(method, Policy{
    99  		Enable: true,
   100  		Type:   1,
   101  	})
   102  	msg := "new retryer[test-Backup] failed, err=newBackupRetryer failed, err=BackupPolicy is nil or retry type not match, cannot do update in backupRetryer, at "
   103  	test.Assert(t, rc.msg[:len(msg)] == msg)
   104  
   105  	// backupPolicy config invalid
   106  	rc.NotifyPolicyChange(method, Policy{
   107  		Enable: true,
   108  		Type:   1,
   109  		BackupPolicy: &BackupPolicy{
   110  			RetryDelayMS: 0,
   111  		},
   112  	})
   113  	msg = "new retryer[test-Backup] failed, err=newBackupRetryer failed, err=invalid backup request delay duration or retryTimes, at "
   114  	test.Assert(t, rc.msg[:len(msg)] == msg)
   115  
   116  	// backupPolicy cBPolicy config invalid
   117  	rc.NotifyPolicyChange(method, Policy{
   118  		Enable: true,
   119  		Type:   1,
   120  		BackupPolicy: &BackupPolicy{
   121  			RetryDelayMS: 100,
   122  			StopPolicy: StopPolicy{
   123  				CBPolicy: CBPolicy{
   124  					ErrorRate: 0.4,
   125  				},
   126  			},
   127  		},
   128  	})
   129  	msg = "new retryer[test-Backup] at "
   130  	test.Assert(t, rc.msg[:len(msg)] == msg)
   131  
   132  	// failurePolicy config invalid
   133  	rc = NewRetryContainer()
   134  	rc.NotifyPolicyChange(method, Policy{
   135  		Enable: true,
   136  		Type:   0,
   137  		FailurePolicy: &FailurePolicy{
   138  			StopPolicy: StopPolicy{
   139  				MaxRetryTimes: 6,
   140  			},
   141  		},
   142  	})
   143  	msg = "new retryer[test-Failure] failed, err=newfailureRetryer failed, err=invalid failure MaxRetryTimes[6], at "
   144  	test.Assert(t, rc.msg[:len(msg)] == msg)
   145  
   146  	// failurePolicy cBPolicy config invalid
   147  	rc = NewRetryContainer()
   148  	rc.NotifyPolicyChange(method, Policy{
   149  		Enable: true,
   150  		Type:   0,
   151  		FailurePolicy: &FailurePolicy{
   152  			StopPolicy: StopPolicy{
   153  				MaxRetryTimes: 5,
   154  				CBPolicy: CBPolicy{
   155  					ErrorRate: 0.4,
   156  				},
   157  			},
   158  		},
   159  	})
   160  	msg = "new retryer[test-Failure] at "
   161  	test.Assert(t, rc.msg[:len(msg)] == msg)
   162  
   163  	// failurePolicy backOffPolicy fixedBackOffType cfg is nil
   164  	rc = NewRetryContainerWithCB(nil, nil)
   165  	rc.NotifyPolicyChange(method, Policy{
   166  		Enable: true,
   167  		Type:   0,
   168  		FailurePolicy: &FailurePolicy{
   169  			StopPolicy: StopPolicy{
   170  				MaxRetryTimes: 5,
   171  			},
   172  			BackOffPolicy: &BackOffPolicy{
   173  				BackOffType: FixedBackOffType,
   174  			},
   175  		},
   176  	})
   177  	msg = "new retryer[test-Failure] at "
   178  	test.Assert(t, rc.msg[:len(msg)] == msg)
   179  
   180  	// failurePolicy backOffPolicy fixedBackOffType cfg invalid
   181  	rc = NewRetryContainerWithCB(nil, nil)
   182  	rc.NotifyPolicyChange(method, Policy{
   183  		Enable: true,
   184  		Type:   0,
   185  		FailurePolicy: &FailurePolicy{
   186  			StopPolicy: StopPolicy{
   187  				MaxRetryTimes: 5,
   188  			},
   189  			BackOffPolicy: &BackOffPolicy{
   190  				BackOffType: FixedBackOffType,
   191  				CfgItems: map[BackOffCfgKey]float64{
   192  					FixMSBackOffCfgKey: 0,
   193  				},
   194  			},
   195  		},
   196  	})
   197  	msg = "new retryer[test-Failure] at "
   198  	test.Assert(t, rc.msg[:len(msg)] == msg)
   199  
   200  	// failurePolicy backOffPolicy randomBackOffType cfg is nil
   201  	rc = NewRetryContainerWithCB(nil, nil)
   202  	rc.NotifyPolicyChange(method, Policy{
   203  		Enable: true,
   204  		Type:   0,
   205  		FailurePolicy: &FailurePolicy{
   206  			StopPolicy: StopPolicy{
   207  				MaxRetryTimes: 5,
   208  			},
   209  			BackOffPolicy: &BackOffPolicy{
   210  				BackOffType: RandomBackOffType,
   211  			},
   212  		},
   213  	})
   214  	msg = "new retryer[test-Failure] at "
   215  	test.Assert(t, rc.msg[:len(msg)] == msg)
   216  
   217  	// failurePolicy backOffPolicy randomBackOffType cfg invalid
   218  	rc = NewRetryContainerWithCB(nil, nil)
   219  	rc.NotifyPolicyChange(method, Policy{
   220  		Enable: true,
   221  		Type:   0,
   222  		FailurePolicy: &FailurePolicy{
   223  			StopPolicy: StopPolicy{
   224  				MaxRetryTimes: 5,
   225  			},
   226  			BackOffPolicy: &BackOffPolicy{
   227  				BackOffType: RandomBackOffType,
   228  				CfgItems: map[BackOffCfgKey]float64{
   229  					MinMSBackOffCfgKey: 20,
   230  					MaxMSBackOffCfgKey: 10,
   231  				},
   232  			},
   233  		},
   234  	})
   235  	msg = "new retryer[test-Failure] at "
   236  	test.Assert(t, rc.msg[:len(msg)] == msg)
   237  
   238  	// failurePolicy backOffPolicy randomBackOffType normal
   239  	rc = NewRetryContainerWithCB(nil, nil)
   240  	rc.NotifyPolicyChange(method, Policy{
   241  		Enable: true,
   242  		Type:   0,
   243  		FailurePolicy: &FailurePolicy{
   244  			StopPolicy: StopPolicy{
   245  				MaxRetryTimes: 5,
   246  			},
   247  			BackOffPolicy: &BackOffPolicy{
   248  				BackOffType: RandomBackOffType,
   249  				CfgItems: map[BackOffCfgKey]float64{
   250  					MinMSBackOffCfgKey: 10,
   251  					MaxMSBackOffCfgKey: 20,
   252  				},
   253  			},
   254  		},
   255  	})
   256  	msg = "new retryer[test-Failure] at "
   257  	test.Assert(t, rc.msg[:len(msg)] == msg)
   258  
   259  	// test init invalid case
   260  	err := rc.Init(nil, nil)
   261  	test.Assert(t, err == nil, err)
   262  	err = rc.Init(map[string]Policy{Wildcard: {
   263  		Enable: true,
   264  		Type:   1,
   265  		BackupPolicy: &BackupPolicy{
   266  			RetryDelayMS: 0,
   267  		},
   268  	}}, nil)
   269  	test.Assert(t, err != nil, err)
   270  
   271  	rc.DeletePolicy(method)
   272  	r = rc.getRetryer(genRPCInfo())
   273  	test.Assert(t, r == nil)
   274  }
   275  
   276  // test container dump
   277  func TestContainer_Dump(t *testing.T) {
   278  	// test backupPolicy dump
   279  	rc := NewRetryContainerWithCB(nil, nil)
   280  	methodPolicies := map[string]Policy{
   281  		method: {
   282  			Enable:       true,
   283  			Type:         1,
   284  			BackupPolicy: NewBackupPolicy(20),
   285  		},
   286  	}
   287  	rc.InitWithPolicies(methodPolicies)
   288  	err := rc.Init(map[string]Policy{Wildcard: {
   289  		Enable:       true,
   290  		Type:         1,
   291  		BackupPolicy: NewBackupPolicy(20),
   292  	}}, nil)
   293  	test.Assert(t, err == nil, err)
   294  	rcDump, ok := rc.Dump().(map[string]interface{})
   295  	test.Assert(t, ok)
   296  	hasCodeCfg, err := jsoni.MarshalToString(rcDump["has_code_cfg"])
   297  	test.Assert(t, err == nil, err)
   298  	test.Assert(t, hasCodeCfg == "true", hasCodeCfg)
   299  	testStr, err := jsoni.MarshalToString(rcDump["test"])
   300  	msg := `{"backupRequest":{"retry_delay_ms":20,"stop_policy":{"max_retry_times":1,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"retry_same_node":false},"enable":true}`
   301  	test.Assert(t, err == nil, err)
   302  	test.Assert(t, testStr == msg)
   303  
   304  	// test failurePolicy dump
   305  	rc = NewRetryContainerWithCB(nil, nil)
   306  	methodPolicies = map[string]Policy{
   307  		method: {
   308  			Enable:        true,
   309  			Type:          FailureType,
   310  			FailurePolicy: NewFailurePolicy(),
   311  		},
   312  	}
   313  	rc.InitWithPolicies(methodPolicies)
   314  	err = rc.Init(map[string]Policy{Wildcard: {
   315  		Enable:        true,
   316  		Type:          FailureType,
   317  		FailurePolicy: NewFailurePolicy(),
   318  	}}, nil)
   319  	test.Assert(t, err == nil, err)
   320  	rcDump, ok = rc.Dump().(map[string]interface{})
   321  	test.Assert(t, ok)
   322  	hasCodeCfg, err = jsoni.MarshalToString(rcDump["has_code_cfg"])
   323  	test.Assert(t, err == nil, err)
   324  	test.Assert(t, hasCodeCfg == "true")
   325  	testStr, err = jsoni.MarshalToString(rcDump["test"])
   326  	msg = `{"enable":true,"failure_retry":{"stop_policy":{"max_retry_times":2,"max_duration_ms":0,"disable_chain_stop":false,"ddl_stop":false,"cb_policy":{"error_rate":0.1}},"backoff_policy":{"backoff_type":"none"},"retry_same_node":false,"extra":""},"specified_result_retry":{"error_retry":false,"resp_retry":false}}`
   327  	test.Assert(t, err == nil, err)
   328  	test.Assert(t, testStr == msg, testStr)
   329  }
   330  
   331  // test FailurePolicy call
   332  func TestFailurePolicyCall(t *testing.T) {
   333  	// call while rpc timeout
   334  	ctx := context.Background()
   335  	rc := NewRetryContainer()
   336  	failurePolicy := NewFailurePolicy()
   337  	failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType
   338  	failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{
   339  		FixMSBackOffCfgKey: 100.0,
   340  	}
   341  	failurePolicy.StopPolicy.MaxDurationMS = 100
   342  	err := rc.Init(map[string]Policy{Wildcard: {
   343  		Enable:        true,
   344  		Type:          0,
   345  		FailurePolicy: failurePolicy,
   346  	}}, nil)
   347  	test.Assert(t, err == nil, err)
   348  	ri := genRPCInfo()
   349  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   350  	_, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   351  		return ri, nil, kerrors.ErrRPCTimeout
   352  	}, ri, nil)
   353  	test.Assert(t, err != nil, err)
   354  	test.Assert(t, !ok)
   355  
   356  	// call normal
   357  	failurePolicy.StopPolicy.MaxDurationMS = 0
   358  	err = rc.Init(map[string]Policy{Wildcard: {
   359  		Enable:        true,
   360  		Type:          0,
   361  		FailurePolicy: failurePolicy,
   362  	}}, nil)
   363  	test.Assert(t, err == nil, err)
   364  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   365  		return ri, nil, nil
   366  	}, ri, nil)
   367  	test.Assert(t, err == nil, err)
   368  	test.Assert(t, ok)
   369  }
   370  
   371  // test retry with one time policy
   372  func TestRetryWithOneTimePolicy(t *testing.T) {
   373  	// call while rpc timeout and exceed MaxDurationMS cause BackOffPolicy is wait fix 100ms, it is invalid config
   374  	failurePolicy := NewFailurePolicy()
   375  	failurePolicy.BackOffPolicy.BackOffType = FixedBackOffType
   376  	failurePolicy.BackOffPolicy.CfgItems = map[BackOffCfgKey]float64{
   377  		FixMSBackOffCfgKey: 100.0,
   378  	}
   379  	failurePolicy.StopPolicy.MaxDurationMS = 100
   380  	p := Policy{
   381  		Enable:        true,
   382  		Type:          0,
   383  		FailurePolicy: failurePolicy,
   384  	}
   385  	ri := genRPCInfo()
   386  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   387  	_, ok, err := NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   388  		return ri, nil, kerrors.ErrRPCTimeout
   389  	}, ri, nil)
   390  	test.Assert(t, err != nil, err)
   391  	test.Assert(t, !ok)
   392  
   393  	// call no MaxDurationMS limit, the retry will success
   394  	failurePolicy.StopPolicy.MaxDurationMS = 0
   395  	var callTimes int32
   396  	ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo())
   397  	_, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   398  		if atomic.LoadInt32(&callTimes) == 0 {
   399  			atomic.AddInt32(&callTimes, 1)
   400  			return ri, nil, kerrors.ErrRPCTimeout
   401  		}
   402  		return ri, nil, nil
   403  	}, ri, nil)
   404  	test.Assert(t, err == nil, err)
   405  	test.Assert(t, !ok)
   406  
   407  	// call backup request
   408  	p = BuildBackupRequest(NewBackupPolicy(10))
   409  	ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), genRPCInfo())
   410  	callTimes = 0
   411  	_, ok, err = NewRetryContainer().WithRetryIfNeeded(ctx, &p, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   412  		if atomic.LoadInt32(&callTimes) == 0 || atomic.LoadInt32(&callTimes) == 1 {
   413  			atomic.AddInt32(&callTimes, 1)
   414  			time.Sleep(time.Millisecond * 100)
   415  		}
   416  		return ri, nil, nil
   417  	}, ri, nil)
   418  	test.Assert(t, err == nil, err)
   419  	test.Assert(t, !ok)
   420  	test.Assert(t, atomic.LoadInt32(&callTimes) == 2)
   421  }
   422  
   423  // test specified error to retry
   424  func TestSpecifiedErrorRetry(t *testing.T) {
   425  	retryWithTransError := func(callTimes int32) RPCCallFunc {
   426  		// fails for the first call if callTimes is initialized to 0
   427  		return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   428  			newVal := atomic.AddInt32(&callTimes, 1)
   429  			if newVal == 1 {
   430  				return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock")
   431  			} else {
   432  				return genRPCInfoWithRemoteTag(remoteTags), nil, nil
   433  			}
   434  		}
   435  	}
   436  	ri := genRPCInfo()
   437  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   438  
   439  	// case1: specified method retry with error
   440  	t.Run("case1", func(t *testing.T) {
   441  		rc := NewRetryContainer()
   442  		shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool {
   443  			if ri.To().Method() == method {
   444  				if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 {
   445  					return true
   446  				}
   447  			}
   448  			return false
   449  		}}
   450  		err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry)
   451  		test.Assert(t, err == nil, err)
   452  		ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil)
   453  		test.Assert(t, err == nil, err)
   454  		test.Assert(t, !ok)
   455  		v, ok := ri.To().Tag(remoteTagKey)
   456  		test.Assert(t, ok)
   457  		test.Assert(t, v == remoteTagValue)
   458  	})
   459  
   460  	// case2: specified method retry with error, but use backup request config cannot be effective
   461  	t.Run("case2", func(t *testing.T) {
   462  		shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool {
   463  			if ri.To().Method() == method {
   464  				if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 {
   465  					return true
   466  				}
   467  			}
   468  			return false
   469  		}}
   470  		rc := NewRetryContainer()
   471  		err := rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry)
   472  		test.Assert(t, err == nil, err)
   473  		ri = genRPCInfo()
   474  		_, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil)
   475  		test.Assert(t, err != nil, err)
   476  		test.Assert(t, !ok)
   477  	})
   478  
   479  	// case3: specified method retry with error, but method not match
   480  	t.Run("case3", func(t *testing.T) {
   481  		shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool {
   482  			if ri.To().Method() != method {
   483  				if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 {
   484  					return true
   485  				}
   486  			}
   487  			return false
   488  		}}
   489  		rc := NewRetryContainer()
   490  		err := rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry)
   491  		test.Assert(t, err == nil, err)
   492  		ri = genRPCInfo()
   493  		ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithTransError(0), ri, nil)
   494  		test.Assert(t, err != nil)
   495  		test.Assert(t, !ok)
   496  		_, ok = ri.To().Tag(remoteTagKey)
   497  		test.Assert(t, !ok)
   498  	})
   499  
   500  	// case4: all error retry
   501  	t.Run("case4", func(t *testing.T) {
   502  		rc := NewRetryContainer()
   503  		p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry()))
   504  		ri = genRPCInfo()
   505  		ri, ok, err := rc.WithRetryIfNeeded(ctx, &p, retryWithTransError(0), ri, nil)
   506  		test.Assert(t, err == nil, err)
   507  		test.Assert(t, !ok)
   508  		v, ok := ri.To().Tag(remoteTagKey)
   509  		test.Assert(t, ok)
   510  		test.Assert(t, v == remoteTagValue)
   511  	})
   512  }
   513  
   514  // test specified resp to retry
   515  func TestSpecifiedRespRetry(t *testing.T) {
   516  	retryResult := &mockResult{}
   517  	retryResp := mockResp{
   518  		code: 500,
   519  		msg:  "retry",
   520  	}
   521  	noRetryResp := mockResp{
   522  		code: 0,
   523  		msg:  "noretry",
   524  	}
   525  	var callTimes int32
   526  	retryWithResp := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   527  		newVal := atomic.AddInt32(&callTimes, 1)
   528  		if newVal == 1 {
   529  			retryResult.SetResult(retryResp)
   530  			return genRPCInfo(), retryResult, nil
   531  		} else {
   532  			retryResult.SetResult(noRetryResp)
   533  			return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil
   534  		}
   535  	}
   536  	ctx := context.Background()
   537  	ri := genRPCInfo()
   538  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   539  	rc := NewRetryContainer()
   540  	// case1: specified method retry with resp
   541  	shouldResultRetry := &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool {
   542  		if ri.To().Method() == method {
   543  			if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp {
   544  				return true
   545  			}
   546  		}
   547  		return false
   548  	}}
   549  	err := rc.Init(map[string]Policy{Wildcard: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry)
   550  	test.Assert(t, err == nil, err)
   551  	ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil)
   552  	test.Assert(t, err == nil, err)
   553  	test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult)
   554  	test.Assert(t, !ok)
   555  	v, ok := ri.To().Tag(remoteTagKey)
   556  	test.Assert(t, ok)
   557  	test.Assert(t, v == remoteTagValue)
   558  
   559  	// case2 specified method retry with resp, but use backup request config cannot be effective
   560  	atomic.StoreInt32(&callTimes, 0)
   561  	rc = NewRetryContainer()
   562  	err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry)
   563  	test.Assert(t, err == nil, err)
   564  	ri = genRPCInfo()
   565  	ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   566  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil)
   567  	test.Assert(t, err == nil, err)
   568  	test.Assert(t, retryResult.GetResult() == retryResp, retryResp)
   569  	test.Assert(t, !ok)
   570  
   571  	// case3: specified method retry with resp, but method not match
   572  	atomic.StoreInt32(&callTimes, 0)
   573  	shouldResultRetry = &ShouldResultRetry{RespRetry: func(resp interface{}, ri rpcinfo.RPCInfo) bool {
   574  		if ri.To().Method() != method {
   575  			if r, ok := resp.(*mockResult); ok && r.GetResult() == retryResp {
   576  				return true
   577  			}
   578  		}
   579  		return false
   580  	}}
   581  	rc = NewRetryContainer()
   582  	err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry)
   583  	test.Assert(t, err == nil, err)
   584  	ri = genRPCInfo()
   585  	ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   586  	ri, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, retryWithResp, ri, nil)
   587  	test.Assert(t, err == nil, err)
   588  	test.Assert(t, retryResult.GetResult() == retryResp, retryResult)
   589  	test.Assert(t, ok)
   590  	_, ok = ri.To().Tag(remoteTagKey)
   591  	test.Assert(t, !ok)
   592  }
   593  
   594  // test different method use different retry policy
   595  func TestDifferentMethodConfig(t *testing.T) {
   596  	var callTimes int32
   597  	methodRetryer := make(map[string]Retryer)
   598  	var lock sync.Mutex
   599  	rpcCall := func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   600  		atomic.AddInt32(&callTimes, 1)
   601  		ri := rpcinfo.GetRPCInfo(ctx)
   602  		if ri.To().Method() == method {
   603  			lock.Lock()
   604  			methodRetryer[ri.To().Method()] = r
   605  			lock.Unlock()
   606  			// specified method mock return error and do error retry
   607  			if atomic.LoadInt32(&callTimes) == 1 {
   608  				return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock")
   609  			}
   610  		} else {
   611  			lock.Lock()
   612  			methodRetryer[ri.To().Method()] = r
   613  			lock.Unlock()
   614  			if atomic.LoadInt32(&callTimes) == 1 {
   615  				// other method do backup request
   616  				time.Sleep(20 * time.Millisecond)
   617  				return genRPCInfo(), nil, kerrors.ErrRPCTimeout
   618  			}
   619  		}
   620  		return genRPCInfo(), nil, nil
   621  	}
   622  	rc := NewRetryContainer()
   623  	shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool {
   624  		if ri.To().Method() == method {
   625  			if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 {
   626  				return true
   627  			}
   628  		}
   629  		return false
   630  	}}
   631  	err := rc.Init(map[string]Policy{
   632  		method:   BuildFailurePolicy(NewFailurePolicy()),
   633  		Wildcard: BuildBackupRequest(NewBackupPolicy(10)),
   634  	}, shouldResultRetry)
   635  	test.Assert(t, err == nil, err)
   636  
   637  	// case1: test method do error retry
   638  	ri := genRPCInfo()
   639  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   640  	_, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil)
   641  	test.Assert(t, err == nil, err)
   642  	test.Assert(t, !ok)
   643  	lock.Lock()
   644  	_, ok = methodRetryer[method].(*failureRetryer)
   645  	lock.Unlock()
   646  	test.Assert(t, ok)
   647  
   648  	// case2: other method do backup request
   649  	method2 := "method2"
   650  	to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method2}, method2).ImmutableView()
   651  	ri = rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method2), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
   652  	ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   653  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, rpcCall, ri, nil)
   654  	test.Assert(t, err == nil, err)
   655  	test.Assert(t, !ok)
   656  	lock.Lock()
   657  	_, ok = methodRetryer[method2].(*backupRetryer)
   658  	lock.Unlock()
   659  	test.Assert(t, ok)
   660  }
   661  
   662  func TestResultRetryWithPolicyChange(t *testing.T) {
   663  	rc := NewRetryContainer()
   664  	shouldResultRetry := &ShouldResultRetry{ErrorRetry: func(err error, ri rpcinfo.RPCInfo) bool {
   665  		if ri.To().Method() == method {
   666  			if te, ok := err.(*remote.TransError); ok && te.TypeID() == 1000 {
   667  				return true
   668  			}
   669  		}
   670  		return false
   671  	}}
   672  	err := rc.Init(nil, shouldResultRetry)
   673  	test.Assert(t, err == nil, err)
   674  
   675  	// case 1: first time trigger NotifyPolicyChange, the `initRetryer` will be executed, check if the ShouldResultRetry is not nil
   676  	rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy()))
   677  	r := rc.getRetryer(genRPCInfo())
   678  	fr, ok := r.(*failureRetryer)
   679  	test.Assert(t, ok)
   680  	test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry)
   681  
   682  	// case 2: second time trigger NotifyPolicyChange, the `UpdatePolicy` will be executed, check if the ShouldResultRetry is not nil
   683  	rc.NotifyPolicyChange(Wildcard, BuildFailurePolicy(NewFailurePolicy()))
   684  	r = rc.getRetryer(genRPCInfo())
   685  	fr, ok = r.(*failureRetryer)
   686  	test.Assert(t, ok)
   687  	test.Assert(t, fr.policy.ShouldResultRetry == shouldResultRetry)
   688  }
   689  
   690  // test BackupPolicy call while rpcTime > delayTime
   691  func TestBackupPolicyCall(t *testing.T) {
   692  	ctx := context.Background()
   693  	rc := NewRetryContainer()
   694  	err := rc.Init(map[string]Policy{Wildcard: {
   695  		Enable: true,
   696  		Type:   1,
   697  		BackupPolicy: &BackupPolicy{
   698  			RetryDelayMS: 30,
   699  			StopPolicy: StopPolicy{
   700  				MaxRetryTimes:    2,
   701  				DisableChainStop: false,
   702  				CBPolicy: CBPolicy{
   703  					ErrorRate: defaultCBErrRate,
   704  				},
   705  			},
   706  		},
   707  	}}, nil)
   708  	test.Assert(t, err == nil, err)
   709  
   710  	callTimes := int32(0)
   711  	firstRI := genRPCInfo()
   712  	secondRI := genRPCInfoWithRemoteTag(remoteTags)
   713  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, firstRI)
   714  	ri, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   715  		atomic.AddInt32(&callTimes, 1)
   716  		if atomic.LoadInt32(&callTimes) == 1 {
   717  			// mock timeout for the first request and get the response of the backup request.
   718  			time.Sleep(time.Millisecond * 50)
   719  			return firstRI, nil, nil
   720  		}
   721  		return secondRI, nil, nil
   722  	}, firstRI, nil)
   723  	test.Assert(t, err == nil, err)
   724  	test.Assert(t, atomic.LoadInt32(&callTimes) == 2)
   725  	test.Assert(t, !ok)
   726  	v, ok := ri.To().Tag(remoteTagKey)
   727  	test.Assert(t, ok)
   728  	test.Assert(t, v == remoteTagValue)
   729  }
   730  
   731  // test policy noRetry call
   732  func TestPolicyNoRetryCall(t *testing.T) {
   733  	ctx := context.Background()
   734  	rc := NewRetryContainer()
   735  
   736  	// case 1(default): no retry policy
   737  	// no retry policy, call success
   738  	var callTimes int32
   739  	ri := genRPCInfo()
   740  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   741  	_, ok, err := rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   742  		atomic.AddInt32(&callTimes, 1)
   743  		return ri, nil, nil
   744  	}, ri, nil)
   745  	test.Assert(t, err == nil, err)
   746  	test.Assert(t, atomic.LoadInt32(&callTimes) == 1)
   747  	test.Assert(t, ok)
   748  
   749  	// no retry policy, call rpcTimeout
   750  	atomic.StoreInt32(&callTimes, 0)
   751  	ri = genRPCInfo()
   752  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   753  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   754  		atomic.AddInt32(&callTimes, 1)
   755  		if atomic.LoadInt32(&callTimes) == 1 {
   756  			return ri, nil, kerrors.ErrRPCTimeout
   757  		}
   758  		return ri, nil, nil
   759  	}, ri, nil)
   760  	test.Assert(t, kerrors.IsTimeoutError(err), err)
   761  	test.Assert(t, atomic.LoadInt32(&callTimes) == 1)
   762  	test.Assert(t, !ok)
   763  
   764  	// case 2: setup retry policy, but not satisfy retry condition eg: circuit, retry times == 0, chain stop, ddl
   765  	// failurePolicy DDLStop rpcTimeOut
   766  	failurePolicy := NewFailurePolicy()
   767  	failurePolicy.WithDDLStop()
   768  	RegisterDDLStop(func(ctx context.Context, policy StopPolicy) (bool, string) {
   769  		return true, "TestDDLStop"
   770  	})
   771  	err = rc.Init(map[string]Policy{Wildcard: {
   772  		Enable:        true,
   773  		Type:          0,
   774  		FailurePolicy: failurePolicy,
   775  	}}, nil)
   776  	test.Assert(t, err == nil, err)
   777  	atomic.StoreInt32(&callTimes, 0)
   778  	ri = genRPCInfo()
   779  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   780  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   781  		atomic.AddInt32(&callTimes, 1)
   782  		if atomic.LoadInt32(&callTimes) == 1 {
   783  			return ri, nil, kerrors.ErrRPCTimeout
   784  		}
   785  		return ri, nil, nil
   786  	}, ri, nil)
   787  	test.Assert(t, kerrors.IsTimeoutError(err), err)
   788  	test.Assert(t, atomic.LoadInt32(&callTimes) == 1)
   789  	test.Assert(t, !ok)
   790  
   791  	// backupPolicy MaxRetryTimes = 0
   792  	backupPolicy := NewBackupPolicy(20)
   793  	backupPolicy.StopPolicy.MaxRetryTimes = 0
   794  	err = rc.Init(map[string]Policy{Wildcard: {
   795  		Enable:       true,
   796  		Type:         1,
   797  		BackupPolicy: backupPolicy,
   798  	}}, nil)
   799  	test.Assert(t, err == nil, err)
   800  	atomic.StoreInt32(&callTimes, 0)
   801  	ri = genRPCInfo()
   802  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   803  	_, ok, err = rc.WithRetryIfNeeded(ctx, &Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   804  		atomic.AddInt32(&callTimes, 1)
   805  		time.Sleep(time.Millisecond * 100)
   806  		return ri, nil, nil
   807  	}, ri, nil)
   808  	test.Assert(t, err == nil, err)
   809  	test.Assert(t, atomic.LoadInt32(&callTimes) == 1)
   810  	test.Assert(t, ok)
   811  }
   812  
   813  func retryCall(callTimes *int32, firstRI rpcinfo.RPCInfo, backup bool) RPCCallFunc {
   814  	// prevRI represents a value of rpcinfo.RPCInfo type.
   815  	var prevRI atomic.Value
   816  	return func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) {
   817  		currCallTimes := int(atomic.AddInt32(callTimes, 1))
   818  		cRI := firstRI
   819  		if currCallTimes > 1 {
   820  			cRI = genRPCInfoWithFirstStats(firstRI)
   821  			cRI.Stats().Record(ctx, stats.RPCFinish, stats.StatusInfo, "")
   822  			remoteInfo := remoteinfo.AsRemoteInfo(cRI.To())
   823  			remoteInfo.SetInstance(discovery.NewInstance("tcp", "10.20.30.40:8888", 10, nil))
   824  			if prevRI.Load() == nil {
   825  				prevRI.Store(firstRI)
   826  			}
   827  			r.Prepare(ctx, prevRI.Load().(rpcinfo.RPCInfo), cRI)
   828  			prevRI.Store(cRI)
   829  			return cRI, nil, nil
   830  		} else {
   831  			if backup {
   832  				time.Sleep(20 * time.Millisecond)
   833  				return cRI, nil, nil
   834  			} else {
   835  				return cRI, nil, kerrors.ErrRPCTimeout
   836  			}
   837  		}
   838  	}
   839  }
   840  
   841  func TestFailureRetryWithRPCInfo(t *testing.T) {
   842  	// failure retry
   843  	ctx := context.Background()
   844  	rc := NewRetryContainer()
   845  
   846  	ri := genRPCInfo()
   847  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   848  	rpcinfo.Record(ctx, ri, stats.RPCStart, nil)
   849  
   850  	// call with retry policy
   851  	var callTimes int32
   852  	policy := BuildFailurePolicy(NewFailurePolicy())
   853  	ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, false), ri, nil)
   854  	test.Assert(t, err == nil, err)
   855  	test.Assert(t, !ok)
   856  	test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo)
   857  	test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo)
   858  	test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888")
   859  	test.Assert(t, atomic.LoadInt32(&callTimes) == 2)
   860  }
   861  
   862  func TestBackupRetryWithRPCInfo(t *testing.T) {
   863  	// backup retry
   864  	ctx := context.Background()
   865  	rc := NewRetryContainer()
   866  
   867  	ri := genRPCInfo()
   868  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
   869  	rpcinfo.Record(ctx, ri, stats.RPCStart, nil)
   870  
   871  	// call with retry policy
   872  	var callTimes int32
   873  	policy := BuildBackupRequest(NewBackupPolicy(10))
   874  	ri, ok, err := rc.WithRetryIfNeeded(ctx, &policy, retryCall(&callTimes, ri, true), ri, nil)
   875  	test.Assert(t, err == nil, err)
   876  	test.Assert(t, !ok)
   877  	test.Assert(t, ri.Stats().GetEvent(stats.RPCStart).Status() == stats.StatusInfo)
   878  	test.Assert(t, ri.Stats().GetEvent(stats.RPCFinish).Status() == stats.StatusInfo)
   879  	test.Assert(t, ri.To().Address().String() == "10.20.30.40:8888")
   880  	test.Assert(t, atomic.LoadInt32(&callTimes) == 2)
   881  }
   882  
   883  type mockResult struct {
   884  	result mockResp
   885  	sync.RWMutex
   886  }
   887  
   888  type mockResp struct {
   889  	code int
   890  	msg  string
   891  }
   892  
   893  func (r *mockResult) GetResult() interface{} {
   894  	r.RLock()
   895  	defer r.RUnlock()
   896  	return r.result
   897  }
   898  
   899  func (r *mockResult) SetResult(ret mockResp) {
   900  	r.Lock()
   901  	defer r.Unlock()
   902  	r.result = ret
   903  }
   904  
   905  func TestNewRetryContainerWithOptions(t *testing.T) {
   906  	t.Run("no_option", func(t *testing.T) {
   907  		rc := NewRetryContainer()
   908  		test.Assertf(t, rc.cbContainer.cbSuite != nil, "cb_suite nil")
   909  		test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true")
   910  	})
   911  
   912  	t.Run("percentage_limit", func(t *testing.T) {
   913  		rc := NewRetryContainer(WithContainerEnablePercentageLimit())
   914  		test.Assertf(t, rc.cbContainer.enablePercentageLimit == true, "percentage_limit not true")
   915  	})
   916  
   917  	t.Run("percentage_limit&&cbOptions", func(t *testing.T) {
   918  		cbSuite := newCBSuite()
   919  		rc := NewRetryContainer(
   920  			WithContainerEnablePercentageLimit(),
   921  			WithContainerCBSuite(cbSuite),
   922  			WithContainerCBControl(cbSuite.ServiceControl()),
   923  			WithContainerCBPanel(cbSuite.ServicePanel()),
   924  		)
   925  		test.Assertf(t, rc.cbContainer.enablePercentageLimit == true, "percentage_limit not true")
   926  		test.Assertf(t, rc.cbContainer.cbSuite != cbSuite, "cbSuite not ignored")
   927  		test.Assertf(t, rc.cbContainer.cbCtl != cbSuite.ServiceControl(), "cbCtl not ignored")
   928  		test.Assertf(t, rc.cbContainer.cbPanel != cbSuite.ServicePanel(), "cbPanel not ignored")
   929  	})
   930  
   931  	t.Run("cb_stat", func(t *testing.T) {
   932  		rc := NewRetryContainer(WithContainerCBStat())
   933  		test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true")
   934  	})
   935  
   936  	t.Run("cb_suite", func(t *testing.T) {
   937  		cbs := newCBSuite()
   938  		rc := NewRetryContainer(WithContainerCBSuite(cbs))
   939  		test.Assert(t, rc.cbContainer.cbSuite == cbs, "cbSuite expected %p, got %p", cbs, rc.cbContainer.cbSuite)
   940  	})
   941  
   942  	t.Run("cb_control&cb_panel", func(t *testing.T) {
   943  		cbs := newCBSuite()
   944  		rc := NewRetryContainer(
   945  			WithContainerCBControl(cbs.ServiceControl()),
   946  			WithContainerCBPanel(cbs.ServicePanel()))
   947  		test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match")
   948  		test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match")
   949  	})
   950  }
   951  
   952  func TestNewRetryContainerWithCBStat(t *testing.T) {
   953  	cbs := newCBSuite()
   954  	rc := NewRetryContainerWithCBStat(cbs.ServiceControl(), cbs.ServicePanel())
   955  	test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match")
   956  	test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match")
   957  	test.Assertf(t, rc.cbContainer.cbStat == true, "cb_stat not true")
   958  	rc.Close()
   959  }
   960  
   961  func TestNewRetryContainerWithCB(t *testing.T) {
   962  	cbs := newCBSuite()
   963  	rc := NewRetryContainerWithCB(cbs.ServiceControl(), cbs.ServicePanel())
   964  	test.Assert(t, rc.cbContainer.cbCtl == cbs.ServiceControl(), "cbControl not match")
   965  	test.Assert(t, rc.cbContainer.cbPanel == cbs.ServicePanel(), "cbPanel not match")
   966  	test.Assertf(t, rc.cbContainer.cbStat == false, "cb_stat not false")
   967  	rc.Close()
   968  }