github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfssync/semaphore_test.go (about)

     1  // Copyright 2017 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package kbfssync
     6  
     7  import (
     8  	"context"
     9  	"math"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/pkg/errors"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  var testTimeout = 10 * time.Second
    18  
// acquireCall captures the inputs and results of one Semaphore.Acquire
// call, as produced by callAcquire, so the outcome can be sent over a
// channel and compared as a single value. Field order matters wherever
// positional literals such as acquireCall{n, count, err} are used.
type acquireCall struct {
	n     int64 // amount passed to Acquire
	count int64 // count returned by Acquire
	err   error // error returned by Acquire (nil when it succeeded)
}
    24  
    25  func callAcquire(ctx context.Context, s *Semaphore, n int64) acquireCall {
    26  	count, err := s.Acquire(ctx, n)
    27  	return acquireCall{n, count, err}
    28  }
    29  
    30  // requireNoCall checks that there is nothing to read from
    31  // callCh. This is a racy check since it doesn't distinguish between
    32  // the goroutine with the call not having run yet, and the goroutine
    33  // with the call having run but being blocked on the semaphore.
    34  func requireNoCall(t *testing.T, callCh <-chan acquireCall) {
    35  	select {
    36  	case call := <-callCh:
    37  		t.Fatalf("Unexpected call: %+v", call)
    38  	default:
    39  	}
    40  }
    41  
    42  // TestSimple tests that Acquire and Release work in a simple
    43  // two-goroutine scenario.
    44  func TestSimple(t *testing.T) {
    45  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    46  	defer cancel()
    47  
    48  	var n int64 = 10
    49  
    50  	s := NewSemaphore()
    51  	require.Equal(t, int64(0), s.Count())
    52  
    53  	callCh := make(chan acquireCall, 1)
    54  	go func() {
    55  		callCh <- callAcquire(ctx, s, n)
    56  	}()
    57  
    58  	requireNoCall(t, callCh)
    59  
    60  	count := s.Release(n - 1)
    61  	require.Equal(t, n-1, count)
    62  	require.Equal(t, n-1, s.Count())
    63  
    64  	requireNoCall(t, callCh)
    65  
    66  	count = s.Release(1)
    67  	require.Equal(t, n, count)
    68  
    69  	select {
    70  	case call := <-callCh:
    71  		require.Equal(t, acquireCall{n, 0, nil}, call)
    72  	case <-ctx.Done():
    73  		t.Fatal(ctx.Err())
    74  	}
    75  
    76  	require.Equal(t, int64(0), s.Count())
    77  }
    78  
    79  // TestForceAcquire tests that ForceAcquire works in a simple two-goroutine
    80  // scenario.
    81  func TestForceAcquire(t *testing.T) {
    82  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    83  	defer cancel()
    84  
    85  	var n int64 = 10
    86  
    87  	s := NewSemaphore()
    88  	require.Equal(t, int64(0), s.Count())
    89  
    90  	callCh := make(chan acquireCall, 1)
    91  	go func() {
    92  		callCh <- callAcquire(ctx, s, n)
    93  	}()
    94  
    95  	requireNoCall(t, callCh)
    96  
    97  	count := s.Release(n - 1)
    98  	require.Equal(t, n-1, count)
    99  	require.Equal(t, n-1, s.Count())
   100  
   101  	requireNoCall(t, callCh)
   102  
   103  	count = s.ForceAcquire(n)
   104  	require.Equal(t, int64(-1), count)
   105  	require.Equal(t, int64(-1), s.Count())
   106  
   107  	count = s.Release(n + 1)
   108  	require.Equal(t, n, count)
   109  
   110  	select {
   111  	case call := <-callCh:
   112  		require.Equal(t, acquireCall{n, 0, nil}, call)
   113  	case <-ctx.Done():
   114  		t.Fatal(ctx.Err())
   115  	}
   116  
   117  	require.Equal(t, int64(0), s.Count())
   118  }
   119  
   120  // TestCancel tests that cancelling the context passed into Acquire
   121  // causes it to return an error.
   122  func TestCancel(t *testing.T) {
   123  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   124  	defer cancel()
   125  
   126  	ctx2, cancel2 := context.WithCancel(ctx)
   127  	defer cancel2()
   128  
   129  	var n int64 = 10
   130  
   131  	s := NewSemaphore()
   132  	require.Equal(t, int64(0), s.Count())
   133  
   134  	// Do this before spawning the goroutine, so that
   135  	// callAcquire() will always return a count of n-1.
   136  	count := s.Release(n - 1)
   137  	require.Equal(t, n-1, count)
   138  	require.Equal(t, n-1, s.Count())
   139  
   140  	callCh := make(chan acquireCall, 1)
   141  	go func() {
   142  		callCh <- callAcquire(ctx2, s, n)
   143  	}()
   144  
   145  	requireNoCall(t, callCh)
   146  
   147  	cancel2()
   148  	require.Equal(t, n-1, s.Count())
   149  
   150  	select {
   151  	case call := <-callCh:
   152  		call.err = errors.Cause(call.err)
   153  		require.Equal(t, acquireCall{n, n - 1, context.Canceled}, call)
   154  	case <-ctx.Done():
   155  		t.Fatal(ctx.Err())
   156  	}
   157  
   158  	require.Equal(t, n-1, s.Count())
   159  }
   160  
   161  // TestSerialRelease tests that Release(1) causes exactly one waiting
   162  // Acquire(1) to wake up at a time.
   163  func TestSerialRelease(t *testing.T) {
   164  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   165  	defer cancel()
   166  
   167  	acquirerCount := 100
   168  
   169  	s := NewSemaphore()
   170  	acquireCount := 0
   171  	callCh := make(chan acquireCall, acquirerCount)
   172  	for i := 0; i < acquirerCount; i++ {
   173  		go func() {
   174  			call := callAcquire(ctx, s, 1)
   175  			acquireCount++
   176  			callCh <- call
   177  		}()
   178  	}
   179  
   180  	for i := 0; i < acquirerCount; i++ {
   181  		requireNoCall(t, callCh)
   182  
   183  		count := s.Release(1)
   184  		require.Equal(t, int64(1), count)
   185  
   186  		select {
   187  		case call := <-callCh:
   188  			require.Equal(t, acquireCall{1, 0, nil}, call)
   189  		case <-ctx.Done():
   190  			t.Fatal(ctx.Err())
   191  		}
   192  
   193  		requireNoCall(t, callCh)
   194  
   195  		require.Equal(t, int64(0), s.Count())
   196  	}
   197  
   198  	// acquireCount should have been incremented race-free.
   199  	require.Equal(t, acquirerCount, acquireCount)
   200  }
   201  
   202  // TestAcquireDifferentSizes tests the scenario where there are
   203  // multiple acquirers for different sizes, and we release each size in
   204  // increasing order.
   205  func TestAcquireDifferentSizes(t *testing.T) {
   206  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   207  	defer cancel()
   208  
   209  	acquirerCount := 10
   210  
   211  	s := NewSemaphore()
   212  	acquireCount := 0
   213  	callCh := make(chan acquireCall, acquirerCount)
   214  	for i := 0; i < acquirerCount; i++ {
   215  		go func(i int) {
   216  			call := callAcquire(ctx, s, int64(i+1))
   217  			acquireCount++
   218  			callCh <- call
   219  		}(i)
   220  	}
   221  
   222  	for i := 0; i < acquirerCount; i++ {
   223  		requireNoCall(t, callCh)
   224  
   225  		if i == 0 {
   226  			require.Equal(t, int64(0), s.Count())
   227  		} else {
   228  			count := s.Release(int64(i))
   229  			require.Equal(t, int64(i), count)
   230  			require.Equal(t, int64(i), s.Count())
   231  		}
   232  
   233  		requireNoCall(t, callCh)
   234  
   235  		count := s.Release(1)
   236  		require.Equal(t, int64(i+1), count)
   237  
   238  		select {
   239  		case call := <-callCh:
   240  			require.Equal(t, acquireCall{int64(i + 1), 0, nil}, call)
   241  		case <-ctx.Done():
   242  			t.Fatalf("err=%+v, i=%d", ctx.Err(), i)
   243  		}
   244  
   245  		requireNoCall(t, callCh)
   246  
   247  		require.Equal(t, int64(0), s.Count())
   248  	}
   249  
   250  	// acquireCount should have been incremented race-free.
   251  	require.Equal(t, acquirerCount, acquireCount)
   252  }
   253  
   254  func TestAcquirePanic(t *testing.T) {
   255  	s := NewSemaphore()
   256  	ctx := context.Background()
   257  	require.Panics(t, func() {
   258  		_, _ = s.Acquire(ctx, 0)
   259  	})
   260  	require.Panics(t, func() {
   261  		_, _ = s.Acquire(ctx, -1)
   262  	})
   263  }
   264  
   265  func TestForceAcquirePanic(t *testing.T) {
   266  	s := NewSemaphore()
   267  	require.Panics(t, func() {
   268  		s.ForceAcquire(0)
   269  	})
   270  	require.Panics(t, func() {
   271  		s.ForceAcquire(-1)
   272  	})
   273  	s.ForceAcquire(2)
   274  	require.Panics(t, func() {
   275  		s.ForceAcquire(math.MaxInt64)
   276  	})
   277  }
   278  
   279  func TestReleasePanic(t *testing.T) {
   280  	s := NewSemaphore()
   281  	require.Panics(t, func() {
   282  		s.Release(0)
   283  	})
   284  	require.Panics(t, func() {
   285  		s.Release(-1)
   286  	})
   287  	s.Release(1)
   288  	require.Panics(t, func() {
   289  		s.Release(math.MaxInt64)
   290  	})
   291  }