trpc.group/trpc-go/trpc-go@v1.0.3/internal/ring/ring_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package ring_test
    15  
    16  import (
    17  	"math"
    18  	"runtime"
    19  	"sync"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"trpc.group/trpc-go/trpc-go/internal/ring"
    24  )
    25  
    26  var (
    27  	defaultMsg = []byte("Hello World!!")
    28  	defaultCap = uint32(3)
    29  )
    30  
    31  func TestNew(t *testing.T) {
    32  	r := ring.New[[]byte](0)
    33  	assert.Equal(t, uint32(1), r.Cap())
    34  
    35  	r = ring.New[[]byte](1)
    36  	assert.Equal(t, uint32(1), r.Cap())
    37  
    38  	r = ring.New[[]byte](math.MaxUint32)
    39  	assert.Equal(t, uint32(1), r.Cap())
    40  
    41  	r = ring.New[[]byte](2)
    42  	assert.Equal(t, uint32(1), r.Cap())
    43  
    44  	r = ring.New[[]byte](3)
    45  	assert.Equal(t, uint32(3), r.Cap())
    46  	assert.Equal(t, "Ring: Cap=3, Head=0, Tail=0, Size=0\n", r.String())
    47  
    48  	r = ring.New[[]byte](4)
    49  	assert.Equal(t, uint32(3), r.Cap())
    50  
    51  	r = ring.New[[]byte](7)
    52  	assert.Equal(t, uint32(7), r.Cap())
    53  }
    54  
    55  func TestPutGet(t *testing.T) {
    56  	r := ring.New[[]byte](defaultCap)
    57  	assert.NotNil(t, r)
    58  	assert.Equal(t, defaultCap, r.Cap())
    59  
    60  	// normal Put.
    61  	err := r.Put(defaultMsg)
    62  	assert.Nil(t, err)
    63  
    64  	// normal Get.
    65  	val, left := r.Get()
    66  	assert.NotEmpty(t, val)
    67  	assert.Equal(t, defaultMsg, val)
    68  	assert.Equal(t, uint32(0), left)
    69  
    70  	// Get an empty ring.
    71  	assert.Equal(t, true, r.IsEmpty())
    72  	val, left = r.Get()
    73  	assert.Empty(t, val)
    74  	assert.Equal(t, uint32(0), left)
    75  
    76  	// Check a full queue.
    77  	for i := uint32(0); i < r.Cap(); i++ {
    78  		err := r.Put(defaultMsg)
    79  		assert.Nil(t, err)
    80  	}
    81  	assert.Equal(t, true, r.IsFull())
    82  	assert.Equal(t, r.Cap(), r.Size())
    83  
    84  	// insert into a full queue.
    85  	err = r.Put(defaultMsg)
    86  	assert.Equal(t, ring.ErrQueueFull, err)
    87  }
    88  
    89  func TestPutGetOrder(t *testing.T) {
    90  	r := ring.New[uint32](defaultCap)
    91  	for i := uint32(0); i < r.Cap(); i++ {
    92  		err := r.Put(i)
    93  		assert.Nil(t, err)
    94  	}
    95  
    96  	for i := uint32(0); i < r.Cap(); i++ {
    97  		val, _ := r.Get()
    98  		assert.Equal(t, i, val)
    99  	}
   100  }
   101  
   102  func TestGetsWithFull(t *testing.T) {
   103  	r := ring.New[[]byte](defaultCap)
   104  	assert.NotNil(t, r)
   105  
   106  	for i := uint32(0); i < r.Cap(); i++ {
   107  		err := r.Put(defaultMsg)
   108  		assert.Nil(t, err)
   109  	}
   110  	values := make([][]byte, 0, r.Cap())
   111  	count, left := r.Gets(&values)
   112  	assert.Equal(t, r.Cap(), count)
   113  	assert.Equal(t, uint32(0), left)
   114  	assert.Equal(t, uint32(len(values)), defaultCap)
   115  
   116  	for _, x := range values {
   117  		assert.Equal(t, x, defaultMsg)
   118  	}
   119  	// the Get queue is empty, execute Gets.
   120  	assert.Equal(t, true, r.IsEmpty())
   121  	count, _ = r.Gets(&values)
   122  	assert.Equal(t, uint32(0), count)
   123  }
   124  
   125  func TestGetsWithAskedSize(t *testing.T) {
   126  	r := ring.New[[]byte](defaultCap)
   127  	assert.NotNil(t, r)
   128  
   129  	for i := uint32(0); i < r.Cap(); i++ {
   130  		err := r.Put(defaultMsg)
   131  		assert.Nil(t, err)
   132  	}
   133  	values := make([][]byte, 0, r.Cap()-1)
   134  	count, left := r.Gets(&values)
   135  	assert.Equal(t, r.Cap()-1, count)
   136  	assert.Equal(t, uint32(1), left)
   137  }
   138  
   139  func TestConcurrentGetPut(t *testing.T) {
   140  	r := ring.New[[]byte](1024)
   141  	cpus := runtime.NumCPU()
   142  
   143  	// starts send goroutine, every goroutine sends N packages.
   144  	wg := &sync.WaitGroup{}
   145  	for i := 0; i < cpus; i++ {
   146  		wg.Add(1)
   147  		go startPutMsgs(r, wg, 10000)
   148  	}
   149  	// starts receive goroutine, every goroutine receives N packages.
   150  	for i := 0; i < cpus; i++ {
   151  		wg.Add(1)
   152  		go startGetMsgs(r, wg, 10000)
   153  	}
   154  	wg.Wait()
   155  	assert.Equal(t, true, r.IsEmpty())
   156  }
   157  
   158  func BenchmarkTestChannel(b *testing.B) {
   159  	ch := make(chan interface{}, 1024)
   160  	cpus := runtime.NumCPU()
   161  	wg := &sync.WaitGroup{}
   162  	b.SetBytes(1)
   163  	b.ReportAllocs()
   164  	b.ResetTimer()
   165  	for i := 0; i < cpus; i++ {
   166  		wg.Add(1)
   167  		go func() {
   168  			for i := 0; i < b.N; i++ {
   169  				ch <- defaultMsg
   170  			}
   171  			wg.Done()
   172  		}()
   173  	}
   174  	for i := 0; i < cpus; i++ {
   175  		wg.Add(1)
   176  		go func() {
   177  			for i := 0; i < b.N; i++ {
   178  				<-ch
   179  			}
   180  			wg.Done()
   181  		}()
   182  	}
   183  	wg.Wait()
   184  }
   185  
   186  func BenchmarkTestRingBuffer(b *testing.B) {
   187  	r := ring.New[[]byte](1024)
   188  	cpus := runtime.NumCPU()
   189  
   190  	wg := &sync.WaitGroup{}
   191  	b.SetBytes(1)
   192  	b.ReportAllocs()
   193  	b.ResetTimer()
   194  	for i := 0; i < cpus; i++ {
   195  		wg.Add(1)
   196  		go startGetMsgs(r, wg, b.N)
   197  	}
   198  	for i := 0; i < cpus; i++ {
   199  		wg.Add(1)
   200  		go startPutMsgs(r, wg, b.N)
   201  	}
   202  	wg.Wait()
   203  }
   204  
   205  func startPutMsgs(r *ring.Ring[[]byte], wg *sync.WaitGroup, num int) {
   206  	for {
   207  		if num <= 0 {
   208  			break
   209  		}
   210  		err := r.Put(defaultMsg)
   211  		if err == nil {
   212  			num = num - 1
   213  		}
   214  	}
   215  	wg.Done()
   216  }
   217  
   218  func startGetMsgs(r *ring.Ring[[]byte], wg *sync.WaitGroup, num int) {
   219  	for {
   220  		if num <= 0 {
   221  			break
   222  		}
   223  		val, _ := r.Get()
   224  		if val != nil {
   225  			num = num - 1
   226  		}
   227  	}
   228  	wg.Done()
   229  }