github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/limiter/batch_test.go (about)

     1  // Copyright 2021 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package limiter
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"strconv"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/Schaudge/grailbase/traverse"
    16  	"golang.org/x/time/rate"
    17  )
    18  
    19  type testBatchApi struct {
    20  	mu          sync.Mutex
    21  	usePtr      bool
    22  	maxPerBatch int
    23  	last        time.Time
    24  	perBatchIds [][]string
    25  	durs        []time.Duration
    26  	idSeenCount map[string]int
    27  }
    28  
    29  func (a *testBatchApi) MaxPerBatch() int { return a.maxPerBatch }
    30  func (a *testBatchApi) Do(results map[ID]*Result) {
    31  	a.mu.Lock()
    32  	defer a.mu.Unlock()
    33  	now := time.Now()
    34  	if a.last.IsZero() {
    35  		a.last = now
    36  	}
    37  	ids := make([]string, 0, len(results))
    38  	for k, r := range results {
    39  		var id string
    40  		if a.usePtr {
    41  			id = *k.(*string)
    42  		} else {
    43  			id = k.(string)
    44  		}
    45  		ids = append(ids, id)
    46  		idSeenCount := a.idSeenCount[id]
    47  		i, err := strconv.Atoi(id)
    48  		if err != nil {
    49  			i = -1
    50  		}
    51  		switch {
    52  		case shouldErr(i):
    53  		case i%2 == 0:
    54  			r.Set(nil, fmt.Errorf("failed_%s_count_%d", id, idSeenCount))
    55  		default:
    56  			r.Set(fmt.Sprintf("value-%s", id), nil)
    57  		}
    58  		a.idSeenCount[id] = idSeenCount + 1
    59  	}
    60  	a.perBatchIds = append(a.perBatchIds, ids)
    61  	a.durs = append(a.durs, now.Sub(a.last))
    62  	a.last = now
    63  	return
    64  }
    65  
    66  func TestSimple(t *testing.T) {
    67  	a := &testBatchApi{idSeenCount: make(map[string]int)}
    68  	l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Millisecond), 1))
    69  	id := "test"
    70  	_, _ = l.Do(context.Background(), id)
    71  	if got, want := a.idSeenCount[id], 1; got != want {
    72  		t.Errorf("got %d, want %d", got, want)
    73  	}
    74  	_, _ = l.Do(context.Background(), id)
    75  	if got, want := a.idSeenCount[id], 2; got != want {
    76  		t.Errorf("got %d, want %d", got, want)
    77  	}
    78  }
    79  
    80  func TestCtxCanceled(t *testing.T) {
    81  	a := &testBatchApi{idSeenCount: make(map[string]int)}
    82  	l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Second), 1))
    83  	id1, id2 := "test1", "test2"
    84  	_, _ = l.Do(context.Background(), id1)
    85  	if got, want := a.idSeenCount[id1], 1; got != want {
    86  		t.Errorf("got %d, want %d", got, want)
    87  	}
    88  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    89  	defer cancel()
    90  	var wg sync.WaitGroup
    91  	wg.Add(1)
    92  	go func() {
    93  		defer wg.Done()
    94  		_, _ = l.Do(ctx, id1)
    95  	}()
    96  	wg.Add(1)
    97  	go func() {
    98  		defer wg.Done()
    99  		_, _ = l.Do(context.Background(), id2)
   100  	}()
   101  	wg.Wait()
   102  	if got, want := a.idSeenCount[id1], 1; got != want {
   103  		t.Errorf("got %d, want %d", got, want)
   104  	}
   105  	if got, want := a.idSeenCount[id2], 1; got != want {
   106  		t.Errorf("got %d, want %d", got, want)
   107  	}
   108  }
   109  
   110  func TestSometimesDedup(t *testing.T) {
   111  	const num = 5
   112  	a := &testBatchApi{idSeenCount: make(map[string]int)}
   113  	l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), num))
   114  	id := "test"
   115  	a.mu.Lock() // Locks the batch API.
   116  	var done sync.WaitGroup
   117  	done.Add(num)
   118  	for i := 0; i < num; i++ {
   119  		go func() {
   120  			defer done.Done()
   121  			_, _ = l.Do(context.Background(), id)
   122  		}()
   123  	}
   124  	var allWaiting bool
   125  	for !allWaiting {
   126  		l.mu.Lock()
   127  		r := l.results[id]
   128  		l.mu.Unlock()
   129  		if r == nil {
   130  			time.Sleep(time.Millisecond)
   131  			continue
   132  		}
   133  		r.mu.Lock()
   134  		allWaiting = r.nWaiters == num
   135  		r.mu.Unlock()
   136  	}
   137  	a.mu.Unlock() // Unlock the batch API.
   138  	done.Wait()   // Wait for all the goroutines on the same ID to complete
   139  	if got, want := a.idSeenCount[id], 1; got != want {
   140  		t.Errorf("got %d, want %d", got, want)
   141  	}
   142  }
   143  
   144  func TestNoDedup(t *testing.T) {
   145  	a := &testBatchApi{usePtr: true, idSeenCount: make(map[string]int)}
   146  	l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), 1))
   147  	id := "test"
   148  	a.mu.Lock() // Locks the batch API.
   149  	var started, done sync.WaitGroup
   150  	started.Add(5)
   151  	done.Add(5)
   152  	for i := 0; i < 5; i++ {
   153  		go func() {
   154  			started.Done()
   155  			id := id
   156  			_, _ = l.Do(context.Background(), &id)
   157  			done.Done()
   158  		}()
   159  	}
   160  	started.Wait() // Wait for all the goroutines on the same ID to start
   161  	a.mu.Unlock()  // Unlock the batch API.
   162  	done.Wait()    // Wait for all the goroutines on the same ID to complete
   163  	if got, want := a.idSeenCount[id], 5; got != want {
   164  		t.Errorf("got %d, want %d", got, want)
   165  	}
   166  }
   167  
   168  func TestDo(t *testing.T) {
   169  	testApi(t, &testBatchApi{idSeenCount: make(map[string]int)}, time.Second)
   170  }
   171  
   172  func TestDoWithMax5(t *testing.T) {
   173  	testApi(t, &testBatchApi{maxPerBatch: 5, idSeenCount: make(map[string]int)}, 3*time.Second)
   174  }
   175  
   176  func TestDoWithMax8(t *testing.T) {
   177  	testApi(t, &testBatchApi{maxPerBatch: 8, idSeenCount: make(map[string]int)}, 2*time.Second)
   178  }
   179  
// result captures the outcome of a single Do call made by testApi.
type result struct {
	// v is the returned value as a string; testApi only fills it in when err
	// is nil.
	v string
	// err is the error returned by the Do call, if any.
	err error
}
   184  
   185  func shouldErr(i int) bool {
   186  	return i%5 == 0 && i%2 != 0
   187  }
   188  
   189  func testApi(t *testing.T, a *testBatchApi, timeout time.Duration) {
   190  	const numIds = 100
   191  	var interval = 100 * time.Millisecond
   192  	l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(interval), 1))
   193  	var mu sync.Mutex
   194  	results := make(map[string]result)
   195  	_ = traverse.Each(numIds, func(i int) error {
   196  		time.Sleep(time.Duration(i*10) * time.Millisecond)
   197  		id := fmt.Sprintf("%d", i)
   198  		ctx, cancel := context.WithTimeout(context.Background(), timeout)
   199  		defer cancel()
   200  		v, err := l.Do(ctx, id)
   201  		mu.Lock()
   202  		r := result{err: err}
   203  		if r.err == nil {
   204  			r.v = v.(string)
   205  		}
   206  		results[id] = r
   207  		mu.Unlock()
   208  		return nil
   209  	})
   210  	for i := 0; i < numIds; i++ {
   211  		id := fmt.Sprintf("%d", i)
   212  		if got, want := a.idSeenCount[id], 1; got != want {
   213  			t.Errorf("[%v] got %d, want %d", id, got, want)
   214  		}
   215  		if shouldErr(i) {
   216  			if got, want := results[id].err, ErrNoResult; got != want {
   217  				t.Errorf("[%d] got %v, want %v", i, got, want)
   218  			}
   219  		}
   220  	}
   221  	for _, dur := range a.durs[1:] {
   222  		if got, want, diff := dur, interval, (dur - interval).Round(5*time.Millisecond); diff < 0 {
   223  			t.Errorf("got %v, want %v, diff %v", got, want, diff)
   224  		}
   225  	}
   226  	for i, batchIds := range a.perBatchIds {
   227  		if want := a.maxPerBatch; want > 0 {
   228  			if got := len(batchIds); got > want {
   229  				t.Errorf("got %v, want <=%v", got, want)
   230  			}
   231  		}
   232  		t.Logf("batch %d (after %s): %v", i, a.durs[i].Round(time.Millisecond), batchIds)
   233  	}
   234  }