github.com/goldeneggg/goa@v1.3.1/middleware/tracer.go (about)

     1  package middleware
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  
     7  	"github.com/goadesign/goa"
     8  	"github.com/goadesign/goa/client"
     9  )
    10  
var (
	// TraceIDHeader is the name of the HTTP request header containing the
	// current trace ID, if any. The middleware built by NewTracer reads it
	// from incoming requests and the Doer returned by TraceDoer sets it on
	// outgoing requests.
	TraceIDHeader = "TraceID"

	// ParentSpanIDHeader is the name of the HTTP request header containing
	// the parent span ID, if any. It is read and written alongside
	// TraceIDHeader.
	ParentSpanIDHeader = "ParentSpanID"
)
    20  
type (
	// IDFunc is a function that produces span and trace IDs for consumption by
	// tracing systems such as Zipkin or AWS X-Ray.
	IDFunc func() string

	// TracerOption is a constructor option that makes it possible to customize
	// the middleware. It receives the options accumulated so far and returns
	// the options NewTracer carries on with.
	TracerOption func(*tracerOptions) *tracerOptions

	// tracerOptions is the struct storing all the options.
	tracerOptions struct {
		// traceIDFunc creates the trace ID of sampled requests that do not
		// already carry one.
		traceIDFunc IDFunc
		// spanIDFunc creates the span ID of every traced request.
		spanIDFunc IDFunc
		// samplingPercent is the percentage of requests to trace; it is
		// used only when maxSamplingRate is not set.
		samplingPercent int
		// maxSamplingRate is the target sampling rate in requests per
		// second; a value greater than 0 selects the adaptive sampler.
		maxSamplingRate int
		// sampleSize is the number of requests between two adjustments of
		// the sampling rate; it only applies if maxSamplingRate is set.
		sampleSize int
	}

	// tracedDoer is a goa client Doer that inserts the tracing headers for
	// each request it makes.
	tracedDoer struct {
		// Doer is the wrapped client that actually performs the request.
		client.Doer
	}
)
    45  
    46  // TraceIDFunc is a constructor option that overrides the function used to
    47  // compute trace IDs.
    48  func TraceIDFunc(f IDFunc) TracerOption {
    49  	return func(o *tracerOptions) *tracerOptions {
    50  		if f == nil {
    51  			panic("trace ID function cannot be nil")
    52  		}
    53  		o.traceIDFunc = f
    54  		return o
    55  	}
    56  }
    57  
    58  // SpanIDFunc is a constructor option that overrides the function used to
    59  // compute span IDs.
    60  func SpanIDFunc(f IDFunc) TracerOption {
    61  	return func(o *tracerOptions) *tracerOptions {
    62  		if f == nil {
    63  			panic("span ID function cannot be nil")
    64  		}
    65  		o.spanIDFunc = f
    66  		return o
    67  	}
    68  }
    69  
    70  // SamplingPercent sets the tracing sampling rate as a percentage value.
    71  // It panics if p is less than 0 or more than 100.
    72  // SamplingPercent and MaxSamplingRate are mutually exclusive.
    73  func SamplingPercent(p int) TracerOption {
    74  	if p < 0 || p > 100 {
    75  		panic("sampling rate must be between 0 and 100")
    76  	}
    77  	return func(o *tracerOptions) *tracerOptions {
    78  		o.samplingPercent = p
    79  		return o
    80  	}
    81  }
    82  
    83  // MaxSamplingRate sets a target sampling rate in requests per second. Setting a
    84  // max sampling rate causes the middleware to adjust the sampling percent
    85  // dynamically.
    86  // SamplingPercent and MaxSamplingRate are mutually exclusive.
    87  func MaxSamplingRate(r int) TracerOption {
    88  	if r <= 0 {
    89  		panic("max sampling rate must be greater than 0")
    90  	}
    91  	return func(o *tracerOptions) *tracerOptions {
    92  		o.maxSamplingRate = r
    93  		return o
    94  	}
    95  }
    96  
    97  // SampleSize sets the number of requests between two adjustments of the sampling
    98  // rate when MaxSamplingRate is set. Defaults to 1,000.
    99  func SampleSize(s int) TracerOption {
   100  	if s <= 0 {
   101  		panic("sample size must be greater than 0")
   102  	}
   103  	return func(o *tracerOptions) *tracerOptions {
   104  		o.sampleSize = s
   105  		return o
   106  	}
   107  }
   108  
   109  // NewTracer returns a trace middleware that initializes the trace information
   110  // in the request context. The information can be retrieved using any of the
   111  // ContextXXX functions.
   112  //
   113  // samplingPercent must be a value between 0 and 100. It represents the percentage
   114  // of requests that should be traced. If the incoming request has a Trace ID
   115  // header then the sampling rate is disregarded and the tracing is enabled.
   116  //
   117  // spanIDFunc and traceIDFunc are the functions used to create Span and Trace
   118  // IDs respectively. This is configurable so that the created IDs are compatible
   119  // with the various backend tracing systems. The xray package provides
   120  // implementations that produce AWS X-Ray compatible IDs.
   121  func NewTracer(opts ...TracerOption) goa.Middleware {
   122  	o := &tracerOptions{
   123  		traceIDFunc:     shortID,
   124  		spanIDFunc:      shortID,
   125  		samplingPercent: 100,
   126  		sampleSize:      1000, // only applies if maxSamplingRate is set
   127  	}
   128  	for _, opt := range opts {
   129  		o = opt(o)
   130  	}
   131  	var sampler Sampler
   132  	if o.maxSamplingRate > 0 {
   133  		sampler = NewAdaptiveSampler(o.maxSamplingRate, o.sampleSize)
   134  	} else {
   135  		sampler = NewFixedSampler(o.samplingPercent)
   136  	}
   137  	return func(h goa.Handler) goa.Handler {
   138  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
   139  			// insert a new trace ID only if not already being traced.
   140  			traceID := req.Header.Get(TraceIDHeader)
   141  			if traceID == "" {
   142  				// insert tracing only within sample.
   143  				if sampler.Sample() {
   144  					traceID = o.traceIDFunc()
   145  				} else {
   146  					return h(ctx, rw, req)
   147  				}
   148  			}
   149  
   150  			// insert IDs into context to enable tracing.
   151  			spanID := o.spanIDFunc()
   152  			parentID := req.Header.Get(ParentSpanIDHeader)
   153  			ctx = WithTrace(ctx, traceID, spanID, parentID)
   154  			return h(ctx, rw, req)
   155  		}
   156  	}
   157  }
   158  
   159  // Tracer is deprecated in favor of NewTracer.
   160  func Tracer(sampleRate int, spanIDFunc, traceIDFunc IDFunc) goa.Middleware {
   161  	return NewTracer(SamplingPercent(sampleRate), SpanIDFunc(spanIDFunc), TraceIDFunc(traceIDFunc))
   162  }
   163  
   164  // TraceDoer wraps a goa client Doer and sets the trace headers so that the
   165  // downstream service may properly retrieve the parent span ID and trace ID.
   166  func TraceDoer(doer client.Doer) client.Doer {
   167  	return &tracedDoer{doer}
   168  }
   169  
   170  // ContextTraceID returns the trace ID extracted from the given context if any,
   171  // the empty string otherwise.
   172  func ContextTraceID(ctx context.Context) string {
   173  	if t := ctx.Value(traceKey); t != nil {
   174  		return t.(string)
   175  	}
   176  	return ""
   177  }
   178  
   179  // ContextSpanID returns the span ID extracted from the given context if any,
   180  // the empty string otherwise.
   181  func ContextSpanID(ctx context.Context) string {
   182  	if s := ctx.Value(spanKey); s != nil {
   183  		return s.(string)
   184  	}
   185  	return ""
   186  }
   187  
   188  // ContextParentSpanID returns the parent span ID extracted from the given
   189  // context if any, the empty string otherwise.
   190  func ContextParentSpanID(ctx context.Context) string {
   191  	if p := ctx.Value(parentSpanKey); p != nil {
   192  		return p.(string)
   193  	}
   194  	return ""
   195  }
   196  
   197  // WithTrace returns a context containing the given trace, span and parent span
   198  // IDs.
   199  func WithTrace(ctx context.Context, traceID, spanID, parentID string) context.Context {
   200  	if parentID != "" {
   201  		ctx = context.WithValue(ctx, parentSpanKey, parentID)
   202  	}
   203  	ctx = context.WithValue(ctx, traceKey, traceID)
   204  	ctx = context.WithValue(ctx, spanKey, spanID)
   205  	return ctx
   206  }
   207  
   208  // Do adds the tracing headers to the requests before making it.
   209  func (d *tracedDoer) Do(ctx context.Context, req *http.Request) (*http.Response, error) {
   210  	var (
   211  		traceID = ContextTraceID(ctx)
   212  		spanID  = ContextSpanID(ctx)
   213  	)
   214  	if traceID != "" {
   215  		req.Header.Set(TraceIDHeader, traceID)
   216  		req.Header.Set(ParentSpanIDHeader, spanID)
   217  	}
   218  
   219  	return d.Doer.Do(ctx, req)
   220  }