github.com/loggregator/cli@v6.33.1-0.20180224010324-82334f081791+incompatible/api/uaa/wrapper/uaa_authentication.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"bytes"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"code.cloudfoundry.org/cli/api/uaa"
    10  )
    11  
    12  //go:generate counterfeiter . UAAClient
    13  
    14  // UAAClient is the interface for getting a valid access token
    15  type UAAClient interface {
    16  	RefreshAccessToken(refreshToken string) (uaa.RefreshedTokens, error)
    17  }
    18  
    19  //go:generate counterfeiter . TokenCache
    20  
    21  // TokenCache is where the UAA token information is stored.
    22  type TokenCache interface {
    23  	AccessToken() string
    24  	RefreshToken() string
    25  	SetAccessToken(token string)
    26  	SetRefreshToken(token string)
    27  }
    28  
    29  // UAAAuthentication wraps connections and adds authentication headers to all
    30  // requests
    31  type UAAAuthentication struct {
    32  	connection uaa.Connection
    33  	client     UAAClient
    34  	cache      TokenCache
    35  }
    36  
    37  // NewUAAAuthentication returns a pointer to a UAAAuthentication wrapper with
    38  // the client and token cache.
    39  func NewUAAAuthentication(client UAAClient, cache TokenCache) *UAAAuthentication {
    40  	return &UAAAuthentication{
    41  		client: client,
    42  		cache:  cache,
    43  	}
    44  }
    45  
    46  // Wrap sets the connection on the UAAAuthentication and returns itself
    47  func (t *UAAAuthentication) Wrap(innerconnection uaa.Connection) uaa.Connection {
    48  	t.connection = innerconnection
    49  	return t
    50  }
    51  
    52  // SetClient sets the UAA client that the wrapper will use.
    53  func (t *UAAAuthentication) SetClient(client UAAClient) {
    54  	t.client = client
    55  }
    56  
    57  // Make adds authentication headers to the passed in request and then calls the
    58  // wrapped connection's Make
    59  func (t *UAAAuthentication) Make(request *http.Request, passedResponse *uaa.Response) error {
    60  	if t.client == nil {
    61  		return t.connection.Make(request, passedResponse)
    62  	}
    63  
    64  	var err error
    65  	var rawRequestBody []byte
    66  
    67  	if request.Body != nil {
    68  		rawRequestBody, err = ioutil.ReadAll(request.Body)
    69  		defer request.Body.Close()
    70  		if err != nil {
    71  			return err
    72  		}
    73  
    74  		request.Body = ioutil.NopCloser(bytes.NewBuffer(rawRequestBody))
    75  
    76  		if skipAuthenticationHeader(request, rawRequestBody) {
    77  			return t.connection.Make(request, passedResponse)
    78  		}
    79  	}
    80  
    81  	request.Header.Set("Authorization", t.cache.AccessToken())
    82  
    83  	err = t.connection.Make(request, passedResponse)
    84  	if _, ok := err.(uaa.InvalidAuthTokenError); ok {
    85  		tokens, refreshErr := t.client.RefreshAccessToken(t.cache.RefreshToken())
    86  		if refreshErr != nil {
    87  			return refreshErr
    88  		}
    89  
    90  		t.cache.SetAccessToken(tokens.AuthorizationToken())
    91  		t.cache.SetRefreshToken(tokens.RefreshToken)
    92  
    93  		if rawRequestBody != nil {
    94  			request.Body = ioutil.NopCloser(bytes.NewBuffer(rawRequestBody))
    95  		}
    96  		request.Header.Set("Authorization", t.cache.AccessToken())
    97  		return t.connection.Make(request, passedResponse)
    98  	}
    99  
   100  	return err
   101  }
   102  
   103  // The authentication header is not added to token refresh requests or login
   104  // requests.
   105  func skipAuthenticationHeader(request *http.Request, body []byte) bool {
   106  	stringBody := string(body)
   107  
   108  	return strings.Contains(request.URL.String(), "/oauth/token") &&
   109  		request.Method == http.MethodPost &&
   110  		(strings.Contains(stringBody, "grant_type=refresh_token") ||
   111  			strings.Contains(stringBody, "grant_type=password") ||
   112  			strings.Contains(stringBody, "grant_type=client_credentials"))
   113  }