go.undefinedlabs.com/scopeagent@v0.4.2/instrumentation/nethttp/server.go (about)

     1  package nethttp
     2  
     3  import (
     4  	"bytes"
     5  	"net"
     6  	"net/http"
     7  	"net/url"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/opentracing/opentracing-go"
    12  	"github.com/opentracing/opentracing-go/ext"
    13  
    14  	"go.undefinedlabs.com/scopeagent/env"
    15  	"go.undefinedlabs.com/scopeagent/errors"
    16  	"go.undefinedlabs.com/scopeagent/instrumentation"
    17  )
    18  
// mwOptions holds the middleware configuration assembled by applying MWOption
// values on top of the defaults set in middlewareFunc.
type mwOptions struct {
	// opNameFunc computes the operation name of the server span for a request.
	opNameFunc             func(r *http.Request) string
	// spanFilter decides whether a request is traced; when it returns false the
	// request is handed to the wrapped handler without creating a span.
	spanFilter             func(r *http.Request) bool
	// spanObserver is invoked with every newly started span and its request.
	spanObserver           func(span opentracing.Span, r *http.Request)
	// urlTagFunc produces the value stored in the span's http.url tag.
	urlTagFunc             func(u *url.URL) string
	// componentName overrides the span's component tag; empty means the
	// package default is used.
	componentName          string
	// payloadInstrumentation enables recording request/response payloads as
	// span tags.
	payloadInstrumentation bool
}
    27  
    28  // MWOption controls the behavior of the Middleware.
    29  type MWOption func(*mwOptions)
    30  
    31  // OperationNameFunc returns a MWOption that uses given function f
    32  // to generate operation name for each server-side span.
    33  func OperationNameFunc(f func(r *http.Request) string) MWOption {
    34  	return func(options *mwOptions) {
    35  		options.opNameFunc = f
    36  	}
    37  }
    38  
    39  // MWComponentName returns a MWOption that sets the component name
    40  // for the server-side span.
    41  func MWComponentName(componentName string) MWOption {
    42  	return func(options *mwOptions) {
    43  		options.componentName = componentName
    44  	}
    45  }
    46  
    47  // MWSpanFilter returns a MWOption that filters requests from creating a span
    48  // for the server-side span.
    49  // Span won't be created if it returns false.
    50  func MWSpanFilter(f func(r *http.Request) bool) MWOption {
    51  	return func(options *mwOptions) {
    52  		options.spanFilter = f
    53  	}
    54  }
    55  
    56  // MWSpanObserver returns a MWOption that observe the span
    57  // for the server-side span.
    58  func MWSpanObserver(f func(span opentracing.Span, r *http.Request)) MWOption {
    59  	return func(options *mwOptions) {
    60  		options.spanObserver = f
    61  	}
    62  }
    63  
    64  // MWURLTagFunc returns a MWOption that uses given function f
    65  // to set the span's http.url tag. Can be used to change the default
    66  // http.url tag, eg to redact sensitive information.
    67  func MWURLTagFunc(f func(u *url.URL) string) MWOption {
    68  	return func(options *mwOptions) {
    69  		options.urlTagFunc = f
    70  	}
    71  }
    72  
    73  // Enable payload instrumentation
    74  func MWPayloadInstrumentation() MWOption {
    75  	return func(options *mwOptions) {
    76  		options.payloadInstrumentation = true
    77  	}
    78  }
    79  
    80  // Middleware wraps an http.Handler and traces incoming requests.
    81  // Additionally, it adds the span to the request's context.
    82  //
    83  // By default, the operation name of the spans is set to "HTTP {method}".
    84  // This can be overriden with options.
    85  //
    86  // Example:
    87  // 	 http.ListenAndServe("localhost:80", nethttp.Middleware(tracer, http.DefaultServeMux))
    88  //
    89  // The options allow fine tuning the behavior of the middleware.
    90  //
    91  // Example:
    92  //   mw := nethttp.Middleware(
    93  //      tracer,
    94  //      http.DefaultServeMux,
    95  //      nethttp.OperationNameFunc(func(r *http.Request) string {
    96  //	        return "HTTP " + r.Method + ":/api/customers"
    97  //      }),
    98  //      nethttp.MWSpanObserver(func(sp opentracing.Span, r *http.Request) {
    99  //			sp.SetTag("http.uri", r.URL.EscapedPath())
   100  //		}),
   101  //   )
   102  func middleware(tr opentracing.Tracer, h http.Handler, options ...MWOption) http.Handler {
   103  	return middlewareFunc(tr, h.ServeHTTP, options...)
   104  }
   105  
   106  // MiddlewareFunc wraps an http.HandlerFunc and traces incoming requests.
   107  // It behaves identically to the Middleware function above.
   108  //
   109  // Example:
   110  //   http.ListenAndServe("localhost:80", nethttp.MiddlewareFunc(tracer, MyHandler))
   111  func middlewareFunc(tr opentracing.Tracer, h http.HandlerFunc, options ...MWOption) http.HandlerFunc {
   112  	opts := mwOptions{
   113  		opNameFunc: func(r *http.Request) string {
   114  			return "HTTP " + r.Method
   115  		},
   116  		spanFilter:   func(r *http.Request) bool { return true },
   117  		spanObserver: func(span opentracing.Span, r *http.Request) {},
   118  		urlTagFunc: func(u *url.URL) string {
   119  			return u.String()
   120  		},
   121  	}
   122  	for _, opt := range options {
   123  		opt(&opts)
   124  	}
   125  	opts.payloadInstrumentation = opts.payloadInstrumentation || env.ScopeInstrumentationHttpPayloads.Value
   126  	fn := func(w http.ResponseWriter, r *http.Request) {
   127  		if !opts.spanFilter(r) {
   128  			h(w, r)
   129  			return
   130  		}
   131  		ctx, _ := tr.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
   132  		sp := tr.StartSpan(opts.opNameFunc(r), ext.RPCServerOption(ctx))
   133  		ext.HTTPMethod.Set(sp, r.Method)
   134  		ext.HTTPUrl.Set(sp, opts.urlTagFunc(r.URL))
   135  		opts.spanObserver(sp, r)
   136  
   137  		// set component name, use "net/http" if caller does not specify
   138  		componentName := opts.componentName
   139  		if componentName == "" {
   140  			componentName = defaultComponentName
   141  		}
   142  		ext.Component.Set(sp, componentName)
   143  
   144  		ext.PeerAddress.Set(sp, r.RemoteAddr)
   145  		ext.PeerHostIPv4.SetString(sp, r.RemoteAddr)
   146  		if idx := strings.IndexByte(r.RemoteAddr, ':'); idx > -1 {
   147  			ip := net.ParseIP(r.RemoteAddr[:idx])
   148  			if ip.Equal(ip.To4()) {
   149  				ext.PeerHostIPv4.SetString(sp, ip.String())
   150  			} else if ip.Equal(ip.To16()) {
   151  				ext.PeerHostIPv6.Set(sp, ip.String())
   152  			}
   153  			if val, err := strconv.ParseUint(r.RemoteAddr[idx+1:], 10, 16); err == nil {
   154  				ext.PeerPort.Set(sp, uint16(val))
   155  			}
   156  		}
   157  
   158  		rtracker := &responseTracker{ResponseWriter: w}
   159  		rtracker.payloadInstrumentation = opts.payloadInstrumentation
   160  		r = r.WithContext(opentracing.ContextWithSpan(r.Context(), sp))
   161  
   162  		defer func() {
   163  			ext.HTTPStatusCode.Set(sp, uint16(rtracker.status))
   164  			if rtracker.status >= http.StatusBadRequest || !rtracker.wroteheader {
   165  				ext.Error.Set(sp, true)
   166  			}
   167  
   168  			if rtracker.payloadInstrumentation {
   169  				rqPayload := getRequestPayload(r, payloadBufferSize)
   170  				sp.SetTag("http.request_payload", rqPayload)
   171  			} else {
   172  				sp.SetTag("http.request_payload.unavailable", "disabled")
   173  			}
   174  
   175  			if rtracker.payloadInstrumentation {
   176  				rsRunes := bytes.Runes(rtracker.payloadBuffer)
   177  				rsPayload := string(rsRunes)
   178  				sp.SetTag("http.response_payload", rsPayload)
   179  			} else {
   180  				sp.SetTag("http.response_payload.unavailable", "disabled")
   181  			}
   182  
   183  			if r := recover(); r != nil {
   184  				errors.WriteExceptionEvent(sp, r, 1)
   185  				sp.Finish()
   186  				panic(r)
   187  			}
   188  
   189  			sp.Finish()
   190  		}()
   191  
   192  		h(rtracker.wrappedResponseWriter(), r)
   193  	}
   194  	return http.HandlerFunc(fn)
   195  }
   196  
   197  func Middleware(h http.Handler, options ...MWOption) http.Handler {
   198  	if h == nil {
   199  		h = http.DefaultServeMux
   200  	}
   201  	return MiddlewareFunc(h.ServeHTTP, options...)
   202  }
   203  
   204  func MiddlewareFunc(h http.HandlerFunc, options ...MWOption) http.Handler {
   205  	// Only trace requests that are part of a test trace
   206  	options = append(options, MWSpanFilter(func(r *http.Request) bool {
   207  		ctx, err := instrumentation.Tracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
   208  		if err != nil {
   209  			return false
   210  		}
   211  		inTest := false
   212  		ctx.ForeachBaggageItem(func(k, v string) bool {
   213  			if k == "trace.kind" && v == "test" {
   214  				inTest = true
   215  				return false
   216  			}
   217  			return true
   218  		})
   219  		return inTest
   220  	}))
   221  	return middlewareFunc(instrumentation.Tracer(), h, options...)
   222  }