github.com/arthur-befumo/witchcraft-go-server@v1.12.0/integration/ratelimit_test.go (about)

     1  // Copyright (c) 2019 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integration
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/palantir/pkg/httpserver"
    26  	werror "github.com/palantir/witchcraft-go-error"
    27  	"github.com/palantir/witchcraft-go-server/conjure/witchcraft/api/health"
    28  	"github.com/palantir/witchcraft-go-server/status/reporter"
    29  	"github.com/palantir/witchcraft-go-server/witchcraft"
    30  	"github.com/palantir/witchcraft-go-server/witchcraft/ratelimit"
    31  	"github.com/palantir/witchcraft-go-server/witchcraft/refreshable"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  func TestNewInflightLimitMiddleware(t *testing.T) {
    37  	healthReporter := reporter.NewHealthReporter()
    38  	healthComponent, err := healthReporter.InitializeHealthComponent("INFLIGHT_MUTATING_REQUESTS")
    39  	require.NoError(t, err)
    40  	requireHealthy := func(msg string) {
    41  		require.Equal(t, string(health.HealthStateHealthy), string(healthComponent.Status()), msg)
    42  	}
    43  	requireRepairing := func(msg string) {
    44  		require.Equal(t, string(health.HealthStateRepairing), string(healthComponent.Status()), msg)
    45  	}
    46  
    47  	limiter := ratelimit.NewInFlightRequestLimitMiddleware(refreshable.NewInt(refreshable.NewDefaultRefreshable(2)), ratelimit.MatchMutating, healthComponent)
    48  
    49  	wait, closeWait := context.WithCancel(context.Background())
    50  	defer closeWait()
    51  
    52  	const totalPostRequests = 4
    53  	reqChan := make(chan struct{}, totalPostRequests)
    54  	initFn := func(ctx context.Context, info witchcraft.InitInfo) (cleanup func(), rErr error) {
    55  		info.Router.RootRouter().AddRouteHandlerMiddleware(limiter)
    56  		if err := info.Router.Get("/get", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    57  			rw.WriteHeader(http.StatusOK)
    58  		})); err != nil {
    59  			return nil, err
    60  		}
    61  		if err := info.Router.Post("/post", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    62  			reqChan <- struct{}{}
    63  			<-wait.Done()
    64  			rw.WriteHeader(http.StatusOK)
    65  		})); err != nil {
    66  			return nil, err
    67  		}
    68  
    69  		return nil, nil
    70  	}
    71  
    72  	port, err := httpserver.AvailablePort()
    73  	require.NoError(t, err)
    74  	server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, initFn, ioutil.Discard, createTestServer)
    75  	defer func() {
    76  		require.NoError(t, server.Close())
    77  	}()
    78  	defer cleanup()
    79  
    80  	client := testServerClient()
    81  
    82  	const testTimeout = time.Minute
    83  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    84  	defer cancel()
    85  
    86  	doGet := func() *http.Response {
    87  		getURL := fmt.Sprintf("https://localhost:%d/%s/get", port, basePath)
    88  		request, err := http.NewRequest(http.MethodGet, getURL, nil)
    89  		if err != nil {
    90  			assert.NoError(t, err)
    91  		}
    92  		resp, err := client.Do(request.WithContext(ctx))
    93  		if err != nil {
    94  			assert.NoError(t, err)
    95  		}
    96  		return resp
    97  	}
    98  	doPost := func() *http.Response {
    99  		postURL := fmt.Sprintf("https://localhost:%d/%s/post", port, basePath)
   100  		request, err := http.NewRequest(http.MethodPost, postURL, nil)
   101  		if err != nil {
   102  			assert.NoError(t, err)
   103  		}
   104  		resp, err := client.Do(request.WithContext(ctx))
   105  		if err != nil {
   106  			assert.NoError(t, err)
   107  		}
   108  		return resp
   109  	}
   110  
   111  	// Fill up rate limit
   112  	resp1c := make(chan *http.Response)
   113  	requireHealthy("expected healthy before the first request")
   114  	go func() { resp1c <- doPost() }()
   115  	resp2c := make(chan *http.Response)
   116  	requireHealthy("expected healthy before the second request")
   117  	go func() { resp2c <- doPost() }()
   118  
   119  	// wait until both of the above requests have made it through
   120  	err = waitForRequests(reqChan, 2, testTimeout)
   121  	require.NoError(t, err)
   122  	requireHealthy("expected healthy before the third request")
   123  
   124  	resp3 := doPost()
   125  	require.Equal(t, http.StatusTooManyRequests, resp3.StatusCode, "expected third POST request to be rate limited")
   126  	requireRepairing("expected repairing after throttled response")
   127  
   128  	require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful")
   129  	require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful")
   130  
   131  	requireRepairing("expected unchanged health after unmatched response")
   132  
   133  	// free the waiting requests, which should return 200
   134  	closeWait()
   135  	resp1 := <-resp1c
   136  	assert.Equal(t, http.StatusOK, resp1.StatusCode, "expected blocked request 1 to return 200 when unblocked")
   137  	resp2 := <-resp2c
   138  	assert.Equal(t, http.StatusOK, resp2.StatusCode, "expected blocked request 2 to return 200 when unblocked")
   139  
   140  	// we should now be unblocked and healthy
   141  	require.Equal(t, http.StatusOK, doPost().StatusCode, "expected fourth POST request to be unblocked")
   142  	requireHealthy("expected healthy after accepted request")
   143  
   144  	select {
   145  	case err := <-serverErr:
   146  		require.NoError(t, err)
   147  	default:
   148  	}
   149  }
   150  
   151  func waitForRequests(reqChan <-chan struct{}, expected int, timeout time.Duration) error {
   152  	t := time.After(timeout)
   153  	var current int
   154  	for {
   155  		select {
   156  		case <-reqChan:
   157  			current++
   158  			if current == expected {
   159  				return nil
   160  			}
   161  		case <-t:
   162  			return werror.Error("timed out waiting for expected number of requests",
   163  				werror.SafeParam("current", current),
   164  				werror.SafeParam("expected", expected))
   165  		}
   166  	}
   167  }