github.com/ManabuSeki/goa-v1@v1.4.3/middleware/xray/middleware_test.go (about) 1 package xray 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "regexp" 12 "strings" 13 "sync" 14 "testing" 15 "time" 16 17 "github.com/goadesign/goa" 18 "github.com/goadesign/goa/middleware" 19 ) 20 21 const ( 22 // udp host:port used to run test server 23 udplisten = "127.0.0.1:62111" 24 ) 25 26 func TestNew(t *testing.T) { 27 cases := map[string]struct { 28 Daemon string 29 Success bool 30 }{ 31 "ok": {udplisten, true}, 32 "not-ok": {"1002.0.0.0:62111", false}, 33 } 34 for k, c := range cases { 35 m, err := New("", c.Daemon) 36 if err == nil && !c.Success { 37 t.Errorf("%s: expected failure but err is nil", k) 38 } 39 if err != nil && c.Success { 40 t.Errorf("%s: unexpected error %s", k, err) 41 } 42 if m == nil && c.Success { 43 t.Errorf("%s: middleware is nil", k) 44 } 45 } 46 } 47 48 func TestMiddleware(t *testing.T) { 49 type ( 50 Tra struct { 51 TraceID, SpanID, ParentID string 52 } 53 Req struct { 54 Method, Host, IP, RemoteAddr string 55 RemoteHost, UserAgent string 56 URL *url.URL 57 } 58 Res struct { 59 Status int 60 } 61 Seg struct { 62 Exception string 63 Error bool 64 } 65 ) 66 var ( 67 traceID = "traceID" 68 spanID = "spanID" 69 parentID = "parentID" 70 host = "goa.design" 71 method = "GET" 72 ip = "104.18.42.42" 73 remoteAddr = "104.18.43.42:443" 74 remoteNoPort = "104.18.43.42" 75 remoteHost = "104.18.43.42" 76 agent = "user agent" 77 url, _ = url.Parse("https://goa.design/path?query#fragment") 78 ) 79 cases := map[string]struct { 80 Trace Tra 81 Request Req 82 Response Res 83 Segment Seg 84 }{ 85 "no-trace": { 86 Trace: Tra{"", "", ""}, 87 Request: Req{"", "", "", "", "", "", nil}, 88 Response: Res{0}, 89 Segment: Seg{"", false}, 90 }, 91 "basic": { 92 Trace: Tra{traceID, spanID, ""}, 93 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 94 Response: Res{http.StatusOK}, 95 Segment: Seg{"", false}, 96 }, 97 "with-parent": { 98 Trace: Tra{traceID, spanID, parentID}, 99 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 100 Response: Res{http.StatusOK}, 101 Segment: Seg{"", false}, 102 }, 103 "without-ip": { 104 Trace: Tra{traceID, spanID, parentID}, 105 Request: Req{method, host, "", remoteAddr, remoteHost, agent, url}, 106 Response: Res{http.StatusOK}, 107 Segment: Seg{"", false}, 108 }, 109 "without-ip-remote-port": { 110 Trace: Tra{traceID, spanID, parentID}, 111 Request: Req{method, host, "", remoteNoPort, remoteHost, agent, url}, 112 Response: Res{http.StatusOK}, 113 Segment: Seg{"", false}, 114 }, 115 "error": { 116 Trace: Tra{traceID, spanID, ""}, 117 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 118 Response: Res{http.StatusBadRequest}, 119 Segment: Seg{"error", true}, 120 }, 121 "fault": { 122 Trace: Tra{traceID, spanID, ""}, 123 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 124 Response: Res{http.StatusInternalServerError}, 125 Segment: Seg{"", true}, 126 }, 127 } 128 for k, c := range cases { 129 m, err := New("service", udplisten) 130 if err != nil { 131 t.Fatalf("%s: failed to create middleware: %s", k, err) 132 } 133 if c.Response.Status == 0 { 134 continue 135 } 136 137 var ( 138 req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil) 139 rw = httptest.NewRecorder() 140 ctx = goa.NewContext(context.Background(), rw, req, nil) 141 h = func(ctx context.Context, rw http.ResponseWriter, _ *http.Request) error { 142 if c.Segment.Exception != "" { 143 ContextSegment(ctx).RecordError(errors.New(c.Segment.Exception)) 144 } 145 rw.WriteHeader(c.Response.Status) 146 return nil 147 } 148 ) 149 150 ctx = middleware.WithTrace(ctx, c.Trace.TraceID, c.Trace.SpanID, c.Trace.ParentID) 151 if c.Request.UserAgent != "" { 152 req.Header.Set("User-Agent", c.Request.UserAgent) 153 } 154 if c.Request.IP != "" { 155 req.Header.Set("X-Forwarded-For", c.Request.IP) 156 } 157 if c.Request.RemoteAddr != "" { 158 req.RemoteAddr = c.Request.RemoteAddr 159 } 160 if c.Request.Host != "" { 161 req.Host = c.Request.Host 162 } 163 164 messages := readUDP(t, 2, func() { 165 m(h)(ctx, goa.ContextResponse(ctx), req) 166 }) 167 168 // expect the first message is InProgress 169 s := extractSegment(t, messages[0]) 170 if !s.InProgress { 171 t.Fatalf("%s: expected first segment to be InProgress but it was not", k) 172 } 173 174 // second message 175 s = extractSegment(t, messages[1]) 176 if s.Name != "service" { 177 t.Errorf("%s: unexpected segment name, expected service - got %s", k, s.Name) 178 } 179 if s.Type != "" { 180 t.Errorf("%s: expected Type to be empty but got %s", k, s.Type) 181 } 182 if s.ID != c.Trace.SpanID { 183 t.Errorf("%s: unexpected segment ID, expected %s - got %s", k, c.Trace.SpanID, s.ID) 184 } 185 if s.TraceID != c.Trace.TraceID { 186 t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID) 187 } 188 if s.ParentID != c.Trace.ParentID { 189 t.Errorf("%s: unexpected parent ID, expected %s - got %s", k, c.Trace.ParentID, s.ParentID) 190 } 191 if s.StartTime == 0 { 192 t.Errorf("%s: StartTime is 0", k) 193 } 194 if s.EndTime == 0 { 195 t.Errorf("%s: EndTime is 0", k) 196 } 197 if s.StartTime > s.EndTime { 198 t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime) 199 } 200 if s.HTTP == nil { 201 t.Fatalf("%s: HTTP field is nil", k) 202 } 203 if s.HTTP.Request == nil { 204 t.Fatalf("%s: HTTP Request field is nil", k) 205 } 206 if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP { 207 t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP) 208 } 209 if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost { 210 t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP) 211 } 212 if s.HTTP.Request.Method != c.Request.Method { 213 t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method) 214 } 215 expected := strings.Split(c.Request.URL.String(), "?")[0] 216 if s.HTTP.Request.URL != expected { 217 t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL) 218 } 219 if s.HTTP.Request.UserAgent != c.Request.UserAgent { 220 t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent) 221 } 222 if s.Cause == nil && c.Segment.Exception != "" { 223 t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception) 224 } 225 if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception { 226 t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message) 227 } 228 if s.Error != c.Segment.Error { 229 t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error) 230 } 231 } 232 } 233 234 func TestNewID(t *testing.T) { 235 id := NewID() 236 if len(id) != 16 { 237 t.Errorf("invalid ID length, expected 16 got %d", len(id)) 238 } 239 if !regexp.MustCompile("[0-9a-f]{16}").MatchString(id) { 240 t.Errorf("invalid ID format, should be hexadecimal, got %s", id) 241 } 242 if id == NewID() { 243 t.Errorf("ids not unique") 244 } 245 } 246 247 func TestNewTraceID(t *testing.T) { 248 id := NewTraceID() 249 if len(id) != 35 { 250 t.Errorf("invalid ID length, expected 35 got %d", len(id)) 251 } 252 if !regexp.MustCompile("1-[0-9a-f]{8}-[0-9a-f]{16}").MatchString(id) { 253 t.Errorf("invalid Trace ID format, got %s", id) 254 } 255 if id == NewTraceID() { 256 t.Errorf("trace ids not unique") 257 } 258 } 259 260 func TestPeriodicallyRedialingConn(t *testing.T) { 261 262 t.Run("dial fails, returns error immediately", func(t *testing.T) { 263 dialErr := errors.New("dialErr") 264 _, err := periodicallyRedialingConn(context.Background(), time.Millisecond, func() (net.Conn, error) { 265 return nil, dialErr 266 }) 267 if err != dialErr { 268 t.Fatalf("Unexpected err, got %q, expected %q", err, dialErr) 269 } 270 }) 271 t.Run("connection gets replaced by new one", func(t *testing.T) { 272 var ( 273 firstConn = &net.UDPConn{} 274 secondConn = &net.UnixConn{} 275 callCount = 0 276 ) 277 wgCheckFirstConnection := sync.WaitGroup{} 278 wgCheckFirstConnection.Add(1) 279 wgThirdDial := sync.WaitGroup{} 280 wgThirdDial.Add(1) 281 dial := func() (net.Conn, error) { 282 callCount++ 283 if callCount == 1 { 284 return firstConn, nil 285 } 286 wgCheckFirstConnection.Wait() 287 if callCount == 3 { 288 wgThirdDial.Done() 289 } 290 return secondConn, nil 291 } 292 293 ctx, cancel := context.WithCancel(context.Background()) 294 defer cancel() 295 conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial) 296 if err != nil { 297 t.Fatalf("Expected nil err but got: %v", err) 298 } 299 300 if c := conn(); c != firstConn { 301 t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn) 302 } 303 wgCheckFirstConnection.Done() 304 305 // by the time the 3rd dial happens, we know conn() should be returning the second connection 306 wgThirdDial.Wait() 307 308 if c := conn(); c != secondConn { 309 t.Fatalf("Unexpected second connection: got %#v, expected %#v", c, secondConn) 310 } 311 }) 312 t.Run("connection not replaced if dial errored", func(t *testing.T) { 313 var ( 314 firstConn = &net.UDPConn{} 315 callCount = 0 316 ) 317 wgCheckFirstConnection := sync.WaitGroup{} 318 wgCheckFirstConnection.Add(1) 319 wgThirdDial := sync.WaitGroup{} 320 wgThirdDial.Add(1) 321 dial := func() (net.Conn, error) { 322 callCount++ 323 if callCount == 1 { 324 return firstConn, nil 325 } 326 wgCheckFirstConnection.Wait() 327 if callCount == 3 { 328 wgThirdDial.Done() 329 } 330 return nil, errors.New("dialErr") 331 } 332 333 ctx, cancel := context.WithCancel(context.Background()) 334 defer cancel() 335 conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial) 336 if err != nil { 337 t.Fatalf("Expected nil err but got: %v", err) 338 } 339 340 if c := conn(); c != firstConn { 341 t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn) 342 } 343 wgCheckFirstConnection.Done() 344 345 // by the time the 3rd dial happens, we know the second dial was processed, and shouldn't have replaced conn() 346 wgThirdDial.Wait() 347 348 if c := conn(); c != firstConn { 349 t.Fatalf("Connection unexpectedly replaced: got %#v, expected %#v", c, firstConn) 350 } 351 }) 352 } 353 354 // readUDP calls sender, reads and returns UDP messages received on udplisten. 355 // Verifies that exactly the expected number of messages are received. 356 func readUDP(t *testing.T, expectedMessages int, sender func()) []string { 357 var ( 358 readChan = make(chan []string) 359 msg = make([]byte, 1024*32) 360 ) 361 resAddr, err := net.ResolveUDPAddr("udp", udplisten) 362 if err != nil { 363 t.Fatal(err) 364 } 365 listener, err := net.ListenUDP("udp", resAddr) 366 if err != nil { 367 t.Fatal(err) 368 } 369 370 go func() { 371 listener.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) 372 var messages []string 373 for { 374 n, _, err := listener.ReadFrom(msg) 375 if err != nil { 376 if !strings.HasSuffix(err.Error(), "i/o timeout") { 377 t.Errorf("expected final timeout error but got: %s", err) 378 } 379 break // we're done 380 } 381 messages = append(messages, string(msg[0:n])) 382 } 383 if len(messages) != expectedMessages { 384 t.Errorf("unexpected number of messages, expected %d got %d. All messages:\n%s", 385 expectedMessages, len(messages), strings.Join(messages, "\n")) 386 } 387 readChan <- messages 388 }() 389 390 sender() 391 392 defer func() { 393 if err := listener.Close(); err != nil { 394 t.Fatal(err) 395 } 396 }() 397 398 return <-readChan 399 } 400 401 // extractSegment returns the unmarshalled segment JSON from a readUDP response. 402 func extractSegment(t *testing.T, js string) *Segment { 403 t.Helper() 404 405 var s *Segment 406 elems := strings.Split(js, "\n") 407 if len(elems) != 2 { 408 t.Fatalf("invalid number of lines, expected 2 got %d: %v", len(elems), elems) 409 } 410 if elems[0] != udpHeader[:len(udpHeader)-1] { 411 t.Errorf("invalid header, got %s", elems[0]) 412 } 413 err := json.Unmarshal([]byte(elems[1]), &s) 414 if err != nil { 415 t.Fatal(err) 416 } 417 return s 418 }