github.com/alwitt/goutils@v0.6.4/oauth.go (about)

     1  package goutils
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"reflect"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/apex/log"
    12  	"github.com/go-playground/validator/v10"
    13  	"github.com/go-resty/resty/v2"
    14  )
    15  
// openIDIssuerConfig holds the OpenID issuer's API info.
//
// This is typically read from http://{{ OpenID issuer }}/.well-known/openid-configuration.
//
// Each field is decoded from the JSON key named in its struct tag. Note that
// both "introspection_endpoint" and "token_introspection_endpoint" are
// captured, in separate fields. Within this file only TokenEP is read (as the
// URL new tokens are requested from); the remaining fields are decoded and
// included in the debug dump of the configuration but are not otherwise used
// here.
type openIDIssuerConfig struct {
	Issuer               string   `json:"issuer"`
	AuthorizationEP      string   `json:"authorization_endpoint"`
	TokenEP              string   `json:"token_endpoint"`
	IntrospectionEP      string   `json:"introspection_endpoint"`
	TokenIntrospectionEP string   `json:"token_introspection_endpoint"`
	UserinfoEP           string   `json:"userinfo_endpoint"`
	EndSessionEP         string   `json:"end_session_endpoint"`
	JwksURI              string   `json:"jwks_uri"`
	ClientRegistrationEP string   `json:"registration_endpoint"`
	RevocationEP         string   `json:"revocation_endpoint"`
	TokenEPAuthMethods   []string `json:"token_endpoint_auth_methods_supported"`
	ClaimsSupported      []string `json:"claims_supported"`
}
    33  
// OAuthTokenManager manages an OAuth token on behalf of its caller, handling both
// the initial fetch and any later refresh of that token.
type OAuthTokenManager interface {
	/*
		GetToken fetches the currently valid OAuth token.

		 @param ctxt context.Context - the execution context
		 @param timestamp time.Time - the current timestamp
		 @returns the token, or an error if no token could be provided
	*/
	GetToken(ctxt context.Context, timestamp time.Time) (string, error)

	/*
		Stop stops any supporting background tasks which were started.

		 @param ctxt context.Context - execution context
		 @returns an error if the background tasks could not be stopped
	*/
	Stop(ctxt context.Context) error
}
    52  
// clientCredOAuthTokenManager client credential flow oauth token manager
//
// Field usage within this file:
//   - httpClient: issues the token request against idpConfig.TokenEP.
//   - tasks: task processor that GetToken submits requests to, and that
//     dispatches them to processGetToken.
//   - clientID, clientSecret, tokenAudience: sent as "client_id",
//     "client_secret" and "audience" in the token request body.
//   - idpConfig: the OpenID configuration fetched when the manager is built.
//   - workerCtxt, workerCtxtCancel: context handed to the task processor, and
//     the function Stop calls to cancel it.
//   - token, tokenExpire: the cached token and the instant it expires; token is
//     nil until the first successful fetch. Both are read and written only in
//     handleGetToken.
//   - wg: wait group passed to the task processor when its event loop starts.
//   - validate: validator used on the token endpoint's response.
type clientCredOAuthTokenManager struct {
	Component
	httpClient       *resty.Client
	tasks            TaskProcessor
	clientID         string
	clientSecret     string
	tokenAudience    string
	idpConfig        openIDIssuerConfig
	workerCtxt       context.Context
	workerCtxtCancel context.CancelFunc
	token            *string
	tokenExpire      time.Time
	wg               sync.WaitGroup
	validate         *validator.Validate
}
    69  
// ClientCredOAuthTokenManagerParam configuration for client credential flow oauth token manager
//
// The `validate` tags are enforced by GetNewClientCredOAuthTokenManager before
// any network request is made.
type ClientCredOAuthTokenManagerParam struct {
	// IDPIssuerURL is the OpenID provider's issuer URL (required, must be a URL).
	// "/.well-known/openid-configuration" is appended to it to locate the
	// provider's configuration document.
	IDPIssuerURL string `validate:"required,url"`
	// ClientID is the OAuth client ID (required).
	ClientID string `validate:"required"`
	// ClientSecret is the OAuth client secret (required).
	ClientSecret string `validate:"required"`
	// TargetAudience is the audience the token is requested for (required, must be a URL).
	TargetAudience string `validate:"required,url"`
	// LogTags are metadata fields to include in the logs.
	LogTags log.Fields
	// CustomLogModifiers are additional log metadata modifiers, applied after
	// the manager's built-in modifier.
	CustomLogModifiers []LogMetadataModifier
}
    85  
    86  /*
    87  GetNewClientCredOAuthTokenManager get client credential flow oauth token manager
    88  
    89  	@param parentCtxt context.Context - parent context
    90  	@param httpClient *resty.Client - use this HTTP client to interact with the IDP
    91  	@param params ClientCredOAuthTokenManagerParam - configuration for the token manager
    92  	@returns new OAuthTokenManager instance
    93  */
    94  func GetNewClientCredOAuthTokenManager(
    95  	parentCtxt context.Context,
    96  	httpClient *resty.Client,
    97  	params ClientCredOAuthTokenManagerParam,
    98  ) (OAuthTokenManager, error) {
    99  	validate := validator.New()
   100  	if err := validate.Struct(&params); err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	params.LogTags["idp-issuer"] = params.IDPIssuerURL
   105  	params.LogTags["oauth-client"] = params.ClientID
   106  
   107  	workerCtxt, workerCtxtCancel := context.WithCancel(parentCtxt)
   108  
   109  	// -----------------------------------------------------------------------------------------
   110  	// Query the OpenID provider config first
   111  	var idpConfig openIDIssuerConfig
   112  	idpCfgEP := fmt.Sprintf("%s/.well-known/openid-configuration", params.IDPIssuerURL)
   113  	log.WithFields(params.LogTags).Infof("Fetching IDP config at %s", idpCfgEP)
   114  	resp, err := httpClient.R().SetResult(&idpConfig).Get(idpCfgEP)
   115  	if err != nil {
   116  		log.WithError(err).WithFields(params.LogTags).Error("Failed to read IDP config")
   117  		workerCtxtCancel()
   118  		return nil, err
   119  	}
   120  	if !resp.IsSuccess() {
   121  		err := fmt.Errorf("got status code %d when reading IDP config", resp.StatusCode())
   122  		log.WithError(err).WithFields(params.LogTags).Error("Failed to read IDP config")
   123  		workerCtxtCancel()
   124  		return nil, err
   125  	}
   126  	{
   127  		t, _ := json.Marshal(&idpConfig)
   128  		log.WithFields(params.LogTags).Debugf("OpenID config: %s", t)
   129  	}
   130  
   131  	// -----------------------------------------------------------------------------------------
   132  	// Prepare instance
   133  
   134  	instance := &clientCredOAuthTokenManager{
   135  		Component: Component{
   136  			LogTags:         params.LogTags,
   137  			LogTagModifiers: []LogMetadataModifier{modifyLogMetadataByRRRequestParam},
   138  		},
   139  		httpClient:       httpClient,
   140  		idpConfig:        idpConfig,
   141  		clientID:         params.ClientID,
   142  		clientSecret:     params.ClientSecret,
   143  		tokenAudience:    params.TargetAudience,
   144  		workerCtxt:       workerCtxt,
   145  		workerCtxtCancel: workerCtxtCancel,
   146  		token:            nil,
   147  		tokenExpire:      time.Time{},
   148  		wg:               sync.WaitGroup{},
   149  		validate:         validate,
   150  	}
   151  
   152  	// Add additional log tag modifiers
   153  	instance.LogTagModifiers = append(instance.LogTagModifiers, params.CustomLogModifiers...)
   154  
   155  	// -----------------------------------------------------------------------------------------
   156  	// Define worker task
   157  
   158  	workerLogTags := log.Fields{}
   159  	for lKey, lVal := range params.LogTags {
   160  		workerLogTags[lKey] = lVal
   161  	}
   162  	workerLogTags["sub-module"] = "core-worker"
   163  	worker, err := GetNewTaskProcessorInstance(workerCtxt, "core-worker", 8, workerLogTags)
   164  	if err != nil {
   165  		log.WithError(err).WithFields(params.LogTags).Error("Unable to define worker")
   166  		return nil, err
   167  	}
   168  	instance.tasks = worker
   169  
   170  	// -----------------------------------------------------------------------------------------
   171  	// Define support tasks
   172  
   173  	if err := worker.AddToTaskExecutionMap(
   174  		reflect.TypeOf(getTokenRequest{}), instance.processGetToken,
   175  	); err != nil {
   176  		log.WithError(err).WithFields(params.LogTags).Error("Unable to install task definition")
   177  		return nil, err
   178  	}
   179  
   180  	// -----------------------------------------------------------------------------------------
   181  	// Start worker
   182  	if err := worker.StartEventLoop(&instance.wg); err != nil {
   183  		log.WithError(err).WithFields(params.LogTags).Error("Unable to start support worker")
   184  		return nil, err
   185  	}
   186  
   187  	return instance, nil
   188  }
   189  
// getTokenRequest is the task payload GetToken submits to the task processor;
// the processor is registered to dispatch this type to processGetToken.
//
// timestamp is the caller-supplied "current time" used for the token expiry
// comparison and for computing a new token's expiry. When the request is
// handled, resultCB is called with the token on success, or errorCB with the
// failure otherwise.
type getTokenRequest struct {
	timestamp time.Time
	resultCB  func(string)
	errorCB   func(error)
}
   195  
   196  func (c *clientCredOAuthTokenManager) GetToken(
   197  	ctxt context.Context, timestamp time.Time,
   198  ) (string, error) {
   199  	logTags := c.GetLogTagsForContext(ctxt)
   200  
   201  	resultChan := make(chan string, 1)
   202  	errorChan := make(chan error, 1)
   203  
   204  	resultCB := func(token string) {
   205  		resultChan <- token
   206  	}
   207  	errorCB := func(err error) {
   208  		errorChan <- err
   209  	}
   210  
   211  	// Make the request
   212  	request := getTokenRequest{timestamp: timestamp, resultCB: resultCB, errorCB: errorCB}
   213  	log.WithFields(logTags).Debug("Submitting 'GetToken' job")
   214  	if err := c.tasks.Submit(ctxt, request); err != nil {
   215  		log.WithError(err).WithFields(logTags).Error("Failed to submit 'GetToken' job")
   216  		return "", err
   217  	}
   218  	log.WithFields(logTags).Debug("Submitted 'GetToken' job. AWaiting response")
   219  
   220  	select {
   221  	case <-ctxt.Done():
   222  		err := fmt.Errorf("request timed out waiting for response")
   223  		log.WithError(err).WithFields(logTags).Error("Unable to get current active token")
   224  		return "", err
   225  	case err, ok := <-errorChan:
   226  		if !ok {
   227  			err = fmt.Errorf("error channel failure")
   228  		}
   229  		log.WithError(err).WithFields(logTags).Error("Unable to get current active token")
   230  		return "", err
   231  	case token, ok := <-resultChan:
   232  		if !ok {
   233  			err := fmt.Errorf("result channel failure")
   234  			log.WithError(err).WithFields(logTags).Error("Unable to get current active token")
   235  			return "", err
   236  		}
   237  		return token, nil
   238  	}
   239  }
   240  
   241  func (c *clientCredOAuthTokenManager) processGetToken(params interface{}) error {
   242  	// Convert params into expected data type
   243  	if requestParams, ok := params.(getTokenRequest); ok {
   244  		return c.handleGetToken(requestParams)
   245  	}
   246  	err := fmt.Errorf("received unexpected call parameters: %s", reflect.TypeOf(params))
   247  	logTags := c.GetLogTagsForContext(c.workerCtxt)
   248  	log.WithError(err).WithFields(logTags).Error("'GetToken' processing failure")
   249  	return err
   250  }
   251  
   252  func (c *clientCredOAuthTokenManager) handleGetToken(params getTokenRequest) error {
   253  	logTags := c.GetLogTagsForContext(c.workerCtxt)
   254  
   255  	if c.token == nil || c.tokenExpire.Before(params.timestamp) {
   256  		log.WithFields(logTags).Debug("Fetching new token")
   257  
   258  		// Get new token
   259  		buildRequest := map[string]interface{}{
   260  			"client_id":     c.clientID,
   261  			"client_secret": c.clientSecret,
   262  			"audience":      c.tokenAudience,
   263  			"grant_type":    "client_credentials",
   264  		}
   265  
   266  		// Make the request
   267  		type tokenResp struct {
   268  			Token string `json:"access_token" validate:"required"`
   269  			TTL   uint32 `json:"expires_in" validate:"gte=0"`
   270  		}
   271  
   272  		var newToken tokenResp
   273  		resp, err := c.httpClient.
   274  			R().
   275  			SetBody(&buildRequest).
   276  			SetResult(&newToken).
   277  			Post(c.idpConfig.TokenEP)
   278  		if err != nil {
   279  			log.WithError(err).WithFields(logTags).Error("Token fetch failure")
   280  			params.errorCB(err)
   281  			return err
   282  		}
   283  		if !resp.IsSuccess() {
   284  			err := fmt.Errorf("token fetch returned status code %d", resp.StatusCode())
   285  			log.WithError(err).WithFields(logTags).Error("Token fetch failure")
   286  			params.errorCB(err)
   287  			return err
   288  		}
   289  
   290  		log.WithFields(logTags).Debugf("Token response is %s", resp.Body())
   291  
   292  		if err := c.validate.Struct(&newToken); err != nil {
   293  			log.WithError(err).WithFields(logTags).Error("Invalid token response")
   294  			params.errorCB(err)
   295  			return err
   296  		}
   297  
   298  		// Compute when the token will expire
   299  		expireAt := params.timestamp.Add(time.Second * time.Duration(newToken.TTL))
   300  
   301  		// Store token
   302  		c.token = &newToken.Token
   303  		c.tokenExpire = expireAt
   304  
   305  		log.WithFields(logTags).Debugf("New token expires at '%s'", expireAt)
   306  	} else {
   307  		log.WithFields(logTags).Debug("Reusing existing token")
   308  	}
   309  
   310  	// Return current active token
   311  	params.resultCB(*c.token)
   312  
   313  	return nil
   314  }
   315  
   316  func (c *clientCredOAuthTokenManager) Stop(ctxt context.Context) error {
   317  	c.workerCtxtCancel()
   318  
   319  	return c.tasks.StopEventLoop()
   320  }