github.com/XiaoMi/Gaea@v1.2.5/util/time_wheel_test.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"strconv"
    19  	"sync/atomic"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/stretchr/testify/assert"
    24  )
    25  
// A is a test fixture whose callback method records that it was invoked.
// The a and b fields are not read by any code in this file; TestRemove
// populates them and uses the *A pointer itself as the task key.
type A struct {
	a            int
	b            string
	isCallbacked int32 // 0 until callback runs, then 1; accessed via sync/atomic
}
    31  
    32  func (a *A) callback() {
    33  	atomic.StoreInt32(&a.isCallbacked, 1)
    34  }
    35  
    36  func (a *A) getCallbackValue() int32 {
    37  	return atomic.LoadInt32(&a.isCallbacked)
    38  }
    39  
    40  func newTimeWheel() *TimeWheel {
    41  	tw, err := NewTimeWheel(time.Second, 3600)
    42  	if err != nil {
    43  		panic(err)
    44  	}
    45  	tw.Start()
    46  	return tw
    47  }
    48  
    49  func TestNewTimeWheel(t *testing.T) {
    50  	tests := []struct {
    51  		name      string
    52  		tick      time.Duration
    53  		bucketNum int
    54  		hasErr    bool
    55  	}{
    56  		{tick: time.Second, bucketNum: 0, hasErr: true},
    57  		{tick: time.Millisecond, bucketNum: 1, hasErr: true},
    58  		{tick: time.Second, bucketNum: 1, hasErr: false},
    59  	}
    60  	for _, test := range tests {
    61  		t.Run(test.name, func(t *testing.T) {
    62  			_, err := NewTimeWheel(test.tick, test.bucketNum)
    63  			assert.Equal(t, test.hasErr, err != nil)
    64  		})
    65  	}
    66  }
    67  
    68  func TestAdd(t *testing.T) {
    69  	tw := newTimeWheel()
    70  	a := &A{}
    71  	err := tw.Add(time.Second*1, "test", a.callback)
    72  	assert.NoError(t, err)
    73  
    74  	time.Sleep(time.Millisecond * 500)
    75  	assert.Equal(t, int32(0), a.getCallbackValue())
    76  	time.Sleep(time.Second * 2)
    77  	assert.Equal(t, int32(1), a.getCallbackValue())
    78  	tw.Stop()
    79  }
    80  
    81  func TestAddMultipleTimes(t *testing.T) {
    82  	a := &A{}
    83  	tw := newTimeWheel()
    84  	for i := 0; i < 4; i++ {
    85  		err := tw.Add(time.Second, "test", a.callback)
    86  		assert.NoError(t, err)
    87  		time.Sleep(time.Millisecond * 500)
    88  		t.Logf("current: %d", i)
    89  		assert.Equal(t, int32(0), a.getCallbackValue())
    90  	}
    91  
    92  	time.Sleep(time.Second * 2)
    93  	assert.Equal(t, int32(1), a.getCallbackValue())
    94  	tw.Stop()
    95  }
    96  
    97  func TestRemove(t *testing.T) {
    98  	a := &A{a: 10, b: "test"}
    99  	tw := newTimeWheel()
   100  	err := tw.Add(time.Second*1, a, a.callback)
   101  	assert.NoError(t, err)
   102  
   103  	time.Sleep(time.Millisecond * 500)
   104  	assert.Equal(t, int32(0), a.getCallbackValue())
   105  	err = tw.Remove(a)
   106  	assert.NoError(t, err)
   107  	time.Sleep(time.Second * 2)
   108  	assert.Equal(t, int32(0), a.getCallbackValue())
   109  	tw.Stop()
   110  }
   111  
   112  func BenchmarkAdd(b *testing.B) {
   113  	a := &A{}
   114  	tw := newTimeWheel()
   115  	for i := 0; i < b.N; i++ {
   116  		key := "test" + strconv.Itoa(i)
   117  		err := tw.Add(time.Second, key, a.callback)
   118  		if err != nil {
   119  			b.Fatalf("benchmark Add failed, %v", err)
   120  		}
   121  	}
   122  }