github.com/Axway/agent-sdk@v1.1.101/pkg/authz/oauth/authclient.go (about) 1 package oauth 2 3 import ( 4 "crypto/rsa" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 "github.com/Axway/agent-sdk/pkg/api" 14 "github.com/Axway/agent-sdk/pkg/util/log" 15 ) 16 17 // AuthClient - Interface representing the auth Client 18 type AuthClient interface { 19 GetToken() (string, error) 20 FetchToken(useCachedToken bool) (string, error) 21 } 22 23 // AuthClientOption - configures auth client. 24 type AuthClientOption func(*authClientOptions) 25 26 type authClientOptions struct { 27 serverName string 28 headers map[string]string 29 queryParams map[string]string 30 authenticator authenticator 31 } 32 33 // authClient - 34 type authClient struct { 35 tokenURL string 36 logger log.FieldLogger 37 apiClient api.Client 38 cachedToken *tokenResponse 39 getTokenMutex *sync.Mutex 40 options *authClientOptions 41 cachedTokenExpiry time.Time 42 } 43 44 type authenticator interface { 45 prepareRequest() (url.Values, map[string]string, error) 46 } 47 48 type tokenResponse struct { 49 AccessToken string `json:"access_token"` 50 ExpiresIn int64 `json:"expires_in"` 51 } 52 53 // NewAuthClient - create a new auth client with client options 54 func NewAuthClient(tokenURL string, apiClient api.Client, opts ...AuthClientOption) (AuthClient, error) { 55 logger := log.NewFieldLogger(). 56 WithComponent("authclient"). 57 WithPackage("sdk.agent.authz.oauth") 58 client := &authClient{ 59 tokenURL: tokenURL, 60 apiClient: apiClient, 61 getTokenMutex: &sync.Mutex{}, 62 options: &authClientOptions{}, 63 logger: logger, 64 } 65 for _, o := range opts { 66 o(client.options) 67 } 68 69 if client.options.serverName == "" { 70 client.options.serverName = defaultServerName 71 } 72 if client.options.authenticator == nil { 73 return nil, errors.New("unable to create client, no authenticator configured") 74 } 75 return client, nil 76 } 77 78 // WithServerName - sets up the server name in auth client 79 func WithServerName(serverName string) AuthClientOption { 80 return func(opt *authClientOptions) { 81 opt.serverName = serverName 82 } 83 } 84 85 // WithRequestHeaders - sets up the additional request headers in auth client 86 func WithRequestHeaders(hdr map[string]string) AuthClientOption { 87 return func(opt *authClientOptions) { 88 opt.headers = hdr 89 } 90 } 91 92 // WithQueryParams - sets up the additional query params in auth client 93 func WithQueryParams(queryParams map[string]string) AuthClientOption { 94 return func(opt *authClientOptions) { 95 opt.queryParams = queryParams 96 } 97 } 98 99 // WithClientSecretBasicAuth - sets up to use client secret basic authenticator 100 func WithClientSecretBasicAuth(clientID, clientSecret, scope string) AuthClientOption { 101 return func(opt *authClientOptions) { 102 opt.authenticator = &clientSecretBasicAuthenticator{ 103 clientID, 104 clientSecret, 105 scope, 106 } 107 } 108 } 109 110 // WithClientSecretPostAuth - sets up to use client secret authenticator 111 func WithClientSecretPostAuth(clientID, clientSecret, scope string) AuthClientOption { 112 return func(opt *authClientOptions) { 113 opt.authenticator = &clientSecretPostAuthenticator{ 114 clientID, 115 clientSecret, 116 scope, 117 } 118 } 119 } 120 121 // WithClientSecretJwtAuth - sets up to use client secret authenticator 122 func WithClientSecretJwtAuth(clientID, clientSecret, scope, issuer, aud, signingMethod string) AuthClientOption { 123 return func(opt *authClientOptions) { 124 opt.authenticator = &clientSecretJwtAuthenticator{ 125 clientID, 126 clientSecret, 127 scope, 128 issuer, 129 aud, 130 signingMethod, 131 } 132 } 133 } 134 135 // WithKeyPairAuth - sets up to use public/private key pair authenticator 136 func WithKeyPairAuth(clientID, issuer, audience string, privKey *rsa.PrivateKey, publicKey []byte, scope, signingMethod string) AuthClientOption { 137 return func(opt *authClientOptions) { 138 opt.authenticator = &keyPairAuthenticator{ 139 clientID, 140 issuer, 141 audience, 142 privKey, 143 publicKey, 144 scope, 145 signingMethod, 146 } 147 } 148 } 149 150 // WithTLSClientAuth - sets up to use tls_client_auth and self_signed_tls_client_auth authenticator 151 func WithTLSClientAuth(clientID, scope string) AuthClientOption { 152 return func(opt *authClientOptions) { 153 opt.authenticator = &tlsClientAuthenticator{ 154 clientID: clientID, 155 scope: scope, 156 } 157 } 158 } 159 160 func (c *authClient) getCachedToken() string { 161 if time.Now().After(c.cachedTokenExpiry) { 162 c.cachedToken = nil 163 } 164 if c.cachedToken != nil { 165 return c.cachedToken.AccessToken 166 } 167 return "" 168 } 169 170 // GetToken returns a token from cache if not expired or fetches a new token 171 func (c *authClient) GetToken() (string, error) { 172 return c.FetchToken(true) 173 } 174 175 // GetToken returns a token from cache if not expired or fetches a new token 176 func (c *authClient) FetchToken(useCachedToken bool) (string, error) { 177 // only one GetToken should execute at a time 178 c.getTokenMutex.Lock() 179 defer c.getTokenMutex.Unlock() 180 token := c.getCachedToken() 181 if useCachedToken && token != "" { 182 return token, nil 183 } 184 185 // try fetching a new token 186 return c.fetchNewToken() 187 } 188 189 // fetchNewToken fetches a new token from the platform and updates the token cache. 190 func (c *authClient) fetchNewToken() (string, error) { 191 tokenResponse, err := c.getOAuthTokens() 192 if err != nil { 193 return "", err 194 } 195 196 almostExpires := (tokenResponse.ExpiresIn * 4) / 5 197 198 c.cachedToken = tokenResponse 199 c.cachedTokenExpiry = time.Now().Add(time.Duration(almostExpires) * time.Second) 200 return c.cachedToken.AccessToken, nil 201 } 202 203 func (c *authClient) getOAuthTokens() (*tokenResponse, error) { 204 req, headers, err := c.options.authenticator.prepareRequest() 205 if err != nil { 206 return nil, err 207 } 208 209 resp, err := c.postAuthForm(req, headers) 210 if err != nil { 211 return nil, err 212 } 213 214 if resp.Code != 200 { 215 err := fmt.Errorf("bad response from %s: %d %s", c.options.serverName, resp.Code, http.StatusText(resp.Code)) 216 c.logger. 217 WithField("server", c.options.serverName). 218 WithField("url", c.tokenURL). 219 WithField("status", resp.Code). 220 WithField("body", string(resp.Body)). 221 WithError(err). 222 Debug(err.Error()) 223 return nil, err 224 } 225 226 tokens := tokenResponse{} 227 if err := json.Unmarshal(resp.Body, &tokens); err != nil { 228 return nil, fmt.Errorf("unable to unmarshal token: %v", err) 229 } 230 231 return &tokens, nil 232 } 233 234 func (c *authClient) postAuthForm(data url.Values, headers map[string]string) (resp *api.Response, err error) { 235 reqHeaders := map[string]string{ 236 hdrContentType: mimeApplicationFormURLEncoded, 237 } 238 for name, value := range c.options.headers { 239 reqHeaders[name] = value 240 } 241 for name, value := range headers { 242 reqHeaders[name] = value 243 } 244 req := api.Request{ 245 Method: api.POST, 246 URL: c.tokenURL, 247 Body: []byte(data.Encode()), 248 Headers: reqHeaders, 249 QueryParams: c.options.queryParams, 250 } 251 return c.apiClient.Send(req) 252 }