github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/middleware/xray/middleware_test.go (about) 1 package xray 2 3 import ( 4 "encoding/json" 5 "errors" 6 "net" 7 "net/http" 8 "net/http/httptest" 9 "net/url" 10 "regexp" 11 "strings" 12 "testing" 13 "time" 14 15 "github.com/goadesign/goa" 16 "github.com/goadesign/goa/middleware" 17 18 "golang.org/x/net/context" 19 ) 20 21 const ( 22 // udp host:port used to run test server 23 udplisten = "127.0.0.1:62111" 24 ) 25 26 func TestNew(t *testing.T) { 27 cases := map[string]struct { 28 Daemon string 29 Success bool 30 }{ 31 "ok": {udplisten, true}, 32 "not-ok": {"1002.0.0.0:62111", false}, 33 } 34 for k, c := range cases { 35 m, err := New("", c.Daemon) 36 if err == nil && !c.Success { 37 t.Errorf("%s: expected failure but err is nil", k) 38 } 39 if err != nil && c.Success { 40 t.Errorf("%s: unexpected error %s", k, err) 41 } 42 if m == nil && c.Success { 43 t.Errorf("%s: middleware is nil", k) 44 } 45 } 46 } 47 48 func TestMiddleware(t *testing.T) { 49 type ( 50 Tra struct { 51 TraceID, SpanID, ParentID string 52 } 53 Req struct { 54 Method, Host, IP, RemoteAddr string 55 RemoteHost, UserAgent string 56 URL *url.URL 57 } 58 Res struct { 59 Status int 60 } 61 Seg struct { 62 Exception string 63 Error bool 64 } 65 ) 66 var ( 67 traceID = "traceID" 68 spanID = "spanID" 69 parentID = "parentID" 70 host = "goa.design" 71 method = "GET" 72 ip = "104.18.42.42" 73 remoteAddr = "104.18.43.42:443" 74 remoteNoPort = "104.18.43.42" 75 remoteHost = "104.18.43.42" 76 agent = "user agent" 77 url, _ = url.Parse("https://goa.design/path?query#fragment") 78 ) 79 cases := map[string]struct { 80 Trace Tra 81 Request Req 82 Response Res 83 Segment Seg 84 }{ 85 "no-trace": { 86 Trace: Tra{"", "", ""}, 87 Request: Req{"", "", "", "", "", "", nil}, 88 Response: Res{0}, 89 Segment: Seg{"", false}, 90 }, 91 "basic": { 92 Trace: Tra{traceID, spanID, ""}, 93 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 94 Response: Res{http.StatusOK}, 95 Segment: Seg{"", false}, 96 }, 97 "with-parent": { 98 Trace: Tra{traceID, spanID, parentID}, 99 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 100 Response: Res{http.StatusOK}, 101 Segment: Seg{"", false}, 102 }, 103 "without-ip": { 104 Trace: Tra{traceID, spanID, parentID}, 105 Request: Req{method, host, "", remoteAddr, remoteHost, agent, url}, 106 Response: Res{http.StatusOK}, 107 Segment: Seg{"", false}, 108 }, 109 "without-ip-remote-port": { 110 Trace: Tra{traceID, spanID, parentID}, 111 Request: Req{method, host, "", remoteNoPort, remoteHost, agent, url}, 112 Response: Res{http.StatusOK}, 113 Segment: Seg{"", false}, 114 }, 115 "error": { 116 Trace: Tra{traceID, spanID, ""}, 117 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 118 Response: Res{http.StatusBadRequest}, 119 Segment: Seg{"error", false}, 120 }, 121 "fault": { 122 Trace: Tra{traceID, spanID, ""}, 123 Request: Req{method, host, ip, remoteAddr, remoteHost, agent, url}, 124 Response: Res{http.StatusInternalServerError}, 125 Segment: Seg{"", true}, 126 }, 127 } 128 for k, c := range cases { 129 m, err := New("service", udplisten) 130 if err != nil { 131 t.Fatalf("%s: failed to create middleware: %s", k, err) 132 } 133 if c.Response.Status == 0 { 134 continue 135 } 136 137 var ( 138 req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil) 139 rw = httptest.NewRecorder() 140 ctx = goa.NewContext(context.Background(), rw, req, nil) 141 h = func(ctx context.Context, rw http.ResponseWriter, _ *http.Request) error { 142 if c.Segment.Exception != "" { 143 ContextSegment(ctx).RecordError(errors.New(c.Segment.Exception)) 144 } 145 rw.WriteHeader(c.Response.Status) 146 return nil 147 } 148 ) 149 ctx = middleware.WithTrace(ctx, c.Trace.TraceID, c.Trace.SpanID, c.Trace.ParentID) 150 if c.Request.UserAgent != "" { 151 req.Header.Set("User-Agent", c.Request.UserAgent) 152 } 153 if c.Request.IP != "" { 154 req.Header.Set("X-Forwarded-For", c.Request.IP) 155 } 156 if c.Request.RemoteAddr != "" { 157 req.RemoteAddr = c.Request.RemoteAddr 158 } 159 if c.Request.Host != "" { 160 req.Host = c.Request.Host 161 } 162 163 js := readUDP(t, func() { 164 m(h)(ctx, goa.ContextResponse(ctx), req) 165 }) 166 167 var s *Segment 168 elems := strings.Split(js, "\n") 169 if len(elems) != 2 { 170 t.Fatalf("%s: invalid number of lines, expected 2 got %d: %v", k, len(elems), elems) 171 } 172 if elems[0] != udpHeader[:len(udpHeader)-1] { 173 t.Errorf("%s: invalid header, got %s", k, elems[0]) 174 } 175 err = json.Unmarshal([]byte(elems[1]), &s) 176 if err != nil { 177 t.Fatal(err) 178 } 179 180 if s.Name != "service" { 181 t.Errorf("%s: unexpected segment name, expected service - got %s", k, s.Name) 182 } 183 if c.Trace.ParentID == "" && s.Type != "" { 184 t.Errorf("%s: expected Type to be empty but got %s", k, s.Type) 185 } 186 if c.Trace.ParentID != "" && s.Type != "subsegment" { 187 t.Errorf("%s: expected Type to subsegment but got %s", k, s.Type) 188 } 189 if s.ID != c.Trace.SpanID { 190 t.Errorf("%s: unexpected segment ID, expected %s - got %s", k, c.Trace.SpanID, s.ID) 191 } 192 if s.TraceID != c.Trace.TraceID { 193 t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID) 194 } 195 if s.ParentID != c.Trace.ParentID { 196 t.Errorf("%s: unexpected parent ID, expected %s - got %s", k, c.Trace.ParentID, s.ParentID) 197 } 198 if s.StartTime == 0 { 199 t.Errorf("%s: StartTime is 0", k) 200 } 201 if s.EndTime == 0 { 202 t.Errorf("%s: EndTime is 0", k) 203 } 204 if s.StartTime > s.EndTime { 205 t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime) 206 } 207 if s.HTTP == nil { 208 t.Fatalf("%s: HTTP field is nil", k) 209 } 210 if s.HTTP.Request == nil { 211 t.Fatalf("%s: HTTP Request field is nil", k) 212 } 213 if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP { 214 t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP) 215 } 216 if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost { 217 t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP) 218 } 219 if s.HTTP.Request.Method != c.Request.Method { 220 t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method) 221 } 222 expected := strings.Split(c.Request.URL.String(), "?")[0] 223 if s.HTTP.Request.URL != expected { 224 t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL) 225 } 226 if s.HTTP.Request.UserAgent != c.Request.UserAgent { 227 t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent) 228 } 229 if (s.Cause == nil && c.Segment.Exception != "") || (s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception) { 230 t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message) 231 } 232 if s.Error != c.Segment.Error { 233 t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error) 234 } 235 } 236 } 237 238 func TestNewID(t *testing.T) { 239 id := NewID() 240 if len(id) != 16 { 241 t.Errorf("invalid ID length, expected 16 got %d", len(id)) 242 } 243 if !regexp.MustCompile("[0-9a-f]{16}").MatchString(id) { 244 t.Errorf("invalid ID format, should be hexadecimal, got %s", id) 245 } 246 if id == NewID() { 247 t.Errorf("ids not unique") 248 } 249 } 250 251 func TestNewTraceID(t *testing.T) { 252 id := NewTraceID() 253 if len(id) != 35 { 254 t.Errorf("invalid ID length, expected 35 got %d", len(id)) 255 } 256 if !regexp.MustCompile("1-[0-9a-f]{8}-[0-9a-f]{16}").MatchString(id) { 257 t.Errorf("invalid Trace ID format, got %s", id) 258 } 259 if id == NewTraceID() { 260 t.Errorf("trace ids not unique") 261 } 262 } 263 264 // readUDP calls sender, reads and returns UDP messages received on udplisten. 265 func readUDP(t *testing.T, sender func()) string { 266 var ( 267 readChan = make(chan string) 268 msg = make([]byte, 1024*32) 269 ) 270 resAddr, err := net.ResolveUDPAddr("udp", udplisten) 271 if err != nil { 272 t.Fatal(err) 273 } 274 listener, err := net.ListenUDP("udp", resAddr) 275 if err != nil { 276 t.Fatal(err) 277 } 278 279 go func() { 280 listener.SetReadDeadline(time.Now().Add(time.Second)) 281 n, _, _ := listener.ReadFrom(msg) 282 readChan <- string(msg[0:n]) 283 }() 284 285 sender() 286 287 defer func() { 288 if err := listener.Close(); err != nil { 289 t.Fatal(err) 290 } 291 }() 292 293 return <-readChan 294 }