github.com/freiheit-com/kuberpult@v1.24.2-0.20240328135542-315d5630abe6/pkg/auth/dex_test.go (about)

     1  /*This file is part of kuberpult.
     2  
     3  Kuberpult is free software: you can redistribute it and/or modify
     4  it under the terms of the Expat(MIT) License as published by
     5  the Free Software Foundation.
     6  
     7  Kuberpult is distributed in the hope that it will be useful,
     8  but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    10  MIT License for more details.
    11  
    12  You should have received a copy of the MIT License
    13  along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>.
    14  
    15  Copyright 2023 freiheit.com*/
    16  
    17  package auth
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"crypto/tls"
    25  	"encoding/json"
    26  	"fmt"
    27  	"io"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"net/url"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/coreos/go-oidc/v3/oidc"
    35  	jwtV5 "github.com/golang-jwt/jwt/v5"
    36  	"github.com/google/go-cmp/cmp"
    37  	"github.com/google/go-cmp/cmp/cmpopts"
    38  	"github.com/lestrrat-go/jwx/v2/jwa"
    39  	"github.com/lestrrat-go/jwx/v2/jwk"
    40  	"github.com/lestrrat-go/jwx/v2/jwt"
    41  )
    42  
    43  func TestNewDexAppClient(t *testing.T) {
    44  	DEX_URL, _ := url.Parse(dexServiceURL)
    45  	testCases := []struct {
    46  		Name          string
    47  		clientID      string
    48  		clientSecret  string
    49  		baseURL       string
    50  		scopes        []string
    51  		wantErr       bool
    52  		wantClientApp *DexAppClient
    53  	}{
    54  		{
    55  			Name:         "Creates the a new Dex App Client as expected",
    56  			clientID:     "test-client",
    57  			clientSecret: "test-secret",
    58  			baseURL:      "www.test-url.com",
    59  			scopes:       []string{"scope1", "scope2"},
    60  			wantErr:      false,
    61  			wantClientApp: &DexAppClient{
    62  				ClientID:     "test-client",
    63  				ClientSecret: "test-secret",
    64  				RedirectURI:  "www.test-url.com/callback",
    65  				IssuerURL:    "www.test-url.com/dex",
    66  				BaseURL:      "www.test-url.com",
    67  				Scopes:       []string{"scope1", "scope2"},
    68  				Client: &http.Client{
    69  					Transport: DexRewriteURLRoundTripper{
    70  						DexURL: DEX_URL,
    71  						T:      http.DefaultTransport,
    72  					},
    73  				},
    74  			},
    75  		},
    76  	}
    77  	for _, tc := range testCases {
    78  		t.Run(tc.Name, func(t *testing.T) {
    79  			a, err := NewDexAppClient(tc.clientID, tc.clientSecret, tc.baseURL, tc.scopes)
    80  			if (err != nil) != tc.wantErr {
    81  				t.Errorf("creating new dex client error = %v, wantErr %v", err, tc.wantErr)
    82  			}
    83  			if diff := cmp.Diff(a, tc.wantClientApp, cmpopts.IgnoreFields(DexRewriteURLRoundTripper{}, "T")); diff != "" {
    84  				t.Errorf("got %v, want %v, diff (-want +got) %s", a, tc.wantClientApp, diff)
    85  			}
    86  		})
    87  	}
    88  }
    89  
    90  func TestNewDexReverseProxy(t *testing.T) {
    91  	testCases := []struct {
    92  		Name           string
    93  		mockDexServer  *httptest.Server
    94  		wantStatusCode int
    95  	}{
    96  		{
    97  			Name:           "Dex reverse proxy is working as expected on success",
    98  			mockDexServer:  makeNewMockServer(http.StatusOK),
    99  			wantStatusCode: http.StatusOK,
   100  		},
   101  		{
   102  			Name:           "Dex reverse proxy is working as expected on error",
   103  			mockDexServer:  makeNewMockServer(http.StatusInternalServerError),
   104  			wantStatusCode: http.StatusInternalServerError,
   105  		},
   106  	}
   107  	for _, tc := range testCases {
   108  		t.Run(tc.Name, func(t *testing.T) {
   109  			// mock Dex server the app is being redirected to.
   110  			mockDexServer := tc.mockDexServer
   111  			defer mockDexServer.Close()
   112  			server := httptest.NewServer(http.HandlerFunc(NewDexReverseProxy(mockDexServer.URL)))
   113  			defer server.Close()
   114  			resp, err := http.Get(server.URL)
   115  			if err != nil {
   116  				t.Errorf("could not create HTTP request: %s", err)
   117  			}
   118  			if diff := cmp.Diff(resp.StatusCode, tc.wantStatusCode); diff != "" {
   119  				t.Errorf("got %v, want %v, diff (-want +got) %s", resp.StatusCode, tc.wantStatusCode, diff)
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func TestDexRoundTripper(t *testing.T) {
   126  	testCases := []struct {
   127  		Name           string
   128  		mockDexServer  *httptest.Server
   129  		wantStatusCode int
   130  	}{
   131  		{
   132  			Name: "Round tripper works as expected",
   133  			mockDexServer: httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   134  				rw.WriteHeader(http.StatusOK)
   135  			})),
   136  			wantStatusCode: http.StatusOK,
   137  		},
   138  	}
   139  	for _, tc := range testCases {
   140  		t.Run(tc.Name, func(t *testing.T) {
   141  			// mock Dex server the app is being redirected to.
   142  			mockDexServer := tc.mockDexServer
   143  			defer mockDexServer.Close()
   144  			serverURL, _ := url.Parse(mockDexServer.URL)
   145  			rt := DexRewriteURLRoundTripper{
   146  				DexURL: serverURL,
   147  				T:      http.DefaultTransport,
   148  			}
   149  			req, _ := http.NewRequest(http.MethodGet, "/", bytes.NewBuffer([]byte("")))
   150  			rt.RoundTrip(req)
   151  			target, _ := url.Parse(mockDexServer.URL)
   152  			if diff := cmp.Diff(req.Host, target.Host); diff != "" {
   153  				t.Errorf("got %v, want %v, diff (-want +got) %s", req.Host, target.Host, diff)
   154  			}
   155  		})
   156  	}
   157  }
   158  
   159  func TestValidateToken(t *testing.T) {
   160  	// Reset Handler registration. Avoids "panic: http: multiple registrations for <PATH>" error.
   161  	http.DefaultServeMux = new(http.ServeMux)
   162  	// Dex app client configuration variables.
   163  	clientID := "test-client"
   164  	clientSecret := "test-client"
   165  	hostURL := "https://www.test.com"
   166  	scopes := []string{"scope1", "scope2"}
   167  	appDex, _ := NewDexAppClient(clientID, clientSecret, hostURL, scopes)
   168  
   169  	testCases := []struct {
   170  		Name            string
   171  		allowedAudience string
   172  		dexApp          *DexAppClient
   173  		wantErr         bool
   174  	}{
   175  		{
   176  			Name:            "Token Verifier works as expected with the correct audience",
   177  			dexApp:          appDex,
   178  			allowedAudience: "test-client",
   179  			wantErr:         false,
   180  		},
   181  		{
   182  			Name:            "Token Verifier works as expected with the wrong audience",
   183  			dexApp:          appDex,
   184  			allowedAudience: "wrong-audience",
   185  			wantErr:         true,
   186  		},
   187  	}
   188  	for _, tc := range testCases {
   189  		// Create a key set, private key and public key.
   190  		keySet, jwkPrivateKey, _ := getJWKeySet()
   191  		t.Run(tc.Name, func(t *testing.T) {
   192  			// Mocks the OIDC server to retrieve the provider.
   193  			oidcServer := MockOIDCTestServer(appDex.IssuerURL, keySet)
   194  			defer oidcServer.Close()
   195  
   196  			// Disable the TLS check to allow the test to run.
   197  			dexURL, _ := url.Parse(oidcServer.URL)
   198  			httpClient := &http.Client{
   199  				Transport: DexRewriteURLRoundTripper{
   200  					DexURL: dexURL,
   201  					T: &http.Transport{
   202  						TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
   203  					},
   204  				},
   205  			}
   206  			claims := jwtV5.MapClaims{jwt.AudienceKey: tc.dexApp.ClientID, jwt.IssuerKey: tc.dexApp.IssuerURL}
   207  			token, _ := GetSignedJwt(jwkPrivateKey, claims)
   208  
   209  			ctx := oidc.ClientContext(context.Background(), httpClient)
   210  			_, err := ValidateOIDCToken(ctx, appDex.IssuerURL, string(token), tc.allowedAudience)
   211  			if (err != nil) != tc.wantErr {
   212  				t.Errorf("creating new dex client error = %v, wantErr %v", err, tc.wantErr)
   213  			}
   214  		})
   215  	}
   216  }
   217  
   218  func TestVerifyToken(t *testing.T) {
   219  	// Reset Handler registration. Avoids "panic: http: multiple registrations for <PATH>" error.
   220  	http.DefaultServeMux = new(http.ServeMux)
   221  	// Dex app client configuration variables.
   222  	clientID := "test-client"
   223  	clientSecret := "test-client"
   224  	hostURL := "https://www.test.com"
   225  	scopes := []string{"scope1", "scope2"}
   226  	appDex, _ := NewDexAppClient(clientID, clientSecret, hostURL, scopes)
   227  
   228  	testCases := []struct {
   229  		Name     string
   230  		claims   jwtV5.MapClaims
   231  		wantErr  string
   232  		wantUser string
   233  	}{
   234  		{
   235  			Name: "Token Verifier works as expected with the correct token value",
   236  			claims: jwtV5.MapClaims{
   237  				jwt.AudienceKey: clientID,
   238  				jwt.IssuerKey:   appDex.IssuerURL,
   239  				"name":          "User",
   240  				"email":         "user@mail.com",
   241  				"groups":        []string{"Developer"}},
   242  			wantUser: "Developer,",
   243  		},
   244  		{
   245  			Name: "Token Verifier works as expected with no name",
   246  			claims: jwtV5.MapClaims{
   247  				jwt.AudienceKey: clientID,
   248  				jwt.IssuerKey:   appDex.IssuerURL,
   249  				"name":          "",
   250  				"email":         "user@mail.com",
   251  				"groups":        []string{}},
   252  			wantErr: "failed to verify token: no group defined",
   253  		},
   254  	}
   255  	for _, tc := range testCases {
   256  		// Create a key set, private key and public key.
   257  		t.Run(tc.Name, func(t *testing.T) {
   258  			keySet, jwkPrivateKey, _ := getJWKeySet()
   259  			idToken, _ := GetSignedJwt(jwkPrivateKey, tc.claims)
   260  
   261  			// Mocks the OIDC server to retrieve the provider.
   262  			oidcServer := MockOIDCTestServer(appDex.IssuerURL, keySet)
   263  			defer oidcServer.Close()
   264  
   265  			// Disable the TLS check to allow the test to run.
   266  			dexURL, _ := url.Parse(oidcServer.URL)
   267  			httpClient := &http.Client{
   268  				Transport: DexRewriteURLRoundTripper{
   269  					DexURL: dexURL,
   270  					T: &http.Transport{
   271  						TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
   272  					},
   273  				},
   274  			}
   275  
   276  			req := httptest.NewRequest(http.MethodGet, "/", nil)
   277  			cookie := &http.Cookie{
   278  				Name:  dexOAUTHTokenName,
   279  				Value: string(idToken),
   280  			}
   281  			req.AddCookie(cookie)
   282  
   283  			ctx := oidc.ClientContext(context.Background(), httpClient)
   284  			u, err := VerifyToken(ctx, req, appDex.ClientID, hostURL)
   285  			if err != nil {
   286  				if diff := cmp.Diff(tc.wantErr, err.Error()); diff != "" {
   287  					t.Errorf("Error mismatch (-want +got):\n%s", diff)
   288  				}
   289  			} else {
   290  				if diff := cmp.Diff(u, tc.wantUser); diff != "" {
   291  					t.Errorf("got %v, want %v, diff (-want +got) %s", u, tc.wantUser, diff)
   292  				}
   293  			}
   294  		})
   295  	}
   296  }
   297  
   298  // Helper function to make a new mock server.
   299  func makeNewMockServer(status int) *httptest.Server {
   300  	return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   301  		rw.WriteHeader(status)
   302  	}))
   303  }
   304  
   305  // Creates a signed JWT token with the respective claims.
   306  func GetSignedJwt(signingKey any, claims jwtV5.MapClaims) ([]byte, error) {
   307  	token := jwt.New()
   308  	_ = token.Set(jwt.ExpirationKey, time.Now().Add(time.Hour*24).Unix())
   309  
   310  	for key, value := range claims {
   311  		_ = token.Set(key, value)
   312  	}
   313  
   314  	signedToken, _ := jwt.Sign(token, jwt.WithKey(jwa.RS256, signingKey))
   315  	return signedToken, nil
   316  }
   317  
   318  // Generates and returns a key set, private key and public key.
   319  func getJWKeySet() (keySet jwk.Set, jwkPrivateKey, jwkPublicKey jwk.Key) {
   320  	rsaPrivate, rsaPublic := getRSAKeyPair()
   321  	jwkPrivateKey, _ = jwk.FromRaw(rsaPrivate)
   322  	jwkPublicKey, _ = jwk.FromRaw(rsaPublic)
   323  
   324  	_ = jwkPrivateKey.Set(jwk.KeyIDKey, "my-unique-kid")
   325  	_ = jwkPublicKey.Set(jwk.KeyIDKey, "my-unique-kid")
   326  
   327  	keySet = jwk.NewSet()
   328  	err := keySet.AddKey(jwkPublicKey)
   329  	if err != nil {
   330  		return nil, nil, nil
   331  	}
   332  
   333  	return keySet, jwkPrivateKey, jwkPublicKey
   334  }
   335  
   336  // Generates and returns a rsa key pair
   337  func getRSAKeyPair() (*rsa.PrivateKey, *rsa.PublicKey) {
   338  	privateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
   339  	publicKey := &privateKey.PublicKey
   340  	return privateKey, publicKey
   341  }
   342  
   343  // Mocks the OIDC server to get all provider.
   344  func MockOIDCTestServer(issuerURL string, keySet jwk.Set) *httptest.Server {
   345  	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   346  	}))
   347  	ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   348  		w.Header().Set("Content-Type", "application/json")
   349  		switch r.RequestURI {
   350  		case "/dex/.well-known/openid-configuration":
   351  			io.WriteString(w, fmt.Sprintf(`
   352  {
   353    "issuer": "%[1]s",
   354    "authorization_endpoint": "%[1]s/auth",
   355    "token_endpoint": "%[1]s/token",
   356    "jwks_uri": "%[1]s/keys",
   357    "userinfo_endpoint": "%[1]s/userinfo",
   358    "device_authorization_endpoint": "%[1]s/device/code",
   359    "grant_types_supported": ["authorization_code"],
   360    "response_types_supported": ["code"],
   361    "subject_types_supported": ["public"],
   362    "id_token_signing_alg_values_supported": ["RS256"],
   363    "code_challenge_methods_supported": ["S256", "plain"],
   364    "scopes_supported": ["openid"],
   365    "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
   366    "claims_supported": ["sub", "aud", "exp"]
   367  }`, issuerURL))
   368  		case "/dex/keys":
   369  			out, _ := json.Marshal(keySet)
   370  			_, _ = io.WriteString(w, string(out))
   371  		default:
   372  			w.WriteHeader(http.StatusNotFound)
   373  		}
   374  	})
   375  	return ts
   376  }