github.com/ManabuSeki/goa-v1@v1.4.3/middleware/xray/transport_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"strings"
    11  	"testing"
    12  
    13  	"context"
    14  )
    15  
    16  type mockRoundTripper struct {
    17  	Callback func(*http.Request) (*http.Response, error)
    18  }
    19  
    20  func (mrt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    21  	return mrt.Callback(req)
    22  }
    23  
    24  func TestTransportExample(t *testing.T) {
    25  	var (
    26  		responseBody = "good morning"
    27  	)
    28  	server := httptest.NewServer(http.HandlerFunc(
    29  		func(rw http.ResponseWriter, req *http.Request) {
    30  			rw.WriteHeader(http.StatusOK)
    31  
    32  			rw.Write([]byte(responseBody))
    33  		}))
    34  	defer server.Close()
    35  
    36  	conn, err := net.Dial("udp", udplisten)
    37  	if err != nil {
    38  		t.Fatalf("failed to connect to daemon - %s", err)
    39  	}
    40  
    41  	//
    42  	// Wrap http client's Transport with xray tracing
    43  	httpClient := &http.Client{
    44  		Transport: WrapTransport(http.DefaultTransport),
    45  	}
    46  
    47  	//
    48  	// Setup context
    49  	parentSegment := NewSegment("hello", NewTraceID(), NewID(), conn)
    50  	ctx := WithSegment(context.Background(), parentSegment)
    51  
    52  	//
    53  	// make Request
    54  	req, err := http.NewRequest("GET", server.URL, nil)
    55  	if err != nil {
    56  		t.Fatalf("NewRequest returned error - %s", err)
    57  	}
    58  	// Setting context on request
    59  	req = req.WithContext(ctx)
    60  
    61  	messages := readUDP(t, 2, func() {
    62  		resp, err := httpClient.Do(req)
    63  		if err != nil {
    64  			t.Fatalf("failed to make request - %s", err)
    65  		}
    66  		if resp.StatusCode != http.StatusOK {
    67  			t.Errorf("HTTP Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode)
    68  		}
    69  	})
    70  
    71  	// expect the first message is InProgress
    72  	s := extractSegment(t, messages[0])
    73  	if !s.InProgress {
    74  		t.Fatalf("expected first segment to be InProgress but it was not")
    75  	}
    76  
    77  	//
    78  	// Verify
    79  	s = extractSegment(t, messages[1])
    80  	url, _ := url.Parse(server.URL)
    81  	if s.Name != url.Host {
    82  		t.Errorf("unexpected segment name, expected %q - got %q", url.Host, s.Name)
    83  	}
    84  	if s.ParentID != parentSegment.ID {
    85  		t.Errorf("unexpected ParentID, expect %q - got %q", parentSegment.ID, s.ParentID)
    86  	}
    87  	if s.HTTP.Response.ContentLength != int64(len(responseBody)) {
    88  		t.Errorf("unexpected ContentLength, expect %d - got %d", len(responseBody), s.HTTP.Response.ContentLength)
    89  	}
    90  }
    91  
    92  func TestTransportNoSegmentInContext(t *testing.T) {
    93  	var (
    94  		url, _ = url.Parse("https://goa.design/path?query#fragment")
    95  		req, _ = http.NewRequest("GET", url.String(), nil)
    96  		rw     = httptest.NewRecorder()
    97  		rt     = &mockRoundTripper{func(*http.Request) (*http.Response, error) {
    98  
    99  			rw.WriteHeader(http.StatusOK)
   100  
   101  			return rw.Result(), nil
   102  		}}
   103  	)
   104  
   105  	resp, err := WrapTransport(rt).RoundTrip(req)
   106  	if err != nil {
   107  		t.Errorf("Expected no error got %s", err)
   108  	}
   109  	if resp.StatusCode != http.StatusOK {
   110  		t.Errorf("Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode)
   111  	}
   112  }
   113  
   114  func TestTransport(t *testing.T) {
   115  	type (
   116  		Tra struct {
   117  			TraceID, SpanID string
   118  		}
   119  		Req struct {
   120  			Method, Host, IP, RemoteAddr string
   121  			RemoteHost, UserAgent        string
   122  			URL                          *url.URL
   123  		}
   124  		Res struct {
   125  			Status int
   126  			Body   string
   127  		}
   128  		Seg struct {
   129  			Exception string
   130  			Error     bool
   131  		}
   132  	)
   133  	var (
   134  		traceID    = "traceID"
   135  		spanID     = "spanID"
   136  		host       = "goa.design"
   137  		method     = "GET"
   138  		ip         = "104.18.42.42"
   139  		remoteAddr = "104.18.43.42:443"
   140  		remoteHost = "104.18.43.42"
   141  		agent      = "user agent"
   142  		url, _     = url.Parse("https://goa.design/path?query#fragment")
   143  	)
   144  	cases := map[string]struct {
   145  		Trace    Tra
   146  		Request  Req
   147  		Response *Res
   148  		Segment  Seg
   149  	}{
   150  		"basic": {
   151  			Trace:    Tra{traceID, spanID},
   152  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   153  			Response: &Res{http.StatusOK, "test"},
   154  			Segment:  Seg{"", false},
   155  		},
   156  		"badRequest": {
   157  			Trace:    Tra{traceID, spanID},
   158  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   159  			Response: &Res{http.StatusBadRequest, "payload not valid"},
   160  			Segment:  Seg{"", false},
   161  		},
   162  		"fault": {
   163  			Trace:    Tra{traceID, spanID},
   164  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   165  			Response: &Res{http.StatusInternalServerError, ""},
   166  			Segment:  Seg{"", true},
   167  		},
   168  		"error": {
   169  			Trace:    Tra{traceID, spanID},
   170  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   171  			Response: nil,
   172  			Segment:  Seg{"some error", true},
   173  		},
   174  	}
   175  	for k, c := range cases {
   176  		conn, err := net.Dial("udp", udplisten)
   177  		if err != nil {
   178  			t.Fatalf("%s: failed to connect to daemon - %s", k, err)
   179  		}
   180  
   181  		var (
   182  			parent = NewSegment(k, c.Trace.TraceID, c.Trace.SpanID, conn)
   183  			req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil)
   184  			rw     = httptest.NewRecorder()
   185  			rt     = &mockRoundTripper{func(*http.Request) (*http.Response, error) {
   186  
   187  				if c.Segment.Exception != "" {
   188  					return nil, errors.New(c.Segment.Exception)
   189  				}
   190  
   191  				rw.WriteHeader(c.Response.Status)
   192  				if _, err := rw.WriteString(c.Response.Body); err != nil {
   193  					t.Fatalf("%s: failed to write response body - %s", k, err)
   194  				}
   195  				rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(c.Response.Body)))
   196  				res := rw.Result()
   197  
   198  				// Fixed in go1.8 with commit
   199  				// https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4
   200  				// to stay backwards compatible with go1.7, we set ContentLength manually
   201  				res.ContentLength = int64(len(c.Response.Body))
   202  
   203  				return res, nil
   204  			}}
   205  		)
   206  
   207  		req = req.WithContext(WithSegment(context.Background(), parent))
   208  
   209  		if c.Request.UserAgent != "" {
   210  			req.Header.Set("User-Agent", c.Request.UserAgent)
   211  		}
   212  		if c.Request.IP != "" {
   213  			req.Header.Set("X-Forwarded-For", c.Request.IP)
   214  		}
   215  		if c.Request.RemoteAddr != "" {
   216  			req.RemoteAddr = c.Request.RemoteAddr
   217  		}
   218  		if c.Request.Host != "" {
   219  			req.Host = c.Request.Host
   220  		}
   221  
   222  		messages := readUDP(t, 2, func() {
   223  			resp, err := WrapTransport(rt).RoundTrip(req)
   224  			if c.Segment.Exception == "" && err != nil {
   225  				t.Errorf("%s: Expected no error got %s", k, err)
   226  			}
   227  			if c.Response != nil && resp.StatusCode != c.Response.Status {
   228  				t.Errorf("%s: Response Status is invalid, expected %d got %d", k, c.Response.Status, resp.StatusCode)
   229  			}
   230  		})
   231  		// expect the first message is InProgress
   232  		s := extractSegment(t, messages[0])
   233  		if !s.InProgress {
   234  			t.Errorf("%s: expected first segment to be InProgress but it was not", k)
   235  		}
   236  
   237  		// second message
   238  		s = extractSegment(t, messages[1])
   239  		if s.Name != host {
   240  			t.Errorf("%s: unexpected segment name, expected %q - got %q", k, host, s.Name)
   241  		}
   242  		if c.Trace.SpanID != s.ParentID {
   243  			t.Errorf("%s: unexpected ParentID, expect %q - got %q", k, c.Trace.SpanID, s.ParentID)
   244  		}
   245  		if s.Type != "subsegment" {
   246  			t.Errorf("%s: expected Type to be 'subsegment' but got %q", k, s.Type)
   247  		}
   248  		if s.ID == "" {
   249  			t.Errorf("%s: segment ID not set", k)
   250  		}
   251  		if s.TraceID != c.Trace.TraceID {
   252  			t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID)
   253  		}
   254  		if s.StartTime == 0 {
   255  			t.Errorf("%s: StartTime is 0", k)
   256  		}
   257  		if s.EndTime == 0 {
   258  			t.Errorf("%s: EndTime is 0", k)
   259  		}
   260  		if s.StartTime > s.EndTime {
   261  			t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime)
   262  		}
   263  		if s.HTTP == nil {
   264  			t.Fatalf("%s: HTTP field is nil", k)
   265  		}
   266  		if s.HTTP.Request == nil {
   267  			t.Fatalf("%s: HTTP Request field is nil", k)
   268  		}
   269  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   270  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   271  		}
   272  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   273  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   274  		}
   275  		if s.HTTP.Request.Method != c.Request.Method {
   276  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   277  		}
   278  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   279  		if s.HTTP.Request.URL != expected {
   280  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   281  		}
   282  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   283  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   284  		}
   285  		if c.Response != nil && s.HTTP.Response.Status != c.Response.Status {
   286  			t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status)
   287  		}
   288  		if c.Response != nil && s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) {
   289  			t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength)
   290  		}
   291  		if s.Cause == nil && c.Segment.Exception != "" {
   292  			t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception)
   293  		}
   294  		if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception {
   295  			t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message)
   296  		}
   297  		if s.Error != c.Segment.Error {
   298  			t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error)
   299  		}
   300  	}
   301  }