github.com/diptanu/nomad@v0.5.7-0.20170516172507-d72e86cbe3d9/command/agent/http_test.go (about)

     1  package agent
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"os"
    16  	"strconv"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/hashicorp/nomad/nomad/mock"
    21  	"github.com/hashicorp/nomad/nomad/structs"
    22  	"github.com/hashicorp/nomad/nomad/structs/config"
    23  	"github.com/hashicorp/nomad/testutil"
    24  )
    25  
    26  type TestServer struct {
    27  	T      testing.TB
    28  	Dir    string
    29  	Agent  *Agent
    30  	Server *HTTPServer
    31  }
    32  
    33  func (s *TestServer) Cleanup() {
    34  	s.Server.Shutdown()
    35  	s.Agent.Shutdown()
    36  	os.RemoveAll(s.Dir)
    37  }
    38  
    39  // makeHTTPServer returns a test server whose logs will be written to
    40  // the passed writer. If the writer is nil, the logs are written to stderr.
    41  func makeHTTPServer(t testing.TB, cb func(c *Config)) *TestServer {
    42  	dir, agent := makeAgent(t, cb)
    43  	srv, err := NewHTTPServer(agent, agent.config)
    44  	if err != nil {
    45  		t.Fatalf("err: %v", err)
    46  	}
    47  	s := &TestServer{
    48  		T:      t,
    49  		Dir:    dir,
    50  		Agent:  agent,
    51  		Server: srv,
    52  	}
    53  	return s
    54  }
    55  
    56  func BenchmarkHTTPRequests(b *testing.B) {
    57  	s := makeHTTPServer(b, func(c *Config) {
    58  		c.Client.Enabled = false
    59  	})
    60  	defer s.Cleanup()
    61  
    62  	job := mock.Job()
    63  	var allocs []*structs.Allocation
    64  	count := 1000
    65  	for i := 0; i < count; i++ {
    66  		alloc := mock.Alloc()
    67  		alloc.Job = job
    68  		alloc.JobID = job.ID
    69  		alloc.Name = fmt.Sprintf("my-job.web[%d]", i)
    70  		allocs = append(allocs, alloc)
    71  	}
    72  
    73  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
    74  		return allocs[:count], nil
    75  	}
    76  	b.ResetTimer()
    77  
    78  	b.RunParallel(func(pb *testing.PB) {
    79  		for pb.Next() {
    80  			resp := httptest.NewRecorder()
    81  			req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
    82  			s.Server.wrap(handler)(resp, req)
    83  		}
    84  	})
    85  }
    86  
    87  func TestSetIndex(t *testing.T) {
    88  	resp := httptest.NewRecorder()
    89  	setIndex(resp, 1000)
    90  	header := resp.Header().Get("X-Nomad-Index")
    91  	if header != "1000" {
    92  		t.Fatalf("Bad: %v", header)
    93  	}
    94  	setIndex(resp, 2000)
    95  	if v := resp.Header()["X-Nomad-Index"]; len(v) != 1 {
    96  		t.Fatalf("bad: %#v", v)
    97  	}
    98  }
    99  
   100  func TestSetKnownLeader(t *testing.T) {
   101  	resp := httptest.NewRecorder()
   102  	setKnownLeader(resp, true)
   103  	header := resp.Header().Get("X-Nomad-KnownLeader")
   104  	if header != "true" {
   105  		t.Fatalf("Bad: %v", header)
   106  	}
   107  	resp = httptest.NewRecorder()
   108  	setKnownLeader(resp, false)
   109  	header = resp.Header().Get("X-Nomad-KnownLeader")
   110  	if header != "false" {
   111  		t.Fatalf("Bad: %v", header)
   112  	}
   113  }
   114  
   115  func TestSetLastContact(t *testing.T) {
   116  	resp := httptest.NewRecorder()
   117  	setLastContact(resp, 123456*time.Microsecond)
   118  	header := resp.Header().Get("X-Nomad-LastContact")
   119  	if header != "123" {
   120  		t.Fatalf("Bad: %v", header)
   121  	}
   122  }
   123  
   124  func TestSetMeta(t *testing.T) {
   125  	meta := structs.QueryMeta{
   126  		Index:       1000,
   127  		KnownLeader: true,
   128  		LastContact: 123456 * time.Microsecond,
   129  	}
   130  	resp := httptest.NewRecorder()
   131  	setMeta(resp, &meta)
   132  	header := resp.Header().Get("X-Nomad-Index")
   133  	if header != "1000" {
   134  		t.Fatalf("Bad: %v", header)
   135  	}
   136  	header = resp.Header().Get("X-Nomad-KnownLeader")
   137  	if header != "true" {
   138  		t.Fatalf("Bad: %v", header)
   139  	}
   140  	header = resp.Header().Get("X-Nomad-LastContact")
   141  	if header != "123" {
   142  		t.Fatalf("Bad: %v", header)
   143  	}
   144  }
   145  
   146  func TestSetHeaders(t *testing.T) {
   147  	s := makeHTTPServer(t, nil)
   148  	s.Agent.config.HTTPAPIResponseHeaders = map[string]string{"foo": "bar"}
   149  	defer s.Cleanup()
   150  
   151  	resp := httptest.NewRecorder()
   152  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   153  		return &structs.Job{Name: "foo"}, nil
   154  	}
   155  
   156  	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   157  	s.Server.wrap(handler)(resp, req)
   158  	header := resp.Header().Get("foo")
   159  
   160  	if header != "bar" {
   161  		t.Fatalf("expected header: %v, actual: %v", "bar", header)
   162  	}
   163  
   164  }
   165  
   166  func TestContentTypeIsJSON(t *testing.T) {
   167  	s := makeHTTPServer(t, nil)
   168  	defer s.Cleanup()
   169  
   170  	resp := httptest.NewRecorder()
   171  
   172  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   173  		return &structs.Job{Name: "foo"}, nil
   174  	}
   175  
   176  	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
   177  	s.Server.wrap(handler)(resp, req)
   178  
   179  	contentType := resp.Header().Get("Content-Type")
   180  
   181  	if contentType != "application/json" {
   182  		t.Fatalf("Content-Type header was not 'application/json'")
   183  	}
   184  }
   185  
   186  func TestPrettyPrint(t *testing.T) {
   187  	testPrettyPrint("pretty=1", true, t)
   188  }
   189  
   190  func TestPrettyPrintOff(t *testing.T) {
   191  	testPrettyPrint("pretty=0", false, t)
   192  }
   193  
   194  func TestPrettyPrintBare(t *testing.T) {
   195  	testPrettyPrint("pretty", true, t)
   196  }
   197  
   198  func testPrettyPrint(pretty string, prettyFmt bool, t *testing.T) {
   199  	s := makeHTTPServer(t, nil)
   200  	defer s.Cleanup()
   201  
   202  	r := &structs.Job{Name: "foo"}
   203  
   204  	resp := httptest.NewRecorder()
   205  	handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
   206  		return r, nil
   207  	}
   208  
   209  	urlStr := "/v1/job/foo?" + pretty
   210  	req, _ := http.NewRequest("GET", urlStr, nil)
   211  	s.Server.wrap(handler)(resp, req)
   212  
   213  	var expected []byte
   214  	if prettyFmt {
   215  		expected, _ = json.MarshalIndent(r, "", "    ")
   216  		expected = append(expected, "\n"...)
   217  	} else {
   218  		expected, _ = json.Marshal(r)
   219  	}
   220  	actual, err := ioutil.ReadAll(resp.Body)
   221  	if err != nil {
   222  		t.Fatalf("err: %s", err)
   223  	}
   224  
   225  	if !bytes.Equal(expected, actual) {
   226  		t.Fatalf("bad:\nexpected:\t%q\nactual:\t\t%q", string(expected), string(actual))
   227  	}
   228  }
   229  
   230  func TestParseWait(t *testing.T) {
   231  	resp := httptest.NewRecorder()
   232  	var b structs.QueryOptions
   233  
   234  	req, err := http.NewRequest("GET",
   235  		"/v1/catalog/nodes?wait=60s&index=1000", nil)
   236  	if err != nil {
   237  		t.Fatalf("err: %v", err)
   238  	}
   239  
   240  	if d := parseWait(resp, req, &b); d {
   241  		t.Fatalf("unexpected done")
   242  	}
   243  
   244  	if b.MinQueryIndex != 1000 {
   245  		t.Fatalf("Bad: %v", b)
   246  	}
   247  	if b.MaxQueryTime != 60*time.Second {
   248  		t.Fatalf("Bad: %v", b)
   249  	}
   250  }
   251  
   252  func TestParseWait_InvalidTime(t *testing.T) {
   253  	resp := httptest.NewRecorder()
   254  	var b structs.QueryOptions
   255  
   256  	req, err := http.NewRequest("GET",
   257  		"/v1/catalog/nodes?wait=60foo&index=1000", nil)
   258  	if err != nil {
   259  		t.Fatalf("err: %v", err)
   260  	}
   261  
   262  	if d := parseWait(resp, req, &b); !d {
   263  		t.Fatalf("expected done")
   264  	}
   265  
   266  	if resp.Code != 400 {
   267  		t.Fatalf("bad code: %v", resp.Code)
   268  	}
   269  }
   270  
   271  func TestParseWait_InvalidIndex(t *testing.T) {
   272  	resp := httptest.NewRecorder()
   273  	var b structs.QueryOptions
   274  
   275  	req, err := http.NewRequest("GET",
   276  		"/v1/catalog/nodes?wait=60s&index=foo", nil)
   277  	if err != nil {
   278  		t.Fatalf("err: %v", err)
   279  	}
   280  
   281  	if d := parseWait(resp, req, &b); !d {
   282  		t.Fatalf("expected done")
   283  	}
   284  
   285  	if resp.Code != 400 {
   286  		t.Fatalf("bad code: %v", resp.Code)
   287  	}
   288  }
   289  
   290  func TestParseConsistency(t *testing.T) {
   291  	var b structs.QueryOptions
   292  
   293  	req, err := http.NewRequest("GET",
   294  		"/v1/catalog/nodes?stale", nil)
   295  	if err != nil {
   296  		t.Fatalf("err: %v", err)
   297  	}
   298  
   299  	parseConsistency(req, &b)
   300  	if !b.AllowStale {
   301  		t.Fatalf("Bad: %v", b)
   302  	}
   303  
   304  	b = structs.QueryOptions{}
   305  	req, err = http.NewRequest("GET",
   306  		"/v1/catalog/nodes?consistent", nil)
   307  	if err != nil {
   308  		t.Fatalf("err: %v", err)
   309  	}
   310  
   311  	parseConsistency(req, &b)
   312  	if b.AllowStale {
   313  		t.Fatalf("Bad: %v", b)
   314  	}
   315  }
   316  
   317  func TestParseRegion(t *testing.T) {
   318  	s := makeHTTPServer(t, nil)
   319  	defer s.Cleanup()
   320  
   321  	req, err := http.NewRequest("GET",
   322  		"/v1/jobs?region=foo", nil)
   323  	if err != nil {
   324  		t.Fatalf("err: %v", err)
   325  	}
   326  
   327  	var region string
   328  	s.Server.parseRegion(req, &region)
   329  	if region != "foo" {
   330  		t.Fatalf("bad %s", region)
   331  	}
   332  
   333  	region = ""
   334  	req, err = http.NewRequest("GET", "/v1/jobs", nil)
   335  	if err != nil {
   336  		t.Fatalf("err: %v", err)
   337  	}
   338  
   339  	s.Server.parseRegion(req, &region)
   340  	if region != "global" {
   341  		t.Fatalf("bad %s", region)
   342  	}
   343  }
   344  
   345  // TestHTTP_VerifyHTTPSClient asserts that a client certificate signed by the
   346  // appropriate CA is required when VerifyHTTPSClient=true.
   347  func TestHTTP_VerifyHTTPSClient(t *testing.T) {
   348  	const (
   349  		cafile  = "../../helper/tlsutil/testdata/ca.pem"
   350  		foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
   351  		fookey  = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
   352  	)
   353  	s := makeHTTPServer(t, func(c *Config) {
   354  		c.Region = "foo" // match the region on foocert
   355  		c.TLSConfig = &config.TLSConfig{
   356  			EnableHTTP:        true,
   357  			VerifyHTTPSClient: true,
   358  			CAFile:            cafile,
   359  			CertFile:          foocert,
   360  			KeyFile:           fookey,
   361  		}
   362  	})
   363  	defer s.Cleanup()
   364  
   365  	reqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
   366  
   367  	// FAIL: Requests that expect 127.0.0.1 as the name should fail
   368  	resp, err := http.Get(reqURL)
   369  	if err == nil {
   370  		resp.Body.Close()
   371  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   372  	}
   373  	urlErr, ok := err.(*url.Error)
   374  	if !ok {
   375  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   376  	}
   377  	hostErr, ok := urlErr.Err.(x509.HostnameError)
   378  	if !ok {
   379  		t.Fatalf("expected a x509.HostnameError but received: %T -> %v", urlErr.Err, urlErr.Err)
   380  	}
   381  	if expected := "127.0.0.1"; hostErr.Host != expected {
   382  		t.Fatalf("expected hostname on error to be %q but found %q", expected, hostErr.Host)
   383  	}
   384  
   385  	// FAIL: Requests that specify a valid hostname but not the CA should
   386  	// fail
   387  	tlsConf := &tls.Config{
   388  		ServerName: "client.regionFoo.nomad",
   389  	}
   390  	transport := &http.Transport{TLSClientConfig: tlsConf}
   391  	client := &http.Client{Transport: transport}
   392  	req, err := http.NewRequest("GET", reqURL, nil)
   393  	if err != nil {
   394  		t.Fatalf("error creating request: %v", err)
   395  	}
   396  	resp, err = client.Do(req)
   397  	if err == nil {
   398  		resp.Body.Close()
   399  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   400  	}
   401  	urlErr, ok = err.(*url.Error)
   402  	if !ok {
   403  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   404  	}
   405  	_, ok = urlErr.Err.(x509.UnknownAuthorityError)
   406  	if !ok {
   407  		t.Fatalf("expected a x509.UnknownAuthorityError but received: %T -> %v", urlErr.Err, urlErr.Err)
   408  	}
   409  
   410  	// FAIL: Requests that specify a valid hostname and CA cert but lack a
   411  	// client certificate should fail
   412  	cacertBytes, err := ioutil.ReadFile(cafile)
   413  	if err != nil {
   414  		t.Fatalf("error reading cacert: %v", err)
   415  	}
   416  	tlsConf.RootCAs = x509.NewCertPool()
   417  	tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
   418  	req, err = http.NewRequest("GET", reqURL, nil)
   419  	if err != nil {
   420  		t.Fatalf("error creating request: %v", err)
   421  	}
   422  	resp, err = client.Do(req)
   423  	if err == nil {
   424  		resp.Body.Close()
   425  		t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
   426  	}
   427  	urlErr, ok = err.(*url.Error)
   428  	if !ok {
   429  		t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
   430  	}
   431  	opErr, ok := urlErr.Err.(*net.OpError)
   432  	if !ok {
   433  		t.Fatalf("expected a *net.OpErr but received: %T -> %v", urlErr.Err, urlErr.Err)
   434  	}
   435  	const badCertificate = "tls: bad certificate" // from crypto/tls/alert.go:52 and RFC 5246 ยง A.3
   436  	if opErr.Err.Error() != badCertificate {
   437  		t.Fatalf("expected tls.alert bad_certificate but received: %q", opErr.Err.Error())
   438  	}
   439  
   440  	// PASS: Requests that specify a valid hostname, CA cert, and client
   441  	// certificate succeed.
   442  	tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   443  		c, err := tls.LoadX509KeyPair(foocert, fookey)
   444  		if err != nil {
   445  			return nil, err
   446  		}
   447  		return &c, nil
   448  	}
   449  	transport = &http.Transport{TLSClientConfig: tlsConf}
   450  	client = &http.Client{Transport: transport}
   451  	req, err = http.NewRequest("GET", reqURL, nil)
   452  	if err != nil {
   453  		t.Fatalf("error creating request: %v", err)
   454  	}
   455  	resp, err = client.Do(req)
   456  	if err != nil {
   457  		t.Fatalf("unexpected error: %v", err)
   458  	}
   459  	resp.Body.Close()
   460  	if resp.StatusCode != 200 {
   461  		t.Fatalf("expected 200 status code but got: %d", resp.StatusCode)
   462  	}
   463  }
   464  
   465  // assertIndex tests that X-Nomad-Index is set and non-zero
   466  func assertIndex(t *testing.T, resp *httptest.ResponseRecorder) {
   467  	header := resp.Header().Get("X-Nomad-Index")
   468  	if header == "" || header == "0" {
   469  		t.Fatalf("Bad: %v", header)
   470  	}
   471  }
   472  
   473  // checkIndex is like assertIndex but returns an error
   474  func checkIndex(resp *httptest.ResponseRecorder) error {
   475  	header := resp.Header().Get("X-Nomad-Index")
   476  	if header == "" || header == "0" {
   477  		return fmt.Errorf("Bad: %v", header)
   478  	}
   479  	return nil
   480  }
   481  
   482  // getIndex parses X-Nomad-Index
   483  func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
   484  	header := resp.Header().Get("X-Nomad-Index")
   485  	if header == "" {
   486  		t.Fatalf("Bad: %v", header)
   487  	}
   488  	val, err := strconv.Atoi(header)
   489  	if err != nil {
   490  		t.Fatalf("Bad: %v", header)
   491  	}
   492  	return uint64(val)
   493  }
   494  
   495  func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestServer)) {
   496  	s := makeHTTPServer(t, cb)
   497  	defer s.Cleanup()
   498  	testutil.WaitForLeader(t, s.Agent.RPC)
   499  	f(s)
   500  }
   501  
   502  func encodeReq(obj interface{}) io.ReadCloser {
   503  	buf := bytes.NewBuffer(nil)
   504  	enc := json.NewEncoder(buf)
   505  	enc.Encode(obj)
   506  	return ioutil.NopCloser(buf)
   507  }