github.com/ManabuSeki/goa-v1@v1.4.3/middleware/xray/segment_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"regexp"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  
    14  	"github.com/goadesign/goa"
    15  	"github.com/pkg/errors"
    16  )
    17  
    18  func TestRecordError(t *testing.T) {
    19  	var (
    20  		errMsg       = "foo"
    21  		cause        = "cause"
    22  		inner        = "inner"
    23  		err          = errors.New(errMsg)
    24  		wrapped      = errors.Wrap(err, cause)
    25  		wrappedTwice = errors.Wrap(wrapped, inner)
    26  	)
    27  	cases := map[string]struct {
    28  		Error    error
    29  		Message  string
    30  		HasCause bool
    31  	}{
    32  		"go-error":     {err, errMsg, false},
    33  		"wrapped":      {wrapped, cause + ": " + errMsg, true},
    34  		"wrappedTwice": {wrappedTwice, inner + ": " + cause + ": " + errMsg, true},
    35  	}
    36  	for k, c := range cases {
    37  		s := Segment{Mutex: &sync.Mutex{}}
    38  		s.RecordError(c.Error)
    39  		w := s.Cause.Exceptions[0]
    40  		if w.Message != c.Message {
    41  			t.Errorf("%s: invalid message, expected %s got %s", k, c.Message, w.Message)
    42  		}
    43  		if c.HasCause && len(w.Stack) < 2 {
    44  			t.Errorf("%s: stack too small: %v", k, w.Stack)
    45  		}
    46  		if !s.Error {
    47  			t.Error("s.Error was not set to true")
    48  		}
    49  	}
    50  }
    51  
    52  func TestRecordResponse(t *testing.T) {
    53  	type Res struct {
    54  		Status int
    55  		Body   string
    56  	}
    57  
    58  	cases := map[string]struct {
    59  		Response Res
    60  		Request  *Request
    61  	}{
    62  		"with-HTTP.Request": {
    63  			Response: Res{Status: http.StatusOK, Body: "hello"},
    64  			Request:  &Request{Method: "GET"},
    65  		},
    66  		"without-HTTP.Request": {
    67  			Response: Res{Status: http.StatusOK, Body: "hello"},
    68  			Request:  nil,
    69  		},
    70  	}
    71  
    72  	for k, c := range cases {
    73  		rw := httptest.NewRecorder()
    74  		rw.WriteHeader(c.Response.Status)
    75  		if _, err := rw.WriteString(c.Response.Body); err != nil {
    76  			t.Fatalf("%s: failed to write response body - %s", k, err)
    77  		}
    78  		resp := rw.Result()
    79  		// Fixed in go1.8 with commit
    80  		// https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4
    81  		// to stay backwards compatible with go1.7, we set ContentLength manually
    82  		resp.ContentLength = int64(len(c.Response.Body))
    83  
    84  		s := Segment{Mutex: &sync.Mutex{}}
    85  		if c.Request != nil {
    86  			s.HTTP = &HTTP{Request: c.Request}
    87  		}
    88  
    89  		s.RecordResponse(resp)
    90  
    91  		if s.HTTP == nil {
    92  			t.Fatalf("%s: HTTP field is nil", k)
    93  		}
    94  		if s.HTTP.Response == nil {
    95  			t.Fatalf("%s: HTTP Response field is nil", k)
    96  		}
    97  		if s.HTTP.Response.Status != c.Response.Status {
    98  			t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status)
    99  		}
   100  		if s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) {
   101  			t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength)
   102  		}
   103  	}
   104  
   105  }
   106  
   107  func TestRecordRequest(t *testing.T) {
   108  	var (
   109  		method     = "GET"
   110  		ip         = "104.18.42.42"
   111  		remoteAddr = "104.18.43.42:443"
   112  		remoteHost = "104.18.43.42"
   113  		userAgent  = "user agent"
   114  		reqURL, _  = url.Parse("https://goa.design/path?query#fragment")
   115  	)
   116  
   117  	type Req struct {
   118  		Method, Host, IP, RemoteAddr string
   119  		RemoteHost, UserAgent        string
   120  		URL                          *url.URL
   121  	}
   122  
   123  	cases := map[string]struct {
   124  		Request  Req
   125  		Response *Response
   126  	}{
   127  		"with-HTTP.Response": {
   128  			Request:  Req{method, reqURL.Host, ip, remoteAddr, remoteHost, userAgent, reqURL},
   129  			Response: &Response{Status: 200},
   130  		},
   131  		"without-HTTP.Response": {
   132  			Request:  Req{method, reqURL.Host, ip, remoteAddr, remoteHost, userAgent, reqURL},
   133  			Response: nil,
   134  		},
   135  	}
   136  
   137  	for k, c := range cases {
   138  		req, _ := http.NewRequest(method, c.Request.URL.String(), nil)
   139  		req.Header.Set("User-Agent", c.Request.UserAgent)
   140  		req.Header.Set("X-Forwarded-For", c.Request.IP)
   141  		req.RemoteAddr = c.Request.RemoteAddr
   142  		req.Host = c.Request.Host
   143  
   144  		s := Segment{Mutex: &sync.Mutex{}}
   145  		if c.Response != nil {
   146  			s.HTTP = &HTTP{Response: c.Response}
   147  		}
   148  
   149  		s.RecordRequest(req, "remote")
   150  
   151  		if s.Namespace != "remote" {
   152  			t.Errorf("%s: Namespace is invalid, expected %q got %q", k, "remote", s.Namespace)
   153  		}
   154  		if s.HTTP == nil {
   155  			t.Fatalf("%s: HTTP field is nil", k)
   156  		}
   157  		if s.HTTP.Request == nil {
   158  			t.Fatalf("%s: HTTP Request field is nil", k)
   159  		}
   160  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   161  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   162  		}
   163  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   164  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   165  		}
   166  		if s.HTTP.Request.Method != c.Request.Method {
   167  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   168  		}
   169  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   170  		if s.HTTP.Request.URL != expected {
   171  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   172  		}
   173  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   174  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   175  		}
   176  		if c.Response != nil && (s.HTTP.Response == nil || c.Response.Status != s.HTTP.Response.Status) {
   177  			t.Errorf("%s: HTTP Response is invalid, expected %#v got %#v", k, c.Response, s.HTTP.Response)
   178  		}
   179  	}
   180  }
   181  
   182  func TestSegment_NewSubsegment(t *testing.T) {
   183  	conn, err := net.Dial("udp", udplisten)
   184  	if err != nil {
   185  		t.Fatalf("failed to connect to daemon - %s", err)
   186  	}
   187  	var (
   188  		name   = "sub"
   189  		s      = &Segment{Mutex: &sync.Mutex{}, conn: conn}
   190  		before = now()
   191  		ss     = s.NewSubsegment(name)
   192  	)
   193  	if ss.submittedInProgressSegment {
   194  		t.Errorf("subsegment submittedInProgressSegment should initially be false")
   195  	}
   196  	if ss.ID == "" {
   197  		t.Errorf("subsegment ID not initialized")
   198  	}
   199  	if !regexp.MustCompile("[0-9a-f]{16}").MatchString(ss.ID) {
   200  		t.Errorf("invalid subsegment ID, got %v", ss.ID)
   201  	}
   202  	if ss.Name != name {
   203  		t.Errorf("invalid subsegemnt name, expected %s got %s", name, ss.Name)
   204  	}
   205  	if ss.StartTime < before {
   206  		t.Errorf("invalid subsegment StartAt, expected at least %v, got %v", before, ss.StartTime)
   207  	}
   208  	if !ss.InProgress {
   209  		t.Errorf("subsegemnt not in progress")
   210  	}
   211  	if ss.Parent != s {
   212  		t.Errorf("invalid subsegment parent, expected %v, got %v", s, ss.Parent)
   213  	}
   214  }
   215  
   216  func TestSegment_SubmitInProgress(t *testing.T) {
   217  	t.Run("call twice then close -- second call is ignored", func(t *testing.T) {
   218  		conn, err := net.Dial("udp", udplisten)
   219  		if err != nil {
   220  			t.Fatalf("failed to connect to daemon - %s", err)
   221  		}
   222  
   223  		segment := NewSegment("hello", NewTraceID(), NewID(), conn)
   224  
   225  		// call SubmitInProgress() twice, then Close it
   226  		messages := readUDP(t, 2, func() {
   227  			segment.Namespace = "1"
   228  			segment.SubmitInProgress()
   229  			segment.Namespace = "2"
   230  			segment.SubmitInProgress() // should have no effect
   231  			segment.Namespace = "3"
   232  			segment.Close()
   233  		})
   234  
   235  		// verify the In-Progress segment
   236  		s := extractSegment(t, messages[0])
   237  		if !s.InProgress {
   238  			t.Errorf("expected segment to be InProgress, but it's not")
   239  		}
   240  		if s.Namespace != "1" {
   241  			t.Errorf("unexpected segment namespace, expected %q got %q", "1", s.Namespace)
   242  		}
   243  
   244  		// verify the final segment (the second In-Progress segment would not have been sent)
   245  		s = extractSegment(t, messages[1])
   246  		if s.InProgress {
   247  			t.Errorf("expected segment to not be InProgress, but it is")
   248  		}
   249  		if s.Namespace != "3" {
   250  			t.Errorf("unexpected segment namespace, expected %q got %q", "3", s.Namespace)
   251  		}
   252  	})
   253  
   254  	t.Run("calling after already Closed -- no effect", func(t *testing.T) {
   255  		conn, err := net.Dial("udp", udplisten)
   256  		if err != nil {
   257  			t.Fatalf("failed to connect to daemon - %s", err)
   258  		}
   259  
   260  		segment := NewSegment("hello", NewTraceID(), NewID(), conn)
   261  
   262  		// Close(), then call SubmitInProgress(), only expect 1 segment written
   263  		messages := readUDP(t, 1, func() {
   264  			segment.Namespace = "1"
   265  			segment.Close()
   266  			segment.Namespace = "2"
   267  			segment.SubmitInProgress() // should have no effect
   268  		})
   269  
   270  		// verify the In-Progress segment
   271  		s := extractSegment(t, messages[0])
   272  		if s.InProgress {
   273  			t.Errorf("expected segment to be closed, but it is still InProgress")
   274  		}
   275  		if s.Namespace != "1" {
   276  			t.Errorf("unexpected segment namespace, expected %q got %q", "1", s.Namespace)
   277  		}
   278  	})
   279  }
   280  
   281  // TestRace starts two goroutines and races them to call Segment's public function. In this way, when tests are run
   282  // with the -race flag, race conditions will be detected.
   283  func TestRace(t *testing.T) {
   284  	var (
   285  		rErr   = errors.New("oh no")
   286  		req, _ = http.NewRequest("GET", "https://goa.design", nil)
   287  		resp   = httptest.NewRecorder().Result()
   288  		ctx    = goa.NewContext(context.Background(), httptest.NewRecorder(), req, nil)
   289  	)
   290  
   291  	conn, err := net.Dial("udp", udplisten)
   292  	if err != nil {
   293  		t.Fatalf("failed to connect to daemon - %s", err)
   294  	}
   295  	s := NewSegment("hello", NewTraceID(), NewID(), conn)
   296  
   297  	wg := &sync.WaitGroup{}
   298  	raceFct := func() {
   299  		s.RecordRequest(req, "")
   300  		s.RecordResponse(resp)
   301  		s.RecordContextResponse(ctx)
   302  		s.RecordError(rErr)
   303  		s.SubmitInProgress()
   304  
   305  		sub := s.NewSubsegment("sub")
   306  		s.Capture("sub2", func() {})
   307  
   308  		s.AddAnnotation("k1", "v1")
   309  		s.AddInt64Annotation("k2", 2)
   310  		s.AddBoolAnnotation("k3", true)
   311  
   312  		s.AddMetadata("k1", "v1")
   313  		s.AddInt64Metadata("k2", 2)
   314  		s.AddBoolMetadata("k3", true)
   315  
   316  		sub.Close()
   317  		s.Close()
   318  
   319  		wg.Done()
   320  	}
   321  
   322  	for i := 0; i < 2; i++ {
   323  		wg.Add(1)
   324  		go raceFct()
   325  	}
   326  
   327  	wg.Wait()
   328  }