github.com/mier85/go-sensor@v1.30.1-0.20220920111756-9bf41b3bc7e0/instrumentation_http.go (about)

     1  // (c) Copyright IBM Corp. 2021
     2  // (c) Copyright Instana Inc. 2020
     3  
     4  package instana
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"mime/multipart"
    10  	"net"
    11  	"net/http"
    12  	"net/textproto"
    13  	"net/url"
    14  
    15  	ot "github.com/opentracing/opentracing-go"
    16  	"github.com/opentracing/opentracing-go/ext"
    17  	otlog "github.com/opentracing/opentracing-go/log"
    18  )
    19  
    20  // TracingHandlerFunc is an HTTP middleware that captures the tracing data and ensures
    21  // trace context propagation via OpenTracing headers. The pathTemplate parameter, when provided,
    22  // will be added to the span as a template string used to match the route containing variables, regular
    23  // expressions, etc.
    24  //
    25  // The wrapped handler will also propagate the W3C trace context (https://www.w3.org/TR/trace-context/)
    26  // if found in request.
    27  func TracingHandlerFunc(sensor *Sensor, pathTemplate string, handler http.HandlerFunc) http.HandlerFunc {
    28  	return TracingNamedHandlerFunc(sensor, "", pathTemplate, handler)
    29  }
    30  
    31  // TracingNamedHandlerFunc is an HTTP middleware that similarly to instana.TracingHandlerFunc() captures the tracing data,
    32  // while allowing to provide a unique route indetifier to be associated with each request.
    33  func TracingNamedHandlerFunc(sensor *Sensor, routeID, pathTemplate string, handler http.HandlerFunc) http.HandlerFunc {
    34  	return func(w http.ResponseWriter, req *http.Request) {
    35  		ctx := req.Context()
    36  
    37  		opts := initSpanOptions(req, routeID)
    38  
    39  		tracer := sensor.Tracer()
    40  		if ps, ok := SpanFromContext(req.Context()); ok {
    41  			tracer = ps.Tracer()
    42  			opts = append(opts, ot.ChildOf(ps.Context()))
    43  		}
    44  
    45  		opts = append(opts, extractStartSpanOptionsFromHeaders(tracer, req, sensor)...)
    46  
    47  		if req.Header.Get(FieldSynthetic) == "1" {
    48  			opts = append(opts, syntheticCall())
    49  		}
    50  
    51  		if pathTemplate != "" && req.URL.Path != pathTemplate {
    52  			opts = append(opts, ot.Tag{Key: "http.path_tpl", Value: pathTemplate})
    53  		}
    54  
    55  		span := tracer.StartSpan("g.http", opts...)
    56  		defer span.Finish()
    57  
    58  		var collectableHTTPHeaders []string
    59  		if t, ok := tracer.(Tracer); ok {
    60  			opts := t.Options()
    61  			collectableHTTPHeaders = opts.CollectableHTTPHeaders
    62  
    63  			params := collectHTTPParams(req, opts.Secrets)
    64  			if len(params) > 0 {
    65  				span.SetTag("http.params", params.Encode())
    66  			}
    67  		}
    68  
    69  		collectedHeaders := make(map[string]string)
    70  		// make sure collected headers are sent in case of panic/error
    71  		defer func() {
    72  			if len(collectedHeaders) > 0 {
    73  				span.SetTag("http.header", collectedHeaders)
    74  			}
    75  		}()
    76  
    77  		collectRequestHeaders(req, collectableHTTPHeaders, collectedHeaders)
    78  
    79  		defer func() {
    80  			// Be sure to capture any kind of panic/error
    81  			if err := recover(); err != nil {
    82  				if e, ok := err.(error); ok {
    83  					span.SetTag("http.error", e.Error())
    84  					span.LogFields(otlog.Error(e))
    85  				} else {
    86  					span.SetTag("http.error", err)
    87  					span.LogFields(otlog.Object("error", err))
    88  				}
    89  
    90  				span.SetTag(string(ext.HTTPStatusCode), http.StatusInternalServerError)
    91  
    92  				// re-throw the panic
    93  				panic(err)
    94  			}
    95  		}()
    96  
    97  		wrapped := wrapResponseWriter(w)
    98  		tracer.Inject(span.Context(), ot.HTTPHeaders, ot.HTTPHeadersCarrier(wrapped.Header()))
    99  
   100  		handler(wrapped, req.WithContext(ContextWithSpan(ctx, span)))
   101  
   102  		collectResponseHeaders(wrapped, collectableHTTPHeaders, collectedHeaders)
   103  		processResponseStatus(wrapped, span)
   104  	}
   105  }
   106  
   107  func initSpanOptions(req *http.Request, routeID string) []ot.StartSpanOption {
   108  	opts := []ot.StartSpanOption{
   109  		ext.SpanKindRPCServer,
   110  		ot.Tags{
   111  			"http.host":     req.Host,
   112  			"http.method":   req.Method,
   113  			"http.protocol": req.URL.Scheme,
   114  			"http.path":     req.URL.Path,
   115  			"http.route_id": routeID,
   116  		},
   117  	}
   118  	return opts
   119  }
   120  
   121  func processResponseStatus(response wrappedResponseWriter, span ot.Span) {
   122  	if response.Status() > 0 {
   123  		if response.Status() >= http.StatusInternalServerError {
   124  			statusText := http.StatusText(response.Status())
   125  
   126  			span.SetTag("http.error", statusText)
   127  			span.LogFields(otlog.Object("error", statusText))
   128  		}
   129  
   130  		span.SetTag("http.status", response.Status())
   131  	}
   132  }
   133  
   134  func collectResponseHeaders(response wrappedResponseWriter, collectableHTTPHeaders []string, collectedHeaders map[string]string) {
   135  	for _, h := range collectableHTTPHeaders {
   136  		if v := response.Header().Get(h); v != "" {
   137  			collectedHeaders[h] = v
   138  		}
   139  	}
   140  }
   141  
   142  func collectRequestHeaders(req *http.Request, collectableHTTPHeaders []string, collectedHeaders map[string]string) {
   143  	for _, h := range collectableHTTPHeaders {
   144  		if v := req.Header.Get(h); v != "" {
   145  			collectedHeaders[h] = v
   146  		}
   147  	}
   148  }
   149  
   150  func extractStartSpanOptionsFromHeaders(tracer ot.Tracer, req *http.Request, sensor *Sensor) []ot.StartSpanOption {
   151  	var opts []ot.StartSpanOption
   152  	wireContext, err := tracer.Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(req.Header))
   153  	switch err {
   154  	case nil:
   155  		opts = append(opts, ext.RPCServerOption(wireContext))
   156  	case ot.ErrSpanContextNotFound:
   157  		sensor.Logger().Debug("no span context provided with ", req.Method, " ", req.URL.Path)
   158  	case ot.ErrUnsupportedFormat:
   159  		sensor.Logger().Info("unsupported span context format provided with ", req.Method, " ", req.URL.Path)
   160  	default:
   161  		sensor.Logger().Warn("failed to extract span context from the request:", err)
   162  	}
   163  	return opts
   164  }
   165  
   166  // RoundTripper wraps an existing http.RoundTripper and injects the tracing headers into the outgoing request.
   167  // If the original RoundTripper is nil, the http.DefaultTransport will be used.
   168  func RoundTripper(sensor *Sensor, original http.RoundTripper) http.RoundTripper {
   169  	return tracingRoundTripper(func(req *http.Request) (*http.Response, error) {
   170  		if original == nil {
   171  			original = http.DefaultTransport
   172  		}
   173  
   174  		ctx := req.Context()
   175  		parentSpan, ok := SpanFromContext(ctx)
   176  		if !ok {
   177  			// don't trace the exit call if there was no entry span provided
   178  			return original.RoundTrip(req)
   179  		}
   180  
   181  		sanitizedURL := cloneURL(req.URL)
   182  		sanitizedURL.RawQuery = ""
   183  		sanitizedURL.User = nil
   184  
   185  		span := sensor.Tracer().StartSpan("http",
   186  			ext.SpanKindRPCClient,
   187  			ot.ChildOf(parentSpan.Context()),
   188  			ot.Tags{
   189  				"http.url":    sanitizedURL.String(),
   190  				"http.method": req.Method,
   191  			})
   192  		defer span.Finish()
   193  
   194  		// clone the request since the RoundTrip should not modify the original one
   195  		req = cloneRequest(ContextWithSpan(ctx, span), req)
   196  		sensor.Tracer().Inject(span.Context(), ot.HTTPHeaders, ot.HTTPHeadersCarrier(req.Header))
   197  
   198  		var collectableHTTPHeaders []string
   199  		if t, ok := sensor.Tracer().(Tracer); ok {
   200  			opts := t.Options()
   201  			collectableHTTPHeaders = opts.CollectableHTTPHeaders
   202  
   203  			params := collectHTTPParams(req, opts.Secrets)
   204  			if len(params) > 0 {
   205  				span.SetTag("http.params", params.Encode())
   206  			}
   207  		}
   208  
   209  		collectedHeaders := make(map[string]string)
   210  		// make sure collected headers are sent in case of panic/error
   211  		defer func() {
   212  			if len(collectedHeaders) > 0 {
   213  				span.SetTag("http.header", collectedHeaders)
   214  			}
   215  		}()
   216  
   217  		// collect request headers
   218  		for _, h := range collectableHTTPHeaders {
   219  			if v := req.Header.Get(h); v != "" {
   220  				collectedHeaders[h] = v
   221  			}
   222  		}
   223  
   224  		resp, err := original.RoundTrip(req)
   225  		if err != nil {
   226  			span.SetTag("http.error", err.Error())
   227  			span.LogFields(otlog.Error(err))
   228  			return resp, err
   229  		}
   230  
   231  		// collect response headers
   232  		for _, h := range collectableHTTPHeaders {
   233  			if v := resp.Header.Get(h); v != "" {
   234  				collectedHeaders[h] = v
   235  			}
   236  		}
   237  
   238  		span.SetTag(string(ext.HTTPStatusCode), resp.StatusCode)
   239  
   240  		return resp, err
   241  	})
   242  }
   243  
   244  type wrappedResponseWriter interface {
   245  	http.ResponseWriter
   246  	Status() int
   247  }
   248  
   249  func wrapResponseWriter(w http.ResponseWriter) wrappedResponseWriter {
   250  	if _, ok := w.(http.Hijacker); ok {
   251  		return &statusCodeRecorderHTTP10{
   252  			ResponseWriter: w,
   253  		}
   254  	}
   255  
   256  	return &statusCodeRecorder{
   257  		ResponseWriter: w,
   258  	}
   259  }
   260  
   261  // statusCodeRecorder is a wrapper over http.ResponseWriter to spy the returned status code
   262  type statusCodeRecorder struct {
   263  	http.ResponseWriter
   264  	status int
   265  }
   266  
   267  func (rec *statusCodeRecorder) SetStatus(status int) {
   268  	rec.status = status
   269  }
   270  
   271  func (rec *statusCodeRecorder) WriteHeader(status int) {
   272  	rec.SetStatus(status)
   273  	rec.ResponseWriter.WriteHeader(status)
   274  }
   275  
   276  func (rec *statusCodeRecorder) Write(b []byte) (int, error) {
   277  	if rec.status == 0 {
   278  		rec.SetStatus(http.StatusOK)
   279  	}
   280  
   281  	return rec.ResponseWriter.Write(b)
   282  }
   283  
   284  func (rec *statusCodeRecorder) Status() int {
   285  	return rec.status
   286  }
   287  
   288  // statusCodeRecorderHTTP10 is a wrapper over http.ResponseWriter similar to statusCodeRecorder, but
   289  // also implementing http.Hijaker
   290  type statusCodeRecorderHTTP10 = statusCodeRecorder
   291  
   292  func (rec *statusCodeRecorderHTTP10) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   293  	return rec.ResponseWriter.(http.Hijacker).Hijack()
   294  }
   295  
   296  type tracingRoundTripper func(*http.Request) (*http.Response, error)
   297  
   298  func (rt tracingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   299  	return rt(req)
   300  }
   301  
   302  func collectHTTPParams(req *http.Request, matcher Matcher) url.Values {
   303  	params := cloneURLValues(req.URL.Query())
   304  
   305  	for k := range params {
   306  		if matcher.Match(k) {
   307  			params[k] = []string{"<redacted>"}
   308  		}
   309  	}
   310  
   311  	return params
   312  }
   313  
   314  // The following code is ported from $GOROOT/src/net/http/clone.go with minor changes
   315  // for compatibility with Go versions prior to 1.13
   316  //
   317  // Copyright 2019 The Go Authors. All rights reserved.
   318  // Use of this source code is governed by a BSD-style
   319  // license that can be found in the LICENSE file.
   320  
   321  func cloneRequest(ctx context.Context, r *http.Request) *http.Request {
   322  	r2 := new(http.Request)
   323  	*r2 = *r
   324  	r2 = r2.WithContext(ctx)
   325  
   326  	r2.URL = cloneURL(r.URL)
   327  	if r.Header != nil {
   328  		r2.Header = cloneHeader(r.Header)
   329  	}
   330  
   331  	if r.Trailer != nil {
   332  		r2.Trailer = cloneHeader(r.Trailer)
   333  	}
   334  
   335  	if s := r.TransferEncoding; s != nil {
   336  		s2 := make([]string, len(s))
   337  		copy(s2, s)
   338  		r2.TransferEncoding = s
   339  	}
   340  
   341  	r2.Form = cloneURLValues(r.Form)
   342  	r2.PostForm = cloneURLValues(r.PostForm)
   343  	r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
   344  
   345  	return r2
   346  }
   347  
   348  func cloneURLValues(v url.Values) url.Values {
   349  	if v == nil {
   350  		return nil
   351  	}
   352  
   353  	// http.Header and url.Values have the same representation, so temporarily
   354  	// treat it like http.Header, which does have a clone:
   355  
   356  	return url.Values(cloneHeader(http.Header(v)))
   357  }
   358  
   359  func cloneURL(u *url.URL) *url.URL {
   360  	if u == nil {
   361  		return nil
   362  	}
   363  
   364  	u2 := new(url.URL)
   365  	*u2 = *u
   366  
   367  	if u.User != nil {
   368  		u2.User = new(url.Userinfo)
   369  		*u2.User = *u.User
   370  	}
   371  
   372  	return u2
   373  }
   374  
   375  func cloneMultipartForm(f *multipart.Form) *multipart.Form {
   376  	if f == nil {
   377  		return nil
   378  	}
   379  
   380  	f2 := &multipart.Form{
   381  		Value: (map[string][]string)(cloneHeader(http.Header(f.Value))),
   382  	}
   383  
   384  	if f.File != nil {
   385  		m := make(map[string][]*multipart.FileHeader)
   386  		for k, vv := range f.File {
   387  			vv2 := make([]*multipart.FileHeader, len(vv))
   388  			for i, v := range vv {
   389  				vv2[i] = cloneMultipartFileHeader(v)
   390  			}
   391  			m[k] = vv2
   392  
   393  		}
   394  
   395  		f2.File = m
   396  	}
   397  
   398  	return f2
   399  }
   400  
   401  func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
   402  	if fh == nil {
   403  		return nil
   404  	}
   405  
   406  	fh2 := new(multipart.FileHeader)
   407  	*fh2 = *fh
   408  
   409  	fh2.Header = textproto.MIMEHeader(cloneHeader(http.Header(fh.Header)))
   410  
   411  	return fh2
   412  }
   413  
   414  // The following code is ported from $GOROOT/src/net/http/header.go with minor changes
   415  // for compatibility with Go versions prior to 1.13
   416  //
   417  // Copyright 2019 The Go Authors. All rights reserved.
   418  // Use of this source code is governed by a BSD-style
   419  // license that can be found in the LICENSE file.
   420  
   421  func cloneHeader(h http.Header) http.Header {
   422  	if h == nil {
   423  		return nil
   424  	}
   425  
   426  	// Find total number of values.
   427  	nv := 0
   428  	for _, vv := range h {
   429  		nv += len(vv)
   430  	}
   431  	sv := make([]string, nv) // shared backing array for headers' values
   432  	h2 := make(http.Header, len(h))
   433  	for k, vv := range h {
   434  		n := copy(sv, vv)
   435  		h2[k] = sv[:n:n]
   436  		sv = sv[n:]
   437  	}
   438  	return h2
   439  }