github.com/xmidt-org/webpa-common@v1.11.9/semaphore/semaphore_test.go (about)

     1  package semaphore
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func ExampleMutex() {
    16  	const routineCount = 5
    17  
    18  	var (
    19  		s     = Mutex()
    20  		wg    = new(sync.WaitGroup)
    21  		value int
    22  	)
    23  
    24  	wg.Add(routineCount)
    25  	for i := 0; i < routineCount; i++ {
    26  		go func() {
    27  			defer wg.Done()
    28  			defer s.Release()
    29  			s.Acquire()
    30  			value++
    31  			fmt.Println(value)
    32  		}()
    33  	}
    34  
    35  	wg.Wait()
    36  
    37  	// Unordered output:
    38  	// 1
    39  	// 2
    40  	// 3
    41  	// 4
    42  	// 5
    43  }
    44  
    45  func ExampleInterface_AcquireWait() {
    46  	var (
    47  		s     = Mutex()
    48  		timer = time.NewTimer(100 * time.Millisecond)
    49  	)
    50  
    51  	defer timer.Stop()
    52  	s.Acquire() // force AcquireWait to block
    53  	err := s.AcquireWait(timer.C)
    54  	fmt.Println(err != nil)
    55  
    56  	// Output:
    57  	// true
    58  }
    59  
    60  func testNewInvalidCount(t *testing.T) {
    61  	for _, c := range []int{0, -1} {
    62  		t.Run(strconv.Itoa(c), func(t *testing.T) {
    63  			assert.Panics(t, func() {
    64  				New(c)
    65  			})
    66  		})
    67  	}
    68  }
    69  
    70  func testNewValidCount(t *testing.T) {
    71  	for _, c := range []int{1, 2, 5} {
    72  		t.Run(strconv.Itoa(c), func(t *testing.T) {
    73  			s := New(c)
    74  			assert.NotNil(t, s)
    75  		})
    76  	}
    77  }
    78  
    79  func TestNew(t *testing.T) {
    80  	t.Run("InvalidCount", testNewInvalidCount)
    81  	t.Run("ValidCount", testNewValidCount)
    82  }
    83  
    84  func testSemaphoreTryAcquire(t *testing.T, s Interface, totalCount int) {
    85  	assert := assert.New(t)
    86  	for i := 0; i < totalCount; i++ {
    87  		assert.True(s.TryAcquire())
    88  	}
    89  
    90  	assert.False(s.TryAcquire())
    91  	s.Release()
    92  	assert.True(s.TryAcquire())
    93  	assert.False(s.TryAcquire())
    94  }
    95  
    96  func testSemaphoreAcquire(t *testing.T, s Interface, totalCount int) {
    97  	var (
    98  		assert  = assert.New(t)
    99  		require = require.New(t)
   100  	)
   101  
   102  	// acquire all the things!
   103  	for i := 0; i < totalCount; i++ {
   104  		done := make(chan struct{})
   105  		go func() {
   106  			defer close(done)
   107  			s.Acquire()
   108  		}()
   109  
   110  		select {
   111  		case <-done:
   112  			// passing
   113  		case <-time.After(time.Second):
   114  			assert.FailNow("Acquire blocked unexpectedly")
   115  		}
   116  	}
   117  
   118  	// post condition: no point continuing if this fails
   119  	require.False(s.TryAcquire())
   120  
   121  	var (
   122  		ready    = make(chan struct{})
   123  		acquired = make(chan struct{})
   124  	)
   125  
   126  	go func() {
   127  		defer close(acquired)
   128  		close(ready)
   129  		s.Acquire() // this should now block
   130  	}()
   131  
   132  	select {
   133  	case <-ready:
   134  		// passing
   135  		require.False(s.TryAcquire())
   136  		s.Release()
   137  	case <-time.After(time.Second):
   138  		require.FailNow("Unable to spawn acquire goroutine")
   139  	}
   140  
   141  	select {
   142  	case <-acquired:
   143  		require.False(s.TryAcquire())
   144  	case <-time.After(time.Second):
   145  		require.FailNow("Acquire blocked unexpectedly")
   146  	}
   147  
   148  	s.Release()
   149  	assert.True(s.TryAcquire())
   150  }
   151  
   152  func testSemaphoreAcquireWait(t *testing.T, s Interface, totalCount int) {
   153  	var (
   154  		assert  = assert.New(t)
   155  		require = require.New(t)
   156  		timer   = make(chan time.Time)
   157  	)
   158  
   159  	// acquire all the things!
   160  	for i := 0; i < totalCount; i++ {
   161  		result := make(chan error)
   162  		go func() {
   163  			result <- s.AcquireWait(timer)
   164  		}()
   165  
   166  		select {
   167  		case err := <-result:
   168  			assert.NoError(err)
   169  		case <-time.After(time.Second):
   170  			assert.FailNow("Acquire blocked unexpectedly")
   171  		}
   172  	}
   173  
   174  	// post condition: no point continuing if this fails
   175  	require.False(s.TryAcquire())
   176  
   177  	var (
   178  		ready  = make(chan struct{})
   179  		result = make(chan error)
   180  	)
   181  
   182  	go func() {
   183  		close(ready)
   184  		result <- s.AcquireWait(timer)
   185  	}()
   186  
   187  	select {
   188  	case <-ready:
   189  		timer <- time.Time{}
   190  	case <-time.After(time.Second):
   191  		require.FailNow("Unable to spawn acquire goroutine")
   192  	}
   193  
   194  	select {
   195  	case err := <-result:
   196  		assert.Equal(ErrTimeout, err)
   197  	case <-time.After(time.Second):
   198  		require.FailNow("AcquireWait blocked unexpectedly")
   199  	}
   200  }
   201  
   202  func testSemaphoreAcquireCtx(t *testing.T, s Interface, totalCount int) {
   203  	var (
   204  		assert      = assert.New(t)
   205  		require     = require.New(t)
   206  		ctx, cancel = context.WithCancel(context.Background())
   207  	)
   208  
   209  	defer cancel()
   210  
   211  	// acquire all the things!
   212  	for i := 0; i < totalCount; i++ {
   213  		result := make(chan error)
   214  		go func() {
   215  			result <- s.AcquireCtx(ctx)
   216  		}()
   217  
   218  		select {
   219  		case err := <-result:
   220  			assert.NoError(err)
   221  		case <-time.After(time.Second):
   222  			assert.FailNow("Acquire blocked unexpectedly")
   223  		}
   224  	}
   225  
   226  	// post condition: no point continuing if this fails
   227  	require.False(s.TryAcquire())
   228  
   229  	var (
   230  		ready  = make(chan struct{})
   231  		result = make(chan error)
   232  	)
   233  
   234  	go func() {
   235  		close(ready)
   236  		result <- s.AcquireCtx(ctx)
   237  	}()
   238  
   239  	select {
   240  	case <-ready:
   241  		cancel()
   242  	case <-time.After(time.Second):
   243  		require.FailNow("Unable to spawn acquire goroutine")
   244  	}
   245  
   246  	select {
   247  	case err := <-result:
   248  		assert.Equal(ctx.Err(), err)
   249  	case <-time.After(time.Second):
   250  		require.FailNow("AcquireWait blocked unexpectedly")
   251  	}
   252  }
   253  
   254  func TestSemaphore(t *testing.T) {
   255  	for _, c := range []int{1, 2, 5} {
   256  		t.Run(fmt.Sprintf("count=%d", c), func(t *testing.T) {
   257  			t.Run("TryAcquire", func(t *testing.T) {
   258  				testSemaphoreTryAcquire(t, New(c), c)
   259  			})
   260  
   261  			t.Run("Acquire", func(t *testing.T) {
   262  				testSemaphoreAcquire(t, New(c), c)
   263  			})
   264  
   265  			t.Run("AcquireWait", func(t *testing.T) {
   266  				testSemaphoreAcquireWait(t, New(c), c)
   267  			})
   268  
   269  			t.Run("AcquireCtx", func(t *testing.T) {
   270  				testSemaphoreAcquireCtx(t, New(c), c)
   271  			})
   272  		})
   273  	}
   274  }
   275  
   276  func TestMutex(t *testing.T) {
   277  	t.Run("TryAcquire", func(t *testing.T) {
   278  		testSemaphoreTryAcquire(t, Mutex(), 1)
   279  	})
   280  
   281  	t.Run("Acquire", func(t *testing.T) {
   282  		testSemaphoreAcquire(t, Mutex(), 1)
   283  	})
   284  
   285  	t.Run("AcquireWait", func(t *testing.T) {
   286  		testSemaphoreAcquireWait(t, Mutex(), 1)
   287  	})
   288  
   289  	t.Run("AcquireCtx", func(t *testing.T) {
   290  		testSemaphoreAcquireCtx(t, Mutex(), 1)
   291  	})
   292  }