github.com/openfga/openfga@v1.5.4-rc1/internal/authn/oidc/oidc.go (about)

     1  package oidc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"slices"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/MicahParks/keyfunc"
    15  	"github.com/golang-jwt/jwt/v4"
    16  	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
    17  	"github.com/hashicorp/go-retryablehttp"
    18  
    19  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    20  	"google.golang.org/grpc/codes"
    21  	"google.golang.org/grpc/status"
    22  
    23  	"github.com/openfga/openfga/internal/authn"
    24  )
    25  
    26  type RemoteOidcAuthenticator struct {
    27  	MainIssuer    string
    28  	IssuerAliases []string
    29  	Audience      string
    30  
    31  	JwksURI string
    32  	JWKs    *keyfunc.JWKS
    33  
    34  	httpClient *http.Client
    35  }
    36  
    37  var (
    38  	jwkRefreshInterval = 48 * time.Hour
    39  
    40  	errInvalidAudience = status.Error(codes.Code(openfgav1.AuthErrorCode_auth_failed_invalid_audience), "invalid audience")
    41  	errInvalidClaims   = status.Error(codes.Code(openfgav1.AuthErrorCode_invalid_claims), "invalid claims")
    42  	errInvalidIssuer   = status.Error(codes.Code(openfgav1.AuthErrorCode_auth_failed_invalid_issuer), "invalid issuer")
    43  	errInvalidSubject  = status.Error(codes.Code(openfgav1.AuthErrorCode_auth_failed_invalid_subject), "invalid subject")
    44  	errInvalidToken    = status.Error(codes.Code(openfgav1.AuthErrorCode_auth_failed_invalid_bearer_token), "invalid bearer token")
    45  
    46  	fetchJWKs = fetchJWK
    47  )
    48  
    49  var _ authn.Authenticator = (*RemoteOidcAuthenticator)(nil)
    50  var _ authn.OIDCAuthenticator = (*RemoteOidcAuthenticator)(nil)
    51  
    52  func NewRemoteOidcAuthenticator(mainIssuer string, issuerAliases []string, audience string) (*RemoteOidcAuthenticator, error) {
    53  	client := retryablehttp.NewClient()
    54  	client.Logger = nil
    55  	oidc := &RemoteOidcAuthenticator{
    56  		MainIssuer:    mainIssuer,
    57  		IssuerAliases: issuerAliases,
    58  		Audience:      audience,
    59  		httpClient:    client.StandardClient(),
    60  	}
    61  	err := fetchJWKs(oidc)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return oidc, nil
    66  }
    67  
    68  func (oidc *RemoteOidcAuthenticator) Authenticate(requestContext context.Context) (*authn.AuthClaims, error) {
    69  	authHeader, err := grpcauth.AuthFromMD(requestContext, "Bearer")
    70  	if err != nil {
    71  		return nil, authn.ErrMissingBearerToken
    72  	}
    73  
    74  	jwtParser := jwt.NewParser(jwt.WithValidMethods([]string{"RS256"}))
    75  
    76  	token, err := jwtParser.Parse(authHeader, func(token *jwt.Token) (any, error) {
    77  		return oidc.JWKs.Keyfunc(token)
    78  	})
    79  	if err != nil {
    80  		return nil, errInvalidToken
    81  	}
    82  
    83  	if !token.Valid {
    84  		return nil, errInvalidToken
    85  	}
    86  
    87  	claims, ok := token.Claims.(jwt.MapClaims)
    88  	if !ok {
    89  		return nil, errInvalidClaims
    90  	}
    91  
    92  	validIssuers := []string{
    93  		oidc.MainIssuer,
    94  	}
    95  	validIssuers = append(validIssuers, oidc.IssuerAliases...)
    96  
    97  	ok = slices.ContainsFunc(validIssuers, func(issuer string) bool {
    98  		return claims.VerifyIssuer(issuer, true)
    99  	})
   100  
   101  	if !ok {
   102  		return nil, errInvalidIssuer
   103  	}
   104  
   105  	if ok := claims.VerifyAudience(oidc.Audience, true); !ok {
   106  		return nil, errInvalidAudience
   107  	}
   108  
   109  	// optional subject
   110  	var subject = ""
   111  	if subjectClaim, ok := claims["sub"]; ok {
   112  		if subject, ok = subjectClaim.(string); !ok {
   113  			return nil, errInvalidSubject
   114  		}
   115  	}
   116  
   117  	principal := &authn.AuthClaims{
   118  		Subject: subject,
   119  		Scopes:  make(map[string]bool),
   120  	}
   121  
   122  	// optional scopes
   123  	if scopeKey, ok := claims["scope"]; ok {
   124  		if scope, ok := scopeKey.(string); ok {
   125  			scopes := strings.Split(scope, " ")
   126  			for _, s := range scopes {
   127  				principal.Scopes[s] = true
   128  			}
   129  		}
   130  	}
   131  
   132  	return principal, nil
   133  }
   134  
   135  func fetchJWK(oidc *RemoteOidcAuthenticator) error {
   136  	oidcConfig, err := oidc.GetConfiguration()
   137  	if err != nil {
   138  		return fmt.Errorf("error fetching OIDC configuration: %w", err)
   139  	}
   140  
   141  	oidc.JwksURI = oidcConfig.JWKsURI
   142  	jwks, err := oidc.GetKeys()
   143  	if err != nil {
   144  		return fmt.Errorf("error fetching OIDC keys: %w", err)
   145  	}
   146  
   147  	oidc.JWKs = jwks
   148  
   149  	return nil
   150  }
   151  
   152  func (oidc *RemoteOidcAuthenticator) GetKeys() (*keyfunc.JWKS, error) {
   153  	jwks, err := keyfunc.Get(oidc.JwksURI, keyfunc.Options{
   154  		Client:          oidc.httpClient,
   155  		RefreshInterval: jwkRefreshInterval,
   156  	})
   157  	if err != nil {
   158  		return nil, fmt.Errorf("error fetching keys from %v: %w", oidc.JwksURI, err)
   159  	}
   160  	return jwks, nil
   161  }
   162  
   163  func (oidc *RemoteOidcAuthenticator) GetConfiguration() (*authn.OidcConfig, error) {
   164  	wellKnown := strings.TrimSuffix(oidc.MainIssuer, "/") + "/.well-known/openid-configuration"
   165  	req, err := http.NewRequest("GET", wellKnown, nil)
   166  	if err != nil {
   167  		return nil, fmt.Errorf("error forming request to get OIDC: %w", err)
   168  	}
   169  
   170  	res, err := oidc.httpClient.Do(req)
   171  	if err != nil {
   172  		return nil, fmt.Errorf("error getting OIDC: %w", err)
   173  	}
   174  	defer res.Body.Close()
   175  
   176  	if res.StatusCode != http.StatusOK {
   177  		return nil, fmt.Errorf("unexpected status code getting OIDC: %v", res.StatusCode)
   178  	}
   179  
   180  	body, err := io.ReadAll(res.Body)
   181  	if err != nil {
   182  		return nil, fmt.Errorf("error reading response body: %w", err)
   183  	}
   184  
   185  	oidcConfig := &authn.OidcConfig{}
   186  	if err := json.Unmarshal(body, oidcConfig); err != nil {
   187  		return nil, fmt.Errorf("failed parsing document: %w", err)
   188  	}
   189  
   190  	if oidcConfig.Issuer == "" {
   191  		return nil, errors.New("missing issuer value")
   192  	}
   193  
   194  	if oidcConfig.JWKsURI == "" {
   195  		return nil, errors.New("missing jwks_uri value")
   196  	}
   197  	return oidcConfig, nil
   198  }
   199  
   200  func (oidc *RemoteOidcAuthenticator) Close() {
   201  	oidc.JWKs.EndBackground()
   202  }