github.com/goldeneggg/goa@v1.3.1/middleware/xray/transport_test.go (about) 1 package xray 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "strings" 12 "testing" 13 14 "context" 15 ) 16 17 type mockRoundTripper struct { 18 Callback func(*http.Request) (*http.Response, error) 19 } 20 21 func (mrt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 22 return mrt.Callback(req) 23 } 24 25 func TestTransportExample(t *testing.T) { 26 var ( 27 responseBody = "good morning" 28 ) 29 server := httptest.NewServer(http.HandlerFunc( 30 func(rw http.ResponseWriter, req *http.Request) { 31 rw.WriteHeader(http.StatusOK) 32 33 rw.Write([]byte(responseBody)) 34 })) 35 defer server.Close() 36 37 conn, err := net.Dial("udp", udplisten) 38 if err != nil { 39 t.Fatalf("failed to connect to daemon - %s", err) 40 } 41 42 // 43 // Wrap http client's Transport with xray tracing 44 httpClient := &http.Client{ 45 Transport: WrapTransport(http.DefaultTransport), 46 } 47 48 // 49 // Setup context 50 parentSegment := NewSegment("hello", NewTraceID(), NewID(), conn) 51 ctx := WithSegment(context.Background(), parentSegment) 52 53 // 54 // make Request 55 req, err := http.NewRequest("GET", server.URL, nil) 56 if err != nil { 57 t.Fatalf("NewRequest returned error - %s", err) 58 } 59 // Setting context on request 60 req = req.WithContext(ctx) 61 62 js := readUDP(t, func() { 63 resp, err := httpClient.Do(req) 64 if err != nil { 65 t.Fatalf("failed to make request - %s", err) 66 } 67 if resp.StatusCode != http.StatusOK { 68 t.Errorf("HTTP Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode) 69 } 70 }) 71 72 // 73 // Verify 74 var s *Segment 75 elems := strings.Split(js, "\n") 76 if len(elems) != 2 { 77 t.Fatalf("invalid number of lines, expected 2 got %d: %v", len(elems), elems) 78 } 79 if elems[0] != udpHeader[:len(udpHeader)-1] { 80 t.Errorf("invalid header, got %s", elems[0]) 81 } 82 err = json.Unmarshal([]byte(elems[1]), &s) 83 if err != nil { 84 t.Fatal(err) 85 } 86 url, _ := url.Parse(server.URL) 87 if s.Name != url.Host { 88 t.Errorf("unexpected segment name, expected %q - got %q", url.Host, s.Name) 89 } 90 if s.ParentID != parentSegment.ID { 91 t.Errorf("unexpected ParentID, expect %q - got %q", parentSegment.ID, s.ParentID) 92 } 93 if s.HTTP.Response.ContentLength != int64(len(responseBody)) { 94 t.Errorf("unexpected ContentLength, expect %d - got %d", len(responseBody), s.HTTP.Response.ContentLength) 95 } 96 } 97 98 func TestTransportNoSegmentInContext(t *testing.T) { 99 var ( 100 url, _ = url.Parse("https://goa.design/path?query#fragment") 101 req, _ = http.NewRequest("GET", url.String(), nil) 102 rw = httptest.NewRecorder() 103 rt = &mockRoundTripper{func(*http.Request) (*http.Response, error) { 104 105 rw.WriteHeader(http.StatusOK) 106 107 return rw.Result(), nil 108 }} 109 ) 110 111 resp, err := WrapTransport(rt).RoundTrip(req) 112 if err != nil { 113 t.Errorf("Expected no error got %s", err) 114 } 115 if resp.StatusCode != http.StatusOK { 116 t.Errorf("Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode) 117 } 118 } 119 120 func TestTransport(t *testing.T) { 121 type ( 122 Tra struct { 123 TraceID, SpanID string 124 } 125 Req struct { 126 Method, Host, IP, RemoteAddr string 127 RemoteHost, UserAgent string 128 URL *url.URL 129 } 130 Res struct { 131 Status int 132 Body string 133 } 134 Seg struct { 135 Exception string 136 Error bool 137 } 138 ) 139 var ( 140 traceID = "traceID" 141 spanID = "spanID" 142 host = "goa.design" 143 method = "GET" 144 ip = "104.18.42.42" 145 remoteAddr = "104.18.43.42:443" 146 remoteHost = "104.18.43.42" 147 agent = "user agent" 148 url, _ = url.Parse("https://goa.design/path?query#fragment") 149 ) 150 cases := map[string]struct { 151 Trace Tra 152 Request Req 153 Response *Res 154 Segment Seg 155 }{ 156 "basic": { 157 Trace: Tra{traceID, spanID}, 158 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 159 Response: &Res{http.StatusOK, "test"}, 160 Segment: Seg{"", false}, 161 }, 162 "badRequest": { 163 Trace: Tra{traceID, spanID}, 164 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 165 Response: &Res{http.StatusBadRequest, "payload not valid"}, 166 Segment: Seg{"", false}, 167 }, 168 "fault": { 169 Trace: Tra{traceID, spanID}, 170 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 171 Response: &Res{http.StatusInternalServerError, ""}, 172 Segment: Seg{"", true}, 173 }, 174 "error": { 175 Trace: Tra{traceID, spanID}, 176 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 177 Response: nil, 178 Segment: Seg{"some error", true}, 179 }, 180 } 181 for k, c := range cases { 182 conn, err := net.Dial("udp", udplisten) 183 if err != nil { 184 t.Fatalf("%s: failed to connect to daemon - %s", k, err) 185 } 186 187 var ( 188 parent = NewSegment(k, c.Trace.TraceID, c.Trace.SpanID, conn) 189 req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil) 190 rw = httptest.NewRecorder() 191 rt = &mockRoundTripper{func(*http.Request) (*http.Response, error) { 192 193 if c.Segment.Exception != "" { 194 return nil, errors.New(c.Segment.Exception) 195 } 196 197 rw.WriteHeader(c.Response.Status) 198 if _, err := rw.WriteString(c.Response.Body); err != nil { 199 t.Fatalf("%s: failed to write response body - %s", k, err) 200 } 201 rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(c.Response.Body))) 202 res := rw.Result() 203 204 // Fixed in go1.8 with commit 205 // https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4 206 // to stay backwards compatible with go1.7, we set ContentLength manually 207 res.ContentLength = int64(len(c.Response.Body)) 208 209 return res, nil 210 }} 211 ) 212 213 req = req.WithContext(WithSegment(context.Background(), parent)) 214 215 if c.Request.UserAgent != "" { 216 req.Header.Set("User-Agent", c.Request.UserAgent) 217 } 218 if c.Request.IP != "" { 219 req.Header.Set("X-Forwarded-For", c.Request.IP) 220 } 221 if c.Request.RemoteAddr != "" { 222 req.RemoteAddr = c.Request.RemoteAddr 223 } 224 if c.Request.Host != "" { 225 req.Host = c.Request.Host 226 } 227 228 js := readUDP(t, func() { 229 resp, err := WrapTransport(rt).RoundTrip(req) 230 if c.Segment.Exception == "" && err != nil { 231 t.Errorf("%s: Expected no error got %s", k, err) 232 } 233 if c.Response != nil && resp.StatusCode != c.Response.Status { 234 t.Errorf("%s: Response Status is invalid, expected %d got %d", k, c.Response.Status, resp.StatusCode) 235 } 236 }) 237 238 var s *Segment 239 elems := strings.Split(js, "\n") 240 if len(elems) != 2 { 241 t.Fatalf("%s: invalid number of lines, expected 2 got %d: %v", k, len(elems), elems) 242 } 243 if elems[0] != udpHeader[:len(udpHeader)-1] { 244 t.Errorf("%s: invalid header, got %s", k, elems[0]) 245 } 246 err = json.Unmarshal([]byte(elems[1]), &s) 247 if err != nil { 248 t.Fatal(err) 249 } 250 251 if s.Name != host { 252 t.Errorf("%s: unexpected segment name, expected %q - got %q", k, host, s.Name) 253 } 254 if c.Trace.SpanID != s.ParentID { 255 t.Errorf("%s: unexpected ParentID, expect %q - got %q", k, c.Trace.SpanID, s.ParentID) 256 } 257 if s.Type != "subsegment" { 258 t.Errorf("%s: expected Type to be 'subsegment' but got %q", k, s.Type) 259 } 260 if s.ID == "" { 261 t.Errorf("%s: segment ID not set", k) 262 } 263 if s.TraceID != c.Trace.TraceID { 264 t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID) 265 } 266 if s.StartTime == 0 { 267 t.Errorf("%s: StartTime is 0", k) 268 } 269 if s.EndTime == 0 { 270 t.Errorf("%s: EndTime is 0", k) 271 } 272 if s.StartTime > s.EndTime { 273 t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime) 274 } 275 if s.HTTP == nil { 276 t.Fatalf("%s: HTTP field is nil", k) 277 } 278 if s.HTTP.Request == nil { 279 t.Fatalf("%s: HTTP Request field is nil", k) 280 } 281 if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP { 282 t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP) 283 } 284 if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost { 285 t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP) 286 } 287 if s.HTTP.Request.Method != c.Request.Method { 288 t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method) 289 } 290 expected := strings.Split(c.Request.URL.String(), "?")[0] 291 if s.HTTP.Request.URL != expected { 292 t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL) 293 } 294 if s.HTTP.Request.UserAgent != c.Request.UserAgent { 295 t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent) 296 } 297 if c.Response != nil && s.HTTP.Response.Status != c.Response.Status { 298 t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status) 299 } 300 if c.Response != nil && s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) { 301 t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength) 302 } 303 if s.Cause == nil && c.Segment.Exception != "" { 304 t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception) 305 } 306 if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception { 307 t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message) 308 } 309 if s.Error != c.Segment.Error { 310 t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error) 311 } 312 } 313 }