github.com/rohankumardubey/nomad@v0.11.8/command/agent/http_test.go (about)

     1  package agent
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/hashicorp/go-msgpack/codec"
    19  	"github.com/hashicorp/nomad/api"
    20  	"github.com/hashicorp/nomad/helper"
    21  	"github.com/hashicorp/nomad/helper/testlog"
    22  	"github.com/hashicorp/nomad/nomad/mock"
    23  	"github.com/hashicorp/nomad/nomad/structs"
    24  	"github.com/hashicorp/nomad/nomad/structs/config"
    25  	"github.com/hashicorp/nomad/testutil"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  // makeHTTPServer returns a test server whose logs will be written to
    31  // the passed writer. If the writer is nil, the logs are written to stderr.
    32  func makeHTTPServer(t testing.TB, cb func(c *Config)) *TestAgent {
    33  	return NewTestAgent(t, t.Name(), cb)
    34  }
    35  
    36  func BenchmarkHTTPRequests(b *testing.B) {
    37  	s := makeHTTPServer(b, func(c *Config) {
    38  		c.Client.Enabled = false
    39  	})
    40  	defer s.Shutdown()
    41  
    42  	job := mock.Job()
    43  	var allocs []*structs.Allocation
    44  	count := 1000
    45  	for i := 0; i < count; i++ {
    46  		alloc := mock.Alloc()
    47  		alloc.Job = job
    48  		alloc.JobID = job.ID
    49  		alloc.Name = fmt.Sprintf("my-job.web[%d]", i)
    50  		allocs = append(allocs, alloc)
    51  	}
    52  
    53  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
    54  		return allocs[:count], nil
    55  	}
    56  	b.ResetTimer()
    57  
    58  	b.RunParallel(func(pb *testing.PB) {
    59  		for pb.Next() {
    60  			resp := httptest.NewRecorder()
    61  			req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
    62  			s.Server.wrap(handler)(resp, req)
    63  		}
    64  	})
    65  }
    66  
    67  // TestRootFallthrough tests rootFallthrough handler to
    68  // verify redirect and 404 behavior
    69  func TestRootFallthrough(t *testing.T) {
    70  	t.Parallel()
    71  
    72  	cases := []struct {
    73  		desc         string
    74  		path         string
    75  		expectedPath string
    76  		expectedCode int
    77  	}{
    78  		{
    79  			desc:         "unknown endpoint 404s",
    80  			path:         "/v1/unknown/endpoint",
    81  			expectedCode: 404,
    82  		},
    83  		{
    84  			desc:         "root path redirects to ui",
    85  			path:         "/",
    86  			expectedPath: "/ui/",
    87  			expectedCode: 307,
    88  		},
    89  	}
    90  
    91  	s := makeHTTPServer(t, nil)
    92  	defer s.Shutdown()
    93  
    94  	// setup a client that doesn't follow redirects
    95  	client := &http.Client{
    96  		CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
    97  			return http.ErrUseLastResponse
    98  		},
    99  	}
   100  
   101  	for _, tc := range cases {
   102  		t.Run(tc.desc, func(t *testing.T) {
   103  
   104  			reqURL := fmt.Sprintf("http://%s%s", s.Agent.config.AdvertiseAddrs.HTTP, tc.path)
   105  
   106  			resp, err := client.Get(reqURL)
   107  			require.NoError(t, err)
   108  			require.Equal(t, tc.expectedCode, resp.StatusCode)
   109  
   110  			if tc.expectedPath != "" {
   111  				loc, err := resp.Location()
   112  				require.NoError(t, err)
   113  				require.Equal(t, tc.expectedPath, loc.Path)
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func TestSetIndex(t *testing.T) {
   120  	t.Parallel()
   121  	resp := httptest.NewRecorder()
   122  	setIndex(resp, 1000)
   123  	header := resp.Header().Get("X-Nomad-Index")
   124  	if header != "1000" {
   125  		t.Fatalf("Bad: %v", header)
   126  	}
   127  	setIndex(resp, 2000)
   128  	if v := resp.Header()["X-Nomad-Index"]; len(v) != 1 {
   129  		t.Fatalf("bad: %#v", v)
   130  	}
   131  }
   132  
   133  func TestSetKnownLeader(t *testing.T) {
   134  	t.Parallel()
   135  	resp := httptest.NewRecorder()
   136  	setKnownLeader(resp, true)
   137  	header := resp.Header().Get("X-Nomad-KnownLeader")
   138  	if header != "true" {
   139  		t.Fatalf("Bad: %v", header)
   140  	}
   141  	resp = httptest.NewRecorder()
   142  	setKnownLeader(resp, false)
   143  	header = resp.Header().Get("X-Nomad-KnownLeader")
   144  	if header != "false" {
   145  		t.Fatalf("Bad: %v", header)
   146  	}
   147  }
   148  
   149  func TestSetLastContact(t *testing.T) {
   150  	t.Parallel()
   151  	resp := httptest.NewRecorder()
   152  	setLastContact(resp, 123456*time.Microsecond)
   153  	header := resp.Header().Get("X-Nomad-LastContact")
   154  	if header != "123" {
   155  		t.Fatalf("Bad: %v", header)
   156  	}
   157  }
   158  
   159  func TestSetMeta(t *testing.T) {
   160  	t.Parallel()
   161  	meta := structs.QueryMeta{
   162  		Index:       1000,
   163  		KnownLeader: true,
   164  		LastContact: 123456 * time.Microsecond,
   165  	}
   166  	resp := httptest.NewRecorder()
   167  	setMeta(resp, &meta)
   168  	header := resp.Header().Get("X-Nomad-Index")
   169  	if header != "1000" {
   170  		t.Fatalf("Bad: %v", header)
   171  	}
   172  	header = resp.Header().Get("X-Nomad-KnownLeader")
   173  	if header != "true" {
   174  		t.Fatalf("Bad: %v", header)
   175  	}
   176  	header = resp.Header().Get("X-Nomad-LastContact")
   177  	if header != "123" {
   178  		t.Fatalf("Bad: %v", header)
   179  	}
   180  }
   181  
   182  func TestSetHeaders(t *testing.T) {
   183  	t.Parallel()
   184  	s := makeHTTPServer(t, nil)
   185  	s.Agent.config.HTTPAPIResponseHeaders = map[string]string{"foo": "bar"}
   186  	defer s.Shutdown()
   187  
   188  	resp := httptest.NewRecorder()
   189  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   190  		return &structs.Job{Name: "foo"}, nil
   191  	}
   192  
   193  	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   194  	s.Server.wrap(handler)(resp, req)
   195  	header := resp.Header().Get("foo")
   196  
   197  	if header != "bar" {
   198  		t.Fatalf("expected header: %v, actual: %v", "bar", header)
   199  	}
   200  
   201  }
   202  
   203  func TestContentTypeIsJSON(t *testing.T) {
   204  	t.Parallel()
   205  	s := makeHTTPServer(t, nil)
   206  	defer s.Shutdown()
   207  
   208  	resp := httptest.NewRecorder()
   209  
   210  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   211  		return &structs.Job{Name: "foo"}, nil
   212  	}
   213  
   214  	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   215  	s.Server.wrap(handler)(resp, req)
   216  
   217  	contentType := resp.Header().Get("Content-Type")
   218  
   219  	if contentType != "application/json" {
   220  		t.Fatalf("Content-Type header was not 'application/json'")
   221  	}
   222  }
   223  
   224  func TestWrapNonJSON(t *testing.T) {
   225  	t.Parallel()
   226  	s := makeHTTPServer(t, nil)
   227  	defer s.Shutdown()
   228  
   229  	resp := httptest.NewRecorder()
   230  
   231  	handler := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
   232  		return []byte("test response"), nil
   233  	}
   234  
   235  	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   236  	s.Server.wrapNonJSON(handler)(resp, req)
   237  
   238  	respBody, _ := ioutil.ReadAll(resp.Body)
   239  	require.Equal(t, respBody, []byte("test response"))
   240  
   241  }
   242  
   243  func TestWrapNonJSON_Error(t *testing.T) {
   244  	t.Parallel()
   245  	s := makeHTTPServer(t, nil)
   246  	defer s.Shutdown()
   247  
   248  	handlerRPCErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
   249  		return nil, structs.NewErrRPCCoded(404, "not found")
   250  	}
   251  
   252  	handlerCodedErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
   253  		return nil, CodedError(422, "unprocessable")
   254  	}
   255  
   256  	// RPC coded error
   257  	{
   258  		resp := httptest.NewRecorder()
   259  		req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   260  		s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
   261  		respBody, _ := ioutil.ReadAll(resp.Body)
   262  		require.Equal(t, []byte("not found"), respBody)
   263  		require.Equal(t, 404, resp.Code)
   264  	}
   265  
   266  	// CodedError
   267  	{
   268  		resp := httptest.NewRecorder()
   269  		req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   270  		s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
   271  		respBody, _ := ioutil.ReadAll(resp.Body)
   272  		require.Equal(t, []byte("unprocessable"), respBody)
   273  		require.Equal(t, 422, resp.Code)
   274  	}
   275  
   276  }
   277  
   278  func TestPrettyPrint(t *testing.T) {
   279  	t.Parallel()
   280  	testPrettyPrint("pretty=1", true, t)
   281  }
   282  
   283  func TestPrettyPrintOff(t *testing.T) {
   284  	t.Parallel()
   285  	testPrettyPrint("pretty=0", false, t)
   286  }
   287  
   288  func TestPrettyPrintBare(t *testing.T) {
   289  	t.Parallel()
   290  	testPrettyPrint("pretty", true, t)
   291  }
   292  
   293  func testPrettyPrint(pretty string, prettyFmt bool, t *testing.T) {
   294  	s := makeHTTPServer(t, nil)
   295  	defer s.Shutdown()
   296  
   297  	r := &structs.Job{Name: "foo"}
   298  
   299  	resp := httptest.NewRecorder()
   300  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   301  		return r, nil
   302  	}
   303  
   304  	urlStr := "/v1/job/foo?" + pretty
   305  	req, _ := http.NewRequest("GET", urlStr, nil)
   306  	s.Server.wrap(handler)(resp, req)
   307  
   308  	var expected bytes.Buffer
   309  	var err error
   310  	if prettyFmt {
   311  		enc := codec.NewEncoder(&expected, structs.JsonHandlePretty)
   312  		err = enc.Encode(r)
   313  		expected.WriteByte('\n')
   314  	} else {
   315  		enc := codec.NewEncoder(&expected, structs.JsonHandle)
   316  		err = enc.Encode(r)
   317  	}
   318  	if err != nil {
   319  		t.Fatalf("failed to encode: %v", err)
   320  	}
   321  	actual, err := ioutil.ReadAll(resp.Body)
   322  	if err != nil {
   323  		t.Fatalf("err: %s", err)
   324  	}
   325  
   326  	if !bytes.Equal(expected.Bytes(), actual) {
   327  		t.Fatalf("bad:\nexpected:\t%q\nactual:\t\t%q", expected.String(), string(actual))
   328  	}
   329  }
   330  
   331  func TestPermissionDenied(t *testing.T) {
   332  	s := makeHTTPServer(t, func(c *Config) {
   333  		c.ACL.Enabled = true
   334  	})
   335  	defer s.Shutdown()
   336  
   337  	{
   338  		resp := httptest.NewRecorder()
   339  		handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   340  			return nil, structs.ErrPermissionDenied
   341  		}
   342  
   343  		req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
   344  		s.Server.wrap(handler)(resp, req)
   345  		assert.Equal(t, resp.Code, 403)
   346  	}
   347  
   348  	// When remote RPC is used the errors have "rpc error: " prependend
   349  	{
   350  		resp := httptest.NewRecorder()
   351  		handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   352  			return nil, fmt.Errorf("rpc error: %v", structs.ErrPermissionDenied)
   353  		}
   354  
   355  		req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
   356  		s.Server.wrap(handler)(resp, req)
   357  		assert.Equal(t, resp.Code, 403)
   358  	}
   359  }
   360  
   361  func TestTokenNotFound(t *testing.T) {
   362  	s := makeHTTPServer(t, func(c *Config) {
   363  		c.ACL.Enabled = true
   364  	})
   365  	defer s.Shutdown()
   366  
   367  	resp := httptest.NewRecorder()
   368  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   369  		return nil, structs.ErrTokenNotFound
   370  	}
   371  
   372  	urlStr := "/v1/job/foo"
   373  	req, _ := http.NewRequest("GET", urlStr, nil)
   374  	s.Server.wrap(handler)(resp, req)
   375  	assert.Equal(t, resp.Code, 403)
   376  }
   377  
   378  func TestParseWait(t *testing.T) {
   379  	t.Parallel()
   380  	resp := httptest.NewRecorder()
   381  	var b structs.QueryOptions
   382  
   383  	req, err := http.NewRequest("GET",
   384  		"/v1/catalog/nodes?wait=60s&index=1000", nil)
   385  	if err != nil {
   386  		t.Fatalf("err: %v", err)
   387  	}
   388  
   389  	if d := parseWait(resp, req, &b); d {
   390  		t.Fatalf("unexpected done")
   391  	}
   392  
   393  	if b.MinQueryIndex != 1000 {
   394  		t.Fatalf("Bad: %v", b)
   395  	}
   396  	if b.MaxQueryTime != 60*time.Second {
   397  		t.Fatalf("Bad: %v", b)
   398  	}
   399  }
   400  
   401  func TestParseWait_InvalidTime(t *testing.T) {
   402  	t.Parallel()
   403  	resp := httptest.NewRecorder()
   404  	var b structs.QueryOptions
   405  
   406  	req, err := http.NewRequest("GET",
   407  		"/v1/catalog/nodes?wait=60foo&index=1000", nil)
   408  	if err != nil {
   409  		t.Fatalf("err: %v", err)
   410  	}
   411  
   412  	if d := parseWait(resp, req, &b); !d {
   413  		t.Fatalf("expected done")
   414  	}
   415  
   416  	if resp.Code != 400 {
   417  		t.Fatalf("bad code: %v", resp.Code)
   418  	}
   419  }
   420  
   421  func TestParseWait_InvalidIndex(t *testing.T) {
   422  	t.Parallel()
   423  	resp := httptest.NewRecorder()
   424  	var b structs.QueryOptions
   425  
   426  	req, err := http.NewRequest("GET",
   427  		"/v1/catalog/nodes?wait=60s&index=foo", nil)
   428  	if err != nil {
   429  		t.Fatalf("err: %v", err)
   430  	}
   431  
   432  	if d := parseWait(resp, req, &b); !d {
   433  		t.Fatalf("expected done")
   434  	}
   435  
   436  	if resp.Code != 400 {
   437  		t.Fatalf("bad code: %v", resp.Code)
   438  	}
   439  }
   440  
   441  func TestParseConsistency(t *testing.T) {
   442  	t.Parallel()
   443  	var b structs.QueryOptions
   444  
   445  	req, err := http.NewRequest("GET",
   446  		"/v1/catalog/nodes?stale", nil)
   447  	if err != nil {
   448  		t.Fatalf("err: %v", err)
   449  	}
   450  
   451  	parseConsistency(req, &b)
   452  	if !b.AllowStale {
   453  		t.Fatalf("Bad: %v", b)
   454  	}
   455  
   456  	b = structs.QueryOptions{}
   457  	req, err = http.NewRequest("GET",
   458  		"/v1/catalog/nodes?consistent", nil)
   459  	if err != nil {
   460  		t.Fatalf("err: %v", err)
   461  	}
   462  
   463  	parseConsistency(req, &b)
   464  	if b.AllowStale {
   465  		t.Fatalf("Bad: %v", b)
   466  	}
   467  }
   468  
   469  func TestParseRegion(t *testing.T) {
   470  	t.Parallel()
   471  	s := makeHTTPServer(t, nil)
   472  	defer s.Shutdown()
   473  
   474  	req, err := http.NewRequest("GET",
   475  		"/v1/jobs?region=foo", nil)
   476  	if err != nil {
   477  		t.Fatalf("err: %v", err)
   478  	}
   479  
   480  	var region string
   481  	s.Server.parseRegion(req, &region)
   482  	if region != "foo" {
   483  		t.Fatalf("bad %s", region)
   484  	}
   485  
   486  	region = ""
   487  	req, err = http.NewRequest("GET", "/v1/jobs", nil)
   488  	if err != nil {
   489  		t.Fatalf("err: %v", err)
   490  	}
   491  
   492  	s.Server.parseRegion(req, &region)
   493  	if region != "global" {
   494  		t.Fatalf("bad %s", region)
   495  	}
   496  }
   497  
   498  func TestParseToken(t *testing.T) {
   499  	t.Parallel()
   500  	s := makeHTTPServer(t, nil)
   501  	defer s.Shutdown()
   502  
   503  	req, err := http.NewRequest("GET", "/v1/jobs", nil)
   504  	req.Header.Add("X-Nomad-Token", "foobar")
   505  	if err != nil {
   506  		t.Fatalf("err: %v", err)
   507  	}
   508  
   509  	var token string
   510  	s.Server.parseToken(req, &token)
   511  	if token != "foobar" {
   512  		t.Fatalf("bad %s", token)
   513  	}
   514  }
   515  
   516  // TestHTTP_VerifyHTTPSClient asserts that a client certificate signed by the
   517  // appropriate CA is required when VerifyHTTPSClient=true.
   518  func TestHTTP_VerifyHTTPSClient(t *testing.T) {
   519  	t.Parallel()
   520  	const (
   521  		cafile  = "../../helper/tlsutil/testdata/ca.pem"
   522  		foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
   523  		fookey  = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
   524  	)
   525  	s := makeHTTPServer(t, func(c *Config) {
   526  		c.Region = "foo" // match the region on foocert
   527  		c.TLSConfig = &config.TLSConfig{
   528  			EnableHTTP:        true,
   529  			VerifyHTTPSClient: true,
   530  			CAFile:            cafile,
   531  			CertFile:          foocert,
   532  			KeyFile:           fookey,
   533  		}
   534  	})
   535  	defer s.Shutdown()
   536  
   537  	reqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
   538  
   539  	// FAIL: Requests that expect 127.0.0.1 as the name should fail
   540  	resp, err := http.Get(reqURL)
   541  	if err == nil {
   542  		resp.Body.Close()
   543  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   544  	}
   545  	urlErr, ok := err.(*url.Error)
   546  	if !ok {
   547  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   548  	}
   549  	hostErr, ok := urlErr.Err.(x509.HostnameError)
   550  	if !ok {
   551  		t.Fatalf("expected a x509.HostnameError but received: %T -> %v", urlErr.Err, urlErr.Err)
   552  	}
   553  	if expected := "127.0.0.1"; hostErr.Host != expected {
   554  		t.Fatalf("expected hostname on error to be %q but found %q", expected, hostErr.Host)
   555  	}
   556  
   557  	// FAIL: Requests that specify a valid hostname but not the CA should
   558  	// fail
   559  	tlsConf := &tls.Config{
   560  		ServerName: "client.regionFoo.nomad",
   561  	}
   562  	transport := &http.Transport{TLSClientConfig: tlsConf}
   563  	client := &http.Client{Transport: transport}
   564  	req, err := http.NewRequest("GET", reqURL, nil)
   565  	if err != nil {
   566  		t.Fatalf("error creating request: %v", err)
   567  	}
   568  	resp, err = client.Do(req)
   569  	if err == nil {
   570  		resp.Body.Close()
   571  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   572  	}
   573  	urlErr, ok = err.(*url.Error)
   574  	if !ok {
   575  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   576  	}
   577  	_, ok = urlErr.Err.(x509.UnknownAuthorityError)
   578  	if !ok {
   579  		t.Fatalf("expected a x509.UnknownAuthorityError but received: %T -> %v", urlErr.Err, urlErr.Err)
   580  	}
   581  
   582  	// FAIL: Requests that specify a valid hostname and CA cert but lack a
   583  	// client certificate should fail
   584  	cacertBytes, err := ioutil.ReadFile(cafile)
   585  	if err != nil {
   586  		t.Fatalf("error reading cacert: %v", err)
   587  	}
   588  	tlsConf.RootCAs = x509.NewCertPool()
   589  	tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
   590  	req, err = http.NewRequest("GET", reqURL, nil)
   591  	if err != nil {
   592  		t.Fatalf("error creating request: %v", err)
   593  	}
   594  	resp, err = client.Do(req)
   595  	if err == nil {
   596  		resp.Body.Close()
   597  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   598  	}
   599  	urlErr, ok = err.(*url.Error)
   600  	if !ok {
   601  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   602  	}
   603  	opErr, ok := urlErr.Err.(*net.OpError)
   604  	if !ok {
   605  		t.Fatalf("expected a *net.OpErr but received: %T -> %v", urlErr.Err, urlErr.Err)
   606  	}
   607  	const badCertificate = "tls: bad certificate" // from crypto/tls/alert.go:52 and RFC 5246 ยง A.3
   608  	if opErr.Err.Error() != badCertificate {
   609  		t.Fatalf("expected tls.alert bad_certificate but received: %q", opErr.Err.Error())
   610  	}
   611  
   612  	// PASS: Requests that specify a valid hostname, CA cert, and client
   613  	// certificate succeed.
   614  	tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   615  		c, err := tls.LoadX509KeyPair(foocert, fookey)
   616  		if err != nil {
   617  			return nil, err
   618  		}
   619  		return &c, nil
   620  	}
   621  	transport = &http.Transport{TLSClientConfig: tlsConf}
   622  	client = &http.Client{Transport: transport}
   623  	req, err = http.NewRequest("GET", reqURL, nil)
   624  	if err != nil {
   625  		t.Fatalf("error creating request: %v", err)
   626  	}
   627  	resp, err = client.Do(req)
   628  	if err != nil {
   629  		t.Fatalf("unexpected error: %v", err)
   630  	}
   631  	resp.Body.Close()
   632  	if resp.StatusCode != 200 {
   633  		t.Fatalf("expected 200 status code but got: %d", resp.StatusCode)
   634  	}
   635  }
   636  
   637  func TestHTTP_VerifyHTTPSClient_AfterConfigReload(t *testing.T) {
   638  	t.Parallel()
   639  	assert := assert.New(t)
   640  
   641  	const (
   642  		cafile   = "../../helper/tlsutil/testdata/ca.pem"
   643  		foocert  = "../../helper/tlsutil/testdata/nomad-bad.pem"
   644  		fookey   = "../../helper/tlsutil/testdata/nomad-bad-key.pem"
   645  		foocert2 = "../../helper/tlsutil/testdata/nomad-foo.pem"
   646  		fookey2  = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
   647  	)
   648  
   649  	agentConfig := &Config{
   650  		TLSConfig: &config.TLSConfig{
   651  			EnableHTTP:        true,
   652  			VerifyHTTPSClient: true,
   653  			CAFile:            cafile,
   654  			CertFile:          foocert,
   655  			KeyFile:           fookey,
   656  		},
   657  	}
   658  
   659  	newConfig := &Config{
   660  		TLSConfig: &config.TLSConfig{
   661  			EnableHTTP:        true,
   662  			VerifyHTTPSClient: true,
   663  			CAFile:            cafile,
   664  			CertFile:          foocert2,
   665  			KeyFile:           fookey2,
   666  		},
   667  	}
   668  
   669  	s := makeHTTPServer(t, func(c *Config) {
   670  		c.TLSConfig = agentConfig.TLSConfig
   671  	})
   672  	defer s.Shutdown()
   673  
   674  	// Make an initial request that should fail.
   675  	// Requests that specify a valid hostname, CA cert, and client
   676  	// certificate succeed.
   677  	tlsConf := &tls.Config{
   678  		ServerName: "client.regionFoo.nomad",
   679  		RootCAs:    x509.NewCertPool(),
   680  		GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   681  			c, err := tls.LoadX509KeyPair(foocert, fookey)
   682  			if err != nil {
   683  				return nil, err
   684  			}
   685  			return &c, nil
   686  		},
   687  	}
   688  
   689  	// HTTPS request should succeed
   690  	httpsReqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
   691  
   692  	cacertBytes, err := ioutil.ReadFile(cafile)
   693  	assert.Nil(err)
   694  	tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
   695  
   696  	transport := &http.Transport{TLSClientConfig: tlsConf}
   697  	client := &http.Client{Transport: transport}
   698  	req, err := http.NewRequest("GET", httpsReqURL, nil)
   699  	assert.Nil(err)
   700  
   701  	// Check that we get an error that the certificate isn't valid for the
   702  	// region we are contacting.
   703  	_, err = client.Do(req)
   704  	assert.Contains(err.Error(), "certificate is valid for")
   705  
   706  	// Reload the TLS configuration==
   707  	assert.Nil(s.Agent.Reload(newConfig))
   708  
   709  	// Requests that specify a valid hostname, CA cert, and client
   710  	// certificate succeed.
   711  	tlsConf = &tls.Config{
   712  		ServerName: "client.regionFoo.nomad",
   713  		RootCAs:    x509.NewCertPool(),
   714  		GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   715  			c, err := tls.LoadX509KeyPair(foocert2, fookey2)
   716  			if err != nil {
   717  				return nil, err
   718  			}
   719  			return &c, nil
   720  		},
   721  	}
   722  
   723  	cacertBytes, err = ioutil.ReadFile(cafile)
   724  	assert.Nil(err)
   725  	tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
   726  
   727  	transport = &http.Transport{TLSClientConfig: tlsConf}
   728  	client = &http.Client{Transport: transport}
   729  	req, err = http.NewRequest("GET", httpsReqURL, nil)
   730  	assert.Nil(err)
   731  
   732  	resp, err := client.Do(req)
   733  	if assert.Nil(err) {
   734  		resp.Body.Close()
   735  		assert.Equal(resp.StatusCode, 200)
   736  	}
   737  }
   738  
   739  // TestHTTPServer_Limits_Error asserts invalid Limits cause errors. This is the
   740  // HTTP counterpart to TestAgent_ServerConfig_Limits_Error.
   741  func TestHTTPServer_Limits_Error(t *testing.T) {
   742  	t.Parallel()
   743  
   744  	cases := []struct {
   745  		tls         bool
   746  		timeout     string
   747  		limit       *int
   748  		expectedErr string
   749  	}{
   750  		{
   751  			tls:         true,
   752  			timeout:     "",
   753  			limit:       nil,
   754  			expectedErr: "error parsing https_handshake_timeout: ",
   755  		},
   756  		{
   757  			tls:         false,
   758  			timeout:     "",
   759  			limit:       nil,
   760  			expectedErr: "error parsing https_handshake_timeout: ",
   761  		},
   762  		{
   763  			tls:         true,
   764  			timeout:     "-1s",
   765  			limit:       nil,
   766  			expectedErr: "https_handshake_timeout must be >= 0",
   767  		},
   768  		{
   769  			tls:         false,
   770  			timeout:     "-1s",
   771  			limit:       nil,
   772  			expectedErr: "https_handshake_timeout must be >= 0",
   773  		},
   774  		{
   775  			tls:         true,
   776  			timeout:     "5s",
   777  			limit:       helper.IntToPtr(-1),
   778  			expectedErr: "http_max_conns_per_client must be >= 0",
   779  		},
   780  		{
   781  			tls:         false,
   782  			timeout:     "5s",
   783  			limit:       helper.IntToPtr(-1),
   784  			expectedErr: "http_max_conns_per_client must be >= 0",
   785  		},
   786  	}
   787  
   788  	for i := range cases {
   789  		tc := cases[i]
   790  		name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
   791  		t.Run(name, func(t *testing.T) {
   792  			t.Parallel()
   793  
   794  			// Use a fake agent since the HTTP server should never start
   795  			agent := &Agent{
   796  				logger: testlog.HCLogger(t),
   797  			}
   798  
   799  			conf := &Config{
   800  				normalizedAddrs: &Addresses{
   801  					HTTP: "localhost:0", // port is never used
   802  				},
   803  				TLSConfig: &config.TLSConfig{
   804  					EnableHTTP: tc.tls,
   805  				},
   806  				Limits: config.Limits{
   807  					HTTPSHandshakeTimeout: tc.timeout,
   808  					HTTPMaxConnsPerClient: tc.limit,
   809  				},
   810  			}
   811  
   812  			srv, err := NewHTTPServer(agent, conf)
   813  			require.Error(t, err)
   814  			require.Nil(t, srv)
   815  			require.Contains(t, err.Error(), tc.expectedErr)
   816  		})
   817  	}
   818  }
   819  
   820  // TestHTTPServer_Limits_OK asserts that all valid limits combinations
   821  // (tls/timeout/conns) work.
   822  func TestHTTPServer_Limits_OK(t *testing.T) {
   823  	t.Parallel()
   824  	const (
   825  		cafile   = "../../helper/tlsutil/testdata/ca.pem"
   826  		foocert  = "../../helper/tlsutil/testdata/nomad-foo.pem"
   827  		fookey   = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
   828  		maxConns = 10 // limit must be < this for testing
   829  	)
   830  
   831  	cases := []struct {
   832  		tls           bool
   833  		timeout       string
   834  		limit         *int
   835  		assertTimeout bool
   836  		assertLimit   bool
   837  	}{
   838  		{
   839  			tls:           false,
   840  			timeout:       "5s",
   841  			limit:         nil,
   842  			assertTimeout: false,
   843  			assertLimit:   false,
   844  		},
   845  		{
   846  			tls:           true,
   847  			timeout:       "5s",
   848  			limit:         nil,
   849  			assertTimeout: true,
   850  			assertLimit:   false,
   851  		},
   852  		{
   853  			tls:           false,
   854  			timeout:       "0",
   855  			limit:         nil,
   856  			assertTimeout: false,
   857  			assertLimit:   false,
   858  		},
   859  		{
   860  			tls:           true,
   861  			timeout:       "0",
   862  			limit:         nil,
   863  			assertTimeout: false,
   864  			assertLimit:   false,
   865  		},
   866  		{
   867  			tls:           false,
   868  			timeout:       "0",
   869  			limit:         helper.IntToPtr(2),
   870  			assertTimeout: false,
   871  			assertLimit:   true,
   872  		},
   873  		{
   874  			tls:           true,
   875  			timeout:       "0",
   876  			limit:         helper.IntToPtr(2),
   877  			assertTimeout: false,
   878  			assertLimit:   true,
   879  		},
   880  		{
   881  			tls:           false,
   882  			timeout:       "5s",
   883  			limit:         helper.IntToPtr(2),
   884  			assertTimeout: false,
   885  			assertLimit:   true,
   886  		},
   887  		{
   888  			tls:           true,
   889  			timeout:       "5s",
   890  			limit:         helper.IntToPtr(2),
   891  			assertTimeout: true,
   892  			assertLimit:   true,
   893  		},
   894  	}
   895  
   896  	assertTimeout := func(t *testing.T, a *TestAgent, assertTimeout bool, timeout string) {
   897  		timeoutDeadline, err := time.ParseDuration(timeout)
   898  		require.NoError(t, err)
   899  
   900  		// Increase deadline to detect timeouts
   901  		deadline := timeoutDeadline + time.Second
   902  
   903  		conn, err := net.DialTimeout("tcp", a.Server.Addr, deadline)
   904  		require.NoError(t, err)
   905  		defer conn.Close()
   906  
   907  		buf := []byte{0}
   908  		readDeadline := time.Now().Add(deadline)
   909  		conn.SetReadDeadline(readDeadline)
   910  		n, err := conn.Read(buf)
   911  		require.Zero(t, n)
   912  		if assertTimeout {
   913  			// Server timeouts == EOF
   914  			require.Equal(t, io.EOF, err)
   915  
   916  			// Perform blocking query to assert timeout is not
   917  			// enabled post-TLS-handshake.
   918  			q := &api.QueryOptions{
   919  				WaitIndex: 10000, // wait a looong time
   920  				WaitTime:  deadline,
   921  			}
   922  
   923  			// Assertions don't require certificate validation
   924  			conf := api.DefaultConfig()
   925  			conf.Address = a.HTTPAddr()
   926  			conf.TLSConfig.Insecure = true
   927  			client, err := api.NewClient(conf)
   928  			require.NoError(t, err)
   929  
   930  			// Assert a blocking query isn't timed out by the
   931  			// handshake timeout
   932  			jobs, meta, err := client.Jobs().List(q)
   933  			require.NoError(t, err)
   934  			require.Len(t, jobs, 0)
   935  			require.Truef(t, meta.RequestTime >= deadline,
   936  				"expected RequestTime (%s) >= Deadline (%s)",
   937  				meta.RequestTime, deadline)
   938  
   939  			return
   940  		}
   941  
   942  		// HTTP Server should *not* have timed out.
   943  		// Now() should always be after the read deadline, but
   944  		// isn't a sufficient assertion for correctness as slow
   945  		// tests may cause this to be true even if the server
   946  		// timed out.
   947  		require.True(t, time.Now().After(readDeadline))
   948  
   949  		testutil.RequireDeadlineErr(t, err)
   950  	}
   951  
   952  	assertNoLimit := func(t *testing.T, addr string) {
   953  		var err error
   954  
   955  		// Create max connections
   956  		conns := make([]net.Conn, maxConns)
   957  		errCh := make(chan error, maxConns)
   958  		for i := 0; i < maxConns; i++ {
   959  			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
   960  			require.NoError(t, err)
   961  			defer conns[i].Close()
   962  
   963  			go func(i int) {
   964  				buf := []byte{0}
   965  				readDeadline := time.Now().Add(1 * time.Second)
   966  				conns[i].SetReadDeadline(readDeadline)
   967  				n, err := conns[i].Read(buf)
   968  				if n > 0 {
   969  					errCh <- fmt.Errorf("n > 0: %d", n)
   970  					return
   971  				}
   972  				errCh <- err
   973  			}(i)
   974  		}
   975  
   976  		// Now assert each error is a clientside read deadline error
   977  		for i := 0; i < maxConns; i++ {
   978  			select {
   979  			case <-time.After(2 * time.Second):
   980  				t.Fatalf("timed out waiting for conn error %d", i)
   981  			case err := <-errCh:
   982  				testutil.RequireDeadlineErr(t, err)
   983  			}
   984  		}
   985  	}
   986  
   987  	assertLimit := func(t *testing.T, addr string, limit int) {
   988  		var err error
   989  
   990  		// Create limit connections
   991  		conns := make([]net.Conn, limit)
   992  		errCh := make(chan error, limit)
   993  		for i := range conns {
   994  			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
   995  			require.NoError(t, err)
   996  			defer conns[i].Close()
   997  
   998  			go func(i int) {
   999  				buf := []byte{0}
  1000  				n, err := conns[i].Read(buf)
  1001  				if n > 0 {
  1002  					errCh <- fmt.Errorf("n > 0: %d", n)
  1003  					return
  1004  				}
  1005  				errCh <- err
  1006  			}(i)
  1007  		}
  1008  
  1009  		select {
  1010  		case err := <-errCh:
  1011  			t.Fatalf("unexpected error from connection prior to limit: %T %v", err, err)
  1012  		case <-time.After(500 * time.Millisecond):
  1013  		}
  1014  
  1015  		// Assert a new connection is dropped
  1016  		conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
  1017  		require.NoError(t, err)
  1018  		defer conn.Close()
  1019  
  1020  		buf := []byte{0}
  1021  		deadline := time.Now().Add(10 * time.Second)
  1022  		conn.SetReadDeadline(deadline)
  1023  		n, err := conn.Read(buf)
  1024  		require.Zero(t, n)
  1025  
  1026  		// Soft-fail as following assertion helps with debugging
  1027  		assert.Equal(t, io.EOF, err)
  1028  
  1029  		// Assert existing connections are ok
  1030  		require.Len(t, errCh, 0)
  1031  
  1032  		// Cleanup
  1033  		for _, conn := range conns {
  1034  			conn.Close()
  1035  		}
  1036  		for range conns {
  1037  			err := <-errCh
  1038  			require.Contains(t, err.Error(), "use of closed network connection")
  1039  		}
  1040  	}
  1041  
  1042  	for i := range cases {
  1043  		tc := cases[i]
  1044  		name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
  1045  		t.Run(name, func(t *testing.T) {
  1046  			t.Parallel()
  1047  
  1048  			if tc.limit != nil && *tc.limit >= maxConns {
  1049  				t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", *tc.limit, maxConns)
  1050  			}
  1051  
  1052  			s := makeHTTPServer(t, func(c *Config) {
  1053  				if tc.tls {
  1054  					c.TLSConfig = &config.TLSConfig{
  1055  						EnableHTTP: true,
  1056  						CAFile:     cafile,
  1057  						CertFile:   foocert,
  1058  						KeyFile:    fookey,
  1059  					}
  1060  				}
  1061  				c.Limits.HTTPSHandshakeTimeout = tc.timeout
  1062  				c.Limits.HTTPMaxConnsPerClient = tc.limit
  1063  			})
  1064  			defer s.Shutdown()
  1065  
  1066  			assertTimeout(t, s, tc.assertTimeout, tc.timeout)
  1067  
  1068  			if tc.assertLimit {
  1069  				// There's a race between assertTimeout(false) closing
  1070  				// its connection and the HTTP server noticing and
  1071  				// untracking it. Since there's no way to coordiante
  1072  				// when this occurs, sleeping is the only way to avoid
  1073  				// asserting limits before the timed out connection is
  1074  				// untracked.
  1075  				time.Sleep(1 * time.Second)
  1076  
  1077  				assertLimit(t, s.Server.Addr, *tc.limit)
  1078  			} else {
  1079  				assertNoLimit(t, s.Server.Addr)
  1080  			}
  1081  		})
  1082  	}
  1083  }
  1084  
  1085  func Test_IsAPIClientError(t *testing.T) {
  1086  	trueCases := []int{400, 403, 404, 499}
  1087  	for _, c := range trueCases {
  1088  		require.Truef(t, isAPIClientError(c), "code: %v", c)
  1089  	}
  1090  
  1091  	falseCases := []int{100, 300, 500, 501, 505}
  1092  	for _, c := range falseCases {
  1093  		require.Falsef(t, isAPIClientError(c), "code: %v", c)
  1094  	}
  1095  }
  1096  
  1097  func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
  1098  	s := makeHTTPServer(t, cb)
  1099  	defer s.Shutdown()
  1100  	testutil.WaitForLeader(t, s.Agent.RPC)
  1101  	f(s)
  1102  }
  1103  
  1104  func httpACLTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
  1105  	s := makeHTTPServer(t, func(c *Config) {
  1106  		c.ACL.Enabled = true
  1107  		if cb != nil {
  1108  			cb(c)
  1109  		}
  1110  	})
  1111  	defer s.Shutdown()
  1112  	testutil.WaitForLeader(t, s.Agent.RPC)
  1113  	f(s)
  1114  }
  1115  
  1116  func setToken(req *http.Request, token *structs.ACLToken) {
  1117  	req.Header.Set("X-Nomad-Token", token.SecretID)
  1118  }
  1119  
  1120  func encodeReq(obj interface{}) io.ReadCloser {
  1121  	buf := bytes.NewBuffer(nil)
  1122  	enc := json.NewEncoder(buf)
  1123  	enc.Encode(obj)
  1124  	return ioutil.NopCloser(buf)
  1125  }