github.com/google/martian/v3@v3.3.3/trafficshape/bucket_test.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package trafficshape
    16  
    import (
    	"errors"
    	"runtime"
    	"sync"
    	"sync/atomic"
    	"testing"
    	"time"
    )
    24  
    25  func TestBucket(t *testing.T) {
    26  	t.Parallel()
    27  
    28  	b := NewBucket(10, 10*time.Millisecond)
    29  	defer b.Close()
    30  
    31  	if got, want := b.Capacity(), int64(10); got != want {
    32  		t.Fatalf("b.Capacity(): got %d, want %d", got, want)
    33  	}
    34  
    35  	n, err := b.Fill(func(remaining int64) (int64, error) {
    36  		if want := int64(10); remaining != want {
    37  			t.Errorf("remaining: got %d, want %d", remaining, want)
    38  		}
    39  		return 5, nil
    40  	})
    41  	if err != nil {
    42  		t.Fatalf("Fill(): got %v, want no error", err)
    43  	}
    44  	if got, want := n, int64(5); got != want {
    45  		t.Fatalf("n: got %d, want %d", got, want)
    46  	}
    47  
    48  	n, err = b.Fill(func(remaining int64) (int64, error) {
    49  		if want := int64(5); remaining != want {
    50  			t.Errorf("remaining: got %d, want %d", remaining, want)
    51  		}
    52  		return 5, nil
    53  	})
    54  	if err != nil {
    55  		t.Fatalf("Fill(): got %v, want no error", err)
    56  	}
    57  	if got, want := n, int64(5); got != want {
    58  		t.Fatalf("n: got %d, want %d", got, want)
    59  	}
    60  	n, err = b.Fill(func(remaining int64) (int64, error) {
    61  		t.Fatal("Fill: executed func when full, want skipped")
    62  		return 0, nil
    63  	})
    64  	if err != nil {
    65  		t.Fatalf("Fill(): got %v, want no error", err)
    66  	}
    67  
    68  	// Wait for the bucket to drain.
    69  	for {
    70  		if atomic.LoadInt64(&b.fill) == 0 {
    71  			break
    72  		}
    73  		// Allow for a goroutine switch, required for GOMAXPROCS = 1.
    74  		runtime.Gosched()
    75  	}
    76  
    77  	wanterr := errors.New("fill function error")
    78  	n, err = b.Fill(func(remaining int64) (int64, error) {
    79  		if want := int64(10); remaining != want {
    80  			t.Errorf("remaining: got %d, want %d", remaining, want)
    81  		}
    82  		return 0, wanterr
    83  	})
    84  	if err != wanterr {
    85  		t.Fatalf("Fill(): got %v, want %v", err, wanterr)
    86  	}
    87  	if got, want := n, int64(0); got != want {
    88  		t.Fatalf("n: got %d, want %d", got, want)
    89  	}
    90  }
    91  
    92  func TestBucketClosed(t *testing.T) {
    93  	t.Parallel()
    94  
    95  	b := NewBucket(0, time.Millisecond)
    96  	b.Close()
    97  
    98  	if _, err := b.Fill(nil); err != errFillClosedBucket {
    99  		t.Errorf("Fill(): got %v, want errFillClosedBucket", err)
   100  	}
   101  	if _, err := b.FillThrottle(nil); err != errFillClosedBucket {
   102  		t.Errorf("FillThrottle(): got %v, want errFillClosedBucket", err)
   103  	}
   104  }
   105  
   106  func TestBucketOverflow(t *testing.T) {
   107  	t.Parallel()
   108  
   109  	b := NewBucket(10, 10*time.Millisecond)
   110  	defer b.Close()
   111  
   112  	n, err := b.Fill(func(remaining int64) (int64, error) {
   113  		return 11, nil
   114  	})
   115  	if err != nil {
   116  		t.Fatalf("Fill(): got %v, want no error", err)
   117  	}
   118  
   119  	n, err = b.Fill(func(int64) (int64, error) {
   120  		t.Fatal("Fill: executed func when full, want skipped")
   121  		return 0, nil
   122  	})
   123  	if err != ErrBucketOverflow {
   124  		t.Fatalf("Fill(): got %v, want ErrBucketOverflow", err)
   125  	}
   126  	if got, want := n, int64(0); got != want {
   127  		t.Fatalf("n: got %d, want %d", got, want)
   128  	}
   129  }
   130  
   131  func TestBucketThrottle(t *testing.T) {
   132  	t.Parallel()
   133  
   134  	b := NewBucket(50, 50*time.Millisecond)
   135  	defer b.Close()
   136  
   137  	closec := make(chan struct{})
   138  	errc := make(chan error, 1)
   139  
   140  	fill := func() {
   141  		for {
   142  			select {
   143  			case <-closec:
   144  				return
   145  			default:
   146  				if _, err := b.FillThrottle(func(remaining int64) (int64, error) {
   147  					if remaining < 10 {
   148  						return remaining, nil
   149  					}
   150  					return 10, nil
   151  				}); err != nil {
   152  					select {
   153  					case errc <- err:
   154  					default:
   155  					}
   156  				}
   157  			}
   158  		}
   159  	}
   160  
   161  	for i := 0; i < 5; i++ {
   162  		go fill()
   163  	}
   164  
   165  	time.Sleep(time.Second)
   166  
   167  	close(closec)
   168  
   169  	select {
   170  	case err := <-errc:
   171  		t.Fatalf("FillThrottle: got %v, want no error", err)
   172  	default:
   173  	}
   174  }
   175  
   176  func TestBucketFillThrottleCloseBeforeTick(t *testing.T) {
   177  	t.Parallel()
   178  
   179  	b := NewBucket(0, time.Minute)
   180  	time.AfterFunc(time.Second, func() { b.Close() })
   181  
   182  	if _, err := b.FillThrottle(func(int64) (int64, error) {
   183  		t.Fatal("FillThrottle(): executed func after close, want skipped")
   184  		return 0, nil
   185  	}); err != errFillClosedBucket {
   186  		t.Errorf("b.FillThrottle(): got nil, want errFillClosedBucket")
   187  	}
   188  }