github.com/uber/kraken@v0.1.4/lib/middleware/middleware.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package middleware
    15  
    16  import (
    17  	"net/http"
    18  	"strconv"
    19  	"strings"
    20  	"time"
    21  
    22  	"github.com/pressly/chi"
    23  	"github.com/uber-go/tally"
    24  )
    25  
    26  // tagEndpoint tags stats by endpoint path and method, ignoring any path variables.
    27  // For example, "/foo/{foo}/bar/{bar}" is tagged with endpoint "foo.bar"
    28  //
    29  // Note: tagEndpoint should always be called AFTER the "next" handler serves,
    30  // such that chi can populate proper route context with the path.
    31  //
    32  // Wrong:
    33  //
    34  //     tagEndpoint(stats, r).Counter("n").Inc(1)
    35  //     next.ServeHTTP(w, r)
    36  //
    37  // Right:
    38  //
    39  //     next.ServeHTTP(w, r)
    40  //     tagEndpoint(stats, r).Counter("n").Inc(1)
    41  //
    42  func tagEndpoint(stats tally.Scope, r *http.Request) tally.Scope {
    43  	ctx := chi.RouteContext(r.Context())
    44  	var staticParts []string
    45  	for _, part := range strings.Split(ctx.RoutePattern(), "/") {
    46  		if len(part) == 0 || isPathVariable(part) {
    47  			continue
    48  		}
    49  		staticParts = append(staticParts, part)
    50  	}
    51  	return stats.Tagged(map[string]string{
    52  		"endpoint": strings.Join(staticParts, "."),
    53  		"method":   strings.ToUpper(r.Method),
    54  	})
    55  }
    56  
    57  // isPathVariable returns true if s is a path variable, e.g. "{foo}".
    58  func isPathVariable(s string) bool {
    59  	return len(s) >= 2 && s[0] == '{' && s[len(s)-1] == '}'
    60  }
    61  
    62  // LatencyTimer measures endpoint latencies.
    63  func LatencyTimer(stats tally.Scope) func(next http.Handler) http.Handler {
    64  	return func(next http.Handler) http.Handler {
    65  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    66  			start := time.Now()
    67  			next.ServeHTTP(w, r)
    68  			tagEndpoint(stats, r).Timer("latency").Record(time.Since(start))
    69  		})
    70  	}
    71  }
    72  
    73  type recordStatusWriter struct {
    74  	http.ResponseWriter
    75  	wroteHeader bool
    76  	code        int
    77  }
    78  
    79  func (w *recordStatusWriter) WriteHeader(code int) {
    80  	if !w.wroteHeader {
    81  		w.code = code
    82  		w.wroteHeader = true
    83  		w.ResponseWriter.WriteHeader(code)
    84  	}
    85  }
    86  
    87  func (w *recordStatusWriter) Write(b []byte) (int, error) {
    88  	w.WriteHeader(http.StatusOK)
    89  	return w.ResponseWriter.Write(b)
    90  }
    91  
    92  // StatusCounter measures endpoint status count.
    93  func StatusCounter(stats tally.Scope) func(next http.Handler) http.Handler {
    94  	return func(next http.Handler) http.Handler {
    95  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    96  			recordw := &recordStatusWriter{w, false, http.StatusOK}
    97  			next.ServeHTTP(recordw, r)
    98  			tagEndpoint(stats, r).Counter(strconv.Itoa(recordw.code)).Inc(1)
    99  		})
   100  	}
   101  }