github.com/goldeneggg/goa@v1.3.1/middleware/xray/transport_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"strings"
    12  	"testing"
    13  
    14  	"context"
    15  )
    16  
    17  type mockRoundTripper struct {
    18  	Callback func(*http.Request) (*http.Response, error)
    19  }
    20  
    21  func (mrt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    22  	return mrt.Callback(req)
    23  }
    24  
    25  func TestTransportExample(t *testing.T) {
    26  	var (
    27  		responseBody = "good morning"
    28  	)
    29  	server := httptest.NewServer(http.HandlerFunc(
    30  		func(rw http.ResponseWriter, req *http.Request) {
    31  			rw.WriteHeader(http.StatusOK)
    32  
    33  			rw.Write([]byte(responseBody))
    34  		}))
    35  	defer server.Close()
    36  
    37  	conn, err := net.Dial("udp", udplisten)
    38  	if err != nil {
    39  		t.Fatalf("failed to connect to daemon - %s", err)
    40  	}
    41  
    42  	//
    43  	// Wrap http client's Transport with xray tracing
    44  	httpClient := &http.Client{
    45  		Transport: WrapTransport(http.DefaultTransport),
    46  	}
    47  
    48  	//
    49  	// Setup context
    50  	parentSegment := NewSegment("hello", NewTraceID(), NewID(), conn)
    51  	ctx := WithSegment(context.Background(), parentSegment)
    52  
    53  	//
    54  	// make Request
    55  	req, err := http.NewRequest("GET", server.URL, nil)
    56  	if err != nil {
    57  		t.Fatalf("NewRequest returned error - %s", err)
    58  	}
    59  	// Setting context on request
    60  	req = req.WithContext(ctx)
    61  
    62  	js := readUDP(t, func() {
    63  		resp, err := httpClient.Do(req)
    64  		if err != nil {
    65  			t.Fatalf("failed to make request - %s", err)
    66  		}
    67  		if resp.StatusCode != http.StatusOK {
    68  			t.Errorf("HTTP Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode)
    69  		}
    70  	})
    71  
    72  	//
    73  	// Verify
    74  	var s *Segment
    75  	elems := strings.Split(js, "\n")
    76  	if len(elems) != 2 {
    77  		t.Fatalf("invalid number of lines, expected 2 got %d: %v", len(elems), elems)
    78  	}
    79  	if elems[0] != udpHeader[:len(udpHeader)-1] {
    80  		t.Errorf("invalid header, got %s", elems[0])
    81  	}
    82  	err = json.Unmarshal([]byte(elems[1]), &s)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	url, _ := url.Parse(server.URL)
    87  	if s.Name != url.Host {
    88  		t.Errorf("unexpected segment name, expected %q - got %q", url.Host, s.Name)
    89  	}
    90  	if s.ParentID != parentSegment.ID {
    91  		t.Errorf("unexpected ParentID, expect %q - got %q", parentSegment.ID, s.ParentID)
    92  	}
    93  	if s.HTTP.Response.ContentLength != int64(len(responseBody)) {
    94  		t.Errorf("unexpected ContentLength, expect %d - got %d", len(responseBody), s.HTTP.Response.ContentLength)
    95  	}
    96  }
    97  
    98  func TestTransportNoSegmentInContext(t *testing.T) {
    99  	var (
   100  		url, _ = url.Parse("https://goa.design/path?query#fragment")
   101  		req, _ = http.NewRequest("GET", url.String(), nil)
   102  		rw     = httptest.NewRecorder()
   103  		rt     = &mockRoundTripper{func(*http.Request) (*http.Response, error) {
   104  
   105  			rw.WriteHeader(http.StatusOK)
   106  
   107  			return rw.Result(), nil
   108  		}}
   109  	)
   110  
   111  	resp, err := WrapTransport(rt).RoundTrip(req)
   112  	if err != nil {
   113  		t.Errorf("Expected no error got %s", err)
   114  	}
   115  	if resp.StatusCode != http.StatusOK {
   116  		t.Errorf("Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode)
   117  	}
   118  }
   119  
   120  func TestTransport(t *testing.T) {
   121  	type (
   122  		Tra struct {
   123  			TraceID, SpanID string
   124  		}
   125  		Req struct {
   126  			Method, Host, IP, RemoteAddr string
   127  			RemoteHost, UserAgent        string
   128  			URL                          *url.URL
   129  		}
   130  		Res struct {
   131  			Status int
   132  			Body   string
   133  		}
   134  		Seg struct {
   135  			Exception string
   136  			Error     bool
   137  		}
   138  	)
   139  	var (
   140  		traceID    = "traceID"
   141  		spanID     = "spanID"
   142  		host       = "goa.design"
   143  		method     = "GET"
   144  		ip         = "104.18.42.42"
   145  		remoteAddr = "104.18.43.42:443"
   146  		remoteHost = "104.18.43.42"
   147  		agent      = "user agent"
   148  		url, _     = url.Parse("https://goa.design/path?query#fragment")
   149  	)
   150  	cases := map[string]struct {
   151  		Trace    Tra
   152  		Request  Req
   153  		Response *Res
   154  		Segment  Seg
   155  	}{
   156  		"basic": {
   157  			Trace:    Tra{traceID, spanID},
   158  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   159  			Response: &Res{http.StatusOK, "test"},
   160  			Segment:  Seg{"", false},
   161  		},
   162  		"badRequest": {
   163  			Trace:    Tra{traceID, spanID},
   164  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   165  			Response: &Res{http.StatusBadRequest, "payload not valid"},
   166  			Segment:  Seg{"", false},
   167  		},
   168  		"fault": {
   169  			Trace:    Tra{traceID, spanID},
   170  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   171  			Response: &Res{http.StatusInternalServerError, ""},
   172  			Segment:  Seg{"", true},
   173  		},
   174  		"error": {
   175  			Trace:    Tra{traceID, spanID},
   176  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   177  			Response: nil,
   178  			Segment:  Seg{"some error", true},
   179  		},
   180  	}
   181  	for k, c := range cases {
   182  		conn, err := net.Dial("udp", udplisten)
   183  		if err != nil {
   184  			t.Fatalf("%s: failed to connect to daemon - %s", k, err)
   185  		}
   186  
   187  		var (
   188  			parent = NewSegment(k, c.Trace.TraceID, c.Trace.SpanID, conn)
   189  			req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil)
   190  			rw     = httptest.NewRecorder()
   191  			rt     = &mockRoundTripper{func(*http.Request) (*http.Response, error) {
   192  
   193  				if c.Segment.Exception != "" {
   194  					return nil, errors.New(c.Segment.Exception)
   195  				}
   196  
   197  				rw.WriteHeader(c.Response.Status)
   198  				if _, err := rw.WriteString(c.Response.Body); err != nil {
   199  					t.Fatalf("%s: failed to write response body - %s", k, err)
   200  				}
   201  				rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(c.Response.Body)))
   202  				res := rw.Result()
   203  
   204  				// Fixed in go1.8 with commit
   205  				// https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4
   206  				// to stay backwards compatible with go1.7, we set ContentLength manually
   207  				res.ContentLength = int64(len(c.Response.Body))
   208  
   209  				return res, nil
   210  			}}
   211  		)
   212  
   213  		req = req.WithContext(WithSegment(context.Background(), parent))
   214  
   215  		if c.Request.UserAgent != "" {
   216  			req.Header.Set("User-Agent", c.Request.UserAgent)
   217  		}
   218  		if c.Request.IP != "" {
   219  			req.Header.Set("X-Forwarded-For", c.Request.IP)
   220  		}
   221  		if c.Request.RemoteAddr != "" {
   222  			req.RemoteAddr = c.Request.RemoteAddr
   223  		}
   224  		if c.Request.Host != "" {
   225  			req.Host = c.Request.Host
   226  		}
   227  
   228  		js := readUDP(t, func() {
   229  			resp, err := WrapTransport(rt).RoundTrip(req)
   230  			if c.Segment.Exception == "" && err != nil {
   231  				t.Errorf("%s: Expected no error got %s", k, err)
   232  			}
   233  			if c.Response != nil && resp.StatusCode != c.Response.Status {
   234  				t.Errorf("%s: Response Status is invalid, expected %d got %d", k, c.Response.Status, resp.StatusCode)
   235  			}
   236  		})
   237  
   238  		var s *Segment
   239  		elems := strings.Split(js, "\n")
   240  		if len(elems) != 2 {
   241  			t.Fatalf("%s: invalid number of lines, expected 2 got %d: %v", k, len(elems), elems)
   242  		}
   243  		if elems[0] != udpHeader[:len(udpHeader)-1] {
   244  			t.Errorf("%s: invalid header, got %s", k, elems[0])
   245  		}
   246  		err = json.Unmarshal([]byte(elems[1]), &s)
   247  		if err != nil {
   248  			t.Fatal(err)
   249  		}
   250  
   251  		if s.Name != host {
   252  			t.Errorf("%s: unexpected segment name, expected %q - got %q", k, host, s.Name)
   253  		}
   254  		if c.Trace.SpanID != s.ParentID {
   255  			t.Errorf("%s: unexpected ParentID, expect %q - got %q", k, c.Trace.SpanID, s.ParentID)
   256  		}
   257  		if s.Type != "subsegment" {
   258  			t.Errorf("%s: expected Type to be 'subsegment' but got %q", k, s.Type)
   259  		}
   260  		if s.ID == "" {
   261  			t.Errorf("%s: segment ID not set", k)
   262  		}
   263  		if s.TraceID != c.Trace.TraceID {
   264  			t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID)
   265  		}
   266  		if s.StartTime == 0 {
   267  			t.Errorf("%s: StartTime is 0", k)
   268  		}
   269  		if s.EndTime == 0 {
   270  			t.Errorf("%s: EndTime is 0", k)
   271  		}
   272  		if s.StartTime > s.EndTime {
   273  			t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime)
   274  		}
   275  		if s.HTTP == nil {
   276  			t.Fatalf("%s: HTTP field is nil", k)
   277  		}
   278  		if s.HTTP.Request == nil {
   279  			t.Fatalf("%s: HTTP Request field is nil", k)
   280  		}
   281  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   282  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   283  		}
   284  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   285  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   286  		}
   287  		if s.HTTP.Request.Method != c.Request.Method {
   288  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   289  		}
   290  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   291  		if s.HTTP.Request.URL != expected {
   292  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   293  		}
   294  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   295  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   296  		}
   297  		if c.Response != nil && s.HTTP.Response.Status != c.Response.Status {
   298  			t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status)
   299  		}
   300  		if c.Response != nil && s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) {
   301  			t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength)
   302  		}
   303  		if s.Cause == nil && c.Segment.Exception != "" {
   304  			t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception)
   305  		}
   306  		if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception {
   307  			t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message)
   308  		}
   309  		if s.Error != c.Segment.Error {
   310  			t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error)
   311  		}
   312  	}
   313  }