github.com/epsagon/epsagon-go@v1.39.0/wrappers/net/http/client.go (about)

     1  package epsagonhttp
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"log"
     9  	"net/http"
    10  	"net/url"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/epsagon/epsagon-go/epsagon"
    15  	"github.com/epsagon/epsagon-go/protocol"
    16  	"github.com/epsagon/epsagon-go/tracer"
    17  	"github.com/google/uuid"
    18  )
    19  
    20  const EPSAGON_TRACEID_HEADER_KEY = "epsagon-trace-id"
    21  const EPSAGON_DOMAIN = "epsagon.com"
    22  const APPSYNC_API_SUBDOMAIN = ".appsync-api."
    23  const AMAZON_REQUEST_ID = "x-amzn-requestid"
    24  const API_GATEWAY_RESOURCE_TYPE = "api_gateway"
    25  
    26  type ValidationFunction func(string, string) bool
    27  
    28  var hasSuffix ValidationFunction = strings.HasSuffix
    29  var contains ValidationFunction = strings.Contains
    30  
    31  var blacklistURLs = map[*ValidationFunction][]string{
    32  	&hasSuffix: {
    33  		EPSAGON_DOMAIN,
    34  		".amazonaws.com",
    35  	},
    36  	&contains: {
    37  		"accounts.google.com",
    38  		"documents.azure.com",
    39  		"169.254.170.2", // AWS Task Metadata Endpoint
    40  	},
    41  }
    42  var whitelistURLs = map[*ValidationFunction][]string{
    43  	&contains: {
    44  		".execute-api.",
    45  		".elb.amazonaws.com",
    46  		APPSYNC_API_SUBDOMAIN,
    47  	},
    48  }
    49  
    50  // ClientWrapper is Epsagon's wrapper for http.Client
    51  type ClientWrapper struct {
    52  	http.Client
    53  
    54  	// MetadataOnly flag overriding the configuration
    55  	MetadataOnly bool
    56  	tracer       tracer.Tracer
    57  }
    58  
    59  // Wrap wraps an http.Client to Epsagon's ClientWrapper
    60  func Wrap(c http.Client, args ...context.Context) ClientWrapper {
    61  	currentTracer := epsagon.ExtractTracer(args)
    62  	return ClientWrapper{c, false, currentTracer}
    63  }
    64  
    65  func (c *ClientWrapper) getMetadataOnly() bool {
    66  	return c.MetadataOnly || c.tracer.GetConfig().MetadataOnly
    67  }
    68  
    69  // TracingTransport is the RoundTripper implementation that traces HTTP calls
    70  type TracingTransport struct {
    71  	// MetadataOnly flag overriding the configuration
    72  	MetadataOnly bool
    73  	tracer       tracer.Tracer
    74  	transport    http.RoundTripper
    75  }
    76  
    77  func NewTracingTransport(args ...context.Context) *TracingTransport {
    78  	return NewWrappedTracingTransport(http.DefaultTransport, args...)
    79  }
    80  
    81  func NewWrappedTracingTransport(rt http.RoundTripper, args ...context.Context) *TracingTransport {
    82  	currentTracer := epsagon.ExtractTracer(args)
    83  	return &TracingTransport{
    84  		tracer:    currentTracer,
    85  		transport: rt,
    86  	}
    87  }
    88  
    89  // RoundTrip implements the RoundTripper interface to trace HTTP calls
    90  func (t *TracingTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
    91  	// reference to the tracer
    92  	tr := t.tracer
    93  	// if the TracingTransport is created before the global tracer is created it will be nil
    94  	if tr == nil {
    95  		tr = epsagon.ExtractTracer(nil)
    96  		if tr != nil && tr.GetConfig().Debug {
    97  			log.Println("EPSAGON DEBUG: defaulting to global tracer in RoundTrip")
    98  		}
    99  	}
   100  
   101  	called := false
   102  	defer func() {
   103  		if !called {
   104  			resp, err = t.transport.RoundTrip(req)
   105  		}
   106  	}()
   107  	defer epsagon.GeneralEpsagonRecover("net.http.RoundTripper", "RoundTrip", t.tracer)
   108  	startTime := tracer.GetTimestamp()
   109  	reqHeaders, reqBody := "", ""
   110  	if !t.getMetadataOnly(tr) {
   111  		reqHeaders, reqBody = epsagon.ExtractRequestData(req)
   112  	}
   113  	if !isBlacklistedURL(req.URL) {
   114  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   115  	}
   116  
   117  	resp, err = t.transport.RoundTrip(req)
   118  
   119  	called = true
   120  	event := postSuperCall(startTime, req.URL.String(), req.Method, resp, err, t.getMetadataOnly(tr))
   121  	t.addDataToEvent(reqHeaders, reqBody, req, event, tr)
   122  	tr.AddEvent(event)
   123  	return
   124  
   125  }
   126  
   127  func (t *TracingTransport) getMetadataOnly(tr tracer.Tracer) bool {
   128  	return t.MetadataOnly || tr.GetConfig().MetadataOnly
   129  }
   130  
   131  func (t *TracingTransport) addDataToEvent(reqHeaders, reqBody string, req *http.Request, event *protocol.Event, tr tracer.Tracer) {
   132  	if req != nil {
   133  		addTraceIdToEvent(req, event)
   134  	}
   135  	if !t.getMetadataOnly(tr) {
   136  		event.Resource.Metadata["request_headers"] = reqHeaders
   137  		event.Resource.Metadata["request_body"] = reqBody
   138  	}
   139  }
   140  
   141  func isBlacklistedURL(parsedUrl *url.URL) bool {
   142  	hostname := parsedUrl.Hostname()
   143  	for method, urls := range whitelistURLs {
   144  		for _, whitelistUrl := range urls {
   145  			if (*method)(hostname, whitelistUrl) {
   146  				return false
   147  			}
   148  		}
   149  	}
   150  	for method, urls := range blacklistURLs {
   151  		for _, blacklistUrl := range urls {
   152  			if (*method)(hostname, blacklistUrl) {
   153  				return true
   154  			}
   155  		}
   156  	}
   157  	return false
   158  }
   159  
   160  func shouldAddHeaderByURL(rawUrl string) bool {
   161  	parsedURL, err := url.Parse(rawUrl)
   162  	if err != nil {
   163  		return false
   164  	}
   165  	return !isBlacklistedURL(parsedURL)
   166  }
   167  
   168  func generateRandomUUID() string {
   169  	uuid, err := uuid.NewRandom()
   170  	if err != nil {
   171  		panic("failed to generate random UUID")
   172  	}
   173  	return strings.ReplaceAll(uuid.String(), "-", "")
   174  }
   175  
   176  func generateEpsagonTraceID() string {
   177  	traceID := generateRandomUUID()
   178  	spanID := generateRandomUUID()[:16]
   179  	parentSpanID := generateRandomUUID()[:16]
   180  	return fmt.Sprintf("%s:%s:%s:1", traceID, spanID, parentSpanID)
   181  }
   182  
   183  func addTraceIdToEvent(req *http.Request, event *protocol.Event) {
   184  	traceIDs, ok := req.Header[EPSAGON_TRACEID_HEADER_KEY]
   185  	if ok && len(traceIDs) > 0 {
   186  		traceID := traceIDs[0]
   187  		event.Resource.Metadata[tracer.EpsagonHTTPTraceIDKey] = traceID
   188  	}
   189  }
   190  
   191  // update event data according to given response headers
   192  // adds amazon request ID, if returned in response headers
   193  // used for traces HTTP correlation (appsync / api gateway targets)
   194  func updateByResponseHeaders(resp *http.Response, resource *protocol.Resource) {
   195  	var amzRequestIDs []string
   196  	for headerKey, headerValues := range resp.Header {
   197  		if strings.ToLower(headerKey) == AMAZON_REQUEST_ID {
   198  			amzRequestIDs = headerValues
   199  			break
   200  		}
   201  	}
   202  	if len(amzRequestIDs) > 0 {
   203  		amzRequestID := amzRequestIDs[0]
   204  		if !strings.Contains(resp.Request.URL.Hostname(), APPSYNC_API_SUBDOMAIN) {
   205  			// api gateway
   206  			resource.Metadata[tracer.AwsServiceKey] = API_GATEWAY_RESOURCE_TYPE
   207  		}
   208  		resource.Metadata[tracer.EpsagonRequestTraceIDKey] = amzRequestID
   209  	}
   210  }
   211  
   212  func (c *ClientWrapper) addDataToEvent(req *http.Request, resp *http.Response, event *protocol.Event) {
   213  	if req != nil {
   214  		addTraceIdToEvent(req, event)
   215  	}
   216  	if resp != nil {
   217  		if reqTraceID := resp.Header.Get("Apigw-Requestid"); reqTraceID != "" {
   218  			event.Resource.Metadata["request_trace_id"] = reqTraceID
   219  		}
   220  		if !c.getMetadataOnly() {
   221  			updateRequestData(resp.Request, event.Resource.Metadata)
   222  		}
   223  	}
   224  }
   225  
   226  // Do wraps http.Client's Do
   227  func (c *ClientWrapper) Do(req *http.Request) (resp *http.Response, err error) {
   228  	called := false
   229  	defer func() {
   230  		if !called {
   231  			resp, err = c.Client.Do(req)
   232  		}
   233  	}()
   234  	defer epsagon.GeneralEpsagonRecover("net.http.Client", "Client.Do", c.tracer)
   235  	startTime := tracer.GetTimestamp()
   236  	if !isBlacklistedURL(req.URL) {
   237  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   238  	}
   239  	resp, err = c.Client.Do(req)
   240  	called = true
   241  	event := postSuperCall(startTime, req.URL.String(), req.Method, resp, err, c.getMetadataOnly())
   242  	c.addDataToEvent(req, resp, event)
   243  	c.tracer.AddEvent(event)
   244  	return
   245  }
   246  
   247  // Get wraps http.Client.Get
   248  func (c *ClientWrapper) Get(rawUrl string) (resp *http.Response, err error) {
   249  	called := false
   250  	defer func() {
   251  		if !called {
   252  			resp, err = c.Client.Get(rawUrl)
   253  		}
   254  	}()
   255  	defer epsagon.GeneralEpsagonRecover("net.http.Client", "Client.Get", c.tracer)
   256  	startTime := tracer.GetTimestamp()
   257  	req, err := http.NewRequest(http.MethodGet, rawUrl, nil)
   258  	if err != nil || !shouldAddHeaderByURL(rawUrl) {
   259  		// err might be nil if rawUrl is invalid. Then, wrapping without any HTTP trace correlation
   260  		resp, err = c.Client.Get(rawUrl)
   261  	} else {
   262  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   263  		resp, err = c.Client.Do(req)
   264  	}
   265  	called = true
   266  	event := postSuperCall(startTime, rawUrl, http.MethodGet, resp, err, c.getMetadataOnly())
   267  	c.addDataToEvent(req, resp, event)
   268  	c.tracer.AddEvent(event)
   269  	return
   270  }
   271  
   272  // Post wraps http.Client.Post
   273  func (c *ClientWrapper) Post(
   274  	rawUrl string, contentType string, body io.Reader) (resp *http.Response, err error) {
   275  	called := false
   276  	defer func() {
   277  		if !called {
   278  			resp, err = c.Client.Post(rawUrl, contentType, body)
   279  		}
   280  	}()
   281  	defer epsagon.GeneralEpsagonRecover("net.http.Client", "Client.Post", c.tracer)
   282  	startTime := tracer.GetTimestamp()
   283  	req, err := http.NewRequest(http.MethodPost, rawUrl, body)
   284  	if err != nil || !shouldAddHeaderByURL(rawUrl) {
   285  		// err might be nil if rawUrl is invalid. Then, wrapping without any HTTP trace correlation
   286  		resp, err = c.Client.Post(rawUrl, contentType, body)
   287  	} else {
   288  		req.Header.Set("Content-Type", contentType)
   289  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   290  		resp, err = c.Client.Do(req)
   291  	}
   292  	called = true
   293  	event := postSuperCall(startTime, rawUrl, http.MethodPost, resp, err, c.getMetadataOnly())
   294  	c.addDataToEvent(req, resp, event)
   295  	c.tracer.AddEvent(event)
   296  	return
   297  }
   298  
   299  // PostForm wraps http.Client.PostForm
   300  func (c *ClientWrapper) PostForm(
   301  	rawUrl string, data url.Values) (resp *http.Response, err error) {
   302  	called := false
   303  	defer func() {
   304  		if !called {
   305  			resp, err = c.Client.PostForm(rawUrl, data)
   306  		}
   307  	}()
   308  	defer epsagon.GeneralEpsagonRecover("net.http.Client", "Client.PostForm", c.tracer)
   309  	startTime := tracer.GetTimestamp()
   310  	req, err := http.NewRequest(http.MethodPost, rawUrl, strings.NewReader(data.Encode()))
   311  	if err != nil || !shouldAddHeaderByURL(rawUrl) {
   312  		// err might be nil if rawUrl is invalid. Then, wrapping without any HTTP trace correlation
   313  		resp, err = c.Client.PostForm(rawUrl, data)
   314  	} else {
   315  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   316  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   317  		resp, err = c.Client.Do(req)
   318  	}
   319  	called = true
   320  	event := postSuperCall(startTime, rawUrl, http.MethodPost, resp, err, c.getMetadataOnly())
   321  	c.addDataToEvent(req, resp, event)
   322  	c.tracer.AddEvent(event)
   323  	return
   324  }
   325  
   326  // Head wraps http.Client.Head
   327  func (c *ClientWrapper) Head(rawUrl string) (resp *http.Response, err error) {
   328  	called := false
   329  	defer func() {
   330  		if !called {
   331  			resp, err = c.Client.Head(rawUrl)
   332  		}
   333  	}()
   334  	defer epsagon.GeneralEpsagonRecover("net.http.Client", "Client.Head", c.tracer)
   335  	startTime := tracer.GetTimestamp()
   336  	req, err := http.NewRequest(http.MethodHead, rawUrl, nil)
   337  	if err != nil || !shouldAddHeaderByURL(rawUrl) {
   338  		// err might be nil if rawUrl is invalid. Then, wrapping without any HTTP trace correlation
   339  		resp, err = c.Client.Head(rawUrl)
   340  	} else {
   341  		req.Header[EPSAGON_TRACEID_HEADER_KEY] = []string{generateEpsagonTraceID()}
   342  		resp, err = c.Client.Do(req)
   343  	}
   344  	called = true
   345  	event := postSuperCall(startTime, rawUrl, http.MethodHead, resp, err, c.getMetadataOnly())
   346  	c.addDataToEvent(req, resp, event)
   347  	c.tracer.AddEvent(event)
   348  	return
   349  }
   350  
   351  func postSuperCall(
   352  	startTime float64,
   353  	url string,
   354  	method string,
   355  	resp *http.Response,
   356  	err error,
   357  	metadataOnly bool) *protocol.Event {
   358  
   359  	endTime := tracer.GetTimestamp()
   360  	duration := endTime - startTime
   361  
   362  	event := createHTTPEvent(url, method, err)
   363  	event.StartTime = startTime
   364  	event.Duration = duration
   365  	if resp != nil {
   366  		updateResponseData(resp, event.Resource, metadataOnly)
   367  	}
   368  	return event
   369  }
   370  
   371  func createHTTPEvent(url, method string, err error) *protocol.Event {
   372  	errorcode := protocol.ErrorCode_OK
   373  	if err != nil {
   374  		errorcode = protocol.ErrorCode_ERROR
   375  	}
   376  	return &protocol.Event{
   377  		Id:        "http.Client-" + uuid.New().String(),
   378  		Origin:    "http.Client",
   379  		ErrorCode: errorcode,
   380  		Resource: &protocol.Resource{
   381  			Name:      url,
   382  			Type:      "http",
   383  			Operation: method,
   384  			Metadata:  map[string]string{},
   385  		},
   386  	}
   387  }
   388  
   389  func updateResponseData(resp *http.Response, resource *protocol.Resource, metadataOnly bool) {
   390  	resource.Metadata["status_code"] = strconv.Itoa(resp.StatusCode)
   391  	updateByResponseHeaders(resp, resource)
   392  	if metadataOnly {
   393  		return
   394  	}
   395  	headers, err := epsagon.FormatHeaders(resp.Header)
   396  	if err == nil {
   397  		resource.Metadata["response_headers"] = headers
   398  	}
   399  	body, err := ioutil.ReadAll(resp.Body)
   400  	resp.Body.Close()
   401  	if err == nil {
   402  		// truncates response body to the first 64KB
   403  		if len(body) > epsagon.MaxMetadataSize {
   404  			resource.Metadata["response_body"] = string(body[0:epsagon.MaxMetadataSize])
   405  		} else {
   406  			resource.Metadata["response_body"] = string(body)
   407  		}
   408  	}
   409  	resp.Body = epsagon.NewReadCloser(body, err)
   410  }
   411  
   412  func updateRequestData(req *http.Request, metadata map[string]string) {
   413  	headers, err := epsagon.FormatHeaders(req.Header)
   414  	if err == nil {
   415  		metadata["request_headers"] = headers
   416  	}
   417  	if req.Body == nil {
   418  		return
   419  	}
   420  	bodyReader, err := req.GetBody()
   421  	if err == nil {
   422  		bodyBytes, err := ioutil.ReadAll(bodyReader)
   423  		if err == nil {
   424  			// truncates request body to the first 64KB
   425  			if len(bodyBytes) > epsagon.MaxMetadataSize {
   426  				bodyBytes = bodyBytes[0:epsagon.MaxMetadataSize]
   427  			}
   428  			metadata["request_body"] = string(bodyBytes)
   429  		}
   430  	}
   431  }