github.com/freiheit-com/kuberpult@v1.24.2-0.20240328135542-315d5630abe6/pkg/auth/azure_test.go (about) 1 /*This file is part of kuberpult. 2 3 Kuberpult is free software: you can redistribute it and/or modify 4 it under the terms of the Expat(MIT) License as published by 5 the Free Software Foundation. 6 7 Kuberpult is distributed in the hope that it will be useful, 8 but WITHOUT ANY WARRANTY; without even the implied warranty of 9 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 MIT License for more details. 11 12 You should have received a copy of the MIT License 13 along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>. 14 15 Copyright 2023 freiheit.com*/ 16 17 package auth 18 19 import ( 20 "fmt" 21 "io" 22 "net/http" 23 "net/http/httptest" 24 "strings" 25 "testing" 26 "time" 27 28 "github.com/MicahParks/keyfunc/v2" 29 jwt "github.com/golang-jwt/jwt/v5" 30 "github.com/google/go-cmp/cmp" 31 "github.com/google/go-cmp/cmp/cmpopts" 32 ) 33 34 // Used to compare two error message strings, needed because errors.Is(fmt.Errorf(text),fmt.Errorf(text)) == false 35 type errMatcher struct { 36 msg string 37 } 38 39 func (e errMatcher) Error() string { 40 return e.msg 41 } 42 43 func (e errMatcher) Is(err error) bool { 44 return e.Error() == err.Error() 45 } 46 47 func TestValidateTokenStatic(t *testing.T) { 48 tcs := []struct { 49 Name string 50 Token string 51 ExpectedError error 52 noInit bool 53 }{ 54 { 55 Name: "Not a token", 56 Token: "asdf", 57 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 58 }, 59 { 60 Name: "Not initialized", 61 Token: "asdf", 62 noInit: true, 63 ExpectedError: errMatcher{"JWKS not initialized."}, 64 }, 65 { 66 Name: "Not a token 2", 67 Token: "asdf.asdf.asdf", 68 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: could not JSON decode header: invalid character 'j' looking for beginning of value"}, 69 }, 70 { 71 Name: "Kid not present", 72 Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.WDlNbJFe8ZX6C1mS27xwxg-9tk8vtkk6sDgucRj8xW0", 73 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the JWT has an invalid kid: could not find kid in JWT header"}, 74 }, 75 { 76 Name: "Kid not part of jwks", 77 Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImFzZGYifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.aNyAK8qpCScGchUmv1q1pBXOddWKN8_7agLUo7pXDog", 78 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the given key ID was not found in the JWKS"}, 79 }, 80 } 81 82 var jwks, err = JWKSInitAzureFromJson() 83 if err != nil { 84 t.Fatal(err) 85 } 86 87 for _, tc := range tcs { 88 tc := tc 89 t.Run(tc.Name, func(t *testing.T) { 90 t.Parallel() 91 testJWKS := jwks 92 if tc.noInit { 93 testJWKS = nil 94 } 95 _, err = ValidateToken(tc.Token, testJWKS, "clientId", "tenantId") 96 if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" { 97 t.Errorf("error mismatch (-want, +got):\n%s", diff) 98 } 99 }) 100 } 101 } 102 103 func getToken(clientId string, tenantId string, kid string, expiry int64, name string, email string) (string, error) { 104 privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(`-----BEGIN RSA PRIVATE KEY----- 105 MIICXQIBAAKBgQC/oyqURHIPNzx4vcKrUUZYr6Bxq2OSD44a63zeIDA1oZkR+sac 106 tmkub+8NI49GqrbssWf944v3ZLp8KXMh6i+U9pkSdDfvKcQUProQ+Tlm/m0SFXa6 107 h7vq6iVD1uawzN9aQaR7WiKV1TuPGUgE86/l+XTvLZ/MbKh0tz9j8JtY4QIDAQAB 108 AoGBAICNeROq8oSIfjVUvlDkHXeCoPN/kDS74IzoaYQsPYrMk30/J5qatuYiyk6b 109 CxLRlBIlU+g5i3vygzKlL4mRqkZuCM4xPbpuW9sdZp61TxWZk7Tm+SYBTStYSGkT 110 tPmvnKsYWkUh1WDSkeLJqHkRbQXAZJkAKRMYgLu2F29fWOZBAkEA8P31nm/AiDiD 111 dkGSGp4GVQ5BBry3XdP3c6rfzmW8sMElxqoj2watdia72+grf8eVo8vtsTiOrVUD 112 ZoS5C5GKKQJBAMuSXXQZrBa4qB7YkGi5ysQRQZoegdYZa44q9L9oBE/iEl/ejR1l 113 EKZi+v2greoIruqczGAD7VbEiwT50+npH/kCQQDJgpGvOaK0RQ0oBQw2VYzV8mVN 114 TN/HBUcU4PzjiQ6OffMoe3wf2SWSdjD/YNN+tVTa8dp/Jdun9D4zqydQFRKBAkBV 115 zlPl5AxNZ3g1yELWYbm9+ygTtlgzznMvcZvIMiffJANqtXv1r+vctkvlLB0iUJap 116 /X2H2x/nOuD+L+/K4KDBAkAHcO3Gv7VZsSHfnd/JfDzxtL0MFWerGZyGlaNFmX27 117 1dWRXvcS5A0zPMgiBWfvHFx2DpSiceffqnis+UryeE+L 118 -----END RSA PRIVATE KEY-----`)) 119 claims := jwt.MapClaims{} 120 if len(clientId) > 0 { 121 claims["aud"] = clientId 122 } 123 if len(tenantId) > 0 { 124 claims["tid"] = tenantId 125 } 126 if len(name) > 0 { 127 claims["name"] = name 128 } 129 if len(email) > 0 { 130 claims["email"] = email 131 } 132 133 claims["exp"] = expiry 134 jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) 135 jwtToken.Header["kid"] = kid 136 tokenString, err := jwtToken.SignedString(privateKey) 137 if err != nil { 138 return "", fmt.Errorf("Could not sign token %s", err.Error()) 139 } 140 return tokenString, nil 141 } 142 143 func getJwks() (*keyfunc.JWKS, error) { 144 publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(`-----BEGIN PUBLIC KEY----- 145 MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC/oyqURHIPNzx4vcKrUUZYr6Bx 146 q2OSD44a63zeIDA1oZkR+sactmkub+8NI49GqrbssWf944v3ZLp8KXMh6i+U9pkS 147 dDfvKcQUProQ+Tlm/m0SFXa6h7vq6iVD1uawzN9aQaR7WiKV1TuPGUgE86/l+XTv 148 LZ/MbKh0tz9j8JtY4QIDAQAB 149 -----END PUBLIC KEY-----`)) 150 if err != nil { 151 return nil, err 152 } 153 givenKey := keyfunc.NewGivenRSA(publicKey, keyfunc.GivenKeyOptions{}) 154 keys := map[string]keyfunc.GivenKey{ 155 "testKey": givenKey, 156 } 157 return keyfunc.NewGiven(keys), nil 158 } 159 160 func TestValidateTokenGenerated(t *testing.T) { 161 tcs := []struct { 162 Name string 163 ClientId string 164 TenantId string 165 ExpectedError error 166 Expiry int64 167 Kid string 168 }{ 169 { 170 Name: "invalid client id", 171 ClientId: "invalidClient", 172 TenantId: "tenantId", 173 ExpectedError: errMatcher{"Unknown client id provided: invalidClient"}, 174 Kid: "testKey", 175 }, 176 { 177 Name: "No client id", 178 ClientId: "", 179 TenantId: "tenantId", 180 ExpectedError: errMatcher{"Client id not found in token."}, 181 Kid: "testKey", 182 }, 183 { 184 Name: "invalid tenant id", 185 ClientId: "clientId", 186 TenantId: "invalidTenant", 187 ExpectedError: errMatcher{"Unknown tenant id provided: invalidTenant"}, 188 Kid: "testKey", 189 }, 190 { 191 Name: "No tenant id", 192 ClientId: "clientId", 193 TenantId: "", 194 ExpectedError: errMatcher{"Tenant id not found in token."}, 195 Kid: "testKey", 196 }, 197 { 198 Name: "invalid kid", 199 ClientId: "clientId", 200 TenantId: "tenantId", 201 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the given key ID was not found in the JWKS"}, 202 Kid: "tests", 203 }, 204 { 205 Name: "Expired key", 206 ClientId: "clientId", 207 TenantId: "tenantId", 208 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token has invalid claims: token is expired"}, 209 Expiry: time.Now().Unix(), 210 Kid: "testKey", 211 }, 212 { 213 Name: "valid key", 214 ClientId: "clientId", 215 TenantId: "tenantId", 216 Kid: "testKey", 217 }, 218 } 219 220 for _, tc := range tcs { 221 tc := tc 222 t.Run(tc.Name, func(t *testing.T) { 223 t.Parallel() 224 duration, err := time.ParseDuration("10m") 225 if err != nil { 226 t.Fatal(err) 227 } 228 expiry := time.Now().Add(duration).Unix() 229 if tc.Expiry != 0 { 230 expiry = tc.Expiry 231 } 232 tokenString, err := getToken(tc.ClientId, tc.TenantId, tc.Kid, expiry, "", "") 233 if err != nil { 234 t.Fatal(err) 235 } 236 jwks, err := getJwks() 237 if err != nil { 238 t.Fatal(err) 239 } 240 _, err = ValidateToken(tokenString, jwks, "clientId", "tenantId") 241 if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" { 242 t.Errorf("error mismatch (-want, +got):\n%s", diff) 243 } 244 }) 245 } 246 } 247 248 func TestHttpMiddleware(t *testing.T) { 249 tcs := []struct { 250 Name string 251 Path string 252 Method string 253 ExpectedError error 254 Authenticated bool 255 }{ 256 { 257 Name: "root path", 258 Path: "/", 259 Method: http.MethodGet, 260 }, 261 { 262 Name: "js path", 263 Path: "/static/js/content.js", 264 Method: http.MethodGet, 265 }, 266 { 267 Name: "css path", 268 Path: "/static/css/content.css", 269 Method: http.MethodGet, 270 }, 271 { 272 Name: "api call - wrong url", 273 Path: "/environment/production/locks/999", 274 Method: http.MethodGet, 275 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 276 Authenticated: false, 277 }, 278 { 279 Name: "api call - wrong url path", 280 Path: "/environment/production/releasetrainisawsome", 281 Method: http.MethodGet, 282 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 283 Authenticated: false, 284 }, 285 { 286 Name: "api call rleasetrain", 287 Path: "/environments/production/releasetrain", 288 Method: http.MethodGet, 289 Authenticated: false, 290 }, 291 { 292 Name: "api call ", 293 Path: "/environments/production/locks/999", 294 Method: http.MethodGet, 295 Authenticated: false, 296 }, 297 { 298 Name: "api call create environment POST", 299 Path: "/environments/dev", 300 Method: http.MethodPost, 301 Authenticated: false, 302 }, 303 { 304 Name: "api call create environment GET", 305 Path: "/environments/dev", 306 Method: http.MethodGet, 307 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 308 Authenticated: false, 309 }, 310 { 311 Name: "api call create environment wrong url", 312 Path: "/environments/dev/something", 313 Method: http.MethodPost, 314 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 315 Authenticated: false, 316 }, 317 { 318 Name: "api call create environment another wrong url GET", 319 Path: "/environments/something/dev", 320 Method: http.MethodPost, 321 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 322 Authenticated: false, 323 }, 324 { 325 Name: "api call create environment another wrong url POST", 326 Path: "/environments/something/dev", 327 Method: http.MethodPost, 328 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 329 Authenticated: false, 330 }, 331 { 332 Name: "api call create environment - no env", 333 Path: "/environments/", 334 Method: http.MethodPost, 335 ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"}, 336 Authenticated: false, 337 }, 338 } 339 340 for _, tc := range tcs { 341 tc := tc 342 t.Run(tc.Name, func(t *testing.T) { 343 t.Parallel() 344 r := strings.NewReader("Test message incoming") 345 sr := io.Reader(r) 346 req, err := http.NewRequest(tc.Method, tc.Path, sr) 347 if err != nil { 348 t.Fatal(err) 349 } 350 duration, err := time.ParseDuration("10m") 351 if err != nil { 352 t.Fatal(err) 353 } 354 expiry := time.Now().Add(duration).Unix() 355 tokenString, err := getToken("clientId", "tenantId", "testKey", expiry, "testName", "test.email@com") 356 if err != nil { 357 t.Fatal(err) 358 } 359 jwks, err := getJwks() 360 if err != nil { 361 t.Fatal(err) 362 } 363 364 if tc.Authenticated { 365 req.Header.Set("Authorization", tokenString) 366 } 367 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 368 err := HttpAuthMiddleWare(w, r, jwks, "clientId", "tenantId", []string{"/"}, []string{"/static/js", "/static/css"}) 369 if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" { 370 t.Errorf("error mismatch (-want, +got):\n%s", diff) 371 } 372 if tc.Authenticated { 373 username := req.Header.Get("username") 374 email := req.Header.Get("email") 375 if username != "testName" { 376 t.Fatalf("Expected username testName but got %q", username) 377 } 378 if email != "test.email@com" { 379 t.Fatalf("Expected email test.email@com but got %q", email) 380 } 381 } 382 }) 383 rw := httptest.NewRecorder() 384 handler := testHandler 385 handler.ServeHTTP(rw, req) 386 }) 387 } 388 } 389 390 func TestAllowBypassingAzureAuth(t *testing.T) { 391 tcs := []struct { 392 Name string 393 allowedPaths []string 394 requestUrlPath string 395 requestMethod string 396 allowedPrefixes []string 397 expectedResult bool 398 }{ 399 { 400 Name: "Bugfix env group locks", 401 allowedPaths: nil, 402 requestUrlPath: "environment-groups/dev/locks/mylock123", 403 requestMethod: "POST", 404 allowedPrefixes: nil, 405 expectedResult: true, 406 }, 407 { 408 Name: "env locks", 409 allowedPaths: nil, 410 requestUrlPath: "environments/dev/locks/mylock123", 411 requestMethod: "POST", 412 allowedPrefixes: nil, 413 expectedResult: true, 414 }, 415 { 416 Name: "env rollout status", 417 allowedPaths: nil, 418 requestUrlPath: "environments/dev/rollout-status", 419 requestMethod: "POST", 420 allowedPrefixes: nil, 421 expectedResult: true, 422 }, 423 { 424 Name: "env group rollout status", 425 allowedPaths: nil, 426 requestUrlPath: "environment-groups/dev/rollout-status", 427 requestMethod: "POST", 428 allowedPrefixes: nil, 429 expectedResult: true, 430 }, 431 { 432 Name: "allowed path succeeds", 433 allowedPaths: []string{"foo/bar"}, 434 requestUrlPath: "foo/bar", 435 requestMethod: "POST", 436 allowedPrefixes: nil, 437 expectedResult: true, 438 }, 439 { 440 Name: "allowed path fails", 441 allowedPaths: []string{"bar/foo"}, 442 requestUrlPath: "foo/bar", 443 requestMethod: "POST", 444 allowedPrefixes: nil, 445 expectedResult: false, 446 }, 447 { 448 Name: "allowed prefix succeeds", 449 allowedPaths: nil, 450 requestUrlPath: "foo/bar", 451 requestMethod: "POST", 452 allowedPrefixes: []string{"foo"}, 453 expectedResult: true, 454 }, 455 { 456 Name: "allowed prefix fails", 457 allowedPaths: nil, 458 requestUrlPath: "foo/bar", 459 requestMethod: "POST", 460 allowedPrefixes: []string{"bar"}, 461 expectedResult: false, 462 }, 463 } 464 465 for _, tc := range tcs { 466 tc := tc 467 t.Run(tc.Name, func(t *testing.T) { 468 t.Parallel() 469 actualResult := AllowBypassingAzureAuth(tc.allowedPaths, tc.requestUrlPath, tc.requestMethod, tc.allowedPrefixes) 470 if actualResult != tc.expectedResult { 471 t.Errorf("Expected %v but got %v", tc.expectedResult, actualResult) 472 } 473 }) 474 } 475 }