github.com/ManabuSeki/goa-v1@v1.4.3/middleware/xray/middleware_test.go (about)

     1  package xray
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"regexp"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/goadesign/goa"
    18  	"github.com/goadesign/goa/middleware"
    19  )
    20  
    21  const (
    22  	// udp host:port used to run test server
    23  	udplisten = "127.0.0.1:62111"
    24  )
    25  
    26  func TestNew(t *testing.T) {
    27  	cases := map[string]struct {
    28  		Daemon  string
    29  		Success bool
    30  	}{
    31  		"ok":     {udplisten, true},
    32  		"not-ok": {"1002.0.0.0:62111", false},
    33  	}
    34  	for k, c := range cases {
    35  		m, err := New("", c.Daemon)
    36  		if err == nil && !c.Success {
    37  			t.Errorf("%s: expected failure but err is nil", k)
    38  		}
    39  		if err != nil && c.Success {
    40  			t.Errorf("%s: unexpected error %s", k, err)
    41  		}
    42  		if m == nil && c.Success {
    43  			t.Errorf("%s: middleware is nil", k)
    44  		}
    45  	}
    46  }
    47  
    48  func TestMiddleware(t *testing.T) {
    49  	type (
    50  		Tra struct {
    51  			TraceID, SpanID, ParentID string
    52  		}
    53  		Req struct {
    54  			Method, Host, IP, RemoteAddr string
    55  			RemoteHost, UserAgent        string
    56  			URL                          *url.URL
    57  		}
    58  		Res struct {
    59  			Status int
    60  		}
    61  		Seg struct {
    62  			Exception string
    63  			Error     bool
    64  		}
    65  	)
    66  	var (
    67  		traceID      = "traceID"
    68  		spanID       = "spanID"
    69  		parentID     = "parentID"
    70  		host         = "goa.design"
    71  		method       = "GET"
    72  		ip           = "104.18.42.42"
    73  		remoteAddr   = "104.18.43.42:443"
    74  		remoteNoPort = "104.18.43.42"
    75  		remoteHost   = "104.18.43.42"
    76  		agent        = "user agent"
    77  		url, _       = url.Parse("https://goa.design/path?query#fragment")
    78  	)
    79  	cases := map[string]struct {
    80  		Trace    Tra
    81  		Request  Req
    82  		Response Res
    83  		Segment  Seg
    84  	}{
    85  		"no-trace": {
    86  			Trace:    Tra{"", "", ""},
    87  			Request:  Req{"", "", "", "", "", "", nil},
    88  			Response: Res{0},
    89  			Segment:  Seg{"", false},
    90  		},
    91  		"basic": {
    92  			Trace:    Tra{traceID, spanID, ""},
    93  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
    94  			Response: Res{http.StatusOK},
    95  			Segment:  Seg{"", false},
    96  		},
    97  		"with-parent": {
    98  			Trace:    Tra{traceID, spanID, parentID},
    99  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   100  			Response: Res{http.StatusOK},
   101  			Segment:  Seg{"", false},
   102  		},
   103  		"without-ip": {
   104  			Trace:    Tra{traceID, spanID, parentID},
   105  			Request:  Req{method, host, "", remoteAddr, remoteHost, agent, url},
   106  			Response: Res{http.StatusOK},
   107  			Segment:  Seg{"", false},
   108  		},
   109  		"without-ip-remote-port": {
   110  			Trace:    Tra{traceID, spanID, parentID},
   111  			Request:  Req{method, host, "", remoteNoPort, remoteHost, agent, url},
   112  			Response: Res{http.StatusOK},
   113  			Segment:  Seg{"", false},
   114  		},
   115  		"error": {
   116  			Trace:    Tra{traceID, spanID, ""},
   117  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   118  			Response: Res{http.StatusBadRequest},
   119  			Segment:  Seg{"error", true},
   120  		},
   121  		"fault": {
   122  			Trace:    Tra{traceID, spanID, ""},
   123  			Request:  Req{method, host, ip, remoteAddr, remoteHost, agent, url},
   124  			Response: Res{http.StatusInternalServerError},
   125  			Segment:  Seg{"", true},
   126  		},
   127  	}
   128  	for k, c := range cases {
   129  		m, err := New("service", udplisten)
   130  		if err != nil {
   131  			t.Fatalf("%s: failed to create middleware: %s", k, err)
   132  		}
   133  		if c.Response.Status == 0 {
   134  			continue
   135  		}
   136  
   137  		var (
   138  			req, _ = http.NewRequest(c.Request.Method, c.Request.URL.String(), nil)
   139  			rw     = httptest.NewRecorder()
   140  			ctx    = goa.NewContext(context.Background(), rw, req, nil)
   141  			h      = func(ctx context.Context, rw http.ResponseWriter, _ *http.Request) error {
   142  				if c.Segment.Exception != "" {
   143  					ContextSegment(ctx).RecordError(errors.New(c.Segment.Exception))
   144  				}
   145  				rw.WriteHeader(c.Response.Status)
   146  				return nil
   147  			}
   148  		)
   149  
   150  		ctx = middleware.WithTrace(ctx, c.Trace.TraceID, c.Trace.SpanID, c.Trace.ParentID)
   151  		if c.Request.UserAgent != "" {
   152  			req.Header.Set("User-Agent", c.Request.UserAgent)
   153  		}
   154  		if c.Request.IP != "" {
   155  			req.Header.Set("X-Forwarded-For", c.Request.IP)
   156  		}
   157  		if c.Request.RemoteAddr != "" {
   158  			req.RemoteAddr = c.Request.RemoteAddr
   159  		}
   160  		if c.Request.Host != "" {
   161  			req.Host = c.Request.Host
   162  		}
   163  
   164  		messages := readUDP(t, 2, func() {
   165  			m(h)(ctx, goa.ContextResponse(ctx), req)
   166  		})
   167  
   168  		// expect the first message is InProgress
   169  		s := extractSegment(t, messages[0])
   170  		if !s.InProgress {
   171  			t.Fatalf("%s: expected first segment to be InProgress but it was not", k)
   172  		}
   173  
   174  		// second message
   175  		s = extractSegment(t, messages[1])
   176  		if s.Name != "service" {
   177  			t.Errorf("%s: unexpected segment name, expected service - got %s", k, s.Name)
   178  		}
   179  		if s.Type != "" {
   180  			t.Errorf("%s: expected Type to be empty but got %s", k, s.Type)
   181  		}
   182  		if s.ID != c.Trace.SpanID {
   183  			t.Errorf("%s: unexpected segment ID, expected %s - got %s", k, c.Trace.SpanID, s.ID)
   184  		}
   185  		if s.TraceID != c.Trace.TraceID {
   186  			t.Errorf("%s: unexpected trace ID, expected %s - got %s", k, c.Trace.TraceID, s.TraceID)
   187  		}
   188  		if s.ParentID != c.Trace.ParentID {
   189  			t.Errorf("%s: unexpected parent ID, expected %s - got %s", k, c.Trace.ParentID, s.ParentID)
   190  		}
   191  		if s.StartTime == 0 {
   192  			t.Errorf("%s: StartTime is 0", k)
   193  		}
   194  		if s.EndTime == 0 {
   195  			t.Errorf("%s: EndTime is 0", k)
   196  		}
   197  		if s.StartTime > s.EndTime {
   198  			t.Errorf("%s: StartTime (%v) is after EndTime (%v)", k, s.StartTime, s.EndTime)
   199  		}
   200  		if s.HTTP == nil {
   201  			t.Fatalf("%s: HTTP field is nil", k)
   202  		}
   203  		if s.HTTP.Request == nil {
   204  			t.Fatalf("%s: HTTP Request field is nil", k)
   205  		}
   206  		if c.Request.IP != "" && s.HTTP.Request.ClientIP != c.Request.IP {
   207  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected %#v got %#v", k, c.Request.IP, s.HTTP.Request.ClientIP)
   208  		}
   209  		if c.Request.IP == "" && s.HTTP.Request.ClientIP != c.Request.RemoteHost {
   210  			t.Errorf("%s: HTTP Request ClientIP is invalid, expected host %#v got %#v", k, c.Request.RemoteHost, s.HTTP.Request.ClientIP)
   211  		}
   212  		if s.HTTP.Request.Method != c.Request.Method {
   213  			t.Errorf("%s: HTTP Request Method is invalid, expected %#v got %#v", k, c.Request.Method, s.HTTP.Request.Method)
   214  		}
   215  		expected := strings.Split(c.Request.URL.String(), "?")[0]
   216  		if s.HTTP.Request.URL != expected {
   217  			t.Errorf("%s: HTTP Request URL is invalid, expected %#v got %#v", k, expected, s.HTTP.Request.URL)
   218  		}
   219  		if s.HTTP.Request.UserAgent != c.Request.UserAgent {
   220  			t.Errorf("%s: HTTP Request UserAgent is invalid, expected %#v got %#v", k, c.Request.UserAgent, s.HTTP.Request.UserAgent)
   221  		}
   222  		if s.Cause == nil && c.Segment.Exception != "" {
   223  			t.Errorf("%s: Exception is invalid, expected %v but got nil Cause", k, c.Segment.Exception)
   224  		}
   225  		if s.Cause != nil && s.Cause.Exceptions[0].Message != c.Segment.Exception {
   226  			t.Errorf("%s: Exception is invalid, expected %v got %v", k, c.Segment.Exception, s.Cause.Exceptions[0].Message)
   227  		}
   228  		if s.Error != c.Segment.Error {
   229  			t.Errorf("%s: Error is invalid, expected %v got %v", k, c.Segment.Error, s.Error)
   230  		}
   231  	}
   232  }
   233  
   234  func TestNewID(t *testing.T) {
   235  	id := NewID()
   236  	if len(id) != 16 {
   237  		t.Errorf("invalid ID length, expected 16 got %d", len(id))
   238  	}
   239  	if !regexp.MustCompile("[0-9a-f]{16}").MatchString(id) {
   240  		t.Errorf("invalid ID format, should be hexadecimal, got %s", id)
   241  	}
   242  	if id == NewID() {
   243  		t.Errorf("ids not unique")
   244  	}
   245  }
   246  
   247  func TestNewTraceID(t *testing.T) {
   248  	id := NewTraceID()
   249  	if len(id) != 35 {
   250  		t.Errorf("invalid ID length, expected 35 got %d", len(id))
   251  	}
   252  	if !regexp.MustCompile("1-[0-9a-f]{8}-[0-9a-f]{16}").MatchString(id) {
   253  		t.Errorf("invalid Trace ID format, got %s", id)
   254  	}
   255  	if id == NewTraceID() {
   256  		t.Errorf("trace ids not unique")
   257  	}
   258  }
   259  
   260  func TestPeriodicallyRedialingConn(t *testing.T) {
   261  
   262  	t.Run("dial fails, returns error immediately", func(t *testing.T) {
   263  		dialErr := errors.New("dialErr")
   264  		_, err := periodicallyRedialingConn(context.Background(), time.Millisecond, func() (net.Conn, error) {
   265  			return nil, dialErr
   266  		})
   267  		if err != dialErr {
   268  			t.Fatalf("Unexpected err, got %q, expected %q", err, dialErr)
   269  		}
   270  	})
   271  	t.Run("connection gets replaced by new one", func(t *testing.T) {
   272  		var (
   273  			firstConn  = &net.UDPConn{}
   274  			secondConn = &net.UnixConn{}
   275  			callCount  = 0
   276  		)
   277  		wgCheckFirstConnection := sync.WaitGroup{}
   278  		wgCheckFirstConnection.Add(1)
   279  		wgThirdDial := sync.WaitGroup{}
   280  		wgThirdDial.Add(1)
   281  		dial := func() (net.Conn, error) {
   282  			callCount++
   283  			if callCount == 1 {
   284  				return firstConn, nil
   285  			}
   286  			wgCheckFirstConnection.Wait()
   287  			if callCount == 3 {
   288  				wgThirdDial.Done()
   289  			}
   290  			return secondConn, nil
   291  		}
   292  
   293  		ctx, cancel := context.WithCancel(context.Background())
   294  		defer cancel()
   295  		conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial)
   296  		if err != nil {
   297  			t.Fatalf("Expected nil err but got: %v", err)
   298  		}
   299  
   300  		if c := conn(); c != firstConn {
   301  			t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn)
   302  		}
   303  		wgCheckFirstConnection.Done()
   304  
   305  		// by the time the 3rd dial happens, we know conn() should be returning the second connection
   306  		wgThirdDial.Wait()
   307  
   308  		if c := conn(); c != secondConn {
   309  			t.Fatalf("Unexpected second connection: got %#v, expected %#v", c, secondConn)
   310  		}
   311  	})
   312  	t.Run("connection not replaced if dial errored", func(t *testing.T) {
   313  		var (
   314  			firstConn = &net.UDPConn{}
   315  			callCount = 0
   316  		)
   317  		wgCheckFirstConnection := sync.WaitGroup{}
   318  		wgCheckFirstConnection.Add(1)
   319  		wgThirdDial := sync.WaitGroup{}
   320  		wgThirdDial.Add(1)
   321  		dial := func() (net.Conn, error) {
   322  			callCount++
   323  			if callCount == 1 {
   324  				return firstConn, nil
   325  			}
   326  			wgCheckFirstConnection.Wait()
   327  			if callCount == 3 {
   328  				wgThirdDial.Done()
   329  			}
   330  			return nil, errors.New("dialErr")
   331  		}
   332  
   333  		ctx, cancel := context.WithCancel(context.Background())
   334  		defer cancel()
   335  		conn, err := periodicallyRedialingConn(ctx, time.Millisecond, dial)
   336  		if err != nil {
   337  			t.Fatalf("Expected nil err but got: %v", err)
   338  		}
   339  
   340  		if c := conn(); c != firstConn {
   341  			t.Fatalf("Unexpected first connection: got %#v, expected %#v", c, firstConn)
   342  		}
   343  		wgCheckFirstConnection.Done()
   344  
   345  		// by the time the 3rd dial happens, we know the second dial was processed, and shouldn't have replaced conn()
   346  		wgThirdDial.Wait()
   347  
   348  		if c := conn(); c != firstConn {
   349  			t.Fatalf("Connection unexpectedly replaced: got %#v, expected %#v", c, firstConn)
   350  		}
   351  	})
   352  }
   353  
   354  // readUDP calls sender, reads and returns UDP messages received on udplisten.
   355  // Verifies that exactly the expected number of messages are received.
   356  func readUDP(t *testing.T, expectedMessages int, sender func()) []string {
   357  	var (
   358  		readChan = make(chan []string)
   359  		msg      = make([]byte, 1024*32)
   360  	)
   361  	resAddr, err := net.ResolveUDPAddr("udp", udplisten)
   362  	if err != nil {
   363  		t.Fatal(err)
   364  	}
   365  	listener, err := net.ListenUDP("udp", resAddr)
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  
   370  	go func() {
   371  		listener.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
   372  		var messages []string
   373  		for {
   374  			n, _, err := listener.ReadFrom(msg)
   375  			if err != nil {
   376  				if !strings.HasSuffix(err.Error(), "i/o timeout") {
   377  					t.Errorf("expected final timeout error but got: %s", err)
   378  				}
   379  				break // we're done
   380  			}
   381  			messages = append(messages, string(msg[0:n]))
   382  		}
   383  		if len(messages) != expectedMessages {
   384  			t.Errorf("unexpected number of messages, expected %d got %d. All messages:\n%s",
   385  				expectedMessages, len(messages), strings.Join(messages, "\n"))
   386  		}
   387  		readChan <- messages
   388  	}()
   389  
   390  	sender()
   391  
   392  	defer func() {
   393  		if err := listener.Close(); err != nil {
   394  			t.Fatal(err)
   395  		}
   396  	}()
   397  
   398  	return <-readChan
   399  }
   400  
   401  // extractSegment returns the unmarshalled segment JSON from a readUDP response.
   402  func extractSegment(t *testing.T, js string) *Segment {
   403  	t.Helper()
   404  
   405  	var s *Segment
   406  	elems := strings.Split(js, "\n")
   407  	if len(elems) != 2 {
   408  		t.Fatalf("invalid number of lines, expected 2 got %d: %v", len(elems), elems)
   409  	}
   410  	if elems[0] != udpHeader[:len(udpHeader)-1] {
   411  		t.Errorf("invalid header, got %s", elems[0])
   412  	}
   413  	err := json.Unmarshal([]byte(elems[1]), &s)
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  	return s
   418  }