github.com/sl1pm4t/consul@v1.4.5-0.20190325224627-74c31c540f9c/agent/http_oss_test.go (about) 1 package agent 2 3 import ( 4 "fmt" 5 "net" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/stretchr/testify/require" 13 14 "github.com/hashicorp/consul/testrpc" 15 16 "github.com/hashicorp/consul/logger" 17 ) 18 19 // extra endpoints that should be tested, and their allowed methods 20 var extraTestEndpoints = map[string][]string{ 21 "/v1/query": []string{"GET", "POST"}, 22 "/v1/query/": []string{"GET", "PUT", "DELETE"}, 23 "/v1/query/xxx/execute": []string{"GET"}, 24 "/v1/query/xxx/explain": []string{"GET"}, 25 } 26 27 // These endpoints are ignored in unit testing for response codes 28 var ignoredEndpoints = []string{"/v1/status/peers", "/v1/agent/monitor", "/v1/agent/reload"} 29 30 // These have custom logic 31 var customEndpoints = []string{"/v1/query", "/v1/query/"} 32 33 // includePathInTest returns whether this path should be ignored for the purpose of testing its response code 34 func includePathInTest(path string) bool { 35 ignored := false 36 for _, p := range ignoredEndpoints { 37 if p == path { 38 ignored = true 39 break 40 } 41 } 42 for _, p := range customEndpoints { 43 if p == path { 44 ignored = true 45 break 46 } 47 } 48 49 return !ignored 50 } 51 52 func newHttpClient(timeout time.Duration) *http.Client { 53 return &http.Client{ 54 Timeout: timeout, 55 Transport: &http.Transport{ 56 Dial: (&net.Dialer{ 57 Timeout: timeout, 58 }).Dial, 59 TLSHandshakeTimeout: timeout, 60 }, 61 } 62 } 63 64 func TestHTTPAPI_MethodNotAllowed_OSS(t *testing.T) { 65 // To avoid actually triggering RPCs that are allowed, lock everything down 66 // with default-deny ACLs. This drops the test runtime from 11s to 0.6s. 67 a := NewTestAgent(t, t.Name(), ` 68 primary_datacenter = "dc1" 69 acl { 70 enabled = true 71 default_policy = "deny" 72 tokens { 73 master = "sekrit" 74 agent = "sekrit" 75 } 76 } 77 `) 78 a.Agent.LogWriter = logger.NewLogWriter(512) 79 defer a.Shutdown() 80 // Use the master token here so the wait actually works. 81 testrpc.WaitForTestAgent(t, a.RPC, "dc1", testrpc.WithToken("sekrit")) 82 83 all := []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"} 84 85 client := newHttpClient(15 * time.Second) 86 87 testMethodNotAllowed := func(t *testing.T, method string, path string, allowedMethods []string) { 88 t.Run(method+" "+path, func(t *testing.T) { 89 uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) 90 req, _ := http.NewRequest(method, uri, nil) 91 resp, err := client.Do(req) 92 if err != nil { 93 t.Fatal("client.Do failed: ", err) 94 } 95 defer resp.Body.Close() 96 97 allowed := method == "OPTIONS" 98 for _, allowedMethod := range allowedMethods { 99 if allowedMethod == method { 100 allowed = true 101 break 102 } 103 } 104 105 if allowed && resp.StatusCode == http.StatusMethodNotAllowed { 106 t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode) 107 } 108 if !allowed && resp.StatusCode != http.StatusMethodNotAllowed { 109 t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) 110 } 111 }) 112 } 113 114 for path, methods := range extraTestEndpoints { 115 for _, method := range all { 116 testMethodNotAllowed(t, method, path, methods) 117 } 118 } 119 120 for path, methods := range allowedMethods { 121 if includePathInTest(path) { 122 for _, method := range all { 123 testMethodNotAllowed(t, method, path, methods) 124 } 125 } 126 } 127 } 128 129 func TestHTTPAPI_OptionMethod_OSS(t *testing.T) { 130 a := NewTestAgent(t, t.Name(), `acl_datacenter = "dc1"`) 131 a.Agent.LogWriter = logger.NewLogWriter(512) 132 defer a.Shutdown() 133 testrpc.WaitForTestAgent(t, a.RPC, "dc1") 134 135 testOptionMethod := func(path string, methods []string) { 136 t.Run("OPTIONS "+path, func(t *testing.T) { 137 uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) 138 req, _ := http.NewRequest("OPTIONS", uri, nil) 139 resp := httptest.NewRecorder() 140 a.srv.Handler.ServeHTTP(resp, req) 141 allMethods := append([]string{"OPTIONS"}, methods...) 142 143 if resp.Code != http.StatusOK { 144 t.Fatalf("options request: got status code %d want %d", resp.Code, http.StatusOK) 145 } 146 147 optionsStr := resp.Header().Get("Allow") 148 if optionsStr == "" { 149 t.Fatalf("options request: got empty 'Allow' header") 150 } else if optionsStr != strings.Join(allMethods, ",") { 151 t.Fatalf("options request: got 'Allow' header value of %s want %s", optionsStr, allMethods) 152 } 153 }) 154 } 155 156 for path, methods := range extraTestEndpoints { 157 testOptionMethod(path, methods) 158 } 159 for path, methods := range allowedMethods { 160 if includePathInTest(path) { 161 testOptionMethod(path, methods) 162 } 163 } 164 } 165 166 func TestHTTPAPI_AllowedNets_OSS(t *testing.T) { 167 a := NewTestAgent(t, t.Name(), ` 168 acl_datacenter = "dc1" 169 http_config { 170 allow_write_http_from = ["127.0.0.1/8"] 171 } 172 `) 173 a.Agent.LogWriter = logger.NewLogWriter(512) 174 defer a.Shutdown() 175 testrpc.WaitForTestAgent(t, a.RPC, "dc1") 176 177 testOptionMethod := func(path string, method string) { 178 t.Run(method+" "+path, func(t *testing.T) { 179 uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) 180 req, _ := http.NewRequest(method, uri, nil) 181 req.RemoteAddr = "192.168.1.2:5555" 182 resp := httptest.NewRecorder() 183 a.srv.Handler.ServeHTTP(resp, req) 184 185 require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path) 186 }) 187 } 188 189 for path, methods := range extraTestEndpoints { 190 if !includePathInTest(path) { 191 continue 192 } 193 for _, method := range methods { 194 if method == http.MethodGet { 195 continue 196 } 197 198 testOptionMethod(path, method) 199 } 200 } 201 for path, methods := range allowedMethods { 202 if !includePathInTest(path) { 203 continue 204 } 205 for _, method := range methods { 206 if method == http.MethodGet { 207 continue 208 } 209 210 testOptionMethod(path, method) 211 } 212 } 213 }