github.com/palantir/witchcraft-go-server/v2@v2.76.0/integration/ratelimit_test.go (about) 1 // Copyright (c) 2019 Palantir Technologies. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package integration 16 17 import ( 18 "context" 19 "fmt" 20 "io/ioutil" 21 "net/http" 22 "testing" 23 "time" 24 25 "github.com/palantir/pkg/httpserver" 26 "github.com/palantir/pkg/refreshable" 27 werror "github.com/palantir/witchcraft-go-error" 28 "github.com/palantir/witchcraft-go-health/conjure/witchcraft/api/health" 29 "github.com/palantir/witchcraft-go-health/reporter" 30 "github.com/palantir/witchcraft-go-server/v2/witchcraft" 31 "github.com/palantir/witchcraft-go-server/v2/witchcraft/ratelimit" 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 ) 35 36 func TestNewInflightLimitMiddleware(t *testing.T) { 37 healthReporter := reporter.NewHealthReporter() 38 healthComponent, err := healthReporter.InitializeHealthComponent("INFLIGHT_MUTATING_REQUESTS") 39 require.NoError(t, err) 40 requireHealthy := func(msg string) { 41 require.Equal(t, health.HealthState_HEALTHY, healthComponent.Status(), msg) 42 } 43 requireRepairing := func(msg string) { 44 require.Equal(t, health.HealthState_REPAIRING, healthComponent.Status(), msg) 45 } 46 47 limiter := ratelimit.NewInFlightRequestLimitMiddleware(refreshable.NewInt(refreshable.NewDefaultRefreshable(2)), ratelimit.MatchMutating, healthComponent) 48 49 wait, closeWait := context.WithCancel(context.Background()) 50 defer closeWait() 51 52 const totalPostRequests = 4 53 reqChan := make(chan struct{}, totalPostRequests) 54 initFn := func(ctx context.Context, info witchcraft.InitInfo) (cleanup func(), rErr error) { 55 info.Router.RootRouter().AddRouteHandlerMiddleware(limiter) 56 if err := info.Router.Get("/get", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 57 rw.WriteHeader(http.StatusOK) 58 })); err != nil { 59 return nil, err 60 } 61 if err := info.Router.Post("/post", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 62 reqChan <- struct{}{} 63 <-wait.Done() 64 rw.WriteHeader(http.StatusOK) 65 })); err != nil { 66 return nil, err 67 } 68 69 return nil, nil 70 } 71 72 port, err := httpserver.AvailablePort() 73 require.NoError(t, err) 74 server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, initFn, ioutil.Discard, createTestServer) 75 defer func() { 76 require.NoError(t, server.Close()) 77 }() 78 defer cleanup() 79 80 client := testServerClient() 81 82 const testTimeout = time.Minute 83 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 84 defer cancel() 85 86 doGet := func() *http.Response { 87 getURL := fmt.Sprintf("https://localhost:%d/%s/get", port, basePath) 88 request, err := http.NewRequest(http.MethodGet, getURL, nil) 89 if err != nil { 90 assert.NoError(t, err) 91 } 92 resp, err := client.Do(request.WithContext(ctx)) 93 if err != nil { 94 assert.NoError(t, err) 95 } 96 return resp 97 } 98 doPost := func() *http.Response { 99 postURL := fmt.Sprintf("https://localhost:%d/%s/post", port, basePath) 100 request, err := http.NewRequest(http.MethodPost, postURL, nil) 101 if err != nil { 102 assert.NoError(t, err) 103 } 104 resp, err := client.Do(request.WithContext(ctx)) 105 if err != nil { 106 assert.NoError(t, err) 107 } 108 return resp 109 } 110 111 // Fill up rate limit 112 resp1c := make(chan *http.Response) 113 requireHealthy("expected healthy before the first request") 114 go func() { resp1c <- doPost() }() 115 resp2c := make(chan *http.Response) 116 requireHealthy("expected healthy before the second request") 117 go func() { resp2c <- doPost() }() 118 119 // wait until both of the above requests have made it through 120 err = waitForRequests(reqChan, 2, testTimeout) 121 require.NoError(t, err) 122 requireHealthy("expected healthy before the third request") 123 124 resp3 := doPost() 125 require.Equal(t, http.StatusTooManyRequests, resp3.StatusCode, "expected third POST request to be rate limited") 126 requireRepairing("expected repairing after throttled response") 127 128 require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful") 129 require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful") 130 131 requireRepairing("expected unchanged health after unmatched response") 132 133 // free the waiting requests, which should return 200 134 closeWait() 135 resp1 := <-resp1c 136 assert.Equal(t, http.StatusOK, resp1.StatusCode, "expected blocked request 1 to return 200 when unblocked") 137 resp2 := <-resp2c 138 assert.Equal(t, http.StatusOK, resp2.StatusCode, "expected blocked request 2 to return 200 when unblocked") 139 140 // we should now be unblocked and healthy 141 require.Equal(t, http.StatusOK, doPost().StatusCode, "expected fourth POST request to be unblocked") 142 requireHealthy("expected healthy after accepted request") 143 144 select { 145 case err := <-serverErr: 146 require.NoError(t, err) 147 default: 148 } 149 } 150 151 func waitForRequests(reqChan <-chan struct{}, expected int, timeout time.Duration) error { 152 t := time.After(timeout) 153 var current int 154 for { 155 select { 156 case <-reqChan: 157 current++ 158 if current == expected { 159 return nil 160 } 161 case <-t: 162 return werror.Error("timed out waiting for expected number of requests", 163 werror.SafeParam("current", current), 164 werror.SafeParam("expected", expected)) 165 } 166 } 167 }