github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/authentication/service_test.go (about)

     1  package authentication_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"testing"
     8  
     9  	"github.com/golang/mock/gomock"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/treeverse/lakefs/pkg/authentication"
    12  	"github.com/treeverse/lakefs/pkg/authentication/apiclient"
    13  	"github.com/treeverse/lakefs/pkg/authentication/mock"
    14  	"github.com/treeverse/lakefs/pkg/logging"
    15  )
    16  
    17  func TestAPIAuthService_STSLogin(t *testing.T) {
    18  	someErr := errors.New("some error")
    19  	tests := []struct {
    20  		name                 string
    21  		responseStatusCode   int
    22  		expectedErr          error
    23  		error                error
    24  		additionalClaim      string
    25  		additionalClaimValue string
    26  		validateClaim        string
    27  		validateClaimValue   string
    28  		returnedSubject      string
    29  	}{
    30  		{
    31  			name:               "ok",
    32  			responseStatusCode: http.StatusOK,
    33  			expectedErr:        nil,
    34  			returnedSubject:    "external_user_id",
    35  		},
    36  		{
    37  			name:                 "With additional claim",
    38  			responseStatusCode:   http.StatusOK,
    39  			expectedErr:          nil,
    40  			additionalClaim:      "additional_claim",
    41  			additionalClaimValue: "additional_claim_value",
    42  			validateClaim:        "additional_claim",
    43  			validateClaimValue:   "additional_claim_value",
    44  			returnedSubject:      "external_user_id",
    45  		},
    46  		{
    47  			name:                 "Non matching additional claim",
    48  			responseStatusCode:   http.StatusOK,
    49  			expectedErr:          authentication.ErrInsufficientPermissions,
    50  			additionalClaim:      "additional_claim",
    51  			additionalClaimValue: "additional_claim_value",
    52  			validateClaim:        "additional_claim",
    53  			validateClaimValue:   "additional_claim_value2",
    54  			returnedSubject:      "external_user_id",
    55  		},
    56  		{
    57  			name:               "Missing subject",
    58  			responseStatusCode: http.StatusOK,
    59  			expectedErr:        authentication.ErrInsufficientPermissions,
    60  		},
    61  		{
    62  			name:               "Not authorized",
    63  			responseStatusCode: http.StatusUnauthorized,
    64  			expectedErr:        authentication.ErrInsufficientPermissions,
    65  		},
    66  		{
    67  			name:               "Internal server error",
    68  			responseStatusCode: http.StatusInternalServerError,
    69  			expectedErr:        authentication.ErrUnexpectedStatusCode,
    70  		},
    71  		{
    72  			name:        "Other error",
    73  			error:       someErr,
    74  			expectedErr: someErr,
    75  		},
    76  	}
    77  	code := "some_code"
    78  	state := "some_state"
    79  	redirectURI := "some_redirect_uri"
    80  	for _, tt := range tests {
    81  		t.Run(tt.name, func(t *testing.T) {
    82  			validateTokenClaims := map[string]string{tt.validateClaim: tt.validateClaimValue}
    83  			mockClient, s := NewTestApiService(t, validateTokenClaims, false)
    84  			ctx := context.Background()
    85  			requestEq := gomock.Eq(apiclient.STSLoginJSONRequestBody{
    86  				RedirectUri: redirectURI,
    87  				Code:        code,
    88  				State:       state,
    89  			})
    90  
    91  			loginResponse := &apiclient.STSLoginResponse{
    92  				Body:         nil,
    93  				HTTPResponse: &http.Response{StatusCode: tt.responseStatusCode},
    94  				JSON200:      nil,
    95  				JSON401:      nil,
    96  				JSONDefault:  nil,
    97  			}
    98  			if tt.responseStatusCode == http.StatusOK {
    99  				loginResponse.JSON200 = &apiclient.OidcTokenData{
   100  					Claims: apiclient.OidcTokenData_Claims{
   101  						AdditionalProperties: map[string]string{tt.additionalClaim: tt.additionalClaimValue},
   102  					},
   103  				}
   104  				if tt.returnedSubject != "" {
   105  					loginResponse.JSON200.Claims.AdditionalProperties["sub"] = tt.returnedSubject
   106  				}
   107  
   108  			}
   109  			mockClient.EXPECT().STSLoginWithResponse(gomock.Any(), requestEq).Return(loginResponse, tt.error)
   110  			externalUserID, err := s.ValidateSTS(ctx, code, redirectURI, state)
   111  			if !errors.Is(err, tt.expectedErr) {
   112  				t.Fatalf("ValidateSTS: expected err: %v got: %v", tt.expectedErr, err)
   113  			}
   114  			if err != nil {
   115  				return
   116  			}
   117  			if externalUserID != tt.returnedSubject {
   118  				t.Fatalf("expected subject to be 'external_user_id', got %s", externalUserID)
   119  			}
   120  		})
   121  	}
   122  }
   123  
   124  func NewTestApiService(t *testing.T, validateIDTokenClaims map[string]string, externalPrincipalsEnabled bool) (*mock.MockClientWithResponsesInterface, *authentication.APIService) {
   125  	t.Helper()
   126  	ctrl := gomock.NewController(t)
   127  	mockClient := mock.NewMockClientWithResponsesInterface(ctrl)
   128  	s, err := authentication.NewAPIServiceWithClients(mockClient, logging.ContextUnavailable(), validateIDTokenClaims, externalPrincipalsEnabled)
   129  	if err != nil {
   130  		t.Fatalf("failed initiating API service with mock")
   131  	}
   132  	return mockClient, s
   133  }
   134  
   135  func TestAPIAuthService_ExternalLogin(t *testing.T) {
   136  	mockClient, s := NewTestApiService(t, map[string]string{}, true)
   137  	ctx := context.Background()
   138  	principalId := "arn"
   139  	externalLoginInfo := map[string]interface{}{"IdentityToken": "Token"}
   140  
   141  	mockClient.EXPECT().ExternalPrincipalLoginWithResponse(gomock.Any(), gomock.Eq(externalLoginInfo)).Return(
   142  		&apiclient.ExternalPrincipalLoginResponse{
   143  			HTTPResponse: &http.Response{
   144  				StatusCode: http.StatusOK,
   145  			},
   146  			JSON200: &apiclient.ExternalPrincipal{Id: principalId},
   147  		}, nil)
   148  
   149  	resp, err := s.ExternalPrincipalLogin(ctx, externalLoginInfo)
   150  	require.NoError(t, err)
   151  	require.Equal(t, principalId, resp.Id)
   152  }