github.imxd.top/hashicorp/consul@v1.4.5/agent/http_oss_test.go (about)

     1  package agent
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/hashicorp/consul/testrpc"
    15  
    16  	"github.com/hashicorp/consul/logger"
    17  )
    18  
    19  // extra endpoints that should be tested, and their allowed methods
    20  var extraTestEndpoints = map[string][]string{
    21  	"/v1/query":             []string{"GET", "POST"},
    22  	"/v1/query/":            []string{"GET", "PUT", "DELETE"},
    23  	"/v1/query/xxx/execute": []string{"GET"},
    24  	"/v1/query/xxx/explain": []string{"GET"},
    25  }
    26  
    27  // These endpoints are ignored in unit testing for response codes
    28  var ignoredEndpoints = []string{"/v1/status/peers", "/v1/agent/monitor", "/v1/agent/reload"}
    29  
    30  // These have custom logic
    31  var customEndpoints = []string{"/v1/query", "/v1/query/"}
    32  
    33  // includePathInTest returns whether this path should be ignored for the purpose of testing its response code
    34  func includePathInTest(path string) bool {
    35  	ignored := false
    36  	for _, p := range ignoredEndpoints {
    37  		if p == path {
    38  			ignored = true
    39  			break
    40  		}
    41  	}
    42  	for _, p := range customEndpoints {
    43  		if p == path {
    44  			ignored = true
    45  			break
    46  		}
    47  	}
    48  
    49  	return !ignored
    50  }
    51  
    52  func newHttpClient(timeout time.Duration) *http.Client {
    53  	return &http.Client{
    54  		Timeout: timeout,
    55  		Transport: &http.Transport{
    56  			Dial: (&net.Dialer{
    57  				Timeout: timeout,
    58  			}).Dial,
    59  			TLSHandshakeTimeout: timeout,
    60  		},
    61  	}
    62  }
    63  
    64  func TestHTTPAPI_MethodNotAllowed_OSS(t *testing.T) {
    65  	// To avoid actually triggering RPCs that are allowed, lock everything down
    66  	// with default-deny ACLs. This drops the test runtime from 11s to 0.6s.
    67  	a := NewTestAgent(t, t.Name(), `
    68  	primary_datacenter = "dc1"
    69  	acl {
    70  		enabled        = true
    71  		default_policy = "deny"
    72  		tokens {
    73  			master  = "sekrit"
    74  			agent   = "sekrit"
    75  		}
    76  	}
    77  	`)
    78  	a.Agent.LogWriter = logger.NewLogWriter(512)
    79  	defer a.Shutdown()
    80  	// Use the master token here so the wait actually works.
    81  	testrpc.WaitForTestAgent(t, a.RPC, "dc1", testrpc.WithToken("sekrit"))
    82  
    83  	all := []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"}
    84  
    85  	client := newHttpClient(15 * time.Second)
    86  
    87  	testMethodNotAllowed := func(t *testing.T, method string, path string, allowedMethods []string) {
    88  		t.Run(method+" "+path, func(t *testing.T) {
    89  			uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
    90  			req, _ := http.NewRequest(method, uri, nil)
    91  			resp, err := client.Do(req)
    92  			if err != nil {
    93  				t.Fatal("client.Do failed: ", err)
    94  			}
    95  			defer resp.Body.Close()
    96  
    97  			allowed := method == "OPTIONS"
    98  			for _, allowedMethod := range allowedMethods {
    99  				if allowedMethod == method {
   100  					allowed = true
   101  					break
   102  				}
   103  			}
   104  
   105  			if allowed && resp.StatusCode == http.StatusMethodNotAllowed {
   106  				t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode)
   107  			}
   108  			if !allowed && resp.StatusCode != http.StatusMethodNotAllowed {
   109  				t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed)
   110  			}
   111  		})
   112  	}
   113  
   114  	for path, methods := range extraTestEndpoints {
   115  		for _, method := range all {
   116  			testMethodNotAllowed(t, method, path, methods)
   117  		}
   118  	}
   119  
   120  	for path, methods := range allowedMethods {
   121  		if includePathInTest(path) {
   122  			for _, method := range all {
   123  				testMethodNotAllowed(t, method, path, methods)
   124  			}
   125  		}
   126  	}
   127  }
   128  
   129  func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
   130  	a := NewTestAgent(t, t.Name(), `acl_datacenter = "dc1"`)
   131  	a.Agent.LogWriter = logger.NewLogWriter(512)
   132  	defer a.Shutdown()
   133  	testrpc.WaitForTestAgent(t, a.RPC, "dc1")
   134  
   135  	testOptionMethod := func(path string, methods []string) {
   136  		t.Run("OPTIONS "+path, func(t *testing.T) {
   137  			uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
   138  			req, _ := http.NewRequest("OPTIONS", uri, nil)
   139  			resp := httptest.NewRecorder()
   140  			a.srv.Handler.ServeHTTP(resp, req)
   141  			allMethods := append([]string{"OPTIONS"}, methods...)
   142  
   143  			if resp.Code != http.StatusOK {
   144  				t.Fatalf("options request: got status code %d want %d", resp.Code, http.StatusOK)
   145  			}
   146  
   147  			optionsStr := resp.Header().Get("Allow")
   148  			if optionsStr == "" {
   149  				t.Fatalf("options request: got empty 'Allow' header")
   150  			} else if optionsStr != strings.Join(allMethods, ",") {
   151  				t.Fatalf("options request: got 'Allow' header value of %s want %s", optionsStr, allMethods)
   152  			}
   153  		})
   154  	}
   155  
   156  	for path, methods := range extraTestEndpoints {
   157  		testOptionMethod(path, methods)
   158  	}
   159  	for path, methods := range allowedMethods {
   160  		if includePathInTest(path) {
   161  			testOptionMethod(path, methods)
   162  		}
   163  	}
   164  }
   165  
   166  func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
   167  	a := NewTestAgent(t, t.Name(), `
   168  		acl_datacenter = "dc1"
   169  		http_config {
   170  			allow_write_http_from = ["127.0.0.1/8"]
   171  		}
   172  	`)
   173  	a.Agent.LogWriter = logger.NewLogWriter(512)
   174  	defer a.Shutdown()
   175  	testrpc.WaitForTestAgent(t, a.RPC, "dc1")
   176  
   177  	testOptionMethod := func(path string, method string) {
   178  		t.Run(method+" "+path, func(t *testing.T) {
   179  			uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
   180  			req, _ := http.NewRequest(method, uri, nil)
   181  			req.RemoteAddr = "192.168.1.2:5555"
   182  			resp := httptest.NewRecorder()
   183  			a.srv.Handler.ServeHTTP(resp, req)
   184  
   185  			require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
   186  		})
   187  	}
   188  
   189  	for path, methods := range extraTestEndpoints {
   190  		if !includePathInTest(path) {
   191  			continue
   192  		}
   193  		for _, method := range methods {
   194  			if method == http.MethodGet {
   195  				continue
   196  			}
   197  
   198  			testOptionMethod(path, method)
   199  		}
   200  	}
   201  	for path, methods := range allowedMethods {
   202  		if !includePathInTest(path) {
   203  			continue
   204  		}
   205  		for _, method := range methods {
   206  			if method == http.MethodGet {
   207  				continue
   208  			}
   209  
   210  			testOptionMethod(path, method)
   211  		}
   212  	}
   213  }