github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/authentication/service.go (about)

     1  package authentication
     2  
     3  //go:generate go run github.com/deepmap/oapi-codegen/cmd/oapi-codegen@v1.5.6 -package apiclient -generate "types,client" -o apiclient/client.gen.go  ../../api/authentication.yml
     4  //go:generate go run github.com/golang/mock/mockgen@v1.6.0 -package=mock -destination=mock/mock_authentication_client.go github.com/treeverse/lakefs/pkg/authentication/apiclient ClientWithResponsesInterface
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"net/http"
    10  
    11  	"github.com/getkin/kin-openapi/openapi3filter"
    12  	"github.com/treeverse/lakefs/pkg/authentication/apiclient"
    13  	"github.com/treeverse/lakefs/pkg/logging"
    14  )
    15  
    16  type Service interface {
    17  	IsExternalPrincipalsEnabled() bool
    18  	ExternalPrincipalLogin(ctx context.Context, identityRequest map[string]interface{}) (*apiclient.ExternalPrincipal, error)
    19  	// ValidateSTS validates the STS parameters and returns the external user ID
    20  	ValidateSTS(ctx context.Context, code, redirectURI, state string) (string, error)
    21  }
    22  
    23  type DummyService struct{}
    24  
    25  func NewDummyService() *DummyService {
    26  	return &DummyService{}
    27  }
    28  
    29  func (d DummyService) ValidateSTS(ctx context.Context, code, redirectURI, state string) (string, error) {
    30  	return "", ErrNotImplemented
    31  }
    32  
    33  func (d DummyService) ExternalPrincipalLogin(_ context.Context, _ map[string]interface{}) (*apiclient.ExternalPrincipal, error) {
    34  	return nil, ErrNotImplemented
    35  }
    36  
    37  func (d DummyService) IsExternalPrincipalsEnabled() bool {
    38  	return false
    39  }
    40  
    41  type APIService struct {
    42  	validateIDTokenClaims     map[string]string
    43  	apiClient                 apiclient.ClientWithResponsesInterface
    44  	logger                    logging.Logger
    45  	externalPrincipalsEnabled bool
    46  }
    47  
    48  func NewAPIService(apiEndpoint string, validateIDTokenClaims map[string]string, logger logging.Logger, externalPrincipalsEnabled bool) (*APIService, error) {
    49  	client, err := apiclient.NewClientWithResponses(apiEndpoint)
    50  	if err != nil {
    51  		return nil, fmt.Errorf("failed to create authentication api client: %w", err)
    52  	}
    53  
    54  	res := &APIService{
    55  		validateIDTokenClaims:     validateIDTokenClaims,
    56  		apiClient:                 client,
    57  		logger:                    logger,
    58  		externalPrincipalsEnabled: externalPrincipalsEnabled,
    59  	}
    60  	return res, nil
    61  }
    62  
    63  func NewAPIServiceWithClients(apiClient apiclient.ClientWithResponsesInterface, logger logging.Logger, validateIDTokenClaims map[string]string, externalPrincipalsEnabled bool) (*APIService, error) {
    64  	return &APIService{
    65  		apiClient:                 apiClient,
    66  		logger:                    logger,
    67  		validateIDTokenClaims:     validateIDTokenClaims,
    68  		externalPrincipalsEnabled: externalPrincipalsEnabled,
    69  	}, nil
    70  }
    71  
    72  // validateResponse returns ErrUnexpectedStatusCode if the response status code is not as expected
    73  func (s *APIService) validateResponse(resp openapi3filter.StatusCoder, expectedStatusCode int) error {
    74  	statusCode := resp.StatusCode()
    75  	if statusCode == expectedStatusCode {
    76  		return nil
    77  	}
    78  	switch statusCode {
    79  	case http.StatusBadRequest:
    80  		return ErrInvalidRequest
    81  	case http.StatusConflict:
    82  		return ErrAlreadyExists
    83  	case http.StatusUnauthorized:
    84  		return ErrInsufficientPermissions
    85  	default:
    86  		return fmt.Errorf("%w - got %d expected %d", ErrUnexpectedStatusCode, statusCode, expectedStatusCode)
    87  	}
    88  }
    89  
    90  // ValidateSTS calls the external authentication service to validate the STS parameters
    91  // validates the required claims and returns the external user id and expiration time
    92  func (s *APIService) ValidateSTS(ctx context.Context, code, redirectURI, state string) (string, error) {
    93  	res, err := s.apiClient.STSLoginWithResponse(ctx, apiclient.STSLoginJSONRequestBody{
    94  		Code:        code,
    95  		RedirectUri: redirectURI,
    96  		State:       state,
    97  	})
    98  	if err != nil {
    99  		return "", fmt.Errorf("failed to authenticate user: %w", err)
   100  	}
   101  
   102  	if err := s.validateResponse(res, http.StatusOK); err != nil {
   103  		return "", fmt.Errorf("invalid authentication response: %w", err)
   104  	}
   105  
   106  	// validate claims
   107  	claims := res.JSON200.Claims
   108  	for claim, expectedValue := range s.validateIDTokenClaims {
   109  		if claimValue, found := claims.Get(claim); !found || claimValue != expectedValue {
   110  			return "", fmt.Errorf("claim %s has unexpected value %s: %w", claim, claimValue, ErrInsufficientPermissions)
   111  		}
   112  	}
   113  	subject, found := claims.Get("sub")
   114  	if !found {
   115  		return "", fmt.Errorf("missing subject in claims: %w", ErrInsufficientPermissions)
   116  	}
   117  	return subject, nil
   118  }
   119  
   120  func (s *APIService) ExternalPrincipalLogin(ctx context.Context, identityRequest map[string]interface{}) (*apiclient.ExternalPrincipal, error) {
   121  	if !s.IsExternalPrincipalsEnabled() {
   122  		return nil, fmt.Errorf("external principals disabled: %w", ErrInvalidRequest)
   123  	}
   124  	resp, err := s.apiClient.ExternalPrincipalLoginWithResponse(ctx, identityRequest)
   125  	if err != nil {
   126  		return nil, fmt.Errorf("calling authenticate user: %w", err)
   127  	}
   128  	if resp.StatusCode() != http.StatusOK {
   129  		switch resp.StatusCode() {
   130  		case http.StatusBadRequest:
   131  			return nil, ErrInvalidRequest
   132  		case http.StatusUnauthorized:
   133  			return nil, ErrInvalidTokenFormat
   134  		case http.StatusForbidden:
   135  			return nil, ErrSessionExpired
   136  		default:
   137  			return nil, fmt.Errorf("%w - got %d expected %d", ErrUnexpectedStatusCode, resp.StatusCode(), http.StatusOK)
   138  		}
   139  	}
   140  	return resp.JSON200, nil
   141  }
   142  
   143  func (s *APIService) IsExternalPrincipalsEnabled() bool {
   144  	return s.externalPrincipalsEnabled
   145  }