github.com/zak-blake/goa@v1.4.1/middleware/xray/middleware_test.go (about) 1 package xray 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "regexp" 12 "strings" 13 "sync" 14 "testing" 15 "time" 16 17 "github.com/goadesign/goa" 18 "github.com/goadesign/goa/middleware" 19 ) 20 21 const ( 22 // udp host:port used to run test server 23 udplisten = "127.0.0.1:62111" 24 ) 25 26 func TestNew(t *testing.T) { 27 cases := map[string]struct { 28 Daemon string 29 Success bool 30 }{ 31 "ok": {udplisten, true}, 32 "not-ok": {"1002.0.0.0:62111", false}, 33 } 34 for k, c := range cases { 35 m, err := New("", c.Daemon) 36 if err == nil && !c.Success { 37 t.Errorf("%s: expected failure but err is nil", k) 38 } 39 if err != nil && c.Success { 40 t.Errorf("%s: unexpected error %s", k, err) 41 } 42 if m == nil && c.Success { 43 t.Errorf("%s: middleware is nil", k) 44 } 45 } 46 } 47 48 func TestMiddleware(t *testing.T) { 49 type ( 50 Tra struct { 51 TraceID, SpanID, ParentID string 52 } 53 Req struct { 54 Method, Host, IP, RemoteAddr string 55 RemoteHost, UserAgent string 56 URL *url.URL 57 } 58 Res struct { 59 Status int 60 } 61 Seg struct { 62 Exception string 63 Error bool 64 } 65 ) 66 var ( 67 traceID = "traceID" 68 spanID = "spanID" 69 parentID = "parentID" 70 host = "goa.design" 71 method = "GET" 72 ip = "104.18.42.42" 73 remoteAddr = "104.18.43.42:443" 74 remoteNoPort = "104.18.43.42" 75 remoteHost = "104.18.43.42" 76 agent = "user agent" 77 url, _ = url.Parse("https://goa.design/path?query#fragment") 78 ) 79 cases := map[string]struct { 80 Trace Tra 81 Request Req 82 Response Res 83 Segment Seg 84 }{ 85 "no-trace": { 86 Trace: Tra{"", "", ""}, 87 Request: Req{"", "", "", "", "", "", nil}, 88 Response: Res{0}, 89 Segment: Seg{"", false}, 90 }, 91 "basic": { 92 Trace: Tra{traceID, spanID, ""}, 93 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 94 Response: Res{http.StatusOK}, 95 Segment: Seg{"", false}, 96 }, 97 "with-parent": { 98 Trace: Tra{traceID, spanID, parentID}, 99 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 100 Response: Res{http.StatusOK}, 101 Segment: Seg{"", false}, 102 }, 103 "without-ip": { 104 Trace: Tra{traceID, spanID, parentID}, 105 Request: Req{method, host, "", remoteAddr, remoteHost, agent, url}, 106 Response: Res{http.StatusOK}, 107 Segment: Seg{"", false}, 108 }, 109 "without-ip-remote-port": { 110 Trace: Tra{traceID, spanID, parentID}, 111 Request: Req{method, host, "", remoteNoPort, remoteHost, agent, url}, 112 Response: Res{http.StatusOK}, 113 Segment: Seg{"", false}, 114 }, 115 "error": { 116 Trace: Tra{traceID, spanID, ""}, 117 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 118 Response: Res{http.StatusBadRequest}, 119 Segment: Seg{"error", true}, 120 }, 121 "fault": { 122 Trace: Tra{traceID, spanID, ""}, 123 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 124 Response: Res{http.StatusInternalServerError}, 125 Segment: Seg{"", true}, 126 }, 127 } 128 for k, c := range cases { 129 m, err := New("service", udplisten) 130 if err != nil { 131 t.Fatalf("%s: failed to create middleware: %s", k, err) 132 } 133 if c.Response.Status == 0 { 134 continue 135 } 136 137 var ( 138 req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil) 139 rw = httptest.NewRecorder() 140 ctx = goa.NewContext(context.Background(), rw, req, nil) 141 h = func(ctx context.Context, rw http.ResponseWriter, _ *http.Request) error { 142 if c.Segment.Exception != "" { 143 ContextSegment(ctx).RecordError(errors.New(c.Segment.Exception)) 144 } 145 rw.WriteHeader(c.Response.Status) 146 return nil 147 } 148 ) 149 150 ctx = middleware.WithTrace(ctx, c.Trace.TraceID, c.Trace.SpanID, c.Trace.ParentID) 151 if c.Request.UserAgent != "" { 152 req.Header.Set("User-Agent", c.Request.UserAgent) 153 } 154 if c.Request.IP != "" { 155 req.Header.Set("X-Forwarded-For", c.Request.IP) 156 } 157 if c.Request.RemoteAddr != "" { 158 req.RemoteAddr = c.Request.RemoteAddr 159 } 160 if c.Request.Host != "" { 161 req.Host = c.Request.Host 162 } 163 164 js := readUDP(t, func() { 165 m(h)(ctx, goa.ContextResponse(ctx), req) 166 }) 167 168 var s *Segment 169 elems := strings.Split(js, "\n") 170 if len(elems) != 2 { 171 t.Fatalf("%s: invalid number of lines, expected 2 got %d: %v", k, len(elems), elems) 172 } 173 if elems[0] != udpHeader[:len(udpHeader)-1] { 174 t.Errorf("%s: invalid header, got %s", k, elems[0]) 175 } 176 err = json.Unmarshal([]byte(elems[1]), &s) 177 if err != nil { 178 t.Fatal(err) 179 } 180 181 if s.Name != "service" { 182 t.Errorf("%s: unexpected segment name, expected service - got %s", k, s.Name) 183 } 184 if s.Type != "" { 185 t.Errorf("%s: expected Type to be empty but got %s", k, s.Type) 186 } 187 if s.ID != c.Trace.SpanID { 188 t.Errorf("%s: unexpected segment ID, expected %s - got %s", k, c.Trace.SpanID, s.ID) 189 } 190 if s.TraceID != c.Trace.TraceID { 191 t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID) 192 } 193 if s.ParentID != c.Trace.ParentID { 194 t.Errorf("%s: unexpected parent ID, expected %s - got %s", k, c.Trace.ParentID, s.ParentID) 195 } 196 if s.StartTime == 0 { 197 t.Errorf("%s: StartTime is 0", k) 198 } 199 if s.EndTime == 0 { 200 t.Errorf("%s: EndTime is 0", k) 201 } 202 if s.StartTime > s.EndTime { 203 t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime) 204 } 205 if s.HTTP == nil { 206 t.Fatalf("%s: HTTP field is nil", k) 207 } 208 if s.HTTP.Request == nil { 209 t.Fatalf("%s: HTTP Request field is nil", k) 210 } 211 if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP { 212 t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP) 213 } 214 if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost { 215 t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP) 216 } 217 if s.HTTP.Request.Method != c.Request.Method { 218 t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method) 219 } 220 expected := strings.Split(c.Request.URL.String(), "?")[0] 221 if s.HTTP.Request.URL != expected { 222 t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL) 223 } 224 if s.HTTP.Request.UserAgent != c.Request.UserAgent { 225 t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent) 226 } 227 if s.Cause == nil && c.Segment.Exception != "" { 228 t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception) 229 } 230 if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception { 231 t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message) 232 } 233 if s.Error != c.Segment.Error { 234 t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error) 235 } 236 } 237 } 238 239 func TestNewID(t *testing.T) { 240 id := NewID() 241 if len(id) != 16 { 242 t.Errorf("invalid ID length, expected 16 got %d", len(id)) 243 } 244 if !regexp.MustCompile("[0-9a-f]{16}").MatchString(id) { 245 t.Errorf("invalid ID format, should be hexadecimal, got %s", id) 246 } 247 if id == NewID() { 248 t.Errorf("ids not unique") 249 } 250 } 251 252 func TestNewTraceID(t *testing.T) { 253 id := NewTraceID() 254 if len(id) != 35 { 255 t.Errorf("invalid ID length, expected 35 got %d", len(id)) 256 } 257 if !regexp.MustCompile("1-[0-9a-f]{8}-[0-9a-f]{16}").MatchString(id) { 258 t.Errorf("invalid Trace ID format, got %s", id) 259 } 260 if id == NewTraceID() { 261 t.Errorf("trace ids not unique") 262 } 263 } 264 265 func TestPeriodicallyRedialingConn(t *testing.T) { 266 267 t.Run("dial fails, returns error immediately", func(t *testing.T) { 268 dialErr := errors.New("dialErr") 269 _, err := periodicallyRedialingConn(context.Background(), time.Millisecond, func() (net.Conn, error) { 270 return nil, dialErr 271 }) 272 if err != dialErr { 273 t.Fatalf("Unexpected err, got %q, expected %q", err, dialErr) 274 } 275 }) 276 t.Run("connection gets replaced by new one", func(t *testing.T) { 277 var ( 278 firstConn = &net.UDPConn{} 279 secondConn = &net.UnixConn{} 280 callCount = 0 281 ) 282 wgCheckFirstConnection := sync.WaitGroup{} 283 wgCheckFirstConnection.Add(1) 284 wgThirdDial := sync.WaitGroup{} 285 wgThirdDial.Add(1) 286 dial := func() (net.Conn, error) { 287 callCount++ 288 if callCount == 1 { 289 return firstConn, nil 290 } 291 wgCheckFirstConnection.Wait() 292 if callCount == 3 { 293 wgThirdDial.Done() 294 } 295 return secondConn, nil 296 } 297 298 ctx, cancel := context.WithCancel(context.Background()) 299 defer cancel() 300 conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial) 301 if err != nil { 302 t.Fatalf("Expected nil err but got: %v", err) 303 } 304 305 if c := conn(); c != firstConn { 306 t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn) 307 } 308 wgCheckFirstConnection.Done() 309 310 // by the time the 3rd dial happens, we know conn() should be returning the second connection 311 wgThirdDial.Wait() 312 313 if c := conn(); c != secondConn { 314 t.Fatalf("Unexpected second connection: got %#v, expected %#v", c, secondConn) 315 } 316 }) 317 t.Run("connection not replaced if dial errored", func(t *testing.T) { 318 var ( 319 firstConn = &net.UDPConn{} 320 callCount = 0 321 ) 322 wgCheckFirstConnection := sync.WaitGroup{} 323 wgCheckFirstConnection.Add(1) 324 wgThirdDial := sync.WaitGroup{} 325 wgThirdDial.Add(1) 326 dial := func() (net.Conn, error) { 327 callCount++ 328 if callCount == 1 { 329 return firstConn, nil 330 } 331 wgCheckFirstConnection.Wait() 332 if callCount == 3 { 333 wgThirdDial.Done() 334 } 335 return nil, errors.New("dialErr") 336 } 337 338 ctx, cancel := context.WithCancel(context.Background()) 339 defer cancel() 340 conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial) 341 if err != nil { 342 t.Fatalf("Expected nil err but got: %v", err) 343 } 344 345 if c := conn(); c != firstConn { 346 t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn) 347 } 348 wgCheckFirstConnection.Done() 349 350 // by the time the 3rd dial happens, we know the second dial was processed, and shouldn't have replaced conn() 351 wgThirdDial.Wait() 352 353 if c := conn(); c != firstConn { 354 t.Fatalf("Connection unexpectedly replaced: got %#v, expected %#v", c, firstConn) 355 } 356 }) 357 } 358 359 // readUDP calls sender, reads and returns UDP messages received on udplisten. 360 func readUDP(t *testing.T, sender func()) string { 361 var ( 362 readChan = make(chan string) 363 msg = make([]byte, 1024*32) 364 ) 365 resAddr, err := net.ResolveUDPAddr("udp", udplisten) 366 if err != nil { 367 t.Fatal(err) 368 } 369 listener, err := net.ListenUDP("udp", resAddr) 370 if err != nil { 371 t.Fatal(err) 372 } 373 374 go func() { 375 listener.SetReadDeadline(time.Now().Add(time.Second)) 376 n, _, _ := listener.ReadFrom(msg) 377 readChan <- string(msg[0:n]) 378 }() 379 380 sender() 381 382 defer func() { 383 if err := listener.Close(); err != nil { 384 t.Fatal(err) 385 } 386 }() 387 388 return <-readChan 389 }