github.com/splucs/witchcraft-go-server@v1.7.0/integration/ratelimit_test.go (about) 1 // Copyright (c) 2019 Palantir Technologies. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package integration 16 17 import ( 18 "context" 19 "fmt" 20 "io/ioutil" 21 "net/http" 22 "testing" 23 "time" 24 25 "github.com/palantir/pkg/httpserver" 26 "github.com/palantir/witchcraft-go-server/conjure/witchcraft/api/health" 27 "github.com/palantir/witchcraft-go-server/status/reporter" 28 "github.com/palantir/witchcraft-go-server/witchcraft" 29 "github.com/palantir/witchcraft-go-server/witchcraft/ratelimit" 30 "github.com/palantir/witchcraft-go-server/witchcraft/refreshable" 31 "github.com/stretchr/testify/assert" 32 "github.com/stretchr/testify/require" 33 ) 34 35 func TestNewInflightLimitMiddleware(t *testing.T) { 36 healthReporter := reporter.NewHealthReporter() 37 healthComponent, err := healthReporter.InitializeHealthComponent("INFLIGHT_MUTATING_REQUESTS") 38 require.NoError(t, err) 39 requireHealthy := func(msg string) { 40 require.Equal(t, string(health.HealthStateHealthy), string(healthComponent.Status()), msg) 41 } 42 requireRepairing := func(msg string) { 43 require.Equal(t, string(health.HealthStateRepairing), string(healthComponent.Status()), msg) 44 } 45 46 limiter := ratelimit.NewInFlightRequestLimitMiddleware(refreshable.NewInt(refreshable.NewDefaultRefreshable(2)), ratelimit.MatchMutating, healthComponent) 47 48 wait, closeWait := context.WithCancel(context.Background()) 49 defer closeWait() 50 51 initFn := func(ctx context.Context, info witchcraft.InitInfo) (cleanup func(), rErr error) { 52 info.Router.RootRouter().AddRouteHandlerMiddleware(limiter) 53 if err := info.Router.Get("/get", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 54 rw.WriteHeader(http.StatusOK) 55 })); err != nil { 56 return nil, err 57 } 58 if err := info.Router.Post("/post", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 59 <-wait.Done() 60 rw.WriteHeader(http.StatusOK) 61 })); err != nil { 62 return nil, err 63 } 64 65 return nil, nil 66 } 67 68 port, err := httpserver.AvailablePort() 69 require.NoError(t, err) 70 server, serverErr, cleanup := createAndRunCustomTestServer(t, port, port, initFn, ioutil.Discard, createTestServer) 71 defer func() { 72 require.NoError(t, server.Close()) 73 }() 74 defer cleanup() 75 76 client := testServerClient() 77 78 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 79 defer cancel() 80 81 doGet := func() *http.Response { 82 getURL := fmt.Sprintf("https://localhost:%d/%s/get", port, basePath) 83 request, err := http.NewRequest(http.MethodGet, getURL, nil) 84 if err != nil { 85 assert.NoError(t, err) 86 } 87 resp, err := client.Do(request.WithContext(ctx)) 88 if err != nil { 89 assert.NoError(t, err) 90 } 91 return resp 92 } 93 doPost := func() *http.Response { 94 postURL := fmt.Sprintf("https://localhost:%d/%s/post", port, basePath) 95 request, err := http.NewRequest(http.MethodPost, postURL, nil) 96 if err != nil { 97 assert.NoError(t, err) 98 } 99 resp, err := client.Do(request.WithContext(ctx)) 100 if err != nil { 101 assert.NoError(t, err) 102 } 103 return resp 104 } 105 106 // Fill up rate limit 107 resp1c := make(chan *http.Response) 108 requireHealthy("expected healthy before the first request") 109 go func() { resp1c <- doPost() }() 110 resp2c := make(chan *http.Response) 111 requireHealthy("expected healthy before the second request") 112 go func() { resp2c <- doPost() }() 113 114 time.Sleep(time.Millisecond) // let things settle 115 requireHealthy("expected healthy before the third request") 116 117 resp3 := doPost() 118 require.Equal(t, http.StatusTooManyRequests, resp3.StatusCode, "expected third POST request to be rate limited") 119 requireRepairing("expected repairing after throttled response") 120 121 require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful") 122 require.Equal(t, http.StatusOK, doGet().StatusCode, "expected get request to be successful") 123 124 requireRepairing("expected unchanged health after unmatched response") 125 126 // free the waiting requests, which should return 200 127 closeWait() 128 resp1 := <-resp1c 129 assert.Equal(t, http.StatusOK, resp1.StatusCode, "expected blocked request 1 to return 200 when unblocked") 130 resp2 := <-resp2c 131 assert.Equal(t, http.StatusOK, resp2.StatusCode, "expected blocked request 2 to return 200 when unblocked") 132 133 // we should now be unblocked and healthy 134 require.Equal(t, http.StatusOK, doPost().StatusCode, "expected fourth POST request to be unblocked") 135 requireHealthy("expected healthy after accepted request") 136 137 select { 138 case err := <-serverErr: 139 require.NoError(t, err) 140 default: 141 } 142 }