github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/checks/client_test.go (about)

     1  package checks
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/hashicorp/nomad/ci"
    15  	"github.com/hashicorp/nomad/helper/freeport"
    16  	"github.com/hashicorp/nomad/helper/testlog"
    17  	"github.com/hashicorp/nomad/nomad/mock"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  	"github.com/shoenig/test/must"
    20  	"golang.org/x/exp/maps"
    21  	"oss.indeed.com/go/libtime/libtimetest"
    22  )
    23  
    24  func splitURL(u string) (string, string) {
    25  	// get the address and port for http server
    26  	tokens := strings.Split(u, ":")
    27  	addr, port := strings.TrimPrefix(tokens[1], "//"), tokens[2]
    28  	return addr, port
    29  }
    30  
    31  func TestChecker_Do_HTTP(t *testing.T) {
    32  	ci.Parallel(t)
    33  
    34  	// an example response that will be truncated
    35  	tooLong, truncate := bigResponse()
    36  
    37  	// create an http server with various responses
    38  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    39  		switch r.URL.Path {
    40  		case "/fail":
    41  			w.WriteHeader(500)
    42  			_, _ = io.WriteString(w, "500 problem")
    43  		case "/hang":
    44  			time.Sleep(1 * time.Second)
    45  			_, _ = io.WriteString(w, "too slow")
    46  		case "/long-fail":
    47  			w.WriteHeader(500)
    48  			_, _ = io.WriteString(w, tooLong)
    49  		case "/long-not-fail":
    50  			w.WriteHeader(201)
    51  			_, _ = io.WriteString(w, tooLong)
    52  		default:
    53  			w.WriteHeader(200)
    54  			_, _ = io.WriteString(w, "200 ok")
    55  		}
    56  	}))
    57  	defer ts.Close()
    58  
    59  	// get the address and port for http server
    60  	addr, port := splitURL(ts.URL)
    61  
    62  	// create a mock clock so we can assert time is set
    63  	now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)
    64  	clock := libtimetest.NewClockMock(t).NowMock.Return(now)
    65  
    66  	makeQueryContext := func() *QueryContext {
    67  		return &QueryContext{
    68  			ID:               "abc123",
    69  			CustomAddress:    addr,
    70  			ServicePortLabel: port,
    71  			Networks:         nil,
    72  			NetworkStatus:    mock.NewNetworkStatus(addr),
    73  			Ports:            nil,
    74  			Group:            "group",
    75  			Task:             "task",
    76  			Service:          "service",
    77  			Check:            "check",
    78  		}
    79  	}
    80  
    81  	makeQuery := func(
    82  		kind structs.CheckMode,
    83  		path string,
    84  	) *Query {
    85  		return &Query{
    86  			Mode:        kind,
    87  			Type:        "http",
    88  			Timeout:     100 * time.Millisecond,
    89  			AddressMode: "auto",
    90  			PortLabel:   port,
    91  			Protocol:    "http",
    92  			Path:        path,
    93  			Method:      "GET",
    94  		}
    95  	}
    96  
    97  	makeExpResult := func(
    98  		kind structs.CheckMode,
    99  		status structs.CheckStatus,
   100  		code int,
   101  		output string,
   102  	) *structs.CheckQueryResult {
   103  		return &structs.CheckQueryResult{
   104  			ID:         "abc123",
   105  			Mode:       kind,
   106  			Status:     status,
   107  			StatusCode: code,
   108  			Output:     output,
   109  			Timestamp:  now.Unix(),
   110  			Group:      "group",
   111  			Task:       "task",
   112  			Service:    "service",
   113  			Check:      "check",
   114  		}
   115  	}
   116  
   117  	cases := []struct {
   118  		name      string
   119  		qc        *QueryContext
   120  		q         *Query
   121  		expResult *structs.CheckQueryResult
   122  	}{{
   123  		name: "200 healthiness",
   124  		qc:   makeQueryContext(),
   125  		q:    makeQuery(structs.Healthiness, "/"),
   126  		expResult: makeExpResult(
   127  			structs.Healthiness,
   128  			structs.CheckSuccess,
   129  			http.StatusOK,
   130  			"nomad: http ok",
   131  		),
   132  	}, {
   133  		name: "200 readiness",
   134  		qc:   makeQueryContext(),
   135  		q:    makeQuery(structs.Readiness, "/"),
   136  		expResult: makeExpResult(
   137  			structs.Readiness,
   138  			structs.CheckSuccess,
   139  			http.StatusOK,
   140  			"nomad: http ok",
   141  		),
   142  	}, {
   143  		name: "500 healthiness",
   144  		qc:   makeQueryContext(),
   145  		q:    makeQuery(structs.Healthiness, "fail"),
   146  		expResult: makeExpResult(
   147  			structs.Healthiness,
   148  			structs.CheckFailure,
   149  			http.StatusInternalServerError,
   150  			"500 problem",
   151  		),
   152  	}, {
   153  		name: "hang",
   154  		qc:   makeQueryContext(),
   155  		q:    makeQuery(structs.Healthiness, "hang"),
   156  		expResult: makeExpResult(
   157  			structs.Healthiness,
   158  			structs.CheckFailure,
   159  			0,
   160  			fmt.Sprintf(`nomad: Get "%s/hang": context deadline exceeded`, ts.URL),
   161  		),
   162  	}, {
   163  		name: "500 truncate",
   164  		qc:   makeQueryContext(),
   165  		q:    makeQuery(structs.Healthiness, "long-fail"),
   166  		expResult: makeExpResult(
   167  			structs.Healthiness,
   168  			structs.CheckFailure,
   169  			http.StatusInternalServerError,
   170  			truncate,
   171  		),
   172  	}, {
   173  		name: "201 truncate",
   174  		qc:   makeQueryContext(),
   175  		q:    makeQuery(structs.Healthiness, "long-not-fail"),
   176  		expResult: makeExpResult(
   177  			structs.Healthiness,
   178  			structs.CheckSuccess,
   179  			http.StatusCreated,
   180  			truncate,
   181  		),
   182  	}}
   183  
   184  	for _, tc := range cases {
   185  		t.Run(tc.name, func(t *testing.T) {
   186  			logger := testlog.HCLogger(t)
   187  
   188  			c := New(logger)
   189  			c.(*checker).clock = clock
   190  
   191  			ctx := context.Background()
   192  			result := c.Do(ctx, tc.qc, tc.q)
   193  			must.Eq(t, tc.expResult, result)
   194  		})
   195  	}
   196  }
   197  
   198  // bigResponse creates a response payload larger than the maximum outputSizeLimit
   199  // as well as the same response but truncated to length of outputSizeLimit
   200  func bigResponse() (string, string) {
   201  	size := outputSizeLimit + 5
   202  	b := make([]byte, size, size)
   203  	for i := 0; i < size; i++ {
   204  		b[i] = 'a'
   205  	}
   206  	s := string(b)
   207  	return s, s[:outputSizeLimit]
   208  }
   209  
   210  func TestChecker_Do_HTTP_extras(t *testing.T) {
   211  	ci.Parallel(t)
   212  
   213  	// record the method, body, and headers of the request
   214  	var (
   215  		method  string
   216  		body    []byte
   217  		headers map[string][]string
   218  		host    string
   219  	)
   220  
   221  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   222  		method = r.Method
   223  		body, _ = io.ReadAll(r.Body)
   224  		headers = maps.Clone(r.Header)
   225  		host = r.Host
   226  		w.WriteHeader(http.StatusOK)
   227  	}))
   228  	defer ts.Close()
   229  
   230  	// get the address and port for http server
   231  	addr, port := splitURL(ts.URL)
   232  
   233  	// make headers from key-value pairs
   234  	makeHeaders := func(more ...[2]string) http.Header {
   235  		h := make(http.Header)
   236  		for _, extra := range more {
   237  			h.Set(extra[0], extra[1])
   238  		}
   239  		return h
   240  	}
   241  
   242  	encoding := [2]string{"Accept-Encoding", "gzip"}
   243  	agent := [2]string{"User-Agent", "Go-http-client/1.1"}
   244  
   245  	cases := []struct {
   246  		name    string
   247  		method  string
   248  		body    string
   249  		headers http.Header
   250  	}{
   251  		{
   252  			name:    "method GET",
   253  			method:  "GET",
   254  			headers: makeHeaders(encoding, agent),
   255  		},
   256  		{
   257  			name:    "method Get",
   258  			method:  "Get",
   259  			headers: makeHeaders(encoding, agent),
   260  		},
   261  		{
   262  			name:    "method HEAD",
   263  			method:  "HEAD",
   264  			headers: makeHeaders(agent),
   265  		},
   266  		{
   267  			name:   "extra headers",
   268  			method: "GET",
   269  			headers: makeHeaders(encoding, agent,
   270  				[2]string{"X-My-Header", "hello"},
   271  				[2]string{"Authorization", "Basic ZWxhc3RpYzpjaGFuZ2VtZQ=="},
   272  			),
   273  		},
   274  		{
   275  			name:   "host header",
   276  			method: "GET",
   277  			headers: makeHeaders(encoding, agent,
   278  				[2]string{"Host", "hello"},
   279  				[2]string{"Test-Abc", "hello"},
   280  			),
   281  		},
   282  		{
   283  			name:   "host header without normalization",
   284  			method: "GET",
   285  			body:   "",
   286  			// This is needed to prevent header normalization by http.Header.Set
   287  			headers: func() map[string][]string {
   288  				h := makeHeaders(encoding, agent, [2]string{"Test-Abc", "hello"})
   289  				h["hoST"] = []string{"heLLO"}
   290  				return h
   291  			}(),
   292  		},
   293  		{
   294  			name:    "with body",
   295  			method:  "POST",
   296  			headers: makeHeaders(encoding, agent),
   297  			body:    "some payload",
   298  		},
   299  	}
   300  
   301  	for _, tc := range cases {
   302  		qc := &QueryContext{
   303  			ID:               "abc123",
   304  			CustomAddress:    addr,
   305  			ServicePortLabel: port,
   306  			Networks:         nil,
   307  			NetworkStatus:    mock.NewNetworkStatus(addr),
   308  			Ports:            nil,
   309  			Group:            "group",
   310  			Task:             "task",
   311  			Service:          "service",
   312  			Check:            "check",
   313  		}
   314  
   315  		q := &Query{
   316  			Mode:        structs.Healthiness,
   317  			Type:        "http",
   318  			Timeout:     1 * time.Second,
   319  			AddressMode: "auto",
   320  			PortLabel:   port,
   321  			Protocol:    "http",
   322  			Path:        "/",
   323  			Method:      tc.method,
   324  			Headers:     tc.headers,
   325  			Body:        tc.body,
   326  		}
   327  
   328  		t.Run(tc.name, func(t *testing.T) {
   329  			logger := testlog.HCLogger(t)
   330  			c := New(logger)
   331  			ctx := context.Background()
   332  			result := c.Do(ctx, qc, q)
   333  			must.Eq(t, http.StatusOK, result.StatusCode,
   334  				must.Sprintf("test.URL: %s", ts.URL),
   335  				must.Sprintf("headers: %v", tc.headers),
   336  				must.Sprintf("received headers: %v", tc.headers),
   337  			)
   338  			must.Eq(t, tc.method, method)
   339  			must.Eq(t, tc.body, string(body))
   340  
   341  			hostSent := false
   342  
   343  			for key, values := range tc.headers {
   344  				if strings.EqualFold(key, "Host") && len(values) > 0 {
   345  					must.Eq(t, values[0], host)
   346  					hostSent = true
   347  					delete(tc.headers, key)
   348  
   349  				}
   350  			}
   351  			if !hostSent {
   352  				must.Eq(t, nil, tc.headers["Host"])
   353  			}
   354  
   355  			must.Eq(t, tc.headers, headers)
   356  		})
   357  	}
   358  }
   359  
   360  func TestChecker_Do_TCP(t *testing.T) {
   361  	ci.Parallel(t)
   362  
   363  	// create a mock clock so we can assert time is set
   364  	now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)
   365  	clock := libtimetest.NewClockMock(t).NowMock.Return(now)
   366  
   367  	makeQueryContext := func(address string, port int) *QueryContext {
   368  		return &QueryContext{
   369  			ID:               "abc123",
   370  			CustomAddress:    address,
   371  			ServicePortLabel: fmt.Sprintf("%d", port),
   372  			Networks:         nil,
   373  			NetworkStatus:    mock.NewNetworkStatus(address),
   374  			Ports:            nil,
   375  			Group:            "group",
   376  			Task:             "task",
   377  			Service:          "service",
   378  			Check:            "check",
   379  		}
   380  	}
   381  
   382  	makeQuery := func(
   383  		kind structs.CheckMode,
   384  		port int,
   385  	) *Query {
   386  		return &Query{
   387  			Mode:        kind,
   388  			Type:        "tcp",
   389  			Timeout:     100 * time.Millisecond,
   390  			AddressMode: "auto",
   391  			PortLabel:   fmt.Sprintf("%d", port),
   392  		}
   393  	}
   394  
   395  	makeExpResult := func(
   396  		kind structs.CheckMode,
   397  		status structs.CheckStatus,
   398  		output string,
   399  	) *structs.CheckQueryResult {
   400  		return &structs.CheckQueryResult{
   401  			ID:        "abc123",
   402  			Mode:      kind,
   403  			Status:    status,
   404  			Output:    output,
   405  			Timestamp: now.Unix(),
   406  			Group:     "group",
   407  			Task:      "task",
   408  			Service:   "service",
   409  			Check:     "check",
   410  		}
   411  	}
   412  
   413  	ports := freeport.MustTake(3)
   414  	defer freeport.Return(ports)
   415  
   416  	cases := []struct {
   417  		name      string
   418  		qc        *QueryContext
   419  		q         *Query
   420  		tcpMode   string // "ok", "off", "hang"
   421  		tcpPort   int
   422  		expResult *structs.CheckQueryResult
   423  	}{{
   424  		name:    "tcp ok",
   425  		qc:      makeQueryContext("localhost", ports[0]),
   426  		q:       makeQuery(structs.Healthiness, ports[0]),
   427  		tcpMode: "ok",
   428  		tcpPort: ports[0],
   429  		expResult: makeExpResult(
   430  			structs.Healthiness,
   431  			structs.CheckSuccess,
   432  			"nomad: tcp ok",
   433  		),
   434  	}, {
   435  		name:    "tcp not listening",
   436  		qc:      makeQueryContext("127.0.0.1", ports[1]),
   437  		q:       makeQuery(structs.Healthiness, ports[1]),
   438  		tcpMode: "off",
   439  		tcpPort: ports[1],
   440  		expResult: makeExpResult(
   441  			structs.Healthiness,
   442  			structs.CheckFailure,
   443  			fmt.Sprintf("dial tcp 127.0.0.1:%d: connect: connection refused", ports[1]),
   444  		),
   445  	}, {
   446  		name:    "tcp slow accept",
   447  		qc:      makeQueryContext("localhost", ports[2]),
   448  		q:       makeQuery(structs.Healthiness, ports[2]),
   449  		tcpMode: "hang",
   450  		tcpPort: ports[2],
   451  		expResult: makeExpResult(
   452  			structs.Healthiness,
   453  			structs.CheckFailure,
   454  			"dial tcp: lookup localhost: i/o timeout",
   455  		),
   456  	}}
   457  
   458  	for _, tc := range cases {
   459  		t.Run(tc.name, func(t *testing.T) {
   460  			logger := testlog.HCLogger(t)
   461  
   462  			ctx, cancel := context.WithCancel(context.Background())
   463  			defer cancel()
   464  
   465  			c := New(logger)
   466  			c.(*checker).clock = clock
   467  
   468  			switch tc.tcpMode {
   469  			case "ok":
   470  				// simulate tcp server by listening
   471  				tcpServer(t, ctx, tc.tcpPort)
   472  			case "hang":
   473  				// simulate tcp hang by setting an already expired context
   474  				timeout, stop := context.WithDeadline(ctx, now.Add(-1*time.Second))
   475  				defer stop()
   476  				ctx = timeout
   477  			case "off":
   478  				// simulate tcp dead connection by not listening
   479  			}
   480  
   481  			result := c.Do(ctx, tc.qc, tc.q)
   482  			must.Eq(t, tc.expResult, result)
   483  		})
   484  	}
   485  }
   486  
   487  // tcpServer will start a tcp listener that accepts connections and closes them.
   488  // The caller can close the listener by cancelling ctx.
   489  func tcpServer(t *testing.T, ctx context.Context, port int) {
   490  	var lc net.ListenConfig
   491  	l, err := lc.Listen(ctx, "tcp", net.JoinHostPort(
   492  		"localhost", fmt.Sprintf("%d", port),
   493  	))
   494  	must.NoError(t, err, must.Sprint("port", port))
   495  	t.Cleanup(func() {
   496  		_ = l.Close()
   497  	})
   498  
   499  	go func() {
   500  		// caller can stop us by cancelling ctx
   501  		for {
   502  			_, acceptErr := l.Accept()
   503  			if acceptErr != nil {
   504  				return
   505  			}
   506  		}
   507  	}()
   508  }