github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/middleware/xray/middleware.go (about)

     1  package xray
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/goadesign/goa"
    13  	"github.com/goadesign/goa/middleware"
    14  	"golang.org/x/net/context"
    15  )
    16  
    17  const (
    18  	// segKey is the key used to store the segments in the context.
    19  	segKey key = iota + 1
    20  )
    21  
    22  // New returns a middleware that sends AWS X-Ray segments to the daemon running
    23  // at the given address.
    24  //
    25  // service is the name of the service reported to X-Ray. daemon is the hostname
    26  // (including port) of the X-Ray daemon collecting the segments.
    27  //
    28  // The middleware works by extracting the trace information from the context
    29  // using the tracing middleware package. The tracing middleware must be mounted
    30  // first on the service.
    31  //
    32  // The middleware stores the request segment in the context. Use ContextSegment
    33  // to retrieve it. User code can further configure the segment for example to set
    34  // a service version or record an error.
    35  //
    36  // User code may create child segments using the Segment NewSubsegment method
    37  // for tracing requests to external services. Such segments should be closed via
    38  // the Close method once the request completes. The middleware takes care of
    39  // closing the top level segment. Typical usage:
    40  //
    41  //     segment := xray.ContextSegment(ctx)
    42  //     sub := segment.NewSubsegment("external-service")
    43  //     defer sub.Close()
    44  //     err := client.MakeRequest()
    45  //     if err != nil {
    46  //         sub.Error = xray.Wrap(err)
    47  //     }
    48  //     return
    49  //
    50  func New(service, daemon string) (goa.Middleware, error) {
    51  	c, err := net.Dial("udp", daemon)
    52  	if err != nil {
    53  		return nil, fmt.Errorf("xray: failed to connect to daemon - %s", err)
    54  	}
    55  	return func(h goa.Handler) goa.Handler {
    56  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    57  			var (
    58  				err     error
    59  				traceID = middleware.ContextTraceID(ctx)
    60  			)
    61  			if traceID == "" {
    62  				// No tracing
    63  				return h(ctx, rw, req)
    64  			}
    65  
    66  			s := newSegment(ctx, traceID, service, req, c)
    67  			ctx = WithSegment(ctx, s)
    68  
    69  			defer func() {
    70  				go record(ctx, s, err)
    71  			}()
    72  
    73  			err = h(ctx, rw, req)
    74  
    75  			return err
    76  		}
    77  	}, nil
    78  }
    79  
    80  // NewID is a span ID creation algorithm which produces values that are
    81  // compatible with AWS X-Ray.
    82  func NewID() string {
    83  	b := make([]byte, 8)
    84  	rand.Read(b)
    85  	return fmt.Sprintf("%x", b)
    86  }
    87  
    88  // NewTraceID is a trace ID creation algorithm which produces values that are
    89  // compatible with AWS X-Ray.
    90  func NewTraceID() string {
    91  	b := make([]byte, 12)
    92  	rand.Read(b)
    93  	return fmt.Sprintf("%d-%x-%s", 1, time.Now().Unix(), fmt.Sprintf("%x", b))
    94  }
    95  
    96  // WithSegment creates a context containing the given segment. Use ContextSegment
    97  // to retrieve it.
    98  func WithSegment(ctx context.Context, s *Segment) context.Context {
    99  	return context.WithValue(ctx, segKey, s)
   100  }
   101  
   102  // ContextSegment extracts the segment set in the context with WithSegment.
   103  func ContextSegment(ctx context.Context) *Segment {
   104  	if s := ctx.Value(segKey); s != nil {
   105  		return s.(*Segment)
   106  	}
   107  	return nil
   108  }
   109  
   110  // newSegment creates a new segment for the incoming request.
   111  func newSegment(ctx context.Context, traceID, name string, req *http.Request, c net.Conn) *Segment {
   112  	var (
   113  		spanID   = middleware.ContextSpanID(ctx)
   114  		parentID = middleware.ContextParentSpanID(ctx)
   115  		h        = &HTTP{Request: requestData(req)}
   116  	)
   117  
   118  	s := NewSegment(name, traceID, spanID, c)
   119  	s.HTTP = h
   120  
   121  	if parentID != "" {
   122  		s.ParentID = parentID
   123  		s.Type = "subsegment"
   124  	}
   125  
   126  	return s
   127  }
   128  
   129  // record finalizes and sends the segment to the X-Ray daemon.
   130  func record(ctx context.Context, s *Segment, err error) {
   131  	resp := goa.ContextResponse(ctx)
   132  	if resp != nil {
   133  		s.Lock()
   134  		switch {
   135  		case resp.Status == 429:
   136  			s.Throttle = true
   137  		case resp.Status >= 500:
   138  			s.Error = true
   139  		}
   140  		s.HTTP.Response = &Response{resp.Status, resp.Length}
   141  		s.Unlock()
   142  	}
   143  	if err != nil {
   144  		fault := false
   145  		if gerr, ok := err.(goa.ServiceError); ok {
   146  			fault = gerr.ResponseStatus() < http.StatusInternalServerError &&
   147  				gerr.ResponseStatus() != http.StatusTooManyRequests
   148  		}
   149  		s.Fault = fault
   150  		s.RecordError(err)
   151  	}
   152  	s.Close()
   153  }
   154  
   155  // requestData creates a Request from a http.Request.
   156  func requestData(req *http.Request) *Request {
   157  	var (
   158  		scheme = "http"
   159  		host   = req.Host
   160  	)
   161  	if len(req.URL.Scheme) > 0 {
   162  		scheme = req.URL.Scheme
   163  	}
   164  	if len(req.URL.Host) > 0 {
   165  		host = req.URL.Host
   166  	}
   167  	return &Request{
   168  		Method:    req.Method,
   169  		URL:       fmt.Sprintf("%s://%s%s", scheme, host, req.URL.Path),
   170  		ClientIP:  getIP(req),
   171  		UserAgent: req.UserAgent(),
   172  	}
   173  }
   174  
   175  // responseData creates a Response from a http.Response.
   176  func responseData(resp *http.Response) *Response {
   177  	var ln int
   178  	if lh := resp.Header.Get("Content-Length"); lh != "" {
   179  		ln, _ = strconv.Atoi(lh)
   180  	}
   181  
   182  	return &Response{
   183  		Status:        resp.StatusCode,
   184  		ContentLength: ln,
   185  	}
   186  }
   187  
   188  // getIP implements a heuristic that returns an origin IP address for a request.
   189  func getIP(req *http.Request) string {
   190  	for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} {
   191  		for _, ip := range strings.Split(req.Header.Get(h), ",") {
   192  			if len(ip) == 0 {
   193  				continue
   194  			}
   195  			realIP := net.ParseIP(strings.Replace(ip, " ", "", -1))
   196  			return realIP.String()
   197  		}
   198  	}
   199  
   200  	// not found in header
   201  	host, _, err := net.SplitHostPort(req.RemoteAddr)
   202  	if err != nil {
   203  		return req.RemoteAddr
   204  	}
   205  	return host
   206  }
   207  
   208  // now returns the current time as a float appropriate for X-Ray processing.
   209  func now() float64 {
   210  	return float64(time.Now().Truncate(time.Millisecond).UnixNano()) / 1e9
   211  }