github.com/grafana/pyroscope@v1.18.0/pkg/util/nethttp/client.go (about)

     1  package nethttp
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptrace"
     8  
     9  	"github.com/opentracing/opentracing-go"
    10  	"github.com/opentracing/opentracing-go/ext"
    11  	"github.com/opentracing/opentracing-go/log"
    12  )
    13  
    14  type contextKey int
    15  
    16  const (
    17  	keyTracer     contextKey = iota
    18  	componentName            = "pyroscope/net/http"
    19  )
    20  
// Transport is an http.RoundTripper that wraps another RoundTripper. Its
// RoundTrip method traces requests that carry a Tracer in their context (see
// TraceRequest) and hands every other request to the wrapped RoundTripper
// unchanged.
type Transport struct {
	// The actual RoundTripper to use for the request. A nil RoundTripper defaults to http.DefaultTransport.
	http.RoundTripper
}
    25  
    26  func TraceRequest(tr opentracing.Tracer, req *http.Request) *http.Request {
    27  	ht := &Tracer{tr: tr}
    28  	ctx := req.Context()
    29  	ctx = httptrace.WithClientTrace(ctx, ht.clientTrace())
    30  	req = req.WithContext(context.WithValue(ctx, keyTracer, ht))
    31  	return req
    32  }
    33  
    34  type closeTracker struct {
    35  	io.ReadCloser
    36  	sp opentracing.Span
    37  }
    38  
    39  func (c closeTracker) Close() error {
    40  	err := c.ReadCloser.Close()
    41  	c.sp.LogFields(log.String("event", "ClosedBody"))
    42  	c.sp.Finish()
    43  	return err
    44  }
    45  
    46  func TracerFromRequest(req *http.Request) *Tracer {
    47  	tr, ok := req.Context().Value(keyTracer).(*Tracer)
    48  	if !ok {
    49  		return nil
    50  	}
    51  	return tr
    52  }
    53  
    54  func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
    55  	rt := t.RoundTripper
    56  	if rt == nil {
    57  		rt = http.DefaultTransport
    58  	}
    59  	tracer := TracerFromRequest(req)
    60  	if tracer == nil {
    61  		return rt.RoundTrip(req)
    62  	}
    63  
    64  	tracer.start(req)
    65  
    66  	carrier := opentracing.HTTPHeadersCarrier(req.Header)
    67  	err := tracer.sp.Tracer().Inject(tracer.sp.Context(), opentracing.HTTPHeaders, carrier)
    68  	if err != nil {
    69  		return rt.RoundTrip(req)
    70  	}
    71  
    72  	resp, err := rt.RoundTrip(req)
    73  
    74  	if err != nil {
    75  		tracer.sp.Finish()
    76  		return resp, err
    77  	}
    78  	ext.HTTPStatusCode.Set(tracer.sp, uint16(resp.StatusCode))
    79  	if resp.StatusCode >= http.StatusInternalServerError {
    80  		ext.Error.Set(tracer.sp, true)
    81  	}
    82  	// Normally the span is finished when the response body is closed, but with streaming the initial HTTP response
    83  	// does not have a body and this never happens. We are patching this here by finishing the span early, knowing
    84  	// that this will make the span shorter than what it actually is.
    85  	if req.Method == "HEAD" || resp.ContentLength < 0 {
    86  		tracer.sp.Finish()
    87  	} else {
    88  		resp.Body = closeTracker{resp.Body, tracer.sp}
    89  	}
    90  	return resp, nil
    91  }
    92  
    93  type Tracer struct {
    94  	tr opentracing.Tracer
    95  	sp opentracing.Span
    96  }
    97  
    98  func (t *Tracer) start(req *http.Request) opentracing.Span {
    99  	ctx := opentracing.SpanFromContext(req.Context()).Context()
   100  	t.sp = t.tr.StartSpan("HTTP "+req.Method, opentracing.ChildOf(ctx))
   101  	ext.SpanKindRPCClient.Set(t.sp)
   102  	ext.Component.Set(t.sp, componentName)
   103  	ext.HTTPMethod.Set(t.sp, req.Method)
   104  	ext.HTTPUrl.Set(t.sp, req.URL.String())
   105  	return t.sp
   106  }
   107  
   108  func (t *Tracer) clientTrace() *httptrace.ClientTrace {
   109  	return &httptrace.ClientTrace{
   110  		GetConn:              t.getConn,
   111  		GotConn:              t.gotConn,
   112  		PutIdleConn:          t.putIdleConn,
   113  		GotFirstResponseByte: t.gotFirstResponseByte,
   114  		Got100Continue:       t.got100Continue,
   115  		DNSStart:             t.dnsStart,
   116  		DNSDone:              t.dnsDone,
   117  		ConnectStart:         t.connectStart,
   118  		ConnectDone:          t.connectDone,
   119  		WroteHeaders:         t.wroteHeaders,
   120  		Wait100Continue:      t.wait100Continue,
   121  		WroteRequest:         t.wroteRequest,
   122  	}
   123  }
   124  
   125  func (t *Tracer) getConn(hostPort string) {
   126  	ext.HTTPUrl.Set(t.sp, hostPort)
   127  	t.sp.LogFields(log.String("event", "GetConn"))
   128  }
   129  
   130  func (t *Tracer) gotConn(info httptrace.GotConnInfo) {
   131  	t.sp.SetTag("net/http.reused", info.Reused)
   132  	t.sp.SetTag("net/http.was_idle", info.WasIdle)
   133  	t.sp.LogFields(log.String("event", "GotConn"))
   134  }
   135  
   136  func (t *Tracer) putIdleConn(error) {
   137  	t.sp.LogFields(log.String("event", "PutIdleConn"))
   138  }
   139  
   140  func (t *Tracer) gotFirstResponseByte() {
   141  	t.sp.LogFields(log.String("event", "GotFirstResponseByte"))
   142  }
   143  
   144  func (t *Tracer) got100Continue() {
   145  	t.sp.LogFields(log.String("event", "Got100Continue"))
   146  }
   147  
   148  func (t *Tracer) dnsStart(info httptrace.DNSStartInfo) {
   149  	t.sp.LogFields(
   150  		log.String("event", "DNSStart"),
   151  		log.String("host", info.Host),
   152  	)
   153  }
   154  
   155  func (t *Tracer) dnsDone(info httptrace.DNSDoneInfo) {
   156  	fields := []log.Field{log.String("event", "DNSDone")}
   157  	for _, addr := range info.Addrs {
   158  		fields = append(fields, log.String("addr", addr.String()))
   159  	}
   160  	if info.Err != nil {
   161  		fields = append(fields, log.Error(info.Err))
   162  	}
   163  	t.sp.LogFields(fields...)
   164  }
   165  
   166  func (t *Tracer) connectStart(network, addr string) {
   167  	t.sp.LogFields(
   168  		log.String("event", "ConnectStart"),
   169  		log.String("network", network),
   170  		log.String("addr", addr),
   171  	)
   172  }
   173  
   174  func (t *Tracer) connectDone(network, addr string, err error) {
   175  	if err != nil {
   176  		t.sp.LogFields(
   177  			log.String("message", "ConnectDone"),
   178  			log.String("network", network),
   179  			log.String("addr", addr),
   180  			log.String("event", "error"),
   181  			log.Error(err),
   182  		)
   183  	} else {
   184  		t.sp.LogFields(
   185  			log.String("event", "ConnectDone"),
   186  			log.String("network", network),
   187  			log.String("addr", addr),
   188  		)
   189  	}
   190  }
   191  
   192  func (t *Tracer) wroteHeaders() {
   193  	t.sp.LogFields(log.String("event", "WroteHeaders"))
   194  }
   195  
   196  func (t *Tracer) wait100Continue() {
   197  	t.sp.LogFields(log.String("event", "Wait100Continue"))
   198  }
   199  
   200  func (t *Tracer) wroteRequest(info httptrace.WroteRequestInfo) {
   201  	if info.Err != nil {
   202  		t.sp.LogFields(
   203  			log.String("message", "WroteRequest"),
   204  			log.String("event", "error"),
   205  			log.Error(info.Err),
   206  		)
   207  		ext.Error.Set(t.sp, true)
   208  	} else {
   209  		t.sp.LogFields(log.String("event", "WroteRequest"))
   210  	}
   211  }