github.com/xmidt-org/webpa-common@v1.11.9/middleware/concurrent_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func testConcurrentNoCancellation(t *testing.T, concurrency int) {
    15  	var (
    16  		require          = require.New(t)
    17  		assert           = assert.New(t)
    18  		expectedRequest  = "expected request"
    19  		expectedResponse = "expected response"
    20  
    21  		nextCalled = false
    22  		next       = func(ctx context.Context, value interface{}) (interface{}, error) {
    23  			nextCalled = true
    24  			assert.Equal(expectedRequest, value)
    25  			return expectedResponse, nil
    26  		}
    27  
    28  		concurrent = Concurrent(concurrency, nil)
    29  	)
    30  
    31  	require.NotNil(concurrent)
    32  	actualResponse, err := concurrent(next)(context.Background(), expectedRequest)
    33  	assert.Equal(expectedResponse, actualResponse)
    34  	assert.NoError(err)
    35  	assert.True(nextCalled)
    36  }
    37  
    38  func testConcurrentCancel(t *testing.T, concurrency int, timeoutError error) {
    39  	var (
    40  		require             = require.New(t)
    41  		assert              = assert.New(t)
    42  		expectedCtx, cancel = context.WithCancel(context.Background())
    43  		expectedResponse    = "expected response"
    44  
    45  		nextWaiting = new(sync.WaitGroup)
    46  		nextBarrier = make(chan struct{})
    47  		next        = func(ctx context.Context, value interface{}) (interface{}, error) {
    48  			wait, ok := value.(func())
    49  			if ok {
    50  				wait()
    51  			}
    52  
    53  			return expectedResponse, nil
    54  		}
    55  
    56  		concurrent = Concurrent(concurrency, timeoutError)
    57  	)
    58  
    59  	require.NotNil(concurrent)
    60  	endpoint := concurrent(next)
    61  
    62  	// spawn enough goroutines to exhaust the semaphore
    63  	nextWaiting.Add(concurrency)
    64  	for r := 0; r < concurrency; r++ {
    65  		go endpoint(expectedCtx, func() {
    66  			nextWaiting.Done()
    67  			<-nextBarrier
    68  		})
    69  	}
    70  
    71  	// wait until we know the semaphore is exhausted, then cancel
    72  	nextWaiting.Wait()
    73  	cancel()
    74  
    75  	// because the context is cancelled, subsequent calls should complete immediately
    76  	actualResponse, err := endpoint(expectedCtx, "request")
    77  	assert.Nil(actualResponse)
    78  	assert.NotNil(err)
    79  
    80  	if timeoutError != nil {
    81  		assert.Equal(timeoutError, err)
    82  	} else {
    83  		assert.Equal(context.Canceled, err)
    84  	}
    85  
    86  	close(nextBarrier)
    87  }
    88  
    89  func TestConcurrent(t *testing.T) {
    90  	t.Run("NoCancellation", func(t *testing.T) {
    91  		for _, c := range []int{1, 10, 15, 100} {
    92  			t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) {
    93  				testConcurrentNoCancellation(t, c)
    94  			})
    95  		}
    96  	})
    97  
    98  	t.Run("Cancel", func(t *testing.T) {
    99  		t.Run("NilTimeoutError", func(t *testing.T) {
   100  			for _, c := range []int{1, 10, 15, 100} {
   101  				t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) {
   102  					testConcurrentCancel(t, c, nil)
   103  				})
   104  			}
   105  		})
   106  
   107  		t.Run("WithTimeoutError", func(t *testing.T) {
   108  			for _, c := range []int{1, 10, 15, 100} {
   109  				t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) {
   110  					testConcurrentCancel(t, c, errors.New("expected timeout error"))
   111  				})
   112  			}
   113  		})
   114  	})
   115  }