github.com/zak-blake/goa@v1.4.1/middleware/xray/middleware_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"regexp"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/goadesign/goa"
    18  	"github.com/goadesign/goa/middleware"
    19  )
    20  
    21  const (
    22  	// udp host:port used to run test server
    23  	udplisten = "127.0.0.1:62111"
    24  )
    25  
    26  func TestNew(t *testing.T) {
    27  	cases := map[string]struct {
    28  		Daemon  string
    29  		Success bool
    30  	}{
    31  		"ok":     {udplisten, true},
    32  		"not-ok": {"1002.0.0.0:62111", false},
    33  	}
    34  	for k, c := range cases {
    35  		m, err := New("", c.Daemon)
    36  		if err == nil && !c.Success {
    37  			t.Errorf("%s: expected failure but err is nil", k)
    38  		}
    39  		if err != nil && c.Success {
    40  			t.Errorf("%s: unexpected error %s", k, err)
    41  		}
    42  		if m == nil && c.Success {
    43  			t.Errorf("%s: middleware is nil", k)
    44  		}
    45  	}
    46  }
    47  
    48  func TestMiddleware(t *testing.T) {
    49  	type (
    50  		Tra struct {
    51  			TraceID, SpanID, ParentID string
    52  		}
    53  		Req struct {
    54  			Method, Host, IP, RemoteAddr string
    55  			RemoteHost, UserAgent        string
    56  			URL                          *url.URL
    57  		}
    58  		Res struct {
    59  			Status int
    60  		}
    61  		Seg struct {
    62  			Exception string
    63  			Error     bool
    64  		}
    65  	)
    66  	var (
    67  		traceID      = "traceID"
    68  		spanID       = "spanID"
    69  		parentID     = "parentID"
    70  		host         = "goa.design"
    71  		method       = "GET"
    72  		ip           = "104.18.42.42"
    73  		remoteAddr   = "104.18.43.42:443"
    74  		remoteNoPort = "104.18.43.42"
    75  		remoteHost   = "104.18.43.42"
    76  		agent        = "user agent"
    77  		url, _       = url.Parse("https://goa.design/path?query#fragment")
    78  	)
    79  	cases := map[string]struct {
    80  		Trace    Tra
    81  		Request  Req
    82  		Response Res
    83  		Segment  Seg
    84  	}{
    85  		"no-trace": {
    86  			Trace:    Tra{"", "", ""},
    87  			Request:  Req{"", "", "", "", "", "", nil},
    88  			Response: Res{0},
    89  			Segment:  Seg{"", false},
    90  		},
    91  		"basic": {
    92  			Trace:    Tra{traceID, spanID, ""},
    93  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
    94  			Response: Res{http.StatusOK},
    95  			Segment:  Seg{"", false},
    96  		},
    97  		"with-parent": {
    98  			Trace:    Tra{traceID, spanID, parentID},
    99  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   100  			Response: Res{http.StatusOK},
   101  			Segment:  Seg{"", false},
   102  		},
   103  		"without-ip": {
   104  			Trace:    Tra{traceID, spanID, parentID},
   105  			Request:  Req{method, host, "", remoteAddr, remoteHost, agent, url},
   106  			Response: Res{http.StatusOK},
   107  			Segment:  Seg{"", false},
   108  		},
   109  		"without-ip-remote-port": {
   110  			Trace:    Tra{traceID, spanID, parentID},
   111  			Request:  Req{method, host, "", remoteNoPort, remoteHost, agent, url},
   112  			Response: Res{http.StatusOK},
   113  			Segment:  Seg{"", false},
   114  		},
   115  		"error": {
   116  			Trace:    Tra{traceID, spanID, ""},
   117  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   118  			Response: Res{http.StatusBadRequest},
   119  			Segment:  Seg{"error", true},
   120  		},
   121  		"fault": {
   122  			Trace:    Tra{traceID, spanID, ""},
   123  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   124  			Response: Res{http.StatusInternalServerError},
   125  			Segment:  Seg{"", true},
   126  		},
   127  	}
   128  	for k, c := range cases {
   129  		m, err := New("service", udplisten)
   130  		if err != nil {
   131  			t.Fatalf("%s: failed to create middleware: %s", k, err)
   132  		}
   133  		if c.Response.Status == 0 {
   134  			continue
   135  		}
   136  
   137  		var (
   138  			req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil)
   139  			rw     = httptest.NewRecorder()
   140  			ctx    = goa.NewContext(context.Background(), rw, req, nil)
   141  			h      = func(ctx context.Context, rw http.ResponseWriter, _ *http.Request) error {
   142  				if c.Segment.Exception != "" {
   143  					ContextSegment(ctx).RecordError(errors.New(c.Segment.Exception))
   144  				}
   145  				rw.WriteHeader(c.Response.Status)
   146  				return nil
   147  			}
   148  		)
   149  
   150  		ctx = middleware.WithTrace(ctx, c.Trace.TraceID, c.Trace.SpanID, c.Trace.ParentID)
   151  		if c.Request.UserAgent != "" {
   152  			req.Header.Set("User-Agent", c.Request.UserAgent)
   153  		}
   154  		if c.Request.IP != "" {
   155  			req.Header.Set("X-Forwarded-For", c.Request.IP)
   156  		}
   157  		if c.Request.RemoteAddr != "" {
   158  			req.RemoteAddr = c.Request.RemoteAddr
   159  		}
   160  		if c.Request.Host != "" {
   161  			req.Host = c.Request.Host
   162  		}
   163  
   164  		js := readUDP(t, func() {
   165  			m(h)(ctx, goa.ContextResponse(ctx), req)
   166  		})
   167  
   168  		var s *Segment
   169  		elems := strings.Split(js, "\n")
   170  		if len(elems) != 2 {
   171  			t.Fatalf("%s: invalid number of lines, expected 2 got %d: %v", k, len(elems), elems)
   172  		}
   173  		if elems[0] != udpHeader[:len(udpHeader)-1] {
   174  			t.Errorf("%s: invalid header, got %s", k, elems[0])
   175  		}
   176  		err = json.Unmarshal([]byte(elems[1]), &s)
   177  		if err != nil {
   178  			t.Fatal(err)
   179  		}
   180  
   181  		if s.Name != "service" {
   182  			t.Errorf("%s: unexpected segment name, expected service - got %s", k, s.Name)
   183  		}
   184  		if s.Type != "" {
   185  			t.Errorf("%s: expected Type to be empty but got %s", k, s.Type)
   186  		}
   187  		if s.ID != c.Trace.SpanID {
   188  			t.Errorf("%s: unexpected segment ID, expected %s - got %s", k, c.Trace.SpanID, s.ID)
   189  		}
   190  		if s.TraceID != c.Trace.TraceID {
   191  			t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID)
   192  		}
   193  		if s.ParentID != c.Trace.ParentID {
   194  			t.Errorf("%s: unexpected parent ID, expected %s - got %s", k, c.Trace.ParentID, s.ParentID)
   195  		}
   196  		if s.StartTime == 0 {
   197  			t.Errorf("%s: StartTime is 0", k)
   198  		}
   199  		if s.EndTime == 0 {
   200  			t.Errorf("%s: EndTime is 0", k)
   201  		}
   202  		if s.StartTime > s.EndTime {
   203  			t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime)
   204  		}
   205  		if s.HTTP == nil {
   206  			t.Fatalf("%s: HTTP field is nil", k)
   207  		}
   208  		if s.HTTP.Request == nil {
   209  			t.Fatalf("%s: HTTP Request field is nil", k)
   210  		}
   211  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   212  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   213  		}
   214  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   215  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   216  		}
   217  		if s.HTTP.Request.Method != c.Request.Method {
   218  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   219  		}
   220  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   221  		if s.HTTP.Request.URL != expected {
   222  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   223  		}
   224  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   225  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   226  		}
   227  		if s.Cause == nil && c.Segment.Exception != "" {
   228  			t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception)
   229  		}
   230  		if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception {
   231  			t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message)
   232  		}
   233  		if s.Error != c.Segment.Error {
   234  			t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error)
   235  		}
   236  	}
   237  }
   238  
   239  func TestNewID(t *testing.T) {
   240  	id := NewID()
   241  	if len(id) != 16 {
   242  		t.Errorf("invalid ID length, expected 16 got %d", len(id))
   243  	}
   244  	if !regexp.MustCompile("[0-9a-f]{16}").MatchString(id) {
   245  		t.Errorf("invalid ID format, should be hexadecimal, got %s", id)
   246  	}
   247  	if id == NewID() {
   248  		t.Errorf("ids not unique")
   249  	}
   250  }
   251  
   252  func TestNewTraceID(t *testing.T) {
   253  	id := NewTraceID()
   254  	if len(id) != 35 {
   255  		t.Errorf("invalid ID length, expected 35 got %d", len(id))
   256  	}
   257  	if !regexp.MustCompile("1-[0-9a-f]{8}-[0-9a-f]{16}").MatchString(id) {
   258  		t.Errorf("invalid Trace ID format, got %s", id)
   259  	}
   260  	if id == NewTraceID() {
   261  		t.Errorf("trace ids not unique")
   262  	}
   263  }
   264  
   265  func TestPeriodicallyRedialingConn(t *testing.T) {
   266  
   267  	t.Run("dial fails, returns error immediately", func(t *testing.T) {
   268  		dialErr := errors.New("dialErr")
   269  		_, err := periodicallyRedialingConn(context.Background(), time.Millisecond, func() (net.Conn, error) {
   270  			return nil, dialErr
   271  		})
   272  		if err != dialErr {
   273  			t.Fatalf("Unexpected err, got %q, expected %q", err, dialErr)
   274  		}
   275  	})
   276  	t.Run("connection gets replaced by new one", func(t *testing.T) {
   277  		var (
   278  			firstConn  = &net.UDPConn{}
   279  			secondConn = &net.UnixConn{}
   280  			callCount  = 0
   281  		)
   282  		wgCheckFirstConnection := sync.WaitGroup{}
   283  		wgCheckFirstConnection.Add(1)
   284  		wgThirdDial := sync.WaitGroup{}
   285  		wgThirdDial.Add(1)
   286  		dial := func() (net.Conn, error) {
   287  			callCount++
   288  			if callCount == 1 {
   289  				return firstConn, nil
   290  			}
   291  			wgCheckFirstConnection.Wait()
   292  			if callCount == 3 {
   293  				wgThirdDial.Done()
   294  			}
   295  			return secondConn, nil
   296  		}
   297  
   298  		ctx, cancel := context.WithCancel(context.Background())
   299  		defer cancel()
   300  		conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial)
   301  		if err != nil {
   302  			t.Fatalf("Expected nil err but got: %v", err)
   303  		}
   304  
   305  		if c := conn(); c != firstConn {
   306  			t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn)
   307  		}
   308  		wgCheckFirstConnection.Done()
   309  
   310  		// by the time the 3rd dial happens, we know conn() should be returning the second connection
   311  		wgThirdDial.Wait()
   312  
   313  		if c := conn(); c != secondConn {
   314  			t.Fatalf("Unexpected second connection: got %#v, expected %#v", c, secondConn)
   315  		}
   316  	})
   317  	t.Run("connection not replaced if dial errored", func(t *testing.T) {
   318  		var (
   319  			firstConn = &net.UDPConn{}
   320  			callCount = 0
   321  		)
   322  		wgCheckFirstConnection := sync.WaitGroup{}
   323  		wgCheckFirstConnection.Add(1)
   324  		wgThirdDial := sync.WaitGroup{}
   325  		wgThirdDial.Add(1)
   326  		dial := func() (net.Conn, error) {
   327  			callCount++
   328  			if callCount == 1 {
   329  				return firstConn, nil
   330  			}
   331  			wgCheckFirstConnection.Wait()
   332  			if callCount == 3 {
   333  				wgThirdDial.Done()
   334  			}
   335  			return nil, errors.New("dialErr")
   336  		}
   337  
   338  		ctx, cancel := context.WithCancel(context.Background())
   339  		defer cancel()
   340  		conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial)
   341  		if err != nil {
   342  			t.Fatalf("Expected nil err but got: %v", err)
   343  		}
   344  
   345  		if c := conn(); c != firstConn {
   346  			t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn)
   347  		}
   348  		wgCheckFirstConnection.Done()
   349  
   350  		// by the time the 3rd dial happens, we know the second dial was processed, and shouldn't have replaced conn()
   351  		wgThirdDial.Wait()
   352  
   353  		if c := conn(); c != firstConn {
   354  			t.Fatalf("Connection unexpectedly replaced: got %#v, expected %#v", c, firstConn)
   355  		}
   356  	})
   357  }
   358  
   359  // readUDP calls sender, reads and returns UDP messages received on udplisten.
   360  func readUDP(t *testing.T, sender func()) string {
   361  	var (
   362  		readChan = make(chan string)
   363  		msg      = make([]byte, 1024*32)
   364  	)
   365  	resAddr, err := net.ResolveUDPAddr("udp", udplisten)
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  	listener, err := net.ListenUDP("udp", resAddr)
   370  	if err != nil {
   371  		t.Fatal(err)
   372  	}
   373  
   374  	go func() {
   375  		listener.SetReadDeadline(time.Now().Add(time.Second))
   376  		n, _, _ := listener.ReadFrom(msg)
   377  		readChan <- string(msg[0:n])
   378  	}()
   379  
   380  	sender()
   381  
   382  	defer func() {
   383  		if err := listener.Close(); err != nil {
   384  			t.Fatal(err)
   385  		}
   386  	}()
   387  
   388  	return <-readChan
   389  }