github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/retryutils/jitter_test.go (about) 1 /* 2 Copyright 2021-2022 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package retryutils 18 19 import ( 20 "fmt" 21 "runtime" 22 "testing" 23 "time" 24 25 "github.com/gravitational/trace" 26 "github.com/stretchr/testify/require" 27 ) 28 29 func TestNewJitterBadParameter(t *testing.T) { 30 t.Parallel() 31 32 for _, tc := range []struct { 33 n time.Duration 34 assertErr require.ErrorAssertionFunc 35 }{ 36 { 37 n: -1, 38 assertErr: func(t require.TestingT, err error, i ...interface{}) { 39 require.True(t, trace.IsBadParameter(err), err) 40 }, 41 }, 42 { 43 n: 0, 44 assertErr: func(t require.TestingT, err error, i ...interface{}) { 45 require.True(t, trace.IsBadParameter(err), err) 46 }, 47 }, 48 { 49 n: 1, 50 assertErr: require.NoError, 51 }, 52 { 53 n: 7, 54 assertErr: require.NoError, 55 }, 56 } { 57 t.Run(fmt.Sprintf("n=%v", tc.n), func(t *testing.T) { 58 _, err := newJitter(tc.n, nil) 59 tc.assertErr(t, err) 60 _, err = newShardedJitter(tc.n, nil) 61 tc.assertErr(t, err) 62 }) 63 } 64 } 65 66 func TestNewJitter(t *testing.T) { 67 t.Parallel() 68 69 baseDuration := time.Second 70 mockInt63nFloor := mockInt63n(func(n int64) int64 { return 0 }) 71 mockInt63nCeiling := mockInt63n(func(n int64) int64 { return n - 1 }) 72 73 for _, tc := range []struct { 74 desc string 75 n time.Duration 76 expectFloor time.Duration 77 expectCeiling time.Duration 78 }{ 79 { 80 desc: "FullJitter", 81 n: 1, 82 expectFloor: 0, 83 expectCeiling: baseDuration - 1, 84 }, 85 { 86 desc: "HalfJitter", 87 n: 2, 88 expectFloor: baseDuration / 2, 89 expectCeiling: baseDuration - 1, 90 }, 91 { 92 desc: "SeventhJitter", 93 n: 7, 94 expectFloor: baseDuration * 6 / 7, 95 expectCeiling: baseDuration - 1, 96 }, 97 } { 98 tc := tc 99 t.Run(tc.desc, func(t *testing.T) { 100 t.Parallel() 101 102 testFloorJitter, err := newJitter(tc.n, mockInt63nFloor) 103 require.NoError(t, err) 104 require.Equal(t, tc.expectFloor, testFloorJitter(baseDuration)) 105 106 testFloorJitter, err = newShardedJitter(tc.n, func() rng { return mockInt63nFloor }) 107 require.NoError(t, err) 108 require.Equal(t, tc.expectFloor, testFloorJitter(baseDuration)) 109 110 testCeilingJitter, err := newJitter(tc.n, mockInt63nCeiling) 111 require.NoError(t, err) 112 require.Equal(t, tc.expectCeiling, testCeilingJitter(baseDuration)) 113 114 testCeilingJitter, err = newShardedJitter(tc.n, func() rng { return mockInt63nCeiling }) 115 require.NoError(t, err) 116 require.Equal(t, tc.expectCeiling, testCeilingJitter(baseDuration)) 117 }) 118 } 119 } 120 121 type mockInt63n func(n int64) int64 122 123 func (m mockInt63n) Int63n(n int64) int64 { 124 return m(n) 125 } 126 127 // BenchmarkJitter is an attempt to check the effect of concurrency on the performance 128 // of a global jitter instance. I'm a bit skeptical of how "true to life" this benchmark 129 // really is, but the results would seem to indicate that >100k concurrent jitters would 130 // still all complete in <1s, which is very good for our purposes. 131 func BenchmarkSingleJitter(b *testing.B) { 132 benchmarkSharedJitter(b, NewHalfJitter()) 133 } 134 135 func BenchmarkShardedJitter(b *testing.B) { 136 benchmarkSharedJitter(b, NewShardedHalfJitter()) 137 } 138 139 func benchmarkSharedJitter(b *testing.B, jitter Jitter) { 140 benchmarkJitter(b, func() Jitter { return jitter }) 141 } 142 143 func benchmarkJitter(b *testing.B, mkjitter func() Jitter) { 144 procs := runtime.GOMAXPROCS(0) 145 for n := procs; n < 200_000; n = n * 2 { 146 b.Run(fmt.Sprintf("n%d", n), func(b *testing.B) { 147 b.SetParallelism(n / procs) 148 b.RunParallel(func(pb *testing.PB) { 149 jitter := mkjitter() 150 for pb.Next() { 151 jitter(time.Hour) 152 } 153 }) 154 }) 155 } 156 }