sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/pkg/throttle/throttle_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package throttle
    18  
    19  import (
    20  	"context"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/sirupsen/logrus"
    26  )
    27  
    28  func TestThrottle(t *testing.T) {
    29  	logrus.SetLevel(logrus.DebugLevel)
    30  	t.Parallel()
    31  	testCases := []struct {
    32  		name string
    33  
    34  		setup []struct {
    35  			hourly int
    36  			burst  int
    37  			org    string
    38  		}
    39  		expectThrottling bool
    40  	}{
    41  		{
    42  			name: "global throttler",
    43  			setup: []struct {
    44  				hourly int
    45  				burst  int
    46  				org    string
    47  			}{
    48  				{1, 2, ""},
    49  			},
    50  			expectThrottling: true,
    51  		},
    52  		{
    53  			name: "our org is throttled",
    54  
    55  			setup: []struct {
    56  				hourly int
    57  				burst  int
    58  				org    string
    59  			}{
    60  				{1, 2, "org"},
    61  			},
    62  			expectThrottling: true,
    63  		},
    64  		{
    65  			name: "different org is throttled, ours is not",
    66  
    67  			setup: []struct {
    68  				hourly int
    69  				burst  int
    70  				org    string
    71  			}{
    72  				{1, 2, "something-else"},
    73  			},
    74  		},
    75  		{
    76  			name: "global throttler and throttler for our org",
    77  
    78  			setup: []struct {
    79  				hourly int
    80  				burst  int
    81  				org    string
    82  			}{
    83  				{100, 100, ""},
    84  				{1, 2, "org"},
    85  			},
    86  			expectThrottling: true,
    87  		},
    88  	}
    89  
    90  	for _, tc := range testCases {
    91  		tc := tc
    92  		t.Run(tc.name, func(t *testing.T) {
    93  			t.Parallel()
    94  			throttler := &Throttler{}
    95  			throttlerKey := throttlerGlobalKey
    96  			for _, setup := range tc.setup {
    97  				if setup.org != "" {
    98  					throttler.Throttle(setup.hourly, setup.burst, setup.org)
    99  					if setup.org == "org" {
   100  						throttlerKey = "org"
   101  					}
   102  				} else {
   103  					throttler.Throttle(setup.hourly, setup.burst)
   104  				}
   105  			}
   106  
   107  			var expectItems int
   108  			if tc.expectThrottling {
   109  				expectItems = 2
   110  			}
   111  			if n := len(throttler.throttle[throttlerKey]); n != expectItems {
   112  				t.Fatalf("Expected %d items in throttle channel, found %d", expectItems, n)
   113  			}
   114  			if n := cap(throttler.throttle[throttlerKey]); n != expectItems {
   115  				t.Fatalf("Expected throttle channel capacity of %d, found %d", expectItems, n)
   116  			}
   117  			check := func(err error) {
   118  				t.Helper()
   119  				if err != nil {
   120  					t.Errorf("Unexpected error: %v", err)
   121  				}
   122  				if tc.expectThrottling {
   123  					if len(throttler.throttle[throttlerKey]) != 1 {
   124  						t.Errorf("Expected one item in throttle channel, found %d", len(throttler.throttle[throttlerKey]))
   125  					}
   126  				} else if _, throttleChannelExists := throttler.throttle[throttlerKey]; throttleChannelExists {
   127  					t.Error("didn't expect throttling, but throttler existed")
   128  				}
   129  			}
   130  			err := throttler.Wait(context.Background(), "org")
   131  			check(err)
   132  			// The following two waits should be properly refunded.
   133  			err = throttler.Wait(context.Background(), "org")
   134  			throttler.Refund("org")
   135  			check(err)
   136  			err = throttler.Wait(context.Background(), "org")
   137  			throttler.Refund("org")
   138  			check(err)
   139  
   140  			// Check that calls are delayed while throttled.
   141  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   142  			go func() {
   143  				if err := throttler.Wait(context.Background(), "org"); err != nil {
   144  					t.Errorf("Unexpected error: %v", err)
   145  				}
   146  				if err := throttler.Wait(context.Background(), "org"); err != nil {
   147  					t.Errorf("Unexpected error: %v", err)
   148  				}
   149  				cancel()
   150  			}()
   151  			slowed := false
   152  			for ctx.Err() == nil {
   153  				// Wait for the client to get throttled
   154  				val := throttler.slow[throttlerKey]
   155  				if val == nil || atomic.LoadInt32(val) == 0 {
   156  					continue
   157  				}
   158  				// Throttled, now add to the channel
   159  				slowed = true
   160  				select {
   161  				case throttler.throttle[throttlerKey] <- time.Now(): // Add items to the channel
   162  				case <-ctx.Done():
   163  				}
   164  			}
   165  			if slowed != tc.expectThrottling {
   166  				t.Errorf("expected throttling: %t, got throttled: %t", tc.expectThrottling, slowed)
   167  			}
   168  			if err := ctx.Err(); err != context.Canceled {
   169  				t.Errorf("Expected context cancellation did not happen: %v", err)
   170  			}
   171  		})
   172  	}
   173  }