github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/limiter/limiter_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package limiter
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"math/rand"
    11  	"sync"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/Schaudge/grailbase/traverse"
    17  )
    18  
    19  func TestLimiter(t *testing.T) {
    20  	l := New()
    21  	l.Release(10)
    22  
    23  	if err := l.Acquire(context.Background(), 5); err != nil {
    24  		t.Fatal(err)
    25  	}
    26  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    27  	defer cancel()
    28  	if want, got := context.DeadlineExceeded, l.Acquire(ctx, 10); got != want {
    29  		t.Fatalf("got %v, want %v", got, want)
    30  	}
    31  	l.Release(5)
    32  	if err := l.Acquire(context.Background(), 10); err != nil {
    33  		t.Fatal(err)
    34  	}
    35  }
    36  
    37  func TestLimiterConcurrently(t *testing.T) {
    38  	const (
    39  		N = 1000
    40  		T = 100
    41  	)
    42  	var pending int32
    43  	l := New()
    44  	l.Release(T)
    45  	var begin sync.WaitGroup
    46  	begin.Add(N)
    47  	err := traverse.Each(N, func(i int) error {
    48  		begin.Done()
    49  		begin.Wait()
    50  		n := rand.Intn(T) + 1
    51  		if err := l.Acquire(context.Background(), n); err != nil {
    52  			return err
    53  		}
    54  		if m := atomic.AddInt32(&pending, int32(n)); m > T {
    55  			return fmt.Errorf("too many tokens: %d > %d", m, T)
    56  		}
    57  		atomic.AddInt32(&pending, -int32(n))
    58  		l.Release(n)
    59  		return nil
    60  	})
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  }