github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/graph/check_test.go (about)

     1  package graph
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestAsyncDispatch(t *testing.T) {
    14  	testCases := []struct {
    15  		numRequests      uint16
    16  		concurrencyLimit uint16
    17  	}{
    18  		{1, 1},
    19  		{10, 1},
    20  		{1, 50},
    21  		{50, 50},
    22  		{1000, 10},
    23  	}
    24  
    25  	for _, tc := range testCases {
    26  		tc := tc
    27  		t.Run(fmt.Sprintf("%d/%d", tc.numRequests, tc.concurrencyLimit), func(t *testing.T) {
    28  			require := require.New(t)
    29  
    30  			ctx := context.Background()
    31  
    32  			l := &sync.Mutex{}
    33  			letFinish := sync.NewCond(l)
    34  			var dispatchedCount uint16
    35  			var completedCount uint16
    36  
    37  			reqs := make([]int, 0, tc.numRequests)
    38  
    39  			for i := 0; i < int(tc.numRequests); i++ {
    40  				reqs = append(reqs, i)
    41  			}
    42  
    43  			channel := make(chan CheckResult, tc.numRequests)
    44  
    45  			dispatchAllAsync(ctx, currentRequestContext{}, reqs,
    46  				func(ctx context.Context, crc currentRequestContext, child int) CheckResult {
    47  					l.Lock()
    48  					defer l.Unlock()
    49  					dispatchedCount++
    50  					letFinish.Wait()
    51  					completedCount++
    52  					return noMembers()
    53  				}, channel, tc.concurrencyLimit)
    54  
    55  			require.Eventually(func() bool {
    56  				l.Lock()
    57  				defer l.Unlock()
    58  
    59  				return (tc.numRequests >= tc.concurrencyLimit && dispatchedCount == tc.concurrencyLimit) ||
    60  					(tc.numRequests < tc.concurrencyLimit && dispatchedCount == tc.numRequests)
    61  			}, 1*time.Second, 1*time.Millisecond)
    62  
    63  			l.Lock()
    64  			require.Equal(uint16(0), completedCount)
    65  			l.Unlock()
    66  
    67  			letFinish.Signal()
    68  
    69  			require.Eventually(func() bool {
    70  				l.Lock()
    71  				defer l.Unlock()
    72  
    73  				return completedCount == 1
    74  			}, 10*time.Millisecond, 10*time.Microsecond)
    75  
    76  			require.Eventually(func() bool {
    77  				l.Lock()
    78  				defer l.Unlock()
    79  
    80  				letFinish.Broadcast()
    81  				return tc.numRequests == dispatchedCount && tc.numRequests == completedCount
    82  			}, 1*time.Second, 1*time.Millisecond)
    83  		})
    84  	}
    85  }