github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/pkg/usertokens/oidc/oidc.go (about)

     1  package oidc
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/sha1"
     7  	"encoding/base64"
     8  	"fmt"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/bluele/gcache"
    16  	oidc "github.com/coreos/go-oidc"
    17  	"github.com/rs/xid"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/usertokens/common"
    19  	"golang.org/x/oauth2"
    20  )
    21  
var (
	// We maintain two caches. The first maintains the set of states that
	// we issue the redirect requests with. This helps us validate the
	// callbacks and verify the state to avoid any cross-origin violations.
	// Entries are keyed by the OAuth2 state value and hold the origin URL
	// that started the login; Callback removes an entry once it is used.
	// NewClient builds this cache (LRU, 2048 entries) with a 120 second
	// expiration, which is the window the user has to authenticate.
	stateCache gcache.Cache
	// The second cache will maintain the validations of the tokens so that
	// we don't go to the authorizer for every request. Entries are keyed by
	// the raw ID token and hold a *clientData. NewClient builds it as an
	// LRU with 2048 entries and no default expiration; per-entry expirations
	// are set by Callback and Validate.
	tokenCache gcache.Cache
)
    32  
    33  // clientData is the state maintained for a client to improve response
    34  // times and hold the refresh tokens.
    35  type clientData struct {
    36  	attributes  []string
    37  	tokenSource oauth2.TokenSource
    38  	expiry      time.Time
    39  	sync.Mutex
    40  }
    41  
// TokenVerifier is an OIDC validator.
//
// The exported fields are configuration supplied by the caller. NewClient
// fills in the unexported fields (and the package-level caches), so
// IssueRedirect, Callback and Validate must only be used on a value that
// NewClient has returned.
type TokenVerifier struct {
	// ProviderURL is handed to oidc.NewProvider for discovery. It is also
	// checked for "accounts.google.com" to decide whether googleHack is set.
	ProviderURL string
	// ClientID is the OAuth2 client ID registered with the provider.
	ClientID string
	// ClientSecret is the OAuth2 client secret used for the code exchange.
	ClientSecret string
	// Scopes lists additional scopes; openid, profile and email are always
	// requested by NewClient.
	Scopes []string
	// RedirectURL is the callback URL the provider sends the user back to.
	RedirectURL string
	// NonceSize is the number of random bytes hashed into each state value.
	NonceSize int
	// CookieDuration is not referenced in this file; presumably callers use
	// it as the session cookie lifetime — confirm.
	CookieDuration time.Duration
	// clientConfig is the OAuth2 client configuration built by NewClient.
	clientConfig *oauth2.Config
	// oauthVerifier verifies ID tokens. NewClient builds it with
	// SkipClientIDCheck set, so the token audience is not checked.
	oauthVerifier *oidc.IDTokenVerifier
	// googleHack adds prompt=consent to redirects when the provider is
	// Google, which otherwise does not hand out refresh tokens.
	googleHack bool
}
    55  
    56  // NewClient creates a new validator client
    57  func NewClient(ctx context.Context, v *TokenVerifier) (*TokenVerifier, error) {
    58  	// Initialize caches only once if they are nil.
    59  	if stateCache == nil {
    60  		stateCache = gcache.New(2048).LRU().Expiration(120 * time.Second).Build()
    61  	}
    62  	if tokenCache == nil {
    63  		tokenCache = gcache.New(2048).LRU().Build()
    64  	}
    65  
    66  	// Create a new generic OIDC provider based on the provider URL.
    67  	// The library will auto-discover the configuration of the provider.
    68  	// If it is not a compliant provider we should report and error here.
    69  	provider, err := oidc.NewProvider(ctx, v.ProviderURL)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("Failed to initialize provider: %s", err)
    72  	}
    73  
    74  	oidConfig := &oidc.Config{
    75  		ClientID:          v.ClientID,
    76  		SkipClientIDCheck: true,
    77  	}
    78  	v.oauthVerifier = provider.Verifier(oidConfig)
    79  	scopes := []string{oidc.ScopeOpenID, "profile", "email"}
    80  	for _, scope := range v.Scopes {
    81  		if scope != oidc.ScopeOpenID && scope != "profile" && scope != "email" {
    82  			scopes = append(scopes, scope)
    83  		}
    84  	}
    85  
    86  	v.clientConfig = &oauth2.Config{
    87  		ClientID:     v.ClientID,
    88  		ClientSecret: v.ClientSecret,
    89  		Endpoint:     provider.Endpoint(),
    90  		RedirectURL:  v.RedirectURL,
    91  		Scopes:       scopes,
    92  	}
    93  
    94  	// Google does not honor the OIDC standard to refresh tokens
    95  	// with a proper scope. Instead it requires a prompt parameter
    96  	// to be passed. In order to deal wit this, we will have to
    97  	// detect Google as the OIDC and pass the parameters.
    98  	if strings.Contains(v.ProviderURL, "accounts.google.com") {
    99  		v.googleHack = true
   100  	}
   101  
   102  	return v, nil
   103  }
   104  
   105  // IssueRedirect creates the redirect URL. The URI is created by the provider
   106  // and it includes a state that is random. The state will be remembered
   107  // for the return. There is an assumption here that the LBs in front of
   108  // applications are sticky or the TCP session is re-used. Otherwise, we will
   109  // need a global state that could introduce additional calls to a central
   110  // system.
   111  // TODO: add support for a global state.
   112  func (v *TokenVerifier) IssueRedirect(originURL string) string {
   113  	state, err := randomSha1(v.NonceSize)
   114  	if err != nil {
   115  		state = xid.New().String()
   116  	}
   117  	if err := stateCache.Set(state, originURL); err != nil {
   118  		return ""
   119  	}
   120  
   121  	redirectURL := v.clientConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
   122  	if v.googleHack {
   123  		redirectURL = redirectURL + "&prompt=consent"
   124  	}
   125  
   126  	return redirectURL
   127  }
   128  
   129  // Callback is the function that is called back by the IDP to catch the token
   130  // and perform all other validations. It will return the resulting token,
   131  // the original URL that was called to initiate the protocol, and the
   132  // http status response.
   133  func (v *TokenVerifier) Callback(ctx context.Context, u *url.URL) (string, string, int, error) {
   134  
   135  	// We first validate that the callback state matches the original redirect
   136  	// state. We clean up the cache once it is validated. During this process
   137  	// we recover the original URL that initiated the protocol. This allows
   138  	// us to redirect the client to their original request.
   139  	receivedState := u.Query().Get("state")
   140  	originURL, err := stateCache.Get(receivedState)
   141  	if err != nil {
   142  		return "", "", http.StatusBadRequest, fmt.Errorf("bad state")
   143  	}
   144  	stateCache.Remove(receivedState)
   145  
   146  	// We exchange the authorization code with an OAUTH token. This is the main
   147  	// step where the OAUTH provider will match the code to the token.
   148  	oauth2Token, err := v.clientConfig.Exchange(ctx, u.Query().Get("code"), oauth2.AccessTypeOffline)
   149  	if err != nil {
   150  		return "", "", http.StatusInternalServerError, fmt.Errorf("bad code: %s", err)
   151  	}
   152  
   153  	// We extract the rawID token.
   154  	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
   155  	if !ok {
   156  		return "", "", http.StatusInternalServerError, fmt.Errorf("bad ID")
   157  	}
   158  
   159  	if err := tokenCache.SetWithExpire(
   160  		rawIDToken,
   161  		&clientData{
   162  			tokenSource: v.clientConfig.TokenSource(ctx, oauth2Token),
   163  			expiry:      oauth2Token.Expiry,
   164  		},
   165  		time.Until(oauth2Token.Expiry.Add(3600*time.Second)),
   166  	); err != nil {
   167  		return "", "", http.StatusInternalServerError, fmt.Errorf("failed to insert token in the cache: %s", err)
   168  	}
   169  
   170  	return rawIDToken, originURL.(string), http.StatusTemporaryRedirect, nil
   171  }
   172  
   173  // Validate checks if the token is valid and returns the claims. The validator
   174  // maintains an internal cache with tokens to accelerate performance. If the
   175  // token is not in the cache, it will validate it with the central authorizer.
   176  func (v *TokenVerifier) Validate(ctx context.Context, token string) ([]string, bool, string, error) {
   177  
   178  	if len(token) == 0 {
   179  		return []string{}, true, token, fmt.Errorf("invalid token presented")
   180  	}
   181  
   182  	var tokenData *clientData
   183  
   184  	// If it is not found in the cache initiate a call back process.
   185  	data, err := tokenCache.Get(token)
   186  	if err == nil {
   187  		var ok bool
   188  		tokenData, ok = data.(*clientData)
   189  		if !ok {
   190  			return nil, true, token, fmt.Errorf("internal server error")
   191  		}
   192  
   193  		// If the cached token hasn't expired yet, we can just accept it and not
   194  		// go through a whole verification process. Nothing new.
   195  		if tokenData.expiry.After(time.Now()) && len(tokenData.attributes) > 0 {
   196  			return tokenData.attributes, false, token, nil
   197  		}
   198  	} else { // No token in the cache. Let's try to see if it is valid and we can cache it now.
   199  		//
   200  		tokenData = &clientData{}
   201  	}
   202  
   203  	// The token has expired. Let's try to refresh it.
   204  	tokenData.Lock()
   205  	defer tokenData.Unlock()
   206  
   207  	// If it is the first time we are verifying the token, let's do
   208  	// it now. This is possible if the token was created earlier
   209  	// but we never had a chance to verify it. In this case, the
   210  	// attributes were empty.
   211  	idToken, err := v.oauthVerifier.Verify(ctx, token)
   212  	if err != nil {
   213  		var ok bool
   214  		// Token is expired. Let's try to refresh it if we have something
   215  		// in the cache. If we don't have a refresh token, we reject it
   216  		// and ask the client to validate again.
   217  		if tokenData.tokenSource == nil {
   218  			return []string{}, true, token, fmt.Errorf("no cached data and expired token - request authorization: %s", err)
   219  		}
   220  		refreshedToken, err := tokenData.tokenSource.Token()
   221  		if err != nil {
   222  			return []string{}, true, token, fmt.Errorf("token validation failed and cannot refresh: %s", err)
   223  		}
   224  		token, ok = refreshedToken.Extra("id_token").(string)
   225  		if !ok {
   226  			return []string{}, true, token, fmt.Errorf("failed to find id_token - initiate re-authorization")
   227  		}
   228  		idToken, err = v.oauthVerifier.Verify(ctx, token)
   229  		if err != nil {
   230  			return []string{}, true, token, fmt.Errorf("invalid token derived from refresh - manual authorization is required: %s", err)
   231  		}
   232  	}
   233  
   234  	// Get the claims out of the token. Use the standard data structure for
   235  	// this and ignore the other fields. We are only interested on the ID.
   236  	resp := struct {
   237  		IDTokenClaims map[string]interface{} // ID Token payload is just JSON.
   238  	}{map[string]interface{}{}}
   239  	if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
   240  		return []string{}, true, token, fmt.Errorf("unable to process claims: %s", err)
   241  	}
   242  
   243  	// Flatten the claims in a generic format.
   244  	attributes := []string{}
   245  	for k, v := range resp.IDTokenClaims {
   246  		attributes = append(attributes, common.FlattenClaim(k, v)...)
   247  	}
   248  
   249  	tokenData.attributes = attributes
   250  	tokenData.expiry = idToken.Expiry
   251  
   252  	// Cache the token and attributes to avoid multiple validations and update the
   253  	// expiration time.
   254  	if err := tokenCache.SetWithExpire(token, tokenData, time.Until(idToken.Expiry.Add(3600*time.Second))); err != nil {
   255  		return []string{}, false, token, fmt.Errorf("cannot cache token: %s", err)
   256  	}
   257  
   258  	return attributes, false, token, nil
   259  }
   260  
   261  // VerifierType returns the type of the TokenVerifier.
   262  func (v *TokenVerifier) VerifierType() common.JWTType {
   263  	return common.OIDC
   264  }
   265  
   266  func randomSha1(nonceSourceSize int) (string, error) {
   267  	nonceSource := make([]byte, nonceSourceSize)
   268  	if _, err := rand.Read(nonceSource); err != nil {
   269  		return "", err
   270  	}
   271  	sha := sha1.Sum(nonceSource)
   272  	return base64.StdEncoding.EncodeToString(sha[:]), nil
   273  }