github.com/ManabuSeki/goa-v1@v1.4.3/middleware/xray/transport_test.go (about) 1 package xray 2 3 import ( 4 "errors" 5 "fmt" 6 "net" 7 "net/http" 8 "net/http/httptest" 9 "net/url" 10 "strings" 11 "testing" 12 13 "context" 14 ) 15 16 type mockRoundTripper struct { 17 Callback func(*http.Request) (*http.Response, error) 18 } 19 20 func (mrt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 21 return mrt.Callback(req) 22 } 23 24 func TestTransportExample(t *testing.T) { 25 var ( 26 responseBody = "good morning" 27 ) 28 server := httptest.NewServer(http.HandlerFunc( 29 func(rw http.ResponseWriter, req *http.Request) { 30 rw.WriteHeader(http.StatusOK) 31 32 rw.Write([]byte(responseBody)) 33 })) 34 defer server.Close() 35 36 conn, err := net.Dial("udp", udplisten) 37 if err != nil { 38 t.Fatalf("failed to connect to daemon - %s", err) 39 } 40 41 // 42 // Wrap http client's Transport with xray tracing 43 httpClient := &http.Client{ 44 Transport: WrapTransport(http.DefaultTransport), 45 } 46 47 // 48 // Setup context 49 parentSegment := NewSegment("hello", NewTraceID(), NewID(), conn) 50 ctx := WithSegment(context.Background(), parentSegment) 51 52 // 53 // make Request 54 req, err := http.NewRequest("GET", server.URL, nil) 55 if err != nil { 56 t.Fatalf("NewRequest returned error - %s", err) 57 } 58 // Setting context on request 59 req = req.WithContext(ctx) 60 61 messages := readUDP(t, 2, func() { 62 resp, err := httpClient.Do(req) 63 if err != nil { 64 t.Fatalf("failed to make request - %s", err) 65 } 66 if resp.StatusCode != http.StatusOK { 67 t.Errorf("HTTP Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode) 68 } 69 }) 70 71 // expect the first message is InProgress 72 s := extractSegment(t, messages[0]) 73 if !s.InProgress { 74 t.Fatalf("expected first segment to be InProgress but it was not") 75 } 76 77 // 78 // Verify 79 s = extractSegment(t, messages[1]) 80 url, _ := url.Parse(server.URL) 81 if s.Name != url.Host { 82 t.Errorf("unexpected segment name, expected %q - got %q", url.Host, s.Name) 83 } 84 if s.ParentID != parentSegment.ID { 85 t.Errorf("unexpected ParentID, expect %q - got %q", parentSegment.ID, s.ParentID) 86 } 87 if s.HTTP.Response.ContentLength != int64(len(responseBody)) { 88 t.Errorf("unexpected ContentLength, expect %d - got %d", len(responseBody), s.HTTP.Response.ContentLength) 89 } 90 } 91 92 func TestTransportNoSegmentInContext(t *testing.T) { 93 var ( 94 url, _ = url.Parse("https://goa.design/path?query#fragment") 95 req, _ = http.NewRequest("GET", url.String(), nil) 96 rw = httptest.NewRecorder() 97 rt = &mockRoundTripper{func(*http.Request) (*http.Response, error) { 98 99 rw.WriteHeader(http.StatusOK) 100 101 return rw.Result(), nil 102 }} 103 ) 104 105 resp, err := WrapTransport(rt).RoundTrip(req) 106 if err != nil { 107 t.Errorf("Expected no error got %s", err) 108 } 109 if resp.StatusCode != http.StatusOK { 110 t.Errorf("Response Status is invalid, expected %d got %d", http.StatusOK, resp.StatusCode) 111 } 112 } 113 114 func TestTransport(t *testing.T) { 115 type ( 116 Tra struct { 117 TraceID, SpanID string 118 } 119 Req struct { 120 Method, Host, IP, RemoteAddr string 121 RemoteHost, UserAgent string 122 URL *url.URL 123 } 124 Res struct { 125 Status int 126 Body string 127 } 128 Seg struct { 129 Exception string 130 Error bool 131 } 132 ) 133 var ( 134 traceID = "traceID" 135 spanID = "spanID" 136 host = "goa.design" 137 method = "GET" 138 ip = "104.18.42.42" 139 remoteAddr = "104.18.43.42:443" 140 remoteHost = "104.18.43.42" 141 agent = "user agent" 142 url, _ = url.Parse("https://goa.design/path?query#fragment") 143 ) 144 cases := map[string]struct { 145 Trace Tra 146 Request Req 147 Response *Res 148 Segment Seg 149 }{ 150 "basic": { 151 Trace: Tra{traceID, spanID}, 152 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 153 Response: &Res{http.StatusOK, "test"}, 154 Segment: Seg{"", false}, 155 }, 156 "badRequest": { 157 Trace: Tra{traceID, spanID}, 158 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 159 Response: &Res{http.StatusBadRequest, "payload not valid"}, 160 Segment: Seg{"", false}, 161 }, 162 "fault": { 163 Trace: Tra{traceID, spanID}, 164 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 165 Response: &Res{http.StatusInternalServerError, ""}, 166 Segment: Seg{"", true}, 167 }, 168 "error": { 169 Trace: Tra{traceID, spanID}, 170 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 171 Response: nil, 172 Segment: Seg{"some error", true}, 173 }, 174 } 175 for k, c := range cases { 176 conn, err := net.Dial("udp", udplisten) 177 if err != nil { 178 t.Fatalf("%s: failed to connect to daemon - %s", k, err) 179 } 180 181 var ( 182 parent = NewSegment(k, c.Trace.TraceID, c.Trace.SpanID, conn) 183 req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil) 184 rw = httptest.NewRecorder() 185 rt = &mockRoundTripper{func(*http.Request) (*http.Response, error) { 186 187 if c.Segment.Exception != "" { 188 return nil, errors.New(c.Segment.Exception) 189 } 190 191 rw.WriteHeader(c.Response.Status) 192 if _, err := rw.WriteString(c.Response.Body); err != nil { 193 t.Fatalf("%s: failed to write response body - %s", k, err) 194 } 195 rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(c.Response.Body))) 196 res := rw.Result() 197 198 // Fixed in go1.8 with commit 199 // https://github.com/golang/go/commit/ea143c299040f8a270fb782c5efd3a3a5e6057a4 200 // to stay backwards compatible with go1.7, we set ContentLength manually 201 res.ContentLength = int64(len(c.Response.Body)) 202 203 return res, nil 204 }} 205 ) 206 207 req = req.WithContext(WithSegment(context.Background(), parent)) 208 209 if c.Request.UserAgent != "" { 210 req.Header.Set("User-Agent", c.Request.UserAgent) 211 } 212 if c.Request.IP != "" { 213 req.Header.Set("X-Forwarded-For", c.Request.IP) 214 } 215 if c.Request.RemoteAddr != "" { 216 req.RemoteAddr = c.Request.RemoteAddr 217 } 218 if c.Request.Host != "" { 219 req.Host = c.Request.Host 220 } 221 222 messages := readUDP(t, 2, func() { 223 resp, err := WrapTransport(rt).RoundTrip(req) 224 if c.Segment.Exception == "" && err != nil { 225 t.Errorf("%s: Expected no error got %s", k, err) 226 } 227 if c.Response != nil && resp.StatusCode != c.Response.Status { 228 t.Errorf("%s: Response Status is invalid, expected %d got %d", k, c.Response.Status, resp.StatusCode) 229 } 230 }) 231 // expect the first message is InProgress 232 s := extractSegment(t, messages[0]) 233 if !s.InProgress { 234 t.Errorf("%s: expected first segment to be InProgress but it was not", k) 235 } 236 237 // second message 238 s = extractSegment(t, messages[1]) 239 if s.Name != host { 240 t.Errorf("%s: unexpected segment name, expected %q - got %q", k, host, s.Name) 241 } 242 if c.Trace.SpanID != s.ParentID { 243 t.Errorf("%s: unexpected ParentID, expect %q - got %q", k, c.Trace.SpanID, s.ParentID) 244 } 245 if s.Type != "subsegment" { 246 t.Errorf("%s: expected Type to be 'subsegment' but got %q", k, s.Type) 247 } 248 if s.ID == "" { 249 t.Errorf("%s: segment ID not set", k) 250 } 251 if s.TraceID != c.Trace.TraceID { 252 t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID) 253 } 254 if s.StartTime == 0 { 255 t.Errorf("%s: StartTime is 0", k) 256 } 257 if s.EndTime == 0 { 258 t.Errorf("%s: EndTime is 0", k) 259 } 260 if s.StartTime > s.EndTime { 261 t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime) 262 } 263 if s.HTTP == nil { 264 t.Fatalf("%s: HTTP field is nil", k) 265 } 266 if s.HTTP.Request == nil { 267 t.Fatalf("%s: HTTP Request field is nil", k) 268 } 269 if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP { 270 t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP) 271 } 272 if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost { 273 t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP) 274 } 275 if s.HTTP.Request.Method != c.Request.Method { 276 t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method) 277 } 278 expected := strings.Split(c.Request.URL.String(), "?")[0] 279 if s.HTTP.Request.URL != expected { 280 t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL) 281 } 282 if s.HTTP.Request.UserAgent != c.Request.UserAgent { 283 t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent) 284 } 285 if c.Response != nil && s.HTTP.Response.Status != c.Response.Status { 286 t.Errorf("%s: HTTP Response Status is invalid, expected %d got %d", k, c.Response.Status, s.HTTP.Response.Status) 287 } 288 if c.Response != nil && s.HTTP.Response.ContentLength != int64(len(c.Response.Body)) { 289 t.Errorf("%s: HTTP Response ContentLength is invalid, expected %d got %d", k, len(c.Response.Body), s.HTTP.Response.ContentLength) 290 } 291 if s.Cause == nil && c.Segment.Exception != "" { 292 t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception) 293 } 294 if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception { 295 t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message) 296 } 297 if s.Error != c.Segment.Error { 298 t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error) 299 } 300 } 301 }