github.com/slackhq/nebula@v1.9.0/timeout_test.go

     1  package nebula
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/slackhq/nebula/firewall"
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  func TestNewTimerWheel(t *testing.T) {
    12  	// Make sure we get an object we expect
    13  	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
    14  	assert.Equal(t, 12, tw.wheelLen)
    15  	assert.Equal(t, 0, tw.current)
    16  	assert.Nil(t, tw.lastTick)
    17  	assert.Equal(t, time.Second*1, tw.tickDuration)
    18  	assert.Equal(t, time.Second*10, tw.wheelDuration)
    19  	assert.Len(t, tw.wheel, 12)
    20  
    21  	// Assert the math is correct
    22  	tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10)
    23  	assert.Equal(t, 5, tw.wheelLen)
    24  
    25  	tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10)
    26  	assert.Equal(t, 7, tw.wheelLen)
    27  
    28  	// Test empty purge of non nil items
    29  	i, ok := tw.Purge()
    30  	assert.Equal(t, firewall.Packet{}, i)
    31  	assert.False(t, ok)
    32  
    33  	// Test empty purges of nil items
    34  	tw2 := NewTimerWheel[*int](time.Second, time.Second*10)
    35  	i2, ok := tw2.Purge()
    36  	assert.Nil(t, i2)
    37  	assert.False(t, ok)
    38  
    39  }
    40  
    41  func TestTimerWheel_findWheel(t *testing.T) {
    42  	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
    43  	assert.Len(t, tw.wheel, 12)
    44  
    45  	// Current + tick + 1 since we don't know how far into current we are
    46  	assert.Equal(t, 2, tw.findWheel(time.Second*1))
    47  
    48  	// Scale up to min duration
    49  	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
    50  
    51  	// Make sure we hit that last index
    52  	assert.Equal(t, 11, tw.findWheel(time.Second*10))
    53  
    54  	// Scale down to max duration
    55  	assert.Equal(t, 11, tw.findWheel(time.Second*11))
    56  
    57  	tw.current = 1
    58  	// Make sure we account for the current position properly
    59  	assert.Equal(t, 3, tw.findWheel(time.Second*1))
    60  	assert.Equal(t, 0, tw.findWheel(time.Second*10))
    61  }
    62  
    63  func TestTimerWheel_Add(t *testing.T) {
    64  	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
    65  
    66  	fp1 := firewall.Packet{}
    67  	tw.Add(fp1, time.Second*1)
    68  
    69  	// Make sure we set head and tail properly
    70  	assert.NotNil(t, tw.wheel[2])
    71  	assert.Equal(t, fp1, tw.wheel[2].Head.Item)
    72  	assert.Nil(t, tw.wheel[2].Head.Next)
    73  	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
    74  	assert.Nil(t, tw.wheel[2].Tail.Next)
    75  
    76  	// Make sure we only modify head
    77  	fp2 := firewall.Packet{}
    78  	tw.Add(fp2, time.Second*1)
    79  	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
    80  	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
    81  	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
    82  	assert.Nil(t, tw.wheel[2].Tail.Next)
    83  
    84  	// Make sure we use free'd items first
    85  	tw.itemCache = &TimeoutItem[firewall.Packet]{}
    86  	tw.itemsCached = 1
    87  	tw.Add(fp2, time.Second*1)
    88  	assert.Nil(t, tw.itemCache)
    89  	assert.Equal(t, 0, tw.itemsCached)
    90  
    91  	// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
    92  	for min := time.Duration(1); min < 100; min++ {
    93  		for max := min; max < 100; max++ {
    94  			tw = NewTimerWheel[firewall.Packet](min, max)
    95  
    96  			for current := 0; current < tw.wheelLen; current++ {
    97  				tw.current = current
    98  				for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
    99  					tick := tw.findWheel(timeout)
   100  					if tick >= tw.wheelLen {
   101  						t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
   102  					}
   103  				}
   104  			}
   105  		}
   106  	}
   107  }
   108  
   109  func TestTimerWheel_Purge(t *testing.T) {
   110  	// First advance should set the lastTick and do nothing else
   111  	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
   112  	assert.Nil(t, tw.lastTick)
   113  	tw.Advance(time.Now())
   114  	assert.NotNil(t, tw.lastTick)
   115  	assert.Equal(t, 0, tw.current)
   116  
   117  	fps := []firewall.Packet{
   118  		{LocalIP: 1},
   119  		{LocalIP: 2},
   120  		{LocalIP: 3},
   121  		{LocalIP: 4},
   122  	}
   123  
   124  	tw.Add(fps[0], time.Second*1)
   125  	tw.Add(fps[1], time.Second*1)
   126  	tw.Add(fps[2], time.Second*2)
   127  	tw.Add(fps[3], time.Second*2)
   128  
   129  	ta := time.Now().Add(time.Second * 3)
   130  	lastTick := *tw.lastTick
   131  	tw.Advance(ta)
   132  	assert.Equal(t, 3, tw.current)
   133  	assert.True(t, tw.lastTick.After(lastTick))
   134  
   135  	// Make sure we get all 4 packets back
   136  	for i := 0; i < 4; i++ {
   137  		p, has := tw.Purge()
   138  		assert.True(t, has)
   139  		assert.Equal(t, fps[i], p)
   140  	}
   141  
   142  	// Make sure there aren't any leftover
   143  	_, ok := tw.Purge()
   144  	assert.False(t, ok)
   145  	assert.Nil(t, tw.expired.Head)
   146  	assert.Nil(t, tw.expired.Tail)
   147  
   148  	// Make sure we cached the free'd items
   149  	assert.Equal(t, 4, tw.itemsCached)
   150  	ci := tw.itemCache
   151  	for i := 0; i < 4; i++ {
   152  		assert.NotNil(t, ci)
   153  		ci = ci.Next
   154  	}
   155  	assert.Nil(t, ci)
   156  
   157  	// Let's make sure we roll over properly
   158  	ta = ta.Add(time.Second * 5)
   159  	tw.Advance(ta)
   160  	assert.Equal(t, 8, tw.current)
   161  
   162  	ta = ta.Add(time.Second * 2)
   163  	tw.Advance(ta)
   164  	assert.Equal(t, 10, tw.current)
   165  
   166  	ta = ta.Add(time.Second * 1)
   167  	tw.Advance(ta)
   168  	assert.Equal(t, 11, tw.current)
   169  
   170  	ta = ta.Add(time.Second * 1)
   171  	tw.Advance(ta)
   172  	assert.Equal(t, 0, tw.current)
   173  }