github.com/splucs/witchcraft-go-server@v1.7.0/integration/ratelimit_test.go (about)

     1  // Copyright (c) 2019 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integration
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/palantir/pkg/httpserver"
    26  	"github.com/palantir/witchcraft-go-server/conjure/witchcraft/api/health"
    27  	"github.com/palantir/witchcraft-go-server/status/reporter"
    28  	"github.com/palantir/witchcraft-go-server/witchcraft"
    29  	"github.com/palantir/witchcraft-go-server/witchcraft/ratelimit"
    30  	"github.com/palantir/witchcraft-go-server/witchcraft/refreshable"
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  )
    34  
    35  func TestNewInflightLimitMiddleware(t *testing.T) {
    36  	healthReporter := reporter.NewHealthReporter()
    37  	healthComponent, err := healthReporter.InitializeHealthComponent("INFLIGHT_MUTATING_REQUESTS")
    38  	require.NoError(t, err)
    39  	requireHealthy := func(msg string) {
    40  		require.Equal(t, string(health.HealthStateHealthy), string(healthComponent.Status()), msg)
    41  	}
    42  	requireRepairing := func(msg string) {
    43  		require.Equal(t, string(health.HealthStateRepairing), string(healthComponent.Status()), msg)
    44  	}
    45  
    46  	limiter := ratelimit.NewInFlightRequestLimitMiddleware(refreshable.NewInt(refreshable.NewDefaultRefreshable(2)), ratelimit.MatchMutating, healthComponent)
    47  
    48  	wait, closeWait := context.WithCancel(context.Background())
    49  	defer closeWait()
    50  
    51  	initFn := func(ctx context.Context, info witchcraft.InitInfo) (cleanup func(), rErr error) {
    52  		info.Router.RootRouter().AddRouteHandlerMiddleware(limiter)
    53  		if err := info.Router.Get("/get", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    54  			rw.WriteHeader(http.StatusOK)
    55  		})); err != nil {
    56  			return nil, err
    57  		}
    58  		if err := info.Router.Post("/post", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
    59  			<-wait.Done()
    60  			rw.WriteHeader(http.StatusOK)
    61  		})); err != nil {
    62  			return nil, err
    63  		}
    64  
    65  		return nil, nil
    66  	}
    67  
    68  	port, err := httpserver.AvailablePort()
    69  	require.NoError(t, err)
    70  	server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, initFn, ioutil.Discard, createTestServer)
    71  	defer func() {
    72  		require.NoError(t, server.Close())
    73  	}()
    74  	defer cleanup()
    75  
    76  	client := testServerClient()
    77  
    78  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    79  	defer cancel()
    80  
    81  	doGet := func() *http.Response {
    82  		getURL := fmt.Sprintf("https://localhost:%d/%s/get", port, basePath)
    83  		request, err := http.NewRequest(http.MethodGet, getURL, nil)
    84  		if err != nil {
    85  			assert.NoError(t, err)
    86  		}
    87  		resp, err := client.Do(request.WithContext(ctx))
    88  		if err != nil {
    89  			assert.NoError(t, err)
    90  		}
    91  		return resp
    92  	}
    93  	doPost := func() *http.Response {
    94  		postURL := fmt.Sprintf("https://localhost:%d/%s/post", port, basePath)
    95  		request, err := http.NewRequest(http.MethodPost, postURL, nil)
    96  		if err != nil {
    97  			assert.NoError(t, err)
    98  		}
    99  		resp, err := client.Do(request.WithContext(ctx))
   100  		if err != nil {
   101  			assert.NoError(t, err)
   102  		}
   103  		return resp
   104  	}
   105  
   106  	// Fill up rate limit
   107  	resp1c := make(chan *http.Response)
   108  	requireHealthy("expected healthy before the first request")
   109  	go func() { resp1c <- doPost() }()
   110  	resp2c := make(chan *http.Response)
   111  	requireHealthy("expected healthy before the second request")
   112  	go func() { resp2c <- doPost() }()
   113  
   114  	time.Sleep(time.Millisecond) // let things settle
   115  	requireHealthy("expected healthy before the third request")
   116  
   117  	resp3 := doPost()
   118  	require.Equal(t, http.StatusTooManyRequests, resp3.StatusCode, "expected third POST request to be rate limited")
   119  	requireRepairing("expected repairing after throttled response")
   120  
   121  	require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful")
   122  	require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful")
   123  
   124  	requireRepairing("expected unchanged health after unmatched response")
   125  
   126  	// free the waiting requests, which should return 200
   127  	closeWait()
   128  	resp1 := <-resp1c
   129  	assert.Equal(t, http.StatusOK, resp1.StatusCode, "expected blocked request 1 to return 200 when unblocked")
   130  	resp2 := <-resp2c
   131  	assert.Equal(t, http.StatusOK, resp2.StatusCode, "expected blocked request 2 to return 200 when unblocked")
   132  
   133  	// we should now be unblocked and healthy
   134  	require.Equal(t, http.StatusOK, doPost().StatusCode, "expected fourth POST request to be unblocked")
   135  	requireHealthy("expected healthy after accepted request")
   136  
   137  	select {
   138  	case err := <-serverErr:
   139  		require.NoError(t, err)
   140  	default:
   141  	}
   142  }