github.com/openfga/openfga@v1.5.4-rc1/internal/authn/oidc/oidc_test.go (about)

     1  package oidc
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"log"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/MicahParks/keyfunc"
    13  	"github.com/golang-jwt/jwt/v4"
    14  	"github.com/stretchr/testify/require"
    15  	"google.golang.org/grpc/metadata"
    16  
    17  	"github.com/openfga/openfga/internal/authn"
    18  )
    19  
    20  func TestRemoteOidcAuthenticator_Authenticate(t *testing.T) {
    21  	t.Run("when_the_authorization_header_is_missing_from_the_gRPC_metadata_of_the_request,_returns_'missing_bearer_token'_error", func(t *testing.T) {
    22  		authenticator := &RemoteOidcAuthenticator{}
    23  		_, err := authenticator.Authenticate(context.Background())
    24  		require.Equal(t, authn.ErrMissingBearerToken, err)
    25  	})
    26  	errorTestCases := []struct {
    27  		testDescription string
    28  		testSetup       func() (*RemoteOidcAuthenticator, context.Context, error)
    29  		expectedError   string
    30  	}{
    31  		{
    32  			testDescription: "when_the_token_has_expired,_return_'invalid_bearer_token'",
    33  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
    34  				return quickConfigSetup(
    35  					"kid_1",
    36  					"kid_1",
    37  					"right_issuer",
    38  					"right_audience",
    39  					nil,
    40  					jwt.MapClaims{
    41  						"iss": "right_issuer",
    42  						"aud": "right_audience",
    43  						"sub": "openfga client",
    44  						"exp": time.Now().Add(-10 * time.Minute).Unix(),
    45  					},
    46  					nil,
    47  				)
    48  			},
    49  			expectedError: "invalid bearer token",
    50  		},
    51  		{
    52  			testDescription: "when_the_JWT_contains_a_future_'iat',_return_'invalid_bearer_token'",
    53  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
    54  				return quickConfigSetup(
    55  					"kid_1",
    56  					"kid_1",
    57  					"right_issuer",
    58  					"right_audience",
    59  					nil,
    60  					jwt.MapClaims{
    61  						"iss": "right_issuer",
    62  						"aud": "right_audience",
    63  						"sub": "openfga client",
    64  						"iat": time.Now().Add(10 * time.Minute).Unix(),
    65  					},
    66  					nil,
    67  				)
    68  			},
    69  			expectedError: "invalid bearer token",
    70  		},
    71  		{
    72  			testDescription: "when_JWT_and_JWK_kid_don't_match,_returns_'invalid_bearer_token'",
    73  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
    74  				return quickConfigSetup("kid_1", "kid_2", "", "", nil, jwt.MapClaims{}, nil)
    75  			},
    76  			expectedError: "invalid bearer token",
    77  		},
    78  		{
    79  			testDescription: "when_token_is_signed_using_different_public/private_key_pairs,_returns__'invalid_bearer_token'",
    80  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
    81  				privateKey, _ := generateJWTSignatureKeys()
    82  				return quickConfigSetup("kid_1", "kid_1", "", "", nil, jwt.MapClaims{}, privateKey)
    83  			},
    84  			expectedError: "invalid bearer token",
    85  		},
    86  		{
    87  			testDescription: "when_token's_issuer_does_not_match_the_one_provided_in_the_server_configuration,_MUST_return_'invalid_issuer'_error",
    88  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
    89  				return quickConfigSetup(
    90  					"kid_1",
    91  					"kid_1",
    92  					"right_issuer",
    93  					"",
    94  					nil,
    95  					jwt.MapClaims{
    96  						"iss": "wrong_issuer",
    97  					},
    98  					nil,
    99  				)
   100  			},
   101  			expectedError: "invalid issuer",
   102  		},
   103  		{
   104  			testDescription: "when_token's_audience_does_not_match_the_one_provided_in_the_server_configuration,_MUST_return_'invalid_audience'_error",
   105  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
   106  				return quickConfigSetup(
   107  					"kid_1",
   108  					"kid_1",
   109  					"right_issuer",
   110  					"right_audience",
   111  					nil,
   112  					jwt.MapClaims{
   113  						"iss": "right_issuer",
   114  						"aud": "wrong_audience",
   115  					},
   116  					nil,
   117  				)
   118  			},
   119  			expectedError: "invalid audience",
   120  		},
   121  		{
   122  			testDescription: "when_the_subject_of_the_token_is_not_a_string,_MUST_return_'invalid_subject'_error",
   123  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
   124  				return quickConfigSetup(
   125  					"kid_1",
   126  					"kid_1",
   127  					"right_issuer",
   128  					"right_audience",
   129  					nil,
   130  					jwt.MapClaims{
   131  						"iss": "right_issuer",
   132  						"aud": "right_audience",
   133  						"sub": 12,
   134  					},
   135  					nil,
   136  				)
   137  			},
   138  			expectedError: "invalid subject",
   139  		},
   140  	}
   141  
   142  	for _, testC := range errorTestCases {
   143  		t.Run(testC.testDescription, func(t *testing.T) {
   144  			if testC.expectedError == "" {
   145  				t.Fatal("this suite is to test error cases and this test didn't have an error expectation")
   146  			}
   147  			oidc, requestContext, _ := testC.testSetup()
   148  			_, err := oidc.Authenticate(requestContext)
   149  			require.Contains(t, err.Error(), testC.expectedError)
   150  		})
   151  	}
   152  
   153  	// Success testcases
   154  
   155  	scopes := "offline_access read write delete"
   156  	successTestCases := []struct {
   157  		testDescription string
   158  		testSetup       func() (*RemoteOidcAuthenticator, context.Context, error)
   159  	}{
   160  		{
   161  			testDescription: "when_the_token_is_valid,_it_MUST_return_the_token_subject_and_its_associated_scopes",
   162  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
   163  				return quickConfigSetup(
   164  					"kid_2",
   165  					"kid_2",
   166  					"right_issuer",
   167  					"right_audience",
   168  					nil,
   169  					jwt.MapClaims{
   170  						"iss":   "right_issuer",
   171  						"aud":   "right_audience",
   172  						"sub":   "openfga client",
   173  						"scope": scopes,
   174  					},
   175  					nil,
   176  				)
   177  			},
   178  		},
   179  		{
   180  			testDescription: "when_the_token_is_valid_with_issuer_alias,_it_MUST_return_the_token_subject_and_its_associated_scopes",
   181  			testSetup: func() (*RemoteOidcAuthenticator, context.Context, error) {
   182  				return quickConfigSetup(
   183  					"kid_2",
   184  					"kid_2",
   185  					"right_issuer",
   186  					"right_audience",
   187  					[]string{"issuer_alias"},
   188  					jwt.MapClaims{
   189  						"iss":   "issuer_alias",
   190  						"aud":   "right_audience",
   191  						"sub":   "openfga client",
   192  						"scope": scopes,
   193  					},
   194  					nil,
   195  				)
   196  			},
   197  		},
   198  	}
   199  
   200  	for _, testC := range successTestCases {
   201  		t.Run(testC.testDescription, func(t *testing.T) {
   202  			oidc, requestContext, err := testC.testSetup()
   203  			if err != nil {
   204  				t.Fatal(err)
   205  			}
   206  			authClaims, err := oidc.Authenticate(requestContext)
   207  			require.NoError(t, err)
   208  			require.Equal(t, "openfga client", authClaims.Subject)
   209  			scopesList := strings.Split(scopes, " ")
   210  			require.Equal(t, len(scopesList), len(authClaims.Scopes))
   211  			for _, scope := range scopesList {
   212  				_, ok := authClaims.Scopes[scope]
   213  				require.True(t, ok)
   214  			}
   215  		})
   216  	}
   217  }
   218  
   219  // quickConfigSetup sets up a basic configuration for testing purposes.
   220  func quickConfigSetup(jwkKid, jwtKid, issuerURL, audience string, issuerAliases []string, jwtClaims jwt.MapClaims, privateKeyOverride *rsa.PrivateKey) (*RemoteOidcAuthenticator, context.Context, error) {
   221  	// Generate JWT signature keys
   222  	privateKey, publicKey := generateJWTSignatureKeys()
   223  	if privateKeyOverride != nil {
   224  		privateKey = privateKeyOverride
   225  	}
   226  	// assign mocked JWKS fetching function to global function
   227  	fetchJWKs = fetchKeysMock(publicKey, jwkKid)
   228  
   229  	// Initialize RemoteOidcAuthenticator
   230  	oidc, err := NewRemoteOidcAuthenticator(issuerURL, issuerAliases, audience)
   231  	if err != nil {
   232  		return nil, nil, err
   233  	}
   234  
   235  	// Generate JWT token
   236  	token := generateJWT(privateKey, jwtKid, jwtClaims)
   237  
   238  	// Generate context with JWT token
   239  	requestContext := generateContext(token)
   240  
   241  	return oidc, requestContext, nil
   242  }
   243  
   244  func generateContext(token string) context.Context {
   245  	md := metadata.Pairs("authorization", "Bearer "+token)
   246  	return metadata.NewIncomingContext(context.Background(), md)
   247  }
   248  
   249  // generateJWTSignatureKeys generates a private key for signing JWT tokens
   250  // and a corresponding public key for verifying JWT token signatures.
   251  func generateJWTSignatureKeys() (*rsa.PrivateKey, *rsa.PublicKey) {
   252  	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
   253  	if err != nil {
   254  		log.Fatal("Private key cannot be created.", err.Error())
   255  	}
   256  	return privateKey, &privateKey.PublicKey
   257  }
   258  
   259  // fetchKeysMock returns a function that sets up a mock JWKS.
   260  func fetchKeysMock(publicKey *rsa.PublicKey, kid string) func(oidc *RemoteOidcAuthenticator) error {
   261  	// Create a keyfunc with the given RSA public key and RS256 algorithm
   262  	givenKeys := keyfunc.NewGivenRSACustomWithOptions(publicKey, keyfunc.GivenKeyOptions{
   263  		Algorithm: "RS256",
   264  	})
   265  	// Return a function that sets up the mock JWKS with the provided kid
   266  	return func(oidc *RemoteOidcAuthenticator) error {
   267  		jwks := keyfunc.NewGiven(map[string]keyfunc.GivenKey{
   268  			kid: givenKeys,
   269  		})
   270  		oidc.JWKs = jwks
   271  		return nil
   272  	}
   273  }
   274  
   275  // generateJWT generates Json Web Tokens signed with the provided privateKey.
   276  func generateJWT(privateKey *rsa.PrivateKey, kid string, claims jwt.MapClaims) string {
   277  	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
   278  	token.Header["kid"] = kid
   279  	signedToken, err := token.SignedString(privateKey)
   280  	if err != nil {
   281  		log.Fatal("Failed to sign JWT token:", err)
   282  	}
   283  	return signedToken
   284  }