github.com/google/martian/v3@v3.3.3/trafficshape/bucket_test.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package trafficshape 16 17 import ( 18 "errors" 19 "runtime" 20 "sync/atomic" 21 "testing" 22 "time" 23 ) 24 25 func TestBucket(t *testing.T) { 26 t.Parallel() 27 28 b := NewBucket(10, 10*time.Millisecond) 29 defer b.Close() 30 31 if got, want := b.Capacity(), int64(10); got != want { 32 t.Fatalf("b.Capacity(): got %d, want %d", got, want) 33 } 34 35 n, err := b.Fill(func(remaining int64) (int64, error) { 36 if want := int64(10); remaining != want { 37 t.Errorf("remaining: got %d, want %d", remaining, want) 38 } 39 return 5, nil 40 }) 41 if err != nil { 42 t.Fatalf("Fill(): got %v, want no error", err) 43 } 44 if got, want := n, int64(5); got != want { 45 t.Fatalf("n: got %d, want %d", got, want) 46 } 47 48 n, err = b.Fill(func(remaining int64) (int64, error) { 49 if want := int64(5); remaining != want { 50 t.Errorf("remaining: got %d, want %d", remaining, want) 51 } 52 return 5, nil 53 }) 54 if err != nil { 55 t.Fatalf("Fill(): got %v, want no error", err) 56 } 57 if got, want := n, int64(5); got != want { 58 t.Fatalf("n: got %d, want %d", got, want) 59 } 60 n, err = b.Fill(func(remaining int64) (int64, error) { 61 t.Fatal("Fill: executed func when full, want skipped") 62 return 0, nil 63 }) 64 if err != nil { 65 t.Fatalf("Fill(): got %v, want no error", err) 66 } 67 68 // Wait for the bucket to drain. 69 for { 70 if atomic.LoadInt64(&b.fill) == 0 { 71 break 72 } 73 // Allow for a goroutine switch, required for GOMAXPROCS = 1. 74 runtime.Gosched() 75 } 76 77 wanterr := errors.New("fill function error") 78 n, err = b.Fill(func(remaining int64) (int64, error) { 79 if want := int64(10); remaining != want { 80 t.Errorf("remaining: got %d, want %d", remaining, want) 81 } 82 return 0, wanterr 83 }) 84 if err != wanterr { 85 t.Fatalf("Fill(): got %v, want %v", err, wanterr) 86 } 87 if got, want := n, int64(0); got != want { 88 t.Fatalf("n: got %d, want %d", got, want) 89 } 90 } 91 92 func TestBucketClosed(t *testing.T) { 93 t.Parallel() 94 95 b := NewBucket(0, time.Millisecond) 96 b.Close() 97 98 if _, err := b.Fill(nil); err != errFillClosedBucket { 99 t.Errorf("Fill(): got %v, want errFillClosedBucket", err) 100 } 101 if _, err := b.FillThrottle(nil); err != errFillClosedBucket { 102 t.Errorf("FillThrottle(): got %v, want errFillClosedBucket", err) 103 } 104 } 105 106 func TestBucketOverflow(t *testing.T) { 107 t.Parallel() 108 109 b := NewBucket(10, 10*time.Millisecond) 110 defer b.Close() 111 112 n, err := b.Fill(func(remaining int64) (int64, error) { 113 return 11, nil 114 }) 115 if err != nil { 116 t.Fatalf("Fill(): got %v, want no error", err) 117 } 118 119 n, err = b.Fill(func(int64) (int64, error) { 120 t.Fatal("Fill: executed func when full, want skipped") 121 return 0, nil 122 }) 123 if err != ErrBucketOverflow { 124 t.Fatalf("Fill(): got %v, want ErrBucketOverflow", err) 125 } 126 if got, want := n, int64(0); got != want { 127 t.Fatalf("n: got %d, want %d", got, want) 128 } 129 } 130 131 func TestBucketThrottle(t *testing.T) { 132 t.Parallel() 133 134 b := NewBucket(50, 50*time.Millisecond) 135 defer b.Close() 136 137 closec := make(chan struct{}) 138 errc := make(chan error, 1) 139 140 fill := func() { 141 for { 142 select { 143 case <-closec: 144 return 145 default: 146 if _, err := b.FillThrottle(func(remaining int64) (int64, error) { 147 if remaining < 10 { 148 return remaining, nil 149 } 150 return 10, nil 151 }); err != nil { 152 select { 153 case errc <- err: 154 default: 155 } 156 } 157 } 158 } 159 } 160 161 for i := 0; i < 5; i++ { 162 go fill() 163 } 164 165 time.Sleep(time.Second) 166 167 close(closec) 168 169 select { 170 case err := <-errc: 171 t.Fatalf("FillThrottle: got %v, want no error", err) 172 default: 173 } 174 } 175 176 func TestBucketFillThrottleCloseBeforeTick(t *testing.T) { 177 t.Parallel() 178 179 b := NewBucket(0, time.Minute) 180 time.AfterFunc(time.Second, func() { b.Close() }) 181 182 if _, err := b.FillThrottle(func(int64) (int64, error) { 183 t.Fatal("FillThrottle(): executed func after close, want skipped") 184 return 0, nil 185 }); err != errFillClosedBucket { 186 t.Errorf("b.FillThrottle(): got nil, want errFillClosedBucket") 187 } 188 }