github.com/goldeneggg/goa@v1.3.1/middleware/xray/segment_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"regexp"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  
    14  	"github.com/goadesign/goa"
    15  	"github.com/pkg/errors"
    16  )
    17  
    18  func TestRecordError(t *testing.T) {
    19  	var (
    20  		errMsg       = "foo"
    21  		cause        = "cause"
    22  		inner        = "inner"
    23  		err          = errors.New(errMsg)
    24  		wrapped      = errors.Wrap(err, cause)
    25  		wrappedTwice = errors.Wrap(wrapped, inner)
    26  	)
    27  	cases := map[string]struct {
    28  		Error    error
    29  		Message  string
    30  		HasCause bool
    31  	}{
    32  		"go-error":     {err, errMsg, false},
    33  		"wrapped":      {wrapped, cause + ": " + errMsg, true},
    34  		"wrappedTwice": {wrappedTwice, inner + ": " + cause + ": " + errMsg, true},
    35  	}
    36  	for k, c := range cases {
    37  		s := Segment{Mutex: &sync.Mutex{}}
    38  		s.RecordError(c.Error)
    39  		w := s.Cause.Exceptions[0]
    40  		if w.Message != c.Message {
    41  			t.Errorf("%s: invalid message, expected %s got %s", k, c.Message, w.Message)
    42  		}
    43  		if c.HasCause && len(w.Stack) < 2 {
    44  			t.Errorf("%s: stack too small: %v", k, w.Stack)
    45  		}
    46  		if !s.Error {
    47  			t.Error("s.Error was not set to true")
    48  		}
    49  	}
    50  }
    51  
    52  func TestRecordResponse(t *testing.T) {
    53  	type Res struct {
    54  		Status int
    55  		Body   string
    56  	}
    57  
    58  	cases := map[string]struct {
    59  		Response Res
    60  		Request  *Request
    61  	}{
    62  		"with-HTTP.Request": {
    63  			Response: Res{Status: http.StatusOK, Body: "hello"},
    64  			Request:  &Request{Method: "GET"},
    65  		},
    66  		"without-HTTP.Request": {
    67  			Response: Res{Status: http.StatusOK, Body: "hello"},
    68  			Request:  nil,
    69  		},
    70  	}
    71  
    72  	for k, c := range cases {
    73  		rw := httptest.NewRecorder()
    74  		rw.WriteHeader(c.Response.Status)
    75  		if _, err := rw.WriteString(c.Response.Body); err != nil {
    76  			t.Fatalf("%s: failed to write response body - %s", k, err)
    77  		}
    78  		resp := rw.Result()
    79  		// Fixed in go1.8 with commit
    80  		// https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4
    81  		// to stay backwards compatible with go1.7, we set ContentLength manually
    82  		resp.ContentLength = int64(len(c.Response.Body))
    83  
    84  		s := Segment{Mutex: &sync.Mutex{}}
    85  		if c.Request != nil {
    86  			s.HTTP = &HTTP{Request: c.Request}
    87  		}
    88  
    89  		s.RecordResponse(resp)
    90  
    91  		if s.HTTP == nil {
    92  			t.Fatalf("%s: HTTP field is nil", k)
    93  		}
    94  		if s.HTTP.Response == nil {
    95  			t.Fatalf("%s: HTTP Response field is nil", k)
    96  		}
    97  		if s.HTTP.Response.Status != c.Response.Status {
    98  			t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status)
    99  		}
   100  		if s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) {
   101  			t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength)
   102  		}
   103  	}
   104  
   105  }
   106  
   107  func TestRecordRequest(t *testing.T) {
   108  	var (
   109  		method     = "GET"
   110  		ip         = "104.18.42.42"
   111  		remoteAddr = "104.18.43.42:443"
   112  		remoteHost = "104.18.43.42"
   113  		userAgent  = "user agent"
   114  		reqURL, _  = url.Parse("https://goa.design/path?query#fragment")
   115  	)
   116  
   117  	type Req struct {
   118  		Method, Host, IP, RemoteAddr string
   119  		RemoteHost, UserAgent        string
   120  		URL                          *url.URL
   121  	}
   122  
   123  	cases := map[string]struct {
   124  		Request  Req
   125  		Response *Response
   126  	}{
   127  		"with-HTTP.Response": {
   128  			Request:  Req{method, reqURL.Host, ip, remoteAddr, remoteHost, userAgent, reqURL},
   129  			Response: &Response{Status: 200},
   130  		},
   131  		"without-HTTP.Response": {
   132  			Request:  Req{method, reqURL.Host, ip, remoteAddr, remoteHost, userAgent, reqURL},
   133  			Response: nil,
   134  		},
   135  	}
   136  
   137  	for k, c := range cases {
   138  		req, _ := http.NewRequest(method, c.Request.URL.String(), nil)
   139  		req.Header.Set("User-Agent", c.Request.UserAgent)
   140  		req.Header.Set("X-Forwarded-For", c.Request.IP)
   141  		req.RemoteAddr = c.Request.RemoteAddr
   142  		req.Host = c.Request.Host
   143  
   144  		s := Segment{Mutex: &sync.Mutex{}}
   145  		if c.Response != nil {
   146  			s.HTTP = &HTTP{Response: c.Response}
   147  		}
   148  
   149  		s.RecordRequest(req, "remote")
   150  
   151  		if s.Namespace != "remote" {
   152  			t.Errorf("%s: Namespace is invalid, expected %q got %q", k, "remote", s.Namespace)
   153  		}
   154  		if s.HTTP == nil {
   155  			t.Fatalf("%s: HTTP field is nil", k)
   156  		}
   157  		if s.HTTP.Request == nil {
   158  			t.Fatalf("%s: HTTP Request field is nil", k)
   159  		}
   160  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   161  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   162  		}
   163  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   164  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   165  		}
   166  		if s.HTTP.Request.Method != c.Request.Method {
   167  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   168  		}
   169  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   170  		if s.HTTP.Request.URL != expected {
   171  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   172  		}
   173  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   174  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   175  		}
   176  		if c.Response != nil && (s.HTTP.Response == nil || c.Response.Status != s.HTTP.Response.Status) {
   177  			t.Errorf("%s: HTTP Response is invalid, expected %#v got %#v", k, c.Response, s.HTTP.Response)
   178  		}
   179  	}
   180  }
   181  
   182  func TestNewSubsegment(t *testing.T) {
   183  	var (
   184  		name   = "sub"
   185  		s      = &Segment{Mutex: &sync.Mutex{}}
   186  		before = now()
   187  		ss     = s.NewSubsegment(name)
   188  	)
   189  	if s.counter != 1 {
   190  		t.Errorf("counter not incremented after call to Subsegment")
   191  	}
   192  	if len(s.Subsegments) != 1 {
   193  		t.Fatalf("invalid count of subsegments, expected 1 got %d", len(s.Subsegments))
   194  	}
   195  	if s.Subsegments[0] != ss {
   196  		t.Errorf("invalid subsegments element, expected %v - got %v", name, s.Subsegments[0])
   197  	}
   198  	if ss.ID == "" {
   199  		t.Errorf("subsegment ID not initialized")
   200  	}
   201  	if !regexp.MustCompile("[0-9a-f]{16}").MatchString(ss.ID) {
   202  		t.Errorf("invalid subsegment ID, got %v", ss.ID)
   203  	}
   204  	if ss.Name != name {
   205  		t.Errorf("invalid subsegemnt name, expected %s got %s", name, ss.Name)
   206  	}
   207  	if ss.StartTime < before {
   208  		t.Errorf("invalid subsegment StartAt, expected at least %v, got %v", before, ss.StartTime)
   209  	}
   210  	if !ss.InProgress {
   211  		t.Errorf("subsegemnt not in progress")
   212  	}
   213  	if ss.Parent != s {
   214  		t.Errorf("invalid subsegment parent, expected %v, got %v", s, ss.Parent)
   215  	}
   216  }
   217  
   218  // TestRace starts two goroutines and races them to call Segment's public function. In this way, when tests are run
   219  // with the -race flag, race conditions will be detected.
   220  func TestRace(t *testing.T) {
   221  	var (
   222  		rErr   = errors.New("oh no")
   223  		req, _ = http.NewRequest("GET", "https://goa.design", nil)
   224  		resp   = httptest.NewRecorder().Result()
   225  		ctx    = goa.NewContext(context.Background(), httptest.NewRecorder(), req, nil)
   226  	)
   227  
   228  	conn, err := net.Dial("udp", udplisten)
   229  	if err != nil {
   230  		t.Fatalf("failed to connect to daemon - %s", err)
   231  	}
   232  	s := NewSegment("hello", NewTraceID(), NewID(), conn)
   233  
   234  	wg := &sync.WaitGroup{}
   235  	raceFct := func() {
   236  		s.RecordRequest(req, "")
   237  		s.RecordResponse(resp)
   238  		s.RecordContextResponse(ctx)
   239  		s.RecordError(rErr)
   240  
   241  		sub := s.NewSubsegment("sub")
   242  		s.Capture("sub2", func() {})
   243  
   244  		s.AddAnnotation("k1", "v1")
   245  		s.AddInt64Annotation("k2", 2)
   246  		s.AddBoolAnnotation("k3", true)
   247  
   248  		s.AddMetadata("k1", "v1")
   249  		s.AddInt64Metadata("k2", 2)
   250  		s.AddBoolMetadata("k3", true)
   251  
   252  		sub.Close()
   253  		s.Close()
   254  
   255  		wg.Done()
   256  	}
   257  
   258  	for i := 0; i < 2; i++ {
   259  		wg.Add(1)
   260  		go raceFct()
   261  	}
   262  
   263  	wg.Wait()
   264  }