
     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     5  package admit
     7  import (
     8  	"context"
     9  	"expvar"
    10  	"fmt"
    11  	"math/rand"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    17  	""
    18  	""
    19  )
    21  func checkState(t *testing.T, p Policy, limit, used int) {
    22  	t.Helper()
    23  	var gotl, gotu int
    24  	switch c := p.(type) {
    25  	case *controller:
    26  		gotl = c.limit
    27  		gotu = c.used
    28  	case *aimd:
    29  		gotl = c.limit
    30  		gotu = c.used
    31  	}
    32  	if gotu != used {
    33  		t.Errorf("c.used: got %d, want %d", gotu, used)
    34  	}
    35  	if gotl != limit {
    36  		t.Errorf("c.limit: got %d, want %d", gotl, limit)
    37  	}
    38  }
    40  func checkVars(t *testing.T, key, max, used string) {
    41  	t.Helper()
    42  	if want, got := max, admitLimit.Get(key).String(); got != want {
    43  		t.Errorf("admitLimit got %s, want %s", got, want)
    44  	}
    45  	if want, got := used, admitUsed.Get(key).String(); got != want {
    46  		t.Errorf("admitUsed got %s, want %s", got, want)
    47  	}
    48  }
    50  func getKeys(m *expvar.Map) map[string]bool {
    51  	keys := make(map[string]bool)
    52  	m.Do(func(kv expvar.KeyValue) {
    53  		keys[kv.Key] = true
    54  	})
    55  	return keys
    56  }
// TestController walks a controller created with newController(10, 15)
// through a fixed sequence of Acquire/Release calls. After each step it
// checks the limit and used count, and, once EnableVarExport has been
// called, the values exported under the "test" key.
func TestController(t *testing.T) {
	c := newController(10, 15)
	// use up 5.
	if err := c.Acquire(context.Background(), 5); err != nil {
		t.Fatal(err)
	}
	checkState(t, c, 10, 5)
	// 5 in use; acquiring 6 more (11 in total, i.e. 1.1 * limit) is still admitted.
	if err := c.Acquire(context.Background(), 6); err != nil {
		t.Fatal(err)
	}
	// release and report capacity error.
	c.Release(5, false)
	checkState(t, c, 10, 6)
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	// 6 still in use and limit should now be 10, so can't acquire 6.
	if want, got := context.DeadlineExceeded, c.Acquire(ctx, 6); got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	cancel()
	// EnableVarExport has not been called yet, so both expvar maps must be
	// empty. NOTE(review): this relies on no earlier test having exported vars.
	if want, got := 0, getKeys(admitLimit); len(got) != want {
		t.Fatalf("admitLimit got %v, want len %d", got, want)
	}
	if want, got := 0, getKeys(admitUsed); len(got) != want {
		t.Fatalf("admitUsed got %v, want len %d", got, want)
	}
	EnableVarExport(c, "test")
	c.Release(6, true)
	checkState(t, c, 10, 0)
	checkVars(t, "test", "10", "0")
	// limit is still 10, but since none are used, should accommodate larger request.
	if err := c.Acquire(context.Background(), 18); err != nil {
		t.Fatal(err)
	}
	checkState(t, c, 10, 18)
	checkVars(t, "test", "10", "18")
	// A successful release while 18 were in use raises the limit. It lands on
	// 15, the second argument to newController (the max, per the comment below).
	c.Release(17, true)
	checkState(t, c, 15, 1)
	checkVars(t, "test", "15", "1")
	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
	// 1 still in use and max is 15, so shouldn't accommodate larger request.
	if want, got := context.DeadlineExceeded, c.Acquire(ctx, 18); got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	cancel()
	checkState(t, c, 15, 1)
	checkVars(t, "test", "15", "1")
	c.Release(1, true)
	checkState(t, c, 15, 0)
	checkVars(t, "test", "15", "0")
}
   110  func TestControllerConcurrently(t *testing.T) {
   111  	testPolicy(t, ControllerWithRetry(100, 1000, nil))
   112  }
// TestAIMD walks an aimd policy created with newAimd(10, 0.2) through a fixed
// sequence of Acquire/Release calls. After each step it checks the limit and
// used count, and, once EnableVarExport has been called, the values exported
// under the "aimd" key. In this sequence each successful release raises the
// limit by one (10 -> 15), and releases reporting a capacity error lower it
// (15 -> 12 -> 10).
func TestAIMD(t *testing.T) {
	c := newAimd(10, 0.2)
	// use up 5.
	if err := c.Acquire(context.Background(), 5); err != nil {
		t.Fatal(err)
	}
	checkState(t, c, 10, 5)
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	// 5 in use and limit should still be 10, so can't acquire 6.
	if want, got := context.DeadlineExceeded, c.Acquire(ctx, 6); got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	cancel()
	// Start exporting vars under "aimd", then release with no capacity error;
	// the limit is unchanged at 10.
	EnableVarExport(c, "aimd")
	c.Release(5, true)
	checkState(t, c, 10, 0)
	checkVars(t, "aimd", "10", "0")

	// Fill the limit exactly, one token at a time.
	for i := 0; i < 10; i++ {
		if err := c.Acquire(context.Background(), 1); err != nil {
			t.Fatal(err)
		}
	}
	checkState(t, c, 10, 10)
	checkVars(t, "aimd", "10", "10")
	// Each successful release makes room to acquire one more token than was
	// released, so usage and limit both climb from 10 to 15.
	for i := 1; i <= 5; i++ {
		c.Release(i, true)
		if err := c.Acquire(context.Background(), i+1); err != nil {
			t.Fatal(err)
		}
	}
	checkState(t, c, 15, 15)
	checkVars(t, "aimd", "15", "15")

	// Report a capacity error: the limit drops from 15 to 12.
	c.Release(1, false)
	checkState(t, c, 12, 14)
	checkVars(t, "aimd", "12", "14")

	// Report another capacity error: the limit drops from 12 to 10.
	c.Release(1, false)
	checkState(t, c, 10, 13)
	checkVars(t, "aimd", "10", "13")

	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
	// 13 still in use and limit should now be 10, so can't acquire 1.
	if want, got := context.DeadlineExceeded, c.Acquire(ctx, 1); got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	cancel()
}
   165  func TestAIMDConcurrently(t *testing.T) {
   166  	testPolicy(t, AIMDWithRetry(100, 0.25, nil))
   167  }
   169  func testPolicy(t *testing.T, p Policy) {
   170  	const (
   171  		N = 100
   172  		T = 100
   173  	)
   174  	var pending int32
   175  	var begin sync.WaitGroup
   176  	begin.Add(N)
   177  	err := traverse.Each(N, func(i int) error {
   178  		begin.Done()
   179  		n := rand.Intn(T/10) + 1
   180  		if err := p.Acquire(context.Background(), n); err != nil {
   181  			return err
   182  		}
   183  		if m := atomic.AddInt32(&pending, int32(n)); m > T {
   184  			return fmt.Errorf("too many tokens: %d > %d", m, T)
   185  		}
   186  		atomic.AddInt32(&pending, -int32(n))
   187  		p.Release(n, (i > 10 && i < 20) || (i > 70 && i < 80))
   188  		return nil
   189  	})
   190  	if err != nil {
   191  		t.Fatal(err)
   192  	}
   193  }
   195  func TestDo(t *testing.T) {
   196  	c := newController(100, 10000)
   197  	// Must satisfy even 150 tokens since none are used.
   198  	if err := Do(context.Background(), c, 150, func() (bool, error) { return true, nil }); err != nil {
   199  		t.Fatal(err)
   200  	}
   201  	checkState(t, c, 150, 0)
   202  	// controller has 150 tokens, use 10 and report capacity error
   203  	if want, got := error(nil), Do(context.Background(), c, 10, func() (bool, error) { return false, nil }); got != want {
   204  		t.Fatalf("got %v, want %v", got, want)
   205  	}
   206  	checkState(t, c, 135, 0)
   207  	// controller has 135 tokens, use up 35...
   208  	c.Acquire(context.Background(), 35)
   209  	checkState(t, c, 135, 35)
   210  	// can go upto 1.1*135 = 148, so should timeout for 114.
   211  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   212  	defer cancel()
   213  	if want, got := context.DeadlineExceeded, Do(ctx, c, 114, func() (bool, error) { return true, nil }); got != want {
   214  		t.Fatalf("got %v, want %v", got, want)
   215  	}
   216  	checkState(t, c, 135, 35)
   217  	// can go upto 1.1*135 = 148, so should timeout for 113.
   218  	if want, got := error(nil), Do(context.Background(), c, 113, func() (bool, error) { return true, nil }); got != want {
   219  		t.Fatalf("got %v, want %v", got, want)
   220  	}
   221  	checkState(t, c, 148, 35)
   222  	// can go upto 1.1*148 = 162, so should go upto 127.
   223  	if err := Do(context.Background(), c, 127, func() (bool, error) { return true, nil }); err != nil {
   224  		t.Fatal(err)
   225  	}
   226  }
   228  func TestRetry(t *testing.T) {
   229  	const (
   230  		N = 1000
   231  	)
   232  	c := ControllerWithRetry(200, 1000, retry.MaxRetries(retry.Backoff(100*time.Millisecond, time.Minute, 1.5), 5))
   233  	var begin sync.WaitGroup
   234  	begin.Add(N)
   235  	err := traverse.Each(N, func(i int) error {
   236  		begin.Done()
   237  		begin.Wait()
   238  		randFunc := func() (CapacityStatus, error) {
   239  			// Out of every three requests, one will (5% of the time) report over capacity with a need to retry,
   240  			// and another (also 5% of the time) will report over capacity with no need to retry.
   241  			switch i % 3 {
   242  			case 0:
   243  				time.Sleep(time.Millisecond * time.Duration(20+rand.Intn(50)))
   244  				if rand.Intn(100) < 5 { // 5% of the time.
   245  					return OverNeedRetry, nil
   246  				}
   247  			case 1:
   248  				time.Sleep(time.Millisecond * time.Duration(20+rand.Intn(50)))
   249  				if rand.Intn(100) < 5 { // 5% of the time.
   250  					return OverNoRetry, nil
   251  				}
   252  			}
   253  			time.Sleep(time.Millisecond * time.Duration(5+rand.Intn(20)))
   254  			return Within, nil
   255  		}
   256  		n := rand.Intn(20) + 1
   257  		return Retry(context.Background(), c, n, randFunc)
   258  	})
   259  	if err != nil {
   260  		t.Fatal(err)
   261  	}
   262  }