github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/retryutils/jitter_test.go (about)

     1  /*
     2  Copyright 2021-2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package retryutils
    18  
    19  import (
    20  	"fmt"
    21  	"runtime"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/gravitational/trace"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func TestNewJitterBadParameter(t *testing.T) {
    30  	t.Parallel()
    31  
    32  	for _, tc := range []struct {
    33  		n         time.Duration
    34  		assertErr require.ErrorAssertionFunc
    35  	}{
    36  		{
    37  			n: -1,
    38  			assertErr: func(t require.TestingT, err error, i ...interface{}) {
    39  				require.True(t, trace.IsBadParameter(err), err)
    40  			},
    41  		},
    42  		{
    43  			n: 0,
    44  			assertErr: func(t require.TestingT, err error, i ...interface{}) {
    45  				require.True(t, trace.IsBadParameter(err), err)
    46  			},
    47  		},
    48  		{
    49  			n:         1,
    50  			assertErr: require.NoError,
    51  		},
    52  		{
    53  			n:         7,
    54  			assertErr: require.NoError,
    55  		},
    56  	} {
    57  		t.Run(fmt.Sprintf("n=%v", tc.n), func(t *testing.T) {
    58  			_, err := newJitter(tc.n, nil)
    59  			tc.assertErr(t, err)
    60  			_, err = newShardedJitter(tc.n, nil)
    61  			tc.assertErr(t, err)
    62  		})
    63  	}
    64  }
    65  
    66  func TestNewJitter(t *testing.T) {
    67  	t.Parallel()
    68  
    69  	baseDuration := time.Second
    70  	mockInt63nFloor := mockInt63n(func(n int64) int64 { return 0 })
    71  	mockInt63nCeiling := mockInt63n(func(n int64) int64 { return n - 1 })
    72  
    73  	for _, tc := range []struct {
    74  		desc          string
    75  		n             time.Duration
    76  		expectFloor   time.Duration
    77  		expectCeiling time.Duration
    78  	}{
    79  		{
    80  			desc:          "FullJitter",
    81  			n:             1,
    82  			expectFloor:   0,
    83  			expectCeiling: baseDuration - 1,
    84  		},
    85  		{
    86  			desc:          "HalfJitter",
    87  			n:             2,
    88  			expectFloor:   baseDuration / 2,
    89  			expectCeiling: baseDuration - 1,
    90  		},
    91  		{
    92  			desc:          "SeventhJitter",
    93  			n:             7,
    94  			expectFloor:   baseDuration * 6 / 7,
    95  			expectCeiling: baseDuration - 1,
    96  		},
    97  	} {
    98  		tc := tc
    99  		t.Run(tc.desc, func(t *testing.T) {
   100  			t.Parallel()
   101  
   102  			testFloorJitter, err := newJitter(tc.n, mockInt63nFloor)
   103  			require.NoError(t, err)
   104  			require.Equal(t, tc.expectFloor, testFloorJitter(baseDuration))
   105  
   106  			testFloorJitter, err = newShardedJitter(tc.n, func() rng { return mockInt63nFloor })
   107  			require.NoError(t, err)
   108  			require.Equal(t, tc.expectFloor, testFloorJitter(baseDuration))
   109  
   110  			testCeilingJitter, err := newJitter(tc.n, mockInt63nCeiling)
   111  			require.NoError(t, err)
   112  			require.Equal(t, tc.expectCeiling, testCeilingJitter(baseDuration))
   113  
   114  			testCeilingJitter, err = newShardedJitter(tc.n, func() rng { return mockInt63nCeiling })
   115  			require.NoError(t, err)
   116  			require.Equal(t, tc.expectCeiling, testCeilingJitter(baseDuration))
   117  		})
   118  	}
   119  }
   120  
   121  type mockInt63n func(n int64) int64
   122  
   123  func (m mockInt63n) Int63n(n int64) int64 {
   124  	return m(n)
   125  }
   126  
   127  // BenchmarkJitter is an attempt to check the effect of concurrency on the performance
   128  // of a global jitter instance. I'm a bit skeptical of how "true to life" this benchmark
   129  // really is, but the results would seem to indicate that >100k concurrent jitters would
   130  // still all complete in <1s, which is very good for our purposes.
   131  func BenchmarkSingleJitter(b *testing.B) {
   132  	benchmarkSharedJitter(b, NewHalfJitter())
   133  }
   134  
   135  func BenchmarkShardedJitter(b *testing.B) {
   136  	benchmarkSharedJitter(b, NewShardedHalfJitter())
   137  }
   138  
   139  func benchmarkSharedJitter(b *testing.B, jitter Jitter) {
   140  	benchmarkJitter(b, func() Jitter { return jitter })
   141  }
   142  
   143  func benchmarkJitter(b *testing.B, mkjitter func() Jitter) {
   144  	procs := runtime.GOMAXPROCS(0)
   145  	for n := procs; n < 200_000; n = n * 2 {
   146  		b.Run(fmt.Sprintf("n%d", n), func(b *testing.B) {
   147  			b.SetParallelism(n / procs)
   148  			b.RunParallel(func(pb *testing.PB) {
   149  				jitter := mkjitter()
   150  				for pb.Next() {
   151  					jitter(time.Hour)
   152  				}
   153  			})
   154  		})
   155  	}
   156  }