github.com/xmidt-org/webpa-common@v1.11.9/middleware/concurrent_test.go (about) 1 package middleware 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12 ) 13 14 func testConcurrentNoCancellation(t *testing.T, concurrency int) { 15 var ( 16 require = require.New(t) 17 assert = assert.New(t) 18 expectedRequest = "expected request" 19 expectedResponse = "expected response" 20 21 nextCalled = false 22 next = func(ctx context.Context, value interface{}) (interface{}, error) { 23 nextCalled = true 24 assert.Equal(expectedRequest, value) 25 return expectedResponse, nil 26 } 27 28 concurrent = Concurrent(concurrency, nil) 29 ) 30 31 require.NotNil(concurrent) 32 actualResponse, err := concurrent(next)(context.Background(), expectedRequest) 33 assert.Equal(expectedResponse, actualResponse) 34 assert.NoError(err) 35 assert.True(nextCalled) 36 } 37 38 func testConcurrentCancel(t *testing.T, concurrency int, timeoutError error) { 39 var ( 40 require = require.New(t) 41 assert = assert.New(t) 42 expectedCtx, cancel = context.WithCancel(context.Background()) 43 expectedResponse = "expected response" 44 45 nextWaiting = new(sync.WaitGroup) 46 nextBarrier = make(chan struct{}) 47 next = func(ctx context.Context, value interface{}) (interface{}, error) { 48 wait, ok := value.(func()) 49 if ok { 50 wait() 51 } 52 53 return expectedResponse, nil 54 } 55 56 concurrent = Concurrent(concurrency, timeoutError) 57 ) 58 59 require.NotNil(concurrent) 60 endpoint := concurrent(next) 61 62 // spawn enough goroutines to exhaust the semaphore 63 nextWaiting.Add(concurrency) 64 for r := 0; r < concurrency; r++ { 65 go endpoint(expectedCtx, func() { 66 nextWaiting.Done() 67 <-nextBarrier 68 }) 69 } 70 71 // wait until we know the semaphore is exhausted, then cancel 72 nextWaiting.Wait() 73 cancel() 74 75 // because the context is cancelled, subsequent calls should complete immediately 76 actualResponse, err := endpoint(expectedCtx, "request") 77 assert.Nil(actualResponse) 78 assert.NotNil(err) 79 80 if timeoutError != nil { 81 assert.Equal(timeoutError, err) 82 } else { 83 assert.Equal(context.Canceled, err) 84 } 85 86 close(nextBarrier) 87 } 88 89 func TestConcurrent(t *testing.T) { 90 t.Run("NoCancellation", func(t *testing.T) { 91 for _, c := range []int{1, 10, 15, 100} { 92 t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) { 93 testConcurrentNoCancellation(t, c) 94 }) 95 } 96 }) 97 98 t.Run("Cancel", func(t *testing.T) { 99 t.Run("NilTimeoutError", func(t *testing.T) { 100 for _, c := range []int{1, 10, 15, 100} { 101 t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) { 102 testConcurrentCancel(t, c, nil) 103 }) 104 } 105 }) 106 107 t.Run("WithTimeoutError", func(t *testing.T) { 108 for _, c := range []int{1, 10, 15, 100} { 109 t.Run(fmt.Sprintf("Concurrency=%d", c), func(t *testing.T) { 110 testConcurrentCancel(t, c, errors.New("expected timeout error")) 111 }) 112 } 113 }) 114 }) 115 }