github.com/wanddynosios/cli/v8@v8.7.9-0.20240221182337-1a92e3a7017f/api/uaa/wrapper/uaa_authentication.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"bytes"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"code.cloudfoundry.org/cli/api/uaa"
    10  )
    11  
    12  //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 . UAAClient
    13  
    14  // UAAClient is the interface for getting a valid access token
    15  type UAAClient interface {
    16  	RefreshAccessToken(refreshToken string) (uaa.RefreshedTokens, error)
    17  }
    18  
    19  //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 . TokenCache
    20  
    21  // TokenCache is where the UAA token information is stored.
    22  type TokenCache interface {
    23  	AccessToken() string
    24  	RefreshToken() string
    25  	SetAccessToken(token string)
    26  	SetRefreshToken(token string)
    27  }
    28  
    29  // UAAAuthentication wraps connections and adds authentication headers to all
    30  // requests
    31  type UAAAuthentication struct {
    32  	connection uaa.Connection
    33  	client     UAAClient
    34  	cache      TokenCache
    35  }
    36  
    37  // NewUAAAuthentication returns a pointer to a UAAAuthentication wrapper with
    38  // the client and token cache.
    39  func NewUAAAuthentication(client UAAClient, cache TokenCache) *UAAAuthentication {
    40  	return &UAAAuthentication{
    41  		client: client,
    42  		cache:  cache,
    43  	}
    44  }
    45  
    46  // Make adds authentication headers to the passed in request and then calls the
    47  // wrapped connection's Make
    48  func (t *UAAAuthentication) Make(request *http.Request, passedResponse *uaa.Response) error {
    49  	if t.client == nil {
    50  		return t.connection.Make(request, passedResponse)
    51  	}
    52  
    53  	var err error
    54  	var rawRequestBody []byte
    55  
    56  	if request.Body != nil {
    57  		rawRequestBody, err = ioutil.ReadAll(request.Body)
    58  		defer request.Body.Close()
    59  		if err != nil {
    60  			return err
    61  		}
    62  
    63  		request.Body = ioutil.NopCloser(bytes.NewBuffer(rawRequestBody))
    64  
    65  		if skipAuthenticationHeader(request, rawRequestBody) {
    66  			return t.connection.Make(request, passedResponse)
    67  		}
    68  	}
    69  
    70  	request.Header.Set("Authorization", t.cache.AccessToken())
    71  
    72  	err = t.connection.Make(request, passedResponse)
    73  	if _, ok := err.(uaa.InvalidAuthTokenError); ok {
    74  		tokens, refreshErr := t.client.RefreshAccessToken(t.cache.RefreshToken())
    75  		if refreshErr != nil {
    76  			return refreshErr
    77  		}
    78  
    79  		t.cache.SetAccessToken(tokens.AuthorizationToken())
    80  		t.cache.SetRefreshToken(tokens.RefreshToken)
    81  
    82  		if rawRequestBody != nil {
    83  			request.Body = ioutil.NopCloser(bytes.NewBuffer(rawRequestBody))
    84  		}
    85  		request.Header.Set("Authorization", t.cache.AccessToken())
    86  		return t.connection.Make(request, passedResponse)
    87  	}
    88  
    89  	return err
    90  }
    91  
    92  // SetClient sets the UAA client that the wrapper will use.
    93  func (t *UAAAuthentication) SetClient(client UAAClient) {
    94  	t.client = client
    95  }
    96  
    97  // Wrap sets the connection on the UAAAuthentication and returns itself
    98  func (t *UAAAuthentication) Wrap(innerconnection uaa.Connection) uaa.Connection {
    99  	t.connection = innerconnection
   100  	return t
   101  }
   102  
   103  // The authentication header is not added to token refresh requests or login
   104  // requests.
   105  func skipAuthenticationHeader(request *http.Request, body []byte) bool {
   106  	stringBody := string(body)
   107  
   108  	return strings.Contains(request.URL.String(), "/oauth/token") &&
   109  		request.Method == http.MethodPost &&
   110  		(strings.Contains(stringBody, "grant_type=refresh_token") ||
   111  			strings.Contains(stringBody, "grant_type=password") ||
   112  			strings.Contains(stringBody, "grant_type=client_credentials"))
   113  }