github.com/slackhq/nebula@v1.9.0/timeout_test.go (about) 1 package nebula 2 3 import ( 4 "testing" 5 "time" 6 7 "github.com/slackhq/nebula/firewall" 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestNewTimerWheel(t *testing.T) { 12 // Make sure we get an object we expect 13 tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) 14 assert.Equal(t, 12, tw.wheelLen) 15 assert.Equal(t, 0, tw.current) 16 assert.Nil(t, tw.lastTick) 17 assert.Equal(t, time.Second*1, tw.tickDuration) 18 assert.Equal(t, time.Second*10, tw.wheelDuration) 19 assert.Len(t, tw.wheel, 12) 20 21 // Assert the math is correct 22 tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10) 23 assert.Equal(t, 5, tw.wheelLen) 24 25 tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10) 26 assert.Equal(t, 7, tw.wheelLen) 27 28 // Test empty purge of non nil items 29 i, ok := tw.Purge() 30 assert.Equal(t, firewall.Packet{}, i) 31 assert.False(t, ok) 32 33 // Test empty purges of nil items 34 tw2 := NewTimerWheel[*int](time.Second, time.Second*10) 35 i2, ok := tw2.Purge() 36 assert.Nil(t, i2) 37 assert.False(t, ok) 38 39 } 40 41 func TestTimerWheel_findWheel(t *testing.T) { 42 tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) 43 assert.Len(t, tw.wheel, 12) 44 45 // Current + tick + 1 since we don't know how far into current we are 46 assert.Equal(t, 2, tw.findWheel(time.Second*1)) 47 48 // Scale up to min duration 49 assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) 50 51 // Make sure we hit that last index 52 assert.Equal(t, 11, tw.findWheel(time.Second*10)) 53 54 // Scale down to max duration 55 assert.Equal(t, 11, tw.findWheel(time.Second*11)) 56 57 tw.current = 1 58 // Make sure we account for the current position properly 59 assert.Equal(t, 3, tw.findWheel(time.Second*1)) 60 assert.Equal(t, 0, tw.findWheel(time.Second*10)) 61 } 62 63 func TestTimerWheel_Add(t *testing.T) { 64 tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) 65 66 fp1 := firewall.Packet{} 67 tw.Add(fp1, time.Second*1) 68 69 // Make sure we set head and tail properly 70 assert.NotNil(t, tw.wheel[2]) 71 assert.Equal(t, fp1, tw.wheel[2].Head.Item) 72 assert.Nil(t, tw.wheel[2].Head.Next) 73 assert.Equal(t, fp1, tw.wheel[2].Tail.Item) 74 assert.Nil(t, tw.wheel[2].Tail.Next) 75 76 // Make sure we only modify head 77 fp2 := firewall.Packet{} 78 tw.Add(fp2, time.Second*1) 79 assert.Equal(t, fp2, tw.wheel[2].Head.Item) 80 assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) 81 assert.Equal(t, fp1, tw.wheel[2].Tail.Item) 82 assert.Nil(t, tw.wheel[2].Tail.Next) 83 84 // Make sure we use free'd items first 85 tw.itemCache = &TimeoutItem[firewall.Packet]{} 86 tw.itemsCached = 1 87 tw.Add(fp2, time.Second*1) 88 assert.Nil(t, tw.itemCache) 89 assert.Equal(t, 0, tw.itemsCached) 90 91 // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel 92 for min := time.Duration(1); min < 100; min++ { 93 for max := min; max < 100; max++ { 94 tw = NewTimerWheel[firewall.Packet](min, max) 95 96 for current := 0; current < tw.wheelLen; current++ { 97 tw.current = current 98 for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ { 99 tick := tw.findWheel(timeout) 100 if tick >= tw.wheelLen { 101 t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick) 102 } 103 } 104 } 105 } 106 } 107 } 108 109 func TestTimerWheel_Purge(t *testing.T) { 110 // First advance should set the lastTick and do nothing else 111 tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) 112 assert.Nil(t, tw.lastTick) 113 tw.Advance(time.Now()) 114 assert.NotNil(t, tw.lastTick) 115 assert.Equal(t, 0, tw.current) 116 117 fps := []firewall.Packet{ 118 {LocalIP: 1}, 119 {LocalIP: 2}, 120 {LocalIP: 3}, 121 {LocalIP: 4}, 122 } 123 124 tw.Add(fps[0], time.Second*1) 125 tw.Add(fps[1], time.Second*1) 126 tw.Add(fps[2], time.Second*2) 127 tw.Add(fps[3], time.Second*2) 128 129 ta := time.Now().Add(time.Second * 3) 130 lastTick := *tw.lastTick 131 tw.Advance(ta) 132 assert.Equal(t, 3, tw.current) 133 assert.True(t, tw.lastTick.After(lastTick)) 134 135 // Make sure we get all 4 packets back 136 for i := 0; i < 4; i++ { 137 p, has := tw.Purge() 138 assert.True(t, has) 139 assert.Equal(t, fps[i], p) 140 } 141 142 // Make sure there aren't any leftover 143 _, ok := tw.Purge() 144 assert.False(t, ok) 145 assert.Nil(t, tw.expired.Head) 146 assert.Nil(t, tw.expired.Tail) 147 148 // Make sure we cached the free'd items 149 assert.Equal(t, 4, tw.itemsCached) 150 ci := tw.itemCache 151 for i := 0; i < 4; i++ { 152 assert.NotNil(t, ci) 153 ci = ci.Next 154 } 155 assert.Nil(t, ci) 156 157 // Let's make sure we roll over properly 158 ta = ta.Add(time.Second * 5) 159 tw.Advance(ta) 160 assert.Equal(t, 8, tw.current) 161 162 ta = ta.Add(time.Second * 2) 163 tw.Advance(ta) 164 assert.Equal(t, 10, tw.current) 165 166 ta = ta.Add(time.Second * 1) 167 tw.Advance(ta) 168 assert.Equal(t, 11, tw.current) 169 170 ta = ta.Add(time.Second * 1) 171 tw.Advance(ta) 172 assert.Equal(t, 0, tw.current) 173 }