github.com/greenpau/go-authcrunch@v1.1.4/pkg/authz/authenticate_test.go (about) 1 // Copyright 2022 Paul Greenberg greenpau@outlook.com 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package authz 16 17 import ( 18 "bytes" 19 "crypto/tls" 20 "crypto/x509" 21 "encoding/json" 22 "fmt" 23 "github.com/greenpau/go-authcrunch/internal/tests" 24 "github.com/greenpau/go-authcrunch/internal/testutils" 25 "github.com/greenpau/go-authcrunch/pkg/acl" 26 "github.com/greenpau/go-authcrunch/pkg/requests" 27 logutil "github.com/greenpau/go-authcrunch/pkg/util/log" 28 "io/ioutil" 29 "net" 30 "net/http" 31 "net/http/cookiejar" 32 "net/url" 33 "strings" 34 "time" 35 36 "net/http/httptest" 37 "testing" 38 ) 39 40 type testRequest struct { 41 id string 42 roles []string 43 method string 44 path string 45 headers map[string]string 46 query map[string]string 47 contentType string 48 token string 49 } 50 51 func TestAuthenticate(t *testing.T) { 52 logger := logutil.NewLogger() 53 54 cfg := &PolicyConfig{ 55 Name: "mygatekeeper", 56 AuthURLPath: "/auth", 57 AccessListRules: []*acl.RuleConfiguration{ 58 { 59 Conditions: []string{ 60 "match roles authp/admin authp/user", 61 }, 62 Action: "allow stop", 63 }, 64 }, 65 cryptoRawConfigs: []string{"key verify " + testutils.GetSharedKey()}, 66 } 67 68 gatekeeper, err := NewGatekeeper(cfg, logger) 69 if err != nil { 70 t.Fatal(err) 71 } 72 73 var testcases = []struct { 74 name string 75 want map[string]interface{} 76 shouldErr bool 77 err error 78 disabled bool 79 req *testRequest 80 }{ 81 { 82 name: "admin accesses version with get", 83 req: &testRequest{ 84 roles: []string{"authp/admin"}, 85 method: "GET", 86 path: "/version", 87 }, 88 want: map[string]interface{}{ 89 "response": map[string]interface{}{ 90 "authorized": true, 91 }, 92 "status_code": 200, 93 "content_type": "text/plain; charset=utf-8", 94 }, 95 }, 96 } 97 98 // Initialize HTTP server. 99 ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 rr := requests.NewAuthorizationRequest() 101 err := gatekeeper.Authenticate(w, r, rr) 102 resp := make(map[string]interface{}) 103 if err != nil { 104 resp["error"] = err 105 } 106 resp["response"] = rr.Response 107 b, err := json.Marshal(resp) 108 if err != nil { 109 t.Fatalf("failed to marshal %T: %v", resp, err) 110 } 111 fmt.Fprintln(w, string(b)) 112 })) 113 defer ts.Close() 114 115 for _, tc := range testcases { 116 t.Run(tc.name, func(t *testing.T) { 117 got := make(map[string]interface{}) 118 if tc.req.method == "" { 119 tc.req.method = "GET" 120 } 121 if tc.disabled { 122 return 123 } 124 msgs := []string{fmt.Sprintf("test name: %s", tc.name)} 125 msgs = append(msgs, fmt.Sprintf("HTTP %s %s", tc.req.method, ts.URL+tc.req.path)) 126 127 client := buildClient(t, ts, tc.req) 128 if len(tc.req.roles) > 0 { 129 msgs = append(msgs, fmt.Sprintf("roles: %s", tc.req.roles)) 130 } 131 if tc.req.token != "" { 132 msgs = append(msgs, fmt.Sprintf("token: %s", tc.req.token)) 133 } 134 135 req := buildRequest(t, ts, tc.req) 136 137 resp, err := client.Do(req) 138 if tests.EvalErrWithLog(t, err, "response error", tc.shouldErr, tc.err, msgs) { 139 return 140 } 141 142 body, err := ioutil.ReadAll(resp.Body) 143 resp.Body.Close() 144 if err != nil { 145 t.Fatal(err) 146 } 147 148 got["status_code"] = resp.StatusCode 149 got["content_type"] = resp.Header.Get("Content-Type") 150 switch resp.Header.Get("Content-Type") { 151 case "image/png": 152 default: 153 msgs = append(msgs, fmt.Sprintf("response body: %s", body)) 154 } 155 156 switch { 157 case bytes.HasPrefix(body, []byte(`{`)): 158 var decodedResponse map[string]interface{} 159 json.Unmarshal(body, &decodedResponse) 160 for k, v := range decodedResponse { 161 got[k] = v 162 } 163 default: 164 t.Fatalf("detected non-JSON body: %s", strings.Join(msgs, "\n")) 165 } 166 tests.EvalObjectsWithLog(t, "response body", tc.want, got, msgs) 167 }) 168 } 169 } 170 171 func buildClient(t *testing.T, ts *httptest.Server, req *testRequest) http.Client { 172 cert, err := x509.ParseCertificate(ts.TLS.Certificates[0].Certificate[0]) 173 if err != nil { 174 t.Fatalf("failed extracting server certs: %v", err) 175 } 176 cp := x509.NewCertPool() 177 cp.AddCert(cert) 178 179 cj, err := cookiejar.New(nil) 180 if err != nil { 181 t.Fatalf("failed adding cookie jar: %v", err) 182 } 183 184 if len(req.roles) > 0 { 185 usr := testutils.NewTestUser() 186 usr.SetRolesClaim(req.roles) 187 188 ks := testutils.NewTestCryptoKeyStore() 189 if err := ks.SignToken("access_token", "HS512", usr); err != nil { 190 t.Fatalf("Failed to get JWT token for %v: %v", usr.AsMap(), err) 191 } 192 cookies := []*http.Cookie{ 193 &http.Cookie{Name: "access_token", Value: usr.Token}, 194 } 195 req.token = usr.Token 196 197 tsURL, _ := url.Parse(ts.URL) 198 cj.SetCookies(tsURL, cookies) 199 } 200 201 return http.Client{ 202 Jar: cj, 203 Timeout: time.Second * 10, 204 Transport: &http.Transport{ 205 Dial: (&net.Dialer{ 206 Timeout: 5 * time.Second, 207 }).Dial, 208 TLSHandshakeTimeout: 5 * time.Second, 209 TLSClientConfig: &tls.Config{ 210 RootCAs: cp, 211 }, 212 }, 213 CheckRedirect: func(r *http.Request, via []*http.Request) error { 214 // Do not follow redirects. 215 return http.ErrUseLastResponse 216 }, 217 } 218 } 219 220 func buildRequest(t *testing.T, ts *httptest.Server, req *testRequest) *http.Request { 221 r, err := http.NewRequest(req.method, ts.URL+req.path, nil) 222 if err != nil { 223 t.Fatal(err) 224 } 225 226 if len(req.headers) > 0 { 227 for k, v := range req.headers { 228 r.Header.Add(k, v) 229 } 230 } 231 232 if len(req.query) > 0 { 233 q := r.URL.Query() 234 for k, v := range req.query { 235 q.Set(k, v) 236 } 237 r.URL.RawQuery = q.Encode() 238 } 239 return r 240 }