github.com/uber/kraken@v0.1.4/lib/middleware/middleware_test.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package middleware
    15  
    16  import (
    17  	"fmt"
    18  	"io"
    19  	"net/http"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/uber/kraken/utils/httputil"
    24  	"github.com/uber/kraken/utils/testutil"
    25  
    26  	"github.com/pressly/chi"
    27  	"github.com/stretchr/testify/require"
    28  	"github.com/uber-go/tally"
    29  )
    30  
    31  func TestScopeByEndpoint(t *testing.T) {
    32  	tests := []struct {
    33  		method           string
    34  		path             string
    35  		reqPath          string
    36  		expectedEndpoint string
    37  	}{
    38  		{"GET", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"},
    39  		{"POST", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"},
    40  		{"GET", "/a/b/c", "/a/b/c", "a.b.c"},
    41  		{"GET", "/", "/", ""},
    42  		{"GET", "/x/{a}/{b}/{c}", "/x/a/b/c", "x"},
    43  	}
    44  
    45  	for _, test := range tests {
    46  		t.Run(test.method+" "+test.path, func(t *testing.T) {
    47  			require := require.New(t)
    48  
    49  			stats := tally.NewTestScope("", nil)
    50  
    51  			r := chi.NewRouter()
    52  			r.HandleFunc(test.path, func(w http.ResponseWriter, r *http.Request) {
    53  				tagEndpoint(stats, r).Counter("count").Inc(1)
    54  			})
    55  			addr, stop := testutil.StartServer(r)
    56  			defer stop()
    57  
    58  			_, err := httputil.Send(test.method, fmt.Sprintf("http://%s%s", addr, test.reqPath))
    59  			require.NoError(err)
    60  
    61  			require.Equal(1, len(stats.Snapshot().Counters()))
    62  			for _, v := range stats.Snapshot().Counters() {
    63  				require.Equal("count", v.Name())
    64  				require.Equal(int64(1), v.Value())
    65  				require.Equal(map[string]string{
    66  					"endpoint": test.expectedEndpoint,
    67  					"method":   test.method,
    68  				}, v.Tags())
    69  			}
    70  		})
    71  	}
    72  }
    73  
    74  func TestLatencyTimer(t *testing.T) {
    75  	require := require.New(t)
    76  
    77  	stats := tally.NewTestScope("", nil)
    78  
    79  	r := chi.NewRouter()
    80  	r.Use(LatencyTimer(stats))
    81  	r.Get("/foo/{foo}", func(w http.ResponseWriter, r *http.Request) {
    82  		time.Sleep(200 * time.Millisecond)
    83  	})
    84  
    85  	addr, stop := testutil.StartServer(r)
    86  	defer stop()
    87  
    88  	_, err := httputil.Get(fmt.Sprintf("http://%s/foo/x", addr))
    89  	require.NoError(err)
    90  
    91  	now := time.Now()
    92  
    93  	require.Equal(1, len(stats.Snapshot().Timers()))
    94  	for _, v := range stats.Snapshot().Timers() {
    95  		require.Equal("latency", v.Name())
    96  		require.WithinDuration(now, now.Add(v.Values()[0]), 500*time.Millisecond)
    97  		require.Equal(map[string]string{
    98  			"endpoint": "foo",
    99  			"method":   "GET",
   100  		}, v.Tags())
   101  	}
   102  }
   103  
   104  func TestStatusCounter(t *testing.T) {
   105  	tests := []struct {
   106  		desc           string
   107  		handler        func(http.ResponseWriter, *http.Request)
   108  		expectedStatus string
   109  	}{
   110  		{
   111  			"empty handler counts 200",
   112  			func(http.ResponseWriter, *http.Request) {},
   113  			"200",
   114  		}, {
   115  			"writes count 200",
   116  			func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "OK") },
   117  			"200",
   118  		}, {
   119  			"write header",
   120  			func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) },
   121  			"500",
   122  		}, {
   123  			"multiple write header calls only measures first call",
   124  			func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(400); w.WriteHeader(500) },
   125  			"400",
   126  		},
   127  	}
   128  	for _, test := range tests {
   129  		t.Run(test.desc, func(t *testing.T) {
   130  			require := require.New(t)
   131  
   132  			stats := tally.NewTestScope("", nil)
   133  
   134  			r := chi.NewRouter()
   135  			r.Use(StatusCounter(stats))
   136  			r.Get("/foo/{foo}", test.handler)
   137  
   138  			addr, stop := testutil.StartServer(r)
   139  			defer stop()
   140  
   141  			for i := 0; i < 5; i++ {
   142  				_, err := http.Get(fmt.Sprintf("http://%s/foo/x", addr))
   143  				require.NoError(err)
   144  			}
   145  
   146  			require.Equal(1, len(stats.Snapshot().Counters()))
   147  			for _, v := range stats.Snapshot().Counters() {
   148  				require.Equal(test.expectedStatus, v.Name())
   149  				require.Equal(int64(5), v.Value())
   150  				require.Equal(map[string]string{
   151  					"endpoint": "foo",
   152  					"method":   "GET",
   153  				}, v.Tags())
   154  			}
   155  		})
   156  	}
   157  }