github.com/freiheit-com/kuberpult@v1.24.2-0.20240328135542-315d5630abe6/pkg/auth/azure_test.go (about)

     1  /*This file is part of kuberpult.
     2  
     3  Kuberpult is free software: you can redistribute it and/or modify
     4  it under the terms of the Expat(MIT) License as published by
     5  the Free Software Foundation.
     6  
     7  Kuberpult is distributed in the hope that it will be useful,
     8  but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    10  MIT License for more details.
    11  
    12  You should have received a copy of the MIT License
    13  along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>.
    14  
    15  Copyright 2023 freiheit.com*/
    16  
    17  package auth
    18  
    19  import (
    20  	"fmt"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/MicahParks/keyfunc/v2"
    29  	jwt "github.com/golang-jwt/jwt/v5"
    30  	"github.com/google/go-cmp/cmp"
    31  	"github.com/google/go-cmp/cmp/cmpopts"
    32  )
    33  
    34  // Used to compare two error message strings, needed because errors.Is(fmt.Errorf(text),fmt.Errorf(text)) == false
    35  type errMatcher struct {
    36  	msg string
    37  }
    38  
    39  func (e errMatcher) Error() string {
    40  	return e.msg
    41  }
    42  
    43  func (e errMatcher) Is(err error) bool {
    44  	return e.Error() == err.Error()
    45  }
    46  
    47  func TestValidateTokenStatic(t *testing.T) {
    48  	tcs := []struct {
    49  		Name          string
    50  		Token         string
    51  		ExpectedError error
    52  		noInit        bool
    53  	}{
    54  		{
    55  			Name:          "Not a token",
    56  			Token:         "asdf",
    57  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
    58  		},
    59  		{
    60  			Name:          "Not initialized",
    61  			Token:         "asdf",
    62  			noInit:        true,
    63  			ExpectedError: errMatcher{"JWKS not initialized."},
    64  		},
    65  		{
    66  			Name:          "Not a token 2",
    67  			Token:         "asdf.asdf.asdf",
    68  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: could not JSON decode header: invalid character 'j' looking for beginning of value"},
    69  		},
    70  		{
    71  			Name:          "Kid not present",
    72  			Token:         "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.WDlNbJFe8ZX6C1mS27xwxg-9tk8vtkk6sDgucRj8xW0",
    73  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the JWT has an invalid kid: could not find kid in JWT header"},
    74  		},
    75  		{
    76  			Name:          "Kid not part of jwks",
    77  			Token:         "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImFzZGYifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.aNyAK8qpCScGchUmv1q1pBXOddWKN8_7agLUo7pXDog",
    78  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the given key ID was not found in the JWKS"},
    79  		},
    80  	}
    81  
    82  	var jwks, err = JWKSInitAzureFromJson()
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	for _, tc := range tcs {
    88  		tc := tc
    89  		t.Run(tc.Name, func(t *testing.T) {
    90  			t.Parallel()
    91  			testJWKS := jwks
    92  			if tc.noInit {
    93  				testJWKS = nil
    94  			}
    95  			_, err = ValidateToken(tc.Token, testJWKS, "clientId", "tenantId")
    96  			if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" {
    97  				t.Errorf("error mismatch (-want, +got):\n%s", diff)
    98  			}
    99  		})
   100  	}
   101  }
   102  
   103  func getToken(clientId string, tenantId string, kid string, expiry int64, name string, email string) (string, error) {
   104  	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(`-----BEGIN RSA PRIVATE KEY-----
   105  MIICXQIBAAKBgQC/oyqURHIPNzx4vcKrUUZYr6Bxq2OSD44a63zeIDA1oZkR+sac
   106  tmkub+8NI49GqrbssWf944v3ZLp8KXMh6i+U9pkSdDfvKcQUProQ+Tlm/m0SFXa6
   107  h7vq6iVD1uawzN9aQaR7WiKV1TuPGUgE86/l+XTvLZ/MbKh0tz9j8JtY4QIDAQAB
   108  AoGBAICNeROq8oSIfjVUvlDkHXeCoPN/kDS74IzoaYQsPYrMk30/J5qatuYiyk6b
   109  CxLRlBIlU+g5i3vygzKlL4mRqkZuCM4xPbpuW9sdZp61TxWZk7Tm+SYBTStYSGkT
   110  tPmvnKsYWkUh1WDSkeLJqHkRbQXAZJkAKRMYgLu2F29fWOZBAkEA8P31nm/AiDiD
   111  dkGSGp4GVQ5BBry3XdP3c6rfzmW8sMElxqoj2watdia72+grf8eVo8vtsTiOrVUD
   112  ZoS5C5GKKQJBAMuSXXQZrBa4qB7YkGi5ysQRQZoegdYZa44q9L9oBE/iEl/ejR1l
   113  EKZi+v2greoIruqczGAD7VbEiwT50+npH/kCQQDJgpGvOaK0RQ0oBQw2VYzV8mVN
   114  TN/HBUcU4PzjiQ6OffMoe3wf2SWSdjD/YNN+tVTa8dp/Jdun9D4zqydQFRKBAkBV
   115  zlPl5AxNZ3g1yELWYbm9+ygTtlgzznMvcZvIMiffJANqtXv1r+vctkvlLB0iUJap
   116  /X2H2x/nOuD+L+/K4KDBAkAHcO3Gv7VZsSHfnd/JfDzxtL0MFWerGZyGlaNFmX27
   117  1dWRXvcS5A0zPMgiBWfvHFx2DpSiceffqnis+UryeE+L
   118  -----END RSA PRIVATE KEY-----`))
   119  	claims := jwt.MapClaims{}
   120  	if len(clientId) > 0 {
   121  		claims["aud"] = clientId
   122  	}
   123  	if len(tenantId) > 0 {
   124  		claims["tid"] = tenantId
   125  	}
   126  	if len(name) > 0 {
   127  		claims["name"] = name
   128  	}
   129  	if len(email) > 0 {
   130  		claims["email"] = email
   131  	}
   132  
   133  	claims["exp"] = expiry
   134  	jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
   135  	jwtToken.Header["kid"] = kid
   136  	tokenString, err := jwtToken.SignedString(privateKey)
   137  	if err != nil {
   138  		return "", fmt.Errorf("Could not sign token %s", err.Error())
   139  	}
   140  	return tokenString, nil
   141  }
   142  
   143  func getJwks() (*keyfunc.JWKS, error) {
   144  	publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(`-----BEGIN PUBLIC KEY-----
   145  MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC/oyqURHIPNzx4vcKrUUZYr6Bx
   146  q2OSD44a63zeIDA1oZkR+sactmkub+8NI49GqrbssWf944v3ZLp8KXMh6i+U9pkS
   147  dDfvKcQUProQ+Tlm/m0SFXa6h7vq6iVD1uawzN9aQaR7WiKV1TuPGUgE86/l+XTv
   148  LZ/MbKh0tz9j8JtY4QIDAQAB
   149  -----END PUBLIC KEY-----`))
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	givenKey := keyfunc.NewGivenRSA(publicKey, keyfunc.GivenKeyOptions{})
   154  	keys := map[string]keyfunc.GivenKey{
   155  		"testKey": givenKey,
   156  	}
   157  	return keyfunc.NewGiven(keys), nil
   158  }
   159  
   160  func TestValidateTokenGenerated(t *testing.T) {
   161  	tcs := []struct {
   162  		Name          string
   163  		ClientId      string
   164  		TenantId      string
   165  		ExpectedError error
   166  		Expiry        int64
   167  		Kid           string
   168  	}{
   169  		{
   170  			Name:          "invalid client id",
   171  			ClientId:      "invalidClient",
   172  			TenantId:      "tenantId",
   173  			ExpectedError: errMatcher{"Unknown client id provided: invalidClient"},
   174  			Kid:           "testKey",
   175  		},
   176  		{
   177  			Name:          "No client id",
   178  			ClientId:      "",
   179  			TenantId:      "tenantId",
   180  			ExpectedError: errMatcher{"Client id not found in token."},
   181  			Kid:           "testKey",
   182  		},
   183  		{
   184  			Name:          "invalid tenant id",
   185  			ClientId:      "clientId",
   186  			TenantId:      "invalidTenant",
   187  			ExpectedError: errMatcher{"Unknown tenant id provided: invalidTenant"},
   188  			Kid:           "testKey",
   189  		},
   190  		{
   191  			Name:          "No tenant id",
   192  			ClientId:      "clientId",
   193  			TenantId:      "",
   194  			ExpectedError: errMatcher{"Tenant id not found in token."},
   195  			Kid:           "testKey",
   196  		},
   197  		{
   198  			Name:          "invalid  kid",
   199  			ClientId:      "clientId",
   200  			TenantId:      "tenantId",
   201  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is unverifiable: error while executing keyfunc: the given key ID was not found in the JWKS"},
   202  			Kid:           "tests",
   203  		},
   204  		{
   205  			Name:          "Expired key",
   206  			ClientId:      "clientId",
   207  			TenantId:      "tenantId",
   208  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token has invalid claims: token is expired"},
   209  			Expiry:        time.Now().Unix(),
   210  			Kid:           "testKey",
   211  		},
   212  		{
   213  			Name:     "valid key",
   214  			ClientId: "clientId",
   215  			TenantId: "tenantId",
   216  			Kid:      "testKey",
   217  		},
   218  	}
   219  
   220  	for _, tc := range tcs {
   221  		tc := tc
   222  		t.Run(tc.Name, func(t *testing.T) {
   223  			t.Parallel()
   224  			duration, err := time.ParseDuration("10m")
   225  			if err != nil {
   226  				t.Fatal(err)
   227  			}
   228  			expiry := time.Now().Add(duration).Unix()
   229  			if tc.Expiry != 0 {
   230  				expiry = tc.Expiry
   231  			}
   232  			tokenString, err := getToken(tc.ClientId, tc.TenantId, tc.Kid, expiry, "", "")
   233  			if err != nil {
   234  				t.Fatal(err)
   235  			}
   236  			jwks, err := getJwks()
   237  			if err != nil {
   238  				t.Fatal(err)
   239  			}
   240  			_, err = ValidateToken(tokenString, jwks, "clientId", "tenantId")
   241  			if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" {
   242  				t.Errorf("error mismatch (-want, +got):\n%s", diff)
   243  			}
   244  		})
   245  	}
   246  }
   247  
   248  func TestHttpMiddleware(t *testing.T) {
   249  	tcs := []struct {
   250  		Name          string
   251  		Path          string
   252  		Method        string
   253  		ExpectedError error
   254  		Authenticated bool
   255  	}{
   256  		{
   257  			Name:   "root path",
   258  			Path:   "/",
   259  			Method: http.MethodGet,
   260  		},
   261  		{
   262  			Name:   "js path",
   263  			Path:   "/static/js/content.js",
   264  			Method: http.MethodGet,
   265  		},
   266  		{
   267  			Name:   "css path",
   268  			Path:   "/static/css/content.css",
   269  			Method: http.MethodGet,
   270  		},
   271  		{
   272  			Name:          "api call - wrong url",
   273  			Path:          "/environment/production/locks/999",
   274  			Method:        http.MethodGet,
   275  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   276  			Authenticated: false,
   277  		},
   278  		{
   279  			Name:          "api call - wrong url path",
   280  			Path:          "/environment/production/releasetrainisawsome",
   281  			Method:        http.MethodGet,
   282  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   283  			Authenticated: false,
   284  		},
   285  		{
   286  			Name:          "api call rleasetrain",
   287  			Path:          "/environments/production/releasetrain",
   288  			Method:        http.MethodGet,
   289  			Authenticated: false,
   290  		},
   291  		{
   292  			Name:          "api call ",
   293  			Path:          "/environments/production/locks/999",
   294  			Method:        http.MethodGet,
   295  			Authenticated: false,
   296  		},
   297  		{
   298  			Name:          "api call create environment POST",
   299  			Path:          "/environments/dev",
   300  			Method:        http.MethodPost,
   301  			Authenticated: false,
   302  		},
   303  		{
   304  			Name:          "api call create environment GET",
   305  			Path:          "/environments/dev",
   306  			Method:        http.MethodGet,
   307  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   308  			Authenticated: false,
   309  		},
   310  		{
   311  			Name:          "api call create environment wrong url",
   312  			Path:          "/environments/dev/something",
   313  			Method:        http.MethodPost,
   314  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   315  			Authenticated: false,
   316  		},
   317  		{
   318  			Name:          "api call create environment another wrong url GET",
   319  			Path:          "/environments/something/dev",
   320  			Method:        http.MethodPost,
   321  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   322  			Authenticated: false,
   323  		},
   324  		{
   325  			Name:          "api call create environment another wrong url POST",
   326  			Path:          "/environments/something/dev",
   327  			Method:        http.MethodPost,
   328  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   329  			Authenticated: false,
   330  		},
   331  		{
   332  			Name:          "api call create environment - no env",
   333  			Path:          "/environments/",
   334  			Method:        http.MethodPost,
   335  			ExpectedError: errMatcher{"Failed to parse the JWT.\nError: token is malformed: token contains an invalid number of segments"},
   336  			Authenticated: false,
   337  		},
   338  	}
   339  
   340  	for _, tc := range tcs {
   341  		tc := tc
   342  		t.Run(tc.Name, func(t *testing.T) {
   343  			t.Parallel()
   344  			r := strings.NewReader("Test message incoming")
   345  			sr := io.Reader(r)
   346  			req, err := http.NewRequest(tc.Method, tc.Path, sr)
   347  			if err != nil {
   348  				t.Fatal(err)
   349  			}
   350  			duration, err := time.ParseDuration("10m")
   351  			if err != nil {
   352  				t.Fatal(err)
   353  			}
   354  			expiry := time.Now().Add(duration).Unix()
   355  			tokenString, err := getToken("clientId", "tenantId", "testKey", expiry, "testName", "test.email@com")
   356  			if err != nil {
   357  				t.Fatal(err)
   358  			}
   359  			jwks, err := getJwks()
   360  			if err != nil {
   361  				t.Fatal(err)
   362  			}
   363  
   364  			if tc.Authenticated {
   365  				req.Header.Set("Authorization", tokenString)
   366  			}
   367  			testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   368  				err := HttpAuthMiddleWare(w, r, jwks, "clientId", "tenantId", []string{"/"}, []string{"/static/js", "/static/css"})
   369  				if diff := cmp.Diff(tc.ExpectedError, err, cmpopts.EquateErrors()); diff != "" {
   370  					t.Errorf("error mismatch (-want, +got):\n%s", diff)
   371  				}
   372  				if tc.Authenticated {
   373  					username := req.Header.Get("username")
   374  					email := req.Header.Get("email")
   375  					if username != "testName" {
   376  						t.Fatalf("Expected username testName but got %q", username)
   377  					}
   378  					if email != "test.email@com" {
   379  						t.Fatalf("Expected email test.email@com but got %q", email)
   380  					}
   381  				}
   382  			})
   383  			rw := httptest.NewRecorder()
   384  			handler := testHandler
   385  			handler.ServeHTTP(rw, req)
   386  		})
   387  	}
   388  }
   389  
   390  func TestAllowBypassingAzureAuth(t *testing.T) {
   391  	tcs := []struct {
   392  		Name            string
   393  		allowedPaths    []string
   394  		requestUrlPath  string
   395  		requestMethod   string
   396  		allowedPrefixes []string
   397  		expectedResult  bool
   398  	}{
   399  		{
   400  			Name:            "Bugfix env group locks",
   401  			allowedPaths:    nil,
   402  			requestUrlPath:  "environment-groups/dev/locks/mylock123",
   403  			requestMethod:   "POST",
   404  			allowedPrefixes: nil,
   405  			expectedResult:  true,
   406  		},
   407  		{
   408  			Name:            "env locks",
   409  			allowedPaths:    nil,
   410  			requestUrlPath:  "environments/dev/locks/mylock123",
   411  			requestMethod:   "POST",
   412  			allowedPrefixes: nil,
   413  			expectedResult:  true,
   414  		},
   415  		{
   416  			Name:            "env rollout status",
   417  			allowedPaths:    nil,
   418  			requestUrlPath:  "environments/dev/rollout-status",
   419  			requestMethod:   "POST",
   420  			allowedPrefixes: nil,
   421  			expectedResult:  true,
   422  		},
   423  		{
   424  			Name:            "env group rollout status",
   425  			allowedPaths:    nil,
   426  			requestUrlPath:  "environment-groups/dev/rollout-status",
   427  			requestMethod:   "POST",
   428  			allowedPrefixes: nil,
   429  			expectedResult:  true,
   430  		},
   431  		{
   432  			Name:            "allowed path succeeds",
   433  			allowedPaths:    []string{"foo/bar"},
   434  			requestUrlPath:  "foo/bar",
   435  			requestMethod:   "POST",
   436  			allowedPrefixes: nil,
   437  			expectedResult:  true,
   438  		},
   439  		{
   440  			Name:            "allowed path fails",
   441  			allowedPaths:    []string{"bar/foo"},
   442  			requestUrlPath:  "foo/bar",
   443  			requestMethod:   "POST",
   444  			allowedPrefixes: nil,
   445  			expectedResult:  false,
   446  		},
   447  		{
   448  			Name:            "allowed prefix succeeds",
   449  			allowedPaths:    nil,
   450  			requestUrlPath:  "foo/bar",
   451  			requestMethod:   "POST",
   452  			allowedPrefixes: []string{"foo"},
   453  			expectedResult:  true,
   454  		},
   455  		{
   456  			Name:            "allowed prefix fails",
   457  			allowedPaths:    nil,
   458  			requestUrlPath:  "foo/bar",
   459  			requestMethod:   "POST",
   460  			allowedPrefixes: []string{"bar"},
   461  			expectedResult:  false,
   462  		},
   463  	}
   464  
   465  	for _, tc := range tcs {
   466  		tc := tc
   467  		t.Run(tc.Name, func(t *testing.T) {
   468  			t.Parallel()
   469  			actualResult := AllowBypassingAzureAuth(tc.allowedPaths, tc.requestUrlPath, tc.requestMethod, tc.allowedPrefixes)
   470  			if actualResult != tc.expectedResult {
   471  				t.Errorf("Expected %v but got %v", tc.expectedResult, actualResult)
   472  			}
   473  		})
   474  	}
   475  }