github.com/Axway/agent-sdk@v1.1.101/pkg/authz/oauth/authclient.go (about)

     1  package oauth
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"net/url"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/Axway/agent-sdk/pkg/api"
    14  	"github.com/Axway/agent-sdk/pkg/util/log"
    15  )
    16  
    17  // AuthClient - Interface representing the auth Client
    18  type AuthClient interface {
    19  	GetToken() (string, error)
    20  	FetchToken(useCachedToken bool) (string, error)
    21  }
    22  
    23  // AuthClientOption - configures auth client.
    24  type AuthClientOption func(*authClientOptions)
    25  
    26  type authClientOptions struct {
    27  	serverName    string
    28  	headers       map[string]string
    29  	queryParams   map[string]string
    30  	authenticator authenticator
    31  }
    32  
    33  // authClient -
    34  type authClient struct {
    35  	tokenURL          string
    36  	logger            log.FieldLogger
    37  	apiClient         api.Client
    38  	cachedToken       *tokenResponse
    39  	getTokenMutex     *sync.Mutex
    40  	options           *authClientOptions
    41  	cachedTokenExpiry time.Time
    42  }
    43  
    44  type authenticator interface {
    45  	prepareRequest() (url.Values, map[string]string, error)
    46  }
    47  
    48  type tokenResponse struct {
    49  	AccessToken string `json:"access_token"`
    50  	ExpiresIn   int64  `json:"expires_in"`
    51  }
    52  
    53  // NewAuthClient - create a new auth client with client options
    54  func NewAuthClient(tokenURL string, apiClient api.Client, opts ...AuthClientOption) (AuthClient, error) {
    55  	logger := log.NewFieldLogger().
    56  		WithComponent("authclient").
    57  		WithPackage("sdk.agent.authz.oauth")
    58  	client := &authClient{
    59  		tokenURL:      tokenURL,
    60  		apiClient:     apiClient,
    61  		getTokenMutex: &sync.Mutex{},
    62  		options:       &authClientOptions{},
    63  		logger:        logger,
    64  	}
    65  	for _, o := range opts {
    66  		o(client.options)
    67  	}
    68  
    69  	if client.options.serverName == "" {
    70  		client.options.serverName = defaultServerName
    71  	}
    72  	if client.options.authenticator == nil {
    73  		return nil, errors.New("unable to create client, no authenticator configured")
    74  	}
    75  	return client, nil
    76  }
    77  
    78  // WithServerName - sets up the server name in auth client
    79  func WithServerName(serverName string) AuthClientOption {
    80  	return func(opt *authClientOptions) {
    81  		opt.serverName = serverName
    82  	}
    83  }
    84  
    85  // WithRequestHeaders - sets up the additional request headers in auth client
    86  func WithRequestHeaders(hdr map[string]string) AuthClientOption {
    87  	return func(opt *authClientOptions) {
    88  		opt.headers = hdr
    89  	}
    90  }
    91  
    92  // WithQueryParams - sets up the additional query params in auth client
    93  func WithQueryParams(queryParams map[string]string) AuthClientOption {
    94  	return func(opt *authClientOptions) {
    95  		opt.queryParams = queryParams
    96  	}
    97  }
    98  
    99  // WithClientSecretBasicAuth - sets up to use client secret basic authenticator
   100  func WithClientSecretBasicAuth(clientID, clientSecret, scope string) AuthClientOption {
   101  	return func(opt *authClientOptions) {
   102  		opt.authenticator = &clientSecretBasicAuthenticator{
   103  			clientID,
   104  			clientSecret,
   105  			scope,
   106  		}
   107  	}
   108  }
   109  
   110  // WithClientSecretPostAuth - sets up to use client secret authenticator
   111  func WithClientSecretPostAuth(clientID, clientSecret, scope string) AuthClientOption {
   112  	return func(opt *authClientOptions) {
   113  		opt.authenticator = &clientSecretPostAuthenticator{
   114  			clientID,
   115  			clientSecret,
   116  			scope,
   117  		}
   118  	}
   119  }
   120  
   121  // WithClientSecretJwtAuth - sets up to use client secret authenticator
   122  func WithClientSecretJwtAuth(clientID, clientSecret, scope, issuer, aud, signingMethod string) AuthClientOption {
   123  	return func(opt *authClientOptions) {
   124  		opt.authenticator = &clientSecretJwtAuthenticator{
   125  			clientID,
   126  			clientSecret,
   127  			scope,
   128  			issuer,
   129  			aud,
   130  			signingMethod,
   131  		}
   132  	}
   133  }
   134  
   135  // WithKeyPairAuth - sets up to use public/private key pair authenticator
   136  func WithKeyPairAuth(clientID, issuer, audience string, privKey *rsa.PrivateKey, publicKey []byte, scope, signingMethod string) AuthClientOption {
   137  	return func(opt *authClientOptions) {
   138  		opt.authenticator = &keyPairAuthenticator{
   139  			clientID,
   140  			issuer,
   141  			audience,
   142  			privKey,
   143  			publicKey,
   144  			scope,
   145  			signingMethod,
   146  		}
   147  	}
   148  }
   149  
   150  // WithTLSClientAuth - sets up to use tls_client_auth and self_signed_tls_client_auth authenticator
   151  func WithTLSClientAuth(clientID, scope string) AuthClientOption {
   152  	return func(opt *authClientOptions) {
   153  		opt.authenticator = &tlsClientAuthenticator{
   154  			clientID: clientID,
   155  			scope:    scope,
   156  		}
   157  	}
   158  }
   159  
   160  func (c *authClient) getCachedToken() string {
   161  	if time.Now().After(c.cachedTokenExpiry) {
   162  		c.cachedToken = nil
   163  	}
   164  	if c.cachedToken != nil {
   165  		return c.cachedToken.AccessToken
   166  	}
   167  	return ""
   168  }
   169  
   170  // GetToken returns a token from cache if not expired or fetches a new token
   171  func (c *authClient) GetToken() (string, error) {
   172  	return c.FetchToken(true)
   173  }
   174  
   175  // GetToken returns a token from cache if not expired or fetches a new token
   176  func (c *authClient) FetchToken(useCachedToken bool) (string, error) {
   177  	// only one GetToken should execute at a time
   178  	c.getTokenMutex.Lock()
   179  	defer c.getTokenMutex.Unlock()
   180  	token := c.getCachedToken()
   181  	if useCachedToken && token != "" {
   182  		return token, nil
   183  	}
   184  
   185  	// try fetching a new token
   186  	return c.fetchNewToken()
   187  }
   188  
   189  // fetchNewToken fetches a new token from the platform and updates the token cache.
   190  func (c *authClient) fetchNewToken() (string, error) {
   191  	tokenResponse, err := c.getOAuthTokens()
   192  	if err != nil {
   193  		return "", err
   194  	}
   195  
   196  	almostExpires := (tokenResponse.ExpiresIn * 4) / 5
   197  
   198  	c.cachedToken = tokenResponse
   199  	c.cachedTokenExpiry = time.Now().Add(time.Duration(almostExpires) * time.Second)
   200  	return c.cachedToken.AccessToken, nil
   201  }
   202  
   203  func (c *authClient) getOAuthTokens() (*tokenResponse, error) {
   204  	req, headers, err := c.options.authenticator.prepareRequest()
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	resp, err := c.postAuthForm(req, headers)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  
   214  	if resp.Code != 200 {
   215  		err := fmt.Errorf("bad response from %s: %d %s", c.options.serverName, resp.Code, http.StatusText(resp.Code))
   216  		c.logger.
   217  			WithField("server", c.options.serverName).
   218  			WithField("url", c.tokenURL).
   219  			WithField("status", resp.Code).
   220  			WithField("body", string(resp.Body)).
   221  			WithError(err).
   222  			Debug(err.Error())
   223  		return nil, err
   224  	}
   225  
   226  	tokens := tokenResponse{}
   227  	if err := json.Unmarshal(resp.Body, &tokens); err != nil {
   228  		return nil, fmt.Errorf("unable to unmarshal token: %v", err)
   229  	}
   230  
   231  	return &tokens, nil
   232  }
   233  
   234  func (c *authClient) postAuthForm(data url.Values, headers map[string]string) (resp *api.Response, err error) {
   235  	reqHeaders := map[string]string{
   236  		hdrContentType: mimeApplicationFormURLEncoded,
   237  	}
   238  	for name, value := range c.options.headers {
   239  		reqHeaders[name] = value
   240  	}
   241  	for name, value := range headers {
   242  		reqHeaders[name] = value
   243  	}
   244  	req := api.Request{
   245  		Method:      api.POST,
   246  		URL:         c.tokenURL,
   247  		Body:        []byte(data.Encode()),
   248  		Headers:     reqHeaders,
   249  		QueryParams: c.options.queryParams,
   250  	}
   251  	return c.apiClient.Send(req)
   252  }