github.com/zak-blake/goa@v1.4.1/middleware/tracer_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  )
    10  
    11  func TestNewTracer(t *testing.T) {
    12  	// valid sampling percentage
    13  	{
    14  		cases := map[string]struct{ Rate int }{
    15  			"zero":  {0},
    16  			"one":   {1},
    17  			"fifty": {50},
    18  			"100":   {100},
    19  		}
    20  		for k, c := range cases {
    21  			m := Tracer(c.Rate, shortID, shortID)
    22  			if m == nil {
    23  				t.Errorf("%s: Tracer return nil", k)
    24  			}
    25  			m = NewTracer(SamplingPercent(c.Rate))
    26  			if m == nil {
    27  				t.Errorf("%s: NewTracer return nil", k)
    28  			}
    29  		}
    30  	}
    31  
    32  	// valid adaptive sampler tests
    33  	{
    34  		m := NewTracer(MaxSamplingRate(2))
    35  		if m == nil {
    36  			t.Error("NewTracer return nil")
    37  		}
    38  		m = NewTracer(MaxSamplingRate(5), SampleSize(100))
    39  		if m == nil {
    40  			t.Error("NewTracer return nil")
    41  		}
    42  	}
    43  
    44  	// invalid sampling percentage
    45  	{
    46  		cases := map[string]struct{ SamplingPercentage int }{
    47  			"negative":  {-1},
    48  			"one-o-one": {101},
    49  			"maxint":    {math.MaxInt64},
    50  		}
    51  
    52  		for k, c := range cases {
    53  			func() {
    54  				defer func() {
    55  					r := recover()
    56  					if r != "sampling rate must be between 0 and 100" {
    57  						t.Errorf("%s: Tracer did *not* panic as expected: %v", k, r)
    58  					}
    59  				}()
    60  				Tracer(c.SamplingPercentage, shortID, shortID)
    61  			}()
    62  			func() {
    63  				defer func() {
    64  					r := recover()
    65  					if r != "sampling rate must be between 0 and 100" {
    66  						t.Errorf("%s: NewTracer did *not* panic as expected: %v", k, r)
    67  					}
    68  				}()
    69  				NewTracer(SamplingPercent(c.SamplingPercentage))
    70  			}()
    71  		}
    72  	}
    73  
    74  	// invalid max sampling rate
    75  	{
    76  		cases := map[string]struct{ MaxSamplingRate int }{
    77  			"negative": {-1},
    78  			"zero":     {0},
    79  		}
    80  		for k, c := range cases {
    81  			func() {
    82  				defer func() {
    83  					r := recover()
    84  					if r != "max sampling rate must be greater than 0" {
    85  						t.Errorf("%s: NewTracer did *not* panic as expected: %v", k, r)
    86  					}
    87  				}()
    88  				NewTracer(MaxSamplingRate(c.MaxSamplingRate))
    89  			}()
    90  		}
    91  	}
    92  
    93  	// invalid sample size
    94  	{
    95  		cases := map[string]struct{ SampleSize int }{
    96  			"negative": {-1},
    97  			"zero":     {0},
    98  		}
    99  		for k, c := range cases {
   100  			func() {
   101  				defer func() {
   102  					r := recover()
   103  					if r != "sample size must be greater than 0" {
   104  						t.Errorf("%s: NewTracer did *not* panic as expected: %v", k, r)
   105  					}
   106  				}()
   107  				NewTracer(SampleSize(c.SampleSize))
   108  			}()
   109  		}
   110  	}
   111  }
   112  
   113  func TestTracerMiddleware(t *testing.T) {
   114  	var (
   115  		traceID    = "testTraceID"
   116  		spanID     = "testSpanID"
   117  		newTraceID = func() string { return traceID }
   118  		newID      = func() string { return spanID }
   119  	)
   120  
   121  	cases := map[string]struct {
   122  		Rate                  int
   123  		TraceID, ParentSpanID string
   124  		// output
   125  		CtxTraceID, CtxSpanID, CtxParentID string
   126  	}{
   127  		"no-trace": {100, "", "", traceID, spanID, ""},
   128  		"trace":    {100, "trace", "", "trace", spanID, ""},
   129  		"parent":   {100, "trace", "parent", "trace", spanID, "parent"},
   130  
   131  		"zero-rate-no-trace": {0, "", "", "", "", ""},
   132  		"zero-rate-trace":    {0, "trace", "", "trace", spanID, ""},
   133  		"zero-rate-parent":   {0, "trace", "parent", "trace", spanID, "parent"},
   134  	}
   135  
   136  	for k, c := range cases {
   137  		var (
   138  			ctxTraceID, ctxSpanID, ctxParentID string
   139  
   140  			m = Tracer(c.Rate, newID, newTraceID)
   141  			h = func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
   142  				ctxTraceID = ContextTraceID(ctx)
   143  				ctxSpanID = ContextSpanID(ctx)
   144  				ctxParentID = ContextParentSpanID(ctx)
   145  				return nil
   146  			}
   147  			headers = make(http.Header)
   148  			ctx     = context.Background()
   149  		)
   150  		if c.TraceID != "" {
   151  			headers.Set(TraceIDHeader, c.TraceID)
   152  		}
   153  		if c.ParentSpanID != "" {
   154  			headers.Set(ParentSpanIDHeader, c.ParentSpanID)
   155  		}
   156  		req, _ := http.NewRequest("GET", "/", nil)
   157  		req.Header = headers
   158  
   159  		m(h)(ctx, httptest.NewRecorder(), req)
   160  
   161  		if ctxTraceID != c.CtxTraceID {
   162  			t.Errorf("%s: invalid TraceID, expected %v - got %v", k, c.CtxTraceID, ctxTraceID)
   163  		}
   164  		if ctxSpanID != c.CtxSpanID {
   165  			t.Errorf("%s: invalid SpanID, expected %v - got %v", k, c.CtxSpanID, ctxSpanID)
   166  		}
   167  		if ctxParentID != c.CtxParentID {
   168  			t.Errorf("%s: invalid ParentSpanID, expected %v - got %v", k, c.CtxParentID, ctxParentID)
   169  		}
   170  	}
   171  }