go.undefinedlabs.com/scopeagent@v0.4.2/instrumentation/nethttp/client.go (about)

     1  package nethttp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptrace"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/opentracing/opentracing-go"
    14  	"github.com/opentracing/opentracing-go/ext"
    15  	"github.com/opentracing/opentracing-go/log"
    16  
    17  	scopeerrors "go.undefinedlabs.com/scopeagent/errors"
    18  	"go.undefinedlabs.com/scopeagent/instrumentation"
    19  	scopetracer "go.undefinedlabs.com/scopeagent/tracer"
    20  )
    21  
    22  type contextKey int
    23  
    24  const (
    25  	keyTracer contextKey = iota
    26  )
    27  
    28  const defaultComponentName = "net/http"
    29  
// Transport wraps a RoundTripper. If a request is being traced with
// Tracer, Transport will inject the current span into the headers,
// and set HTTP related tags on the span.
//
// Only requests whose context already holds an active opentracing span are
// traced; all other requests are forwarded to the wrapped RoundTripper as-is.
type Transport struct {
	// The actual RoundTripper to use for the request. A nil
	// RoundTripper defaults to http.DefaultTransport.
	http.RoundTripper

	// Enable payload instrumentation. When true, a prefix of at most
	// payloadBufferSize bytes of the request and response bodies is recorded
	// in the "http.request_payload" / "http.response_payload" span tags; when
	// false those tags are replaced by "*.unavailable" = "disabled".
	PayloadInstrumentation bool

	// Enable stacktrace. When true, the stack at the time of the round trip
	// is recorded in the span's "stacktrace" tag.
	Stacktrace bool
}
    44  
// clientOptions holds the settings that TraceRequest collects from its
// ClientOption arguments:
//   - operationName: value recorded by OperationName.
//   - componentName: "component" span tag; empty means defaultComponentName.
//   - disableClientTrace: skip installing the httptrace hooks (ClientTrace(false)).
//   - disableInjectSpanContext: skip injecting the span context into the
//     request headers (InjectSpanContext(false)).
//   - spanObserver: callback invoked with the client span and the request.
type clientOptions struct {
	operationName            string
	componentName            string
	disableClientTrace       bool
	disableInjectSpanContext bool
	spanObserver             func(span opentracing.Span, r *http.Request)
}

// ClientOption controls the behavior of TraceRequest. Each option mutates
// the clientOptions being built for one traced request.
type ClientOption func(*clientOptions)
    55  
    56  // OperationName returns a ClientOption that sets the operation
    57  // name for the client-side span.
    58  func OperationName(operationName string) ClientOption {
    59  	return func(options *clientOptions) {
    60  		options.operationName = operationName
    61  	}
    62  }
    63  
    64  // ComponentName returns a ClientOption that sets the component
    65  // name for the client-side span.
    66  func ComponentName(componentName string) ClientOption {
    67  	return func(options *clientOptions) {
    68  		options.componentName = componentName
    69  	}
    70  }
    71  
    72  // ClientTrace returns a ClientOption that turns on or off
    73  // extra instrumentation via httptrace.WithClientTrace.
    74  func ClientTrace(enabled bool) ClientOption {
    75  	return func(options *clientOptions) {
    76  		options.disableClientTrace = !enabled
    77  	}
    78  }
    79  
    80  // InjectSpanContext returns a ClientOption that turns on or off
    81  // injection of the Span context in the request HTTP headers.
    82  // If this option is not used, the default behaviour is to
    83  // inject the span context.
    84  func InjectSpanContext(enabled bool) ClientOption {
    85  	return func(options *clientOptions) {
    86  		options.disableInjectSpanContext = !enabled
    87  	}
    88  }
    89  
    90  // ClientSpanObserver returns a ClientOption that observes the span
    91  // for the client-side span.
    92  func ClientSpanObserver(f func(span opentracing.Span, r *http.Request)) ClientOption {
    93  	return func(options *clientOptions) {
    94  		options.spanObserver = f
    95  	}
    96  }
    97  
    98  // TraceRequest adds a ClientTracer to req, tracing the request and
    99  // all requests caused due to redirects. When tracing requests this
   100  // way you must also use Transport.
   101  //
   102  // Example:
   103  //
   104  // 	func AskGoogle(ctx context.Context) error {
   105  // 		client := &http.Client{Transport: &nethttp.Transport{}}
   106  // 		req, err := http.NewRequest("GET", "http://google.com", nil)
   107  // 		if err != nil {
   108  // 			return err
   109  // 		}
   110  // 		req = req.WithContext(ctx) // extend existing trace, if any
   111  //
   112  // 		req, ht := nethttp.TraceRequest(tracer, req)
   113  // 		defer ht.Finish()
   114  //
   115  // 		res, err := client.Do(req)
   116  // 		if err != nil {
   117  // 			return err
   118  // 		}
   119  // 		res.Body.Close()
   120  // 		return nil
   121  // 	}
   122  func TraceRequest(tr opentracing.Tracer, req *http.Request, options ...ClientOption) (*http.Request, *Tracer) {
   123  	opts := &clientOptions{
   124  		spanObserver: func(_ opentracing.Span, _ *http.Request) {},
   125  	}
   126  	for _, opt := range options {
   127  		opt(opts)
   128  	}
   129  	ht := &Tracer{tr: tr, opts: opts}
   130  	ctx := req.Context()
   131  	if !opts.disableClientTrace {
   132  		ctx = httptrace.WithClientTrace(ctx, ht.clientTrace())
   133  	}
   134  	req = req.WithContext(context.WithValue(ctx, keyTracer, ht))
   135  	return req, ht
   136  }
   137  
   138  type closeTracker struct {
   139  	io.ReadCloser
   140  	sp opentracing.Span
   141  }
   142  
   143  func (c closeTracker) Close() error {
   144  	err := c.ReadCloser.Close()
   145  	c.sp.Finish()
   146  	return err
   147  }
   148  
   149  // TracerFromRequest retrieves the Tracer from the request. If the request does
   150  // not have a Tracer it will return nil.
   151  func TracerFromRequest(req *http.Request) *Tracer {
   152  	tr, ok := req.Context().Value(keyTracer).(*Tracer)
   153  	if !ok {
   154  		return nil
   155  	}
   156  	return tr
   157  }
   158  
   159  func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
   160  	// Only trace outgoing requests that are inside an active trace
   161  	parent := opentracing.SpanFromContext(req.Context())
   162  	if parent == nil {
   163  		rt := t.RoundTripper
   164  		if rt == nil {
   165  			rt = http.DefaultTransport
   166  		}
   167  		return rt.RoundTrip(req)
   168  	}
   169  	req, _ = TraceRequest(instrumentation.Tracer(), req)
   170  	return t.doRoundTrip(req)
   171  }
   172  
   173  // RoundTrip implements the RoundTripper interface.
   174  func (t *Transport) doRoundTrip(req *http.Request) (*http.Response, error) {
   175  	rt := t.RoundTripper
   176  	if rt == nil {
   177  		rt = http.DefaultTransport
   178  	}
   179  	tracer := TracerFromRequest(req)
   180  	if tracer == nil {
   181  		return rt.RoundTrip(req)
   182  	}
   183  
   184  	tracer.start(req)
   185  
   186  	if t.Stacktrace {
   187  		if span, ok := tracer.sp.(scopetracer.Span); ok {
   188  			span.UnsafeSetTag("stacktrace", scopeerrors.GetCurrentStackTrace(2))
   189  		} else {
   190  			tracer.sp.SetTag("stacktrace", scopeerrors.GetCurrentStackTrace(2))
   191  		}
   192  	}
   193  
   194  	ext.HTTPMethod.Set(tracer.sp, req.Method)
   195  	ext.HTTPUrl.Set(tracer.sp, req.URL.String())
   196  	tracer.opts.spanObserver(tracer.sp, req)
   197  
   198  	if !tracer.opts.disableInjectSpanContext {
   199  		carrier := opentracing.HTTPHeadersCarrier(req.Header)
   200  		tracer.sp.Tracer().Inject(tracer.sp.Context(), opentracing.HTTPHeaders, carrier)
   201  	}
   202  
   203  	if t.PayloadInstrumentation {
   204  		rqPayload := getRequestPayload(req, payloadBufferSize)
   205  		tracer.sp.SetTag("http.request_payload", rqPayload)
   206  	} else {
   207  		tracer.sp.SetTag("http.request_payload.unavailable", "disabled")
   208  	}
   209  
   210  	resp, err := rt.RoundTrip(req)
   211  
   212  	if t.PayloadInstrumentation {
   213  		rsPayLoad := getResponsePayload(resp, payloadBufferSize)
   214  		tracer.sp.SetTag("http.response_payload", rsPayLoad)
   215  	} else {
   216  		tracer.sp.SetTag("http.response_payload.unavailable", "disabled")
   217  	}
   218  
   219  	if err != nil {
   220  		tracer.sp.Finish()
   221  		return resp, err
   222  	}
   223  	ext.HTTPStatusCode.Set(tracer.sp, uint16(resp.StatusCode))
   224  	if resp.StatusCode >= http.StatusBadRequest {
   225  		ext.Error.Set(tracer.sp, true)
   226  	}
   227  	if req.Method == "HEAD" {
   228  		tracer.sp.Finish()
   229  	} else {
   230  		resp.Body = closeTracker{resp.Body, tracer.sp}
   231  	}
   232  	return resp, nil
   233  }
   234  
   235  // Gets the request payload
   236  func getRequestPayload(req *http.Request, bufferSize int) string {
   237  	if req == nil || req.Body == nil || req.Body == http.NoBody {
   238  		return ""
   239  	}
   240  	if req.GetBody == nil {
   241  		// GetBody is nil in server requests
   242  		nBody, payload := getBodyPayload(req.Body, bufferSize)
   243  		req.Body = nBody
   244  		return payload
   245  	}
   246  	rqBody, rqErr := req.GetBody()
   247  	if rqErr != nil {
   248  		return ""
   249  	}
   250  	rqBodyBuffer := make([]byte, bufferSize)
   251  	if ln, err := rqBody.Read(rqBodyBuffer); err == nil && ln > 0 {
   252  		if ln < bufferSize {
   253  			rqBodyBuffer = rqBodyBuffer[:ln]
   254  		}
   255  		return string(bytes.Runes(rqBodyBuffer))
   256  	}
   257  	return ""
   258  }
   259  
   260  // Gets the payload from a body
   261  func getBodyPayload(body io.ReadCloser, bufferSize int) (io.ReadCloser, string) {
   262  	if body == nil {
   263  		return body, ""
   264  	}
   265  	rsBodyBuffer := make([]byte, bufferSize)
   266  	ln, _ := body.Read(rsBodyBuffer)
   267  	if ln == 0 {
   268  		return body, ""
   269  	}
   270  	if ln < bufferSize {
   271  		rsBodyBuffer = rsBodyBuffer[:ln]
   272  	}
   273  	rsPayload := string(bytes.Runes(rsBodyBuffer))
   274  	rBody := struct {
   275  		io.Reader
   276  		io.Closer
   277  	}{
   278  		io.MultiReader(bytes.NewReader(rsBodyBuffer), body),
   279  		body,
   280  	}
   281  	return rBody, rsPayload
   282  }
   283  
   284  // Gets the response payload
   285  func getResponsePayload(resp *http.Response, bufferSize int) string {
   286  	if resp == nil || resp.Body == nil || resp.Body == http.NoBody {
   287  		return ""
   288  	}
   289  	rsBodyBuffer := make([]byte, bufferSize)
   290  	ln, _ := resp.Body.Read(rsBodyBuffer)
   291  	if ln == 0 {
   292  		return ""
   293  	}
   294  	if ln < bufferSize {
   295  		rsBodyBuffer = rsBodyBuffer[:ln]
   296  	}
   297  	rsPayload := string(bytes.Runes(rsBodyBuffer))
   298  	resp.Body = struct {
   299  		io.Reader
   300  		io.Closer
   301  	}{
   302  		io.MultiReader(bytes.NewReader(rsBodyBuffer), resp.Body),
   303  		resp.Body,
   304  	}
   305  	return rsPayload
   306  }
   307  
// Tracer holds tracing details for one HTTP request.
//
// Fields:
//   - tr: tracer used to start the client span.
//   - root: parent span captured from the request context on the first call
//     to start; Finish and Span operate on it.
//   - sp: client span of the current round trip, set by start and used by
//     the httptrace hooks.
//   - opts: options collected by TraceRequest.
type Tracer struct {
	tr   opentracing.Tracer
	root opentracing.Span
	sp   opentracing.Span
	opts *clientOptions
}
   315  
   316  func (h *Tracer) start(req *http.Request) opentracing.Span {
   317  	if h.root == nil {
   318  		parent := opentracing.SpanFromContext(req.Context())
   319  		h.root = parent
   320  	}
   321  
   322  	ctx := h.root.Context()
   323  	h.sp = h.tr.StartSpan("HTTP "+req.Method, opentracing.ChildOf(ctx))
   324  	ext.SpanKindRPCClient.Set(h.sp)
   325  
   326  	componentName := h.opts.componentName
   327  	if componentName == "" {
   328  		componentName = defaultComponentName
   329  	}
   330  	ext.Component.Set(h.sp, componentName)
   331  
   332  	return h.sp
   333  }
   334  
   335  // Finish finishes the span of the traced request.
   336  func (h *Tracer) Finish() {
   337  	if h.root != nil {
   338  		h.root.Finish()
   339  	}
   340  }
   341  
// Span returns the root span of the traced request. This function
// should only be called after the request has been executed.
//
// The root span is the parent span that start captured from the request
// context; it is nil until start has run.
func (h *Tracer) Span() opentracing.Span {
	return h.root
}
   347  
   348  func (h *Tracer) clientTrace() *httptrace.ClientTrace {
   349  	return &httptrace.ClientTrace{
   350  		GetConn:              h.getConn,
   351  		GotConn:              h.gotConn,
   352  		PutIdleConn:          h.putIdleConn,
   353  		GotFirstResponseByte: h.gotFirstResponseByte,
   354  		Got100Continue:       h.got100Continue,
   355  		DNSStart:             h.dnsStart,
   356  		DNSDone:              h.dnsDone,
   357  		ConnectStart:         h.connectStart,
   358  		ConnectDone:          h.connectDone,
   359  		WroteHeaders:         h.wroteHeaders,
   360  		Wait100Continue:      h.wait100Continue,
   361  		WroteRequest:         h.wroteRequest,
   362  	}
   363  }
   364  
// getConn is the httptrace GetConn hook; it currently records nothing.
func (h *Tracer) getConn(hostPort string) {
}
   367  
   368  func (h *Tracer) gotConn(info httptrace.GotConnInfo) {
   369  	h.sp.SetTag("net/http.reused", info.Reused)
   370  	h.sp.SetTag("net/http.was_idle", info.WasIdle)
   371  }
   372  
// putIdleConn is the httptrace PutIdleConn hook; it currently records nothing.
func (h *Tracer) putIdleConn(error) {
}
   375  
// gotFirstResponseByte is the httptrace GotFirstResponseByte hook; it
// currently records nothing.
func (h *Tracer) gotFirstResponseByte() {
}
   378  
// got100Continue is the httptrace Got100Continue hook; it currently records
// nothing.
func (h *Tracer) got100Continue() {
}
   381  
   382  func (h *Tracer) dnsStart(info httptrace.DNSStartInfo) {
   383  	ext.PeerHostname.Set(h.sp, info.Host)
   384  }
   385  
// dnsDone is the httptrace DNSDone hook; it currently records nothing.
func (h *Tracer) dnsDone(info httptrace.DNSDoneInfo) {
}
   388  
   389  func (h *Tracer) connectStart(network, addr string) {
   390  	ext.PeerAddress.Set(h.sp, addr)
   391  	if idx := strings.IndexByte(addr, ':'); idx > -1 {
   392  		ip := net.ParseIP(addr[:idx])
   393  		if ip.Equal(ip.To4()) {
   394  			ext.PeerHostIPv4.SetString(h.sp, ip.String())
   395  		} else if ip.Equal(ip.To16()) {
   396  			ext.PeerHostIPv6.Set(h.sp, ip.String())
   397  		}
   398  		if val, err := strconv.ParseUint(addr[idx+1:], 10, 16); err == nil {
   399  			ext.PeerPort.Set(h.sp, uint16(val))
   400  		}
   401  	}
   402  }
   403  
   404  func (h *Tracer) connectDone(network, addr string, err error) {
   405  	if err != nil {
   406  		h.sp.LogFields(
   407  			log.String("message", "ConnectDone"),
   408  			log.String("network", network),
   409  			log.String("addr", addr),
   410  			log.String("event", "error"),
   411  			log.Error(err),
   412  		)
   413  	}
   414  }
   415  
// wroteHeaders is the httptrace WroteHeaders hook; it currently records
// nothing.
func (h *Tracer) wroteHeaders() {
}
   418  
// wait100Continue is the httptrace Wait100Continue hook; it currently records
// nothing.
func (h *Tracer) wait100Continue() {
}
   421  
   422  func (h *Tracer) wroteRequest(info httptrace.WroteRequestInfo) {
   423  	if info.Err != nil {
   424  		h.sp.LogFields(
   425  			log.String("message", "WroteRequest"),
   426  			log.String("event", "error"),
   427  			log.Error(info.Err),
   428  		)
   429  		ext.Error.Set(h.sp, true)
   430  	}
   431  }