github.com/xmidt-org/webpa-common@v1.11.9/semaphore/closeable_test.go (about)

     1  package semaphore
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func testNewCloseableInvalidCount(t *testing.T) {
    15  	for _, c := range []int{0, -1} {
    16  		t.Run(strconv.Itoa(c), func(t *testing.T) {
    17  			assert.Panics(t, func() {
    18  				NewCloseable(c)
    19  			})
    20  		})
    21  	}
    22  }
    23  
    24  func testNewCloseableValidCount(t *testing.T) {
    25  	for _, c := range []int{1, 2, 5} {
    26  		t.Run(strconv.Itoa(c), func(t *testing.T) {
    27  			s := NewCloseable(c)
    28  			assert.NotNil(t, s)
    29  		})
    30  	}
    31  }
    32  
    33  func TestNewCloseable(t *testing.T) {
    34  	t.Run("InvalidCount", testNewCloseableInvalidCount)
    35  	t.Run("ValidCount", testNewCloseableValidCount)
    36  }
    37  
    38  func testCloseableTryAcquire(t *testing.T, cs Closeable, totalCount int) {
    39  	assert := assert.New(t)
    40  	for i := 0; i < totalCount; i++ {
    41  		assert.True(cs.TryAcquire())
    42  	}
    43  
    44  	assert.False(cs.TryAcquire())
    45  	assert.NoError(cs.Release())
    46  	assert.True(cs.TryAcquire())
    47  	assert.False(cs.TryAcquire())
    48  
    49  	assert.NoError(cs.Release())
    50  	assert.NoError(cs.Close())
    51  	assert.False(cs.TryAcquire())
    52  	assert.Equal(ErrClosed, cs.Close())
    53  	assert.Equal(ErrClosed, cs.Release())
    54  }
    55  
    56  func testCloseableAcquireSuccess(t *testing.T, cs Closeable, totalCount int) {
    57  	var (
    58  		assert  = assert.New(t)
    59  		require = require.New(t)
    60  	)
    61  
    62  	// acquire all the things!
    63  	for i := 0; i < totalCount; i++ {
    64  		done := make(chan struct{})
    65  		go func() {
    66  			defer close(done)
    67  			cs.Acquire()
    68  		}()
    69  
    70  		select {
    71  		case <-done:
    72  			// passing
    73  		case <-time.After(time.Second):
    74  			assert.FailNow("Acquire blocked unexpectedly")
    75  		}
    76  	}
    77  
    78  	// post condition: no point continuing if this fails
    79  	require.False(cs.TryAcquire())
    80  
    81  	var (
    82  		ready    = make(chan struct{})
    83  		acquired = make(chan struct{})
    84  	)
    85  
    86  	go func() {
    87  		defer close(acquired)
    88  		close(ready)
    89  		cs.Acquire() // this should now block
    90  	}()
    91  
    92  	select {
    93  	case <-ready:
    94  		// passing
    95  		require.False(cs.TryAcquire())
    96  		cs.Release()
    97  	case <-time.After(time.Second):
    98  		require.FailNow("Unable to spawn acquire goroutine")
    99  	}
   100  
   101  	select {
   102  	case <-acquired:
   103  		require.False(cs.TryAcquire())
   104  	case <-time.After(time.Second):
   105  		require.FailNow("Acquire blocked unexpectedly")
   106  	}
   107  
   108  	assert.NoError(cs.Release())
   109  	assert.True(cs.TryAcquire())
   110  	assert.NoError(cs.Release())
   111  }
   112  
   113  func testCloseableAcquireClose(t *testing.T, cs Closeable, totalCount int) {
   114  	var (
   115  		assert  = assert.New(t)
   116  		require = require.New(t)
   117  
   118  		acquiredAll = make(chan struct{})
   119  		results     = make(chan error, totalCount)
   120  		closeWait   = make(chan struct{})
   121  	)
   122  
   123  	defer cs.Close()
   124  
   125  	go func() {
   126  		defer close(acquiredAll)
   127  		for i := 0; i < totalCount; i++ {
   128  			assert.NoError(cs.Acquire())
   129  		}
   130  	}()
   131  
   132  	select {
   133  	case <-acquiredAll:
   134  		// passing
   135  	case <-time.After(5 * time.Second):
   136  		assert.FailNow("Unable to acquire all resources")
   137  	}
   138  
   139  	// block multiple routines waiting to acquire the semaphore
   140  	for i := 0; i < totalCount; i++ {
   141  		ready := make(chan struct{})
   142  		go func() {
   143  			close(ready)
   144  			results <- cs.Acquire()
   145  		}()
   146  
   147  		select {
   148  		case <-ready:
   149  			// passing
   150  		case <-time.After(time.Second):
   151  			assert.FailNow("Failed to spawn Acquire goroutine")
   152  		}
   153  	}
   154  
   155  	go func() {
   156  		defer close(closeWait)
   157  		<-cs.Closed()
   158  	}()
   159  
   160  	// post condition: no point continuing if this fails
   161  	require.False(cs.TryAcquire())
   162  
   163  	assert.NoError(cs.Close())
   164  	for i := 0; i < totalCount; i++ {
   165  		select {
   166  		case err := <-results:
   167  			assert.Equal(ErrClosed, err)
   168  		case <-time.After(5 * time.Second):
   169  			assert.FailNow("Acquire blocked unexpectedly")
   170  		}
   171  	}
   172  
   173  	select {
   174  	case <-closeWait:
   175  		assert.False(cs.TryAcquire())
   176  		assert.Equal(ErrClosed, cs.Close())
   177  		assert.Equal(ErrClosed, cs.Acquire())
   178  		assert.Equal(ErrClosed, cs.Release())
   179  
   180  	case <-time.After(5 * time.Second):
   181  		assert.FailNow("Closed channel did not get signaled")
   182  	}
   183  }
   184  
   185  func testCloseableAcquireWaitSuccess(t *testing.T, cs Closeable, totalCount int) {
   186  	var (
   187  		assert  = assert.New(t)
   188  		require = require.New(t)
   189  		timer   = make(chan time.Time)
   190  	)
   191  
   192  	// acquire all the things!
   193  	for i := 0; i < totalCount; i++ {
   194  		result := make(chan error)
   195  		go func() {
   196  			result <- cs.AcquireWait(timer)
   197  		}()
   198  
   199  		select {
   200  		case err := <-result:
   201  			assert.NoError(err)
   202  		case <-time.After(time.Second):
   203  			assert.FailNow("Acquire blocked unexpectedly")
   204  		}
   205  	}
   206  
   207  	defer cs.Close()
   208  
   209  	// post condition: no point continuing if this fails
   210  	require.False(cs.TryAcquire())
   211  
   212  	var (
   213  		ready  = make(chan struct{})
   214  		result = make(chan error)
   215  	)
   216  
   217  	go func() {
   218  		close(ready)
   219  		result <- cs.AcquireWait(timer)
   220  	}()
   221  
   222  	select {
   223  	case <-ready:
   224  		timer <- time.Time{}
   225  	case <-time.After(time.Second):
   226  		require.FailNow("Unable to spawn acquire goroutine")
   227  	}
   228  
   229  	select {
   230  	case err := <-result:
   231  		assert.Equal(ErrTimeout, err)
   232  	case <-time.After(time.Second):
   233  		require.FailNow("AcquireWait blocked unexpectedly")
   234  	}
   235  }
   236  
   237  func testCloseableAcquireWaitClose(t *testing.T, cs Closeable, totalCount int) {
   238  	var (
   239  		assert  = assert.New(t)
   240  		require = require.New(t)
   241  		timer   = make(chan time.Time)
   242  
   243  		acquiredAll = make(chan struct{})
   244  		results     = make(chan error, totalCount)
   245  		closeWait   = make(chan struct{})
   246  	)
   247  
   248  	defer cs.Close()
   249  
   250  	go func() {
   251  		defer close(acquiredAll)
   252  		for i := 0; i < totalCount; i++ {
   253  			assert.NoError(cs.Acquire())
   254  		}
   255  	}()
   256  
   257  	select {
   258  	case <-acquiredAll:
   259  		// passing
   260  	case <-time.After(5 * time.Second):
   261  		assert.FailNow("Unable to acquire all resources")
   262  	}
   263  
   264  	// acquire all the things!
   265  	for i := 0; i < totalCount; i++ {
   266  		ready := make(chan struct{})
   267  		go func() {
   268  			close(ready)
   269  			results <- cs.AcquireWait(timer)
   270  		}()
   271  
   272  		select {
   273  		case <-ready:
   274  			// passing
   275  		case <-time.After(5 * time.Second):
   276  			assert.FailNow("Failed to spawn AcquireWait goroutine")
   277  		}
   278  	}
   279  
   280  	// post condition: no point continuing if this fails
   281  	require.False(cs.TryAcquire())
   282  
   283  	go func() {
   284  		defer close(closeWait)
   285  		<-cs.Closed()
   286  	}()
   287  
   288  	assert.NoError(cs.Close())
   289  	for i := 0; i < totalCount; i++ {
   290  		select {
   291  		case err := <-results:
   292  			assert.Equal(ErrClosed, err)
   293  		case <-time.After(5 * time.Second):
   294  			assert.FailNow("AcquireWait blocked unexpectedly")
   295  		}
   296  	}
   297  
   298  	select {
   299  	case <-closeWait:
   300  		assert.False(cs.TryAcquire())
   301  		assert.Equal(ErrClosed, cs.Close())
   302  		assert.Equal(ErrClosed, cs.Acquire())
   303  		assert.Equal(ErrClosed, cs.Release())
   304  
   305  	case <-time.After(5 * time.Second):
   306  		assert.FailNow("Closed channel did not get signaled")
   307  	}
   308  }
   309  
   310  func testCloseableAcquireCtxSuccess(t *testing.T, cs Closeable, totalCount int) {
   311  	var (
   312  		assert      = assert.New(t)
   313  		require     = require.New(t)
   314  		ctx, cancel = context.WithCancel(context.Background())
   315  	)
   316  
   317  	defer cancel()
   318  
   319  	// acquire all the things!
   320  	for i := 0; i < totalCount; i++ {
   321  		result := make(chan error)
   322  		go func() {
   323  			result <- cs.AcquireCtx(ctx)
   324  		}()
   325  
   326  		select {
   327  		case err := <-result:
   328  			assert.NoError(err)
   329  		case <-time.After(time.Second):
   330  			assert.FailNow("Acquire blocked unexpectedly")
   331  		}
   332  	}
   333  
   334  	// post condition: no point continuing if this fails
   335  	require.False(cs.TryAcquire())
   336  
   337  	var (
   338  		ready  = make(chan struct{})
   339  		result = make(chan error)
   340  	)
   341  
   342  	go func() {
   343  		close(ready)
   344  		result <- cs.AcquireCtx(ctx)
   345  	}()
   346  
   347  	select {
   348  	case <-ready:
   349  		cancel()
   350  	case <-time.After(time.Second):
   351  		require.FailNow("Unable to spawn acquire goroutine")
   352  	}
   353  
   354  	select {
   355  	case err := <-result:
   356  		assert.Equal(ctx.Err(), err)
   357  	case <-time.After(time.Second):
   358  		require.FailNow("AcquireWait blocked unexpectedly")
   359  	}
   360  }
   361  
   362  func testCloseableAcquireCtxClose(t *testing.T, cs Closeable, totalCount int) {
   363  	var (
   364  		assert      = assert.New(t)
   365  		require     = require.New(t)
   366  		ctx, cancel = context.WithCancel(context.Background())
   367  
   368  		acquiredAll = make(chan struct{})
   369  		results     = make(chan error, totalCount)
   370  		closeWait   = make(chan struct{})
   371  	)
   372  
   373  	defer cancel()
   374  
   375  	go func() {
   376  		defer close(acquiredAll)
   377  		for i := 0; i < totalCount; i++ {
   378  			assert.NoError(cs.Acquire())
   379  		}
   380  	}()
   381  
   382  	select {
   383  	case <-acquiredAll:
   384  		// passing
   385  	case <-time.After(5 * time.Second):
   386  		assert.FailNow("Unable to acquire all resources")
   387  	}
   388  
   389  	// acquire all the things!
   390  	for i := 0; i < totalCount; i++ {
   391  		ready := make(chan struct{})
   392  		go func() {
   393  			close(ready)
   394  			results <- cs.AcquireCtx(ctx)
   395  		}()
   396  
   397  		select {
   398  		case <-ready:
   399  			// passing
   400  		case <-time.After(5 * time.Second):
   401  			assert.FailNow("Could not spawn AcquireCtx goroutine")
   402  		}
   403  	}
   404  
   405  	// post condition: no point continuing if this fails
   406  	require.False(cs.TryAcquire())
   407  
   408  	go func() {
   409  		defer close(closeWait)
   410  		<-cs.Closed()
   411  	}()
   412  
   413  	assert.NoError(cs.Close())
   414  	for i := 0; i < totalCount; i++ {
   415  		select {
   416  		case err := <-results:
   417  			assert.Equal(ErrClosed, err)
   418  		case <-time.After(5 * time.Second):
   419  			assert.FailNow("AcquireCtx blocked unexpectedly")
   420  		}
   421  	}
   422  
   423  	select {
   424  	case <-closeWait:
   425  		assert.False(cs.TryAcquire())
   426  		assert.Equal(ErrClosed, cs.Close())
   427  		assert.Equal(ErrClosed, cs.Acquire())
   428  		assert.Equal(ErrClosed, cs.Release())
   429  
   430  	case <-time.After(5 * time.Second):
   431  		assert.FailNow("Closed channel did not get signaled")
   432  	}
   433  }
   434  
   435  func TestCloseable(t *testing.T) {
   436  	for _, c := range []int{1, 2, 5} {
   437  		t.Run(fmt.Sprintf("count=%d", c), func(t *testing.T) {
   438  			t.Run("TryAcquire", func(t *testing.T) {
   439  				testCloseableTryAcquire(t, NewCloseable(c), c)
   440  			})
   441  
   442  			t.Run("Acquire", func(t *testing.T) {
   443  				t.Run("Success", func(t *testing.T) {
   444  					testCloseableAcquireSuccess(t, NewCloseable(c), c)
   445  				})
   446  
   447  				t.Run("Close", func(t *testing.T) {
   448  					testCloseableAcquireClose(t, NewCloseable(c), c)
   449  				})
   450  			})
   451  
   452  			t.Run("AcquireWait", func(t *testing.T) {
   453  				t.Run("Success", func(t *testing.T) {
   454  					testCloseableAcquireWaitSuccess(t, NewCloseable(c), c)
   455  				})
   456  
   457  				t.Run("Close", func(t *testing.T) {
   458  					testCloseableAcquireWaitClose(t, NewCloseable(c), c)
   459  				})
   460  			})
   461  
   462  			t.Run("AcquireCtx", func(t *testing.T) {
   463  				t.Run("Success", func(t *testing.T) {
   464  					testCloseableAcquireCtxSuccess(t, NewCloseable(c), c)
   465  				})
   466  
   467  				t.Run("Close", func(t *testing.T) {
   468  					testCloseableAcquireCtxClose(t, NewCloseable(c), c)
   469  				})
   470  			})
   471  		})
   472  	}
   473  }
   474  
   475  func TestCloseableMutex(t *testing.T) {
   476  	t.Run("TryAcquire", func(t *testing.T) {
   477  		testCloseableTryAcquire(t, CloseableMutex(), 1)
   478  	})
   479  
   480  	t.Run("Acquire", func(t *testing.T) {
   481  		t.Run("Success", func(t *testing.T) {
   482  			testCloseableAcquireSuccess(t, CloseableMutex(), 1)
   483  		})
   484  
   485  		t.Run("Close", func(t *testing.T) {
   486  			testCloseableAcquireClose(t, CloseableMutex(), 1)
   487  		})
   488  	})
   489  
   490  	t.Run("AcquireWait", func(t *testing.T) {
   491  		t.Run("Success", func(t *testing.T) {
   492  			testCloseableAcquireWaitSuccess(t, CloseableMutex(), 1)
   493  		})
   494  
   495  		t.Run("Close", func(t *testing.T) {
   496  			testCloseableAcquireWaitClose(t, CloseableMutex(), 1)
   497  		})
   498  	})
   499  
   500  	t.Run("AcquireCtx", func(t *testing.T) {
   501  		t.Run("Success", func(t *testing.T) {
   502  			testCloseableAcquireCtxSuccess(t, CloseableMutex(), 1)
   503  		})
   504  
   505  		t.Run("Close", func(t *testing.T) {
   506  			testCloseableAcquireCtxClose(t, CloseableMutex(), 1)
   507  		})
   508  	})
   509  }