github.com/uber/kraken@v0.1.4/lib/middleware/middleware_test.go (about) 1 // Copyright (c) 2016-2019 Uber Technologies, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 package middleware 15 16 import ( 17 "fmt" 18 "io" 19 "net/http" 20 "testing" 21 "time" 22 23 "github.com/uber/kraken/utils/httputil" 24 "github.com/uber/kraken/utils/testutil" 25 26 "github.com/pressly/chi" 27 "github.com/stretchr/testify/require" 28 "github.com/uber-go/tally" 29 ) 30 31 func TestScopeByEndpoint(t *testing.T) { 32 tests := []struct { 33 method string 34 path string 35 reqPath string 36 expectedEndpoint string 37 }{ 38 {"GET", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"}, 39 {"POST", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"}, 40 {"GET", "/a/b/c", "/a/b/c", "a.b.c"}, 41 {"GET", "/", "/", ""}, 42 {"GET", "/x/{a}/{b}/{c}", "/x/a/b/c", "x"}, 43 } 44 45 for _, test := range tests { 46 t.Run(test.method+" "+test.path, func(t *testing.T) { 47 require := require.New(t) 48 49 stats := tally.NewTestScope("", nil) 50 51 r := chi.NewRouter() 52 r.HandleFunc(test.path, func(w http.ResponseWriter, r *http.Request) { 53 tagEndpoint(stats, r).Counter("count").Inc(1) 54 }) 55 addr, stop := testutil.StartServer(r) 56 defer stop() 57 58 _, err := httputil.Send(test.method, fmt.Sprintf("http://%s%s", addr, test.reqPath)) 59 require.NoError(err) 60 61 require.Equal(1, len(stats.Snapshot().Counters())) 62 for _, v := range stats.Snapshot().Counters() { 63 require.Equal("count", v.Name()) 64 require.Equal(int64(1), v.Value()) 65 require.Equal(map[string]string{ 66 "endpoint": test.expectedEndpoint, 67 "method": test.method, 68 }, v.Tags()) 69 } 70 }) 71 } 72 } 73 74 func TestLatencyTimer(t *testing.T) { 75 require := require.New(t) 76 77 stats := tally.NewTestScope("", nil) 78 79 r := chi.NewRouter() 80 r.Use(LatencyTimer(stats)) 81 r.Get("/foo/{foo}", func(w http.ResponseWriter, r *http.Request) { 82 time.Sleep(200 * time.Millisecond) 83 }) 84 85 addr, stop := testutil.StartServer(r) 86 defer stop() 87 88 _, err := httputil.Get(fmt.Sprintf("http://%s/foo/x", addr)) 89 require.NoError(err) 90 91 now := time.Now() 92 93 require.Equal(1, len(stats.Snapshot().Timers())) 94 for _, v := range stats.Snapshot().Timers() { 95 require.Equal("latency", v.Name()) 96 require.WithinDuration(now, now.Add(v.Values()[0]), 500*time.Millisecond) 97 require.Equal(map[string]string{ 98 "endpoint": "foo", 99 "method": "GET", 100 }, v.Tags()) 101 } 102 } 103 104 func TestStatusCounter(t *testing.T) { 105 tests := []struct { 106 desc string 107 handler func(http.ResponseWriter, *http.Request) 108 expectedStatus string 109 }{ 110 { 111 "empty handler counts 200", 112 func(http.ResponseWriter, *http.Request) {}, 113 "200", 114 }, { 115 "writes count 200", 116 func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "OK") }, 117 "200", 118 }, { 119 "write header", 120 func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) }, 121 "500", 122 }, { 123 "multiple write header calls only measures first call", 124 func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(400); w.WriteHeader(500) }, 125 "400", 126 }, 127 } 128 for _, test := range tests { 129 t.Run(test.desc, func(t *testing.T) { 130 require := require.New(t) 131 132 stats := tally.NewTestScope("", nil) 133 134 r := chi.NewRouter() 135 r.Use(StatusCounter(stats)) 136 r.Get("/foo/{foo}", test.handler) 137 138 addr, stop := testutil.StartServer(r) 139 defer stop() 140 141 for i := 0; i < 5; i++ { 142 _, err := http.Get(fmt.Sprintf("http://%s/foo/x", addr)) 143 require.NoError(err) 144 } 145 146 require.Equal(1, len(stats.Snapshot().Counters())) 147 for _, v := range stats.Snapshot().Counters() { 148 require.Equal(test.expectedStatus, v.Name()) 149 require.Equal(int64(5), v.Value()) 150 require.Equal(map[string]string{ 151 "endpoint": "foo", 152 "method": "GET", 153 }, v.Tags()) 154 } 155 }) 156 } 157 }