github.com/govau/cf-common@v0.0.7/uaa/token_validator.go (about)

     1  package uaa
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/json"
     9  	"errors"
    10  	"net/http"
    11  	"net/url"
    12  	"sync"
    13  
    14  	jwt "github.com/dgrijalva/jwt-go"
    15  )
    16  
    17  // OAuthGrant used to parse JSON for an access token from UAA server.
    18  type OAuthGrant struct {
    19  	AccessToken  string `json:"access_token"`
    20  	TokenType    string `json:"token_type"`
    21  	ExpiresIn    int    `json:"expires_in"`
    22  	Scope        string `json:"scope"`
    23  	RefreshToken string `json:"refresh_token"`
    24  	JTI          string `json:"jti"`
    25  }
    26  
    27  // FetchAccessToken sends data to endpoint to fetch a token and returns a grant
    28  // object.
    29  func (c *Client) FetchAccessToken(postData url.Values) (*OAuthGrant, error) {
    30  	err := c.init()
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	req, err := http.NewRequest(http.MethodPost, c.GetTokenEndpoint(), bytes.NewReader([]byte(postData.Encode())))
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	// Older versions of CF require this to be set via header, not in POST data
    40  	// WONTFIX: we query escape these per OAuth spec. Apparently UAA does not -
    41  	// might cause an issue if they don't fix their end.
    42  	req.SetBasicAuth(url.QueryEscape(c.ClientID), url.QueryEscape(c.ClientSecret))
    43  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
    44  
    45  	resp, err := c.uaaHTTPClient.Do(req)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	if resp.StatusCode != http.StatusOK {
    51  		return nil, errors.New("bad status code")
    52  	}
    53  
    54  	var og OAuthGrant
    55  	err = json.NewDecoder(resp.Body).Decode(&og)
    56  	resp.Body.Close()
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	return &og, nil
    62  }
    63  
    64  // Client will validate access tokens against a UAA instance, caching keys as
    65  // required.
    66  type Client struct {
    67  	// URL is the URL to UAA, e.g. https://uaa.system.example.com.
    68  	URL string
    69  
    70  	// Used for authorize redirects, and issuer validation
    71  	ExternalURL string
    72  
    73  	ClientID     string
    74  	ClientSecret string
    75  
    76  	// If specified, used in instead of system CAs
    77  	CACerts []string
    78  
    79  	// cachedKeysMu protects cachedKeys.
    80  	cachedKeysMu sync.RWMutex
    81  
    82  	// cachedKeys is the public key map.
    83  	cachedKeys map[string]*rsa.PublicKey
    84  
    85  	initMutex     sync.Mutex
    86  	inited        bool
    87  	uaaHTTPClient *http.Client
    88  }
    89  
    90  func (c *Client) init() error {
    91  	c.initMutex.Lock()
    92  	defer c.initMutex.Unlock()
    93  
    94  	if c.inited {
    95  		return nil
    96  	}
    97  
    98  	if len(c.CACerts) == 0 {
    99  		c.uaaHTTPClient = http.DefaultClient
   100  	} else {
   101  		uaaCaCertPool := x509.NewCertPool()
   102  		for _, ca := range c.CACerts {
   103  			ok := uaaCaCertPool.AppendCertsFromPEM([]byte(ca))
   104  			if !ok {
   105  				return errors.New("AppendCertsFromPEM was not ok")
   106  			}
   107  		}
   108  		uaaTLS := &tls.Config{RootCAs: uaaCaCertPool}
   109  		uaaTLS.BuildNameToCertificate()
   110  		c.uaaHTTPClient = &http.Client{Transport: &http.Transport{TLSClientConfig: uaaTLS}}
   111  	}
   112  
   113  	c.inited = true
   114  	return nil
   115  }
   116  
   117  // NewClientFromAPIURL looks up, via the apiEndpoint, the correct UAA address
   118  // and returns a client.
   119  func NewClientFromAPIURL(apiEndpoint string) (*Client, error) {
   120  	resp, err := http.Get(apiEndpoint)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	defer resp.Body.Close()
   125  
   126  	var m struct {
   127  		Links struct {
   128  			UAA struct {
   129  				URL string `json:"href"`
   130  			} `json:"uaa"`
   131  		} `json:"links"`
   132  	}
   133  	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	u := m.Links.UAA.URL
   138  	if u == "" {
   139  		return nil, errors.New("no uaa URL returned")
   140  	}
   141  
   142  	return &Client{
   143  		URL:         u,
   144  		ExternalURL: u,
   145  	}, nil
   146  }
   147  
   148  func (c *Client) GetAuthorizeEndpoint() string {
   149  	return c.ExternalURL + "/oauth/authorize"
   150  }
   151  
   152  func (c *Client) GetTokenEndpoint() string {
   153  	return c.URL + "/oauth/token"
   154  }
   155  
   156  // ExchangeBearerTokenForClientToken takes a bearer token (such as that returned
   157  // by CF), and exchanges via the API auth flow, for an OAuthGrant for the
   158  // specified clientID. The clientSecret here is really not a secret.
   159  func (c *Client) ExchangeBearerTokenForClientToken(bearerLine string) (*OAuthGrant, error) {
   160  	err := c.init()
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	req, err := http.NewRequest(http.MethodPost, c.GetAuthorizeEndpoint(), bytes.NewReader([]byte(url.Values{
   166  		"client_id":     {c.ClientID},
   167  		"response_type": {"code"},
   168  	}.Encode())))
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   173  	req.Header.Set("Authorization", bearerLine)
   174  
   175  	hc := &http.Client{
   176  		CheckRedirect: func(r *http.Request, via []*http.Request) error {
   177  			return http.ErrUseLastResponse
   178  		},
   179  	}
   180  	if c.uaaHTTPClient != nil {
   181  		hc.Transport = c.uaaHTTPClient.Transport
   182  	}
   183  
   184  	resp, err := hc.Do(req)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	defer resp.Body.Close()
   189  	if resp.StatusCode != http.StatusFound {
   190  		return nil, errors.New("expected 302 back from UAA")
   191  	}
   192  	u, err := url.Parse(resp.Header.Get("Location"))
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  	authCode := u.Query().Get("code")
   197  	if authCode == "" {
   198  		return nil, errors.New("expected auth code back from UAA")
   199  	}
   200  
   201  	return c.FetchAccessToken(url.Values(map[string][]string{
   202  		"response_type": {"token"},
   203  		"grant_type":    {"authorization_code"},
   204  		"code":          {authCode},
   205  	}))
   206  }
   207  
   208  // pubKeyForID returns public key for a given key ID, if we have it, else nil
   209  // is returned.
   210  func (c *Client) pubKeyForID(kid string) *rsa.PublicKey {
   211  	c.cachedKeysMu.RLock()
   212  	defer c.cachedKeysMu.RUnlock()
   213  
   214  	if c.cachedKeys == nil {
   215  		return nil
   216  	}
   217  
   218  	rv, ok := c.cachedKeys[kid]
   219  	if !ok {
   220  		return nil
   221  	}
   222  
   223  	return rv
   224  }
   225  
   226  // fetchAndSaveLatestKey contacts UAA to fetch latest public key, and if it
   227  // matches the key ID requested, then return it, else an error will be returned.
   228  func (c *Client) fetchAndSaveLatestKey(idWanted string) (*rsa.PublicKey, error) {
   229  	err := c.init()
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	resp, err := c.uaaHTTPClient.Get(c.URL + "/token_key")
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	var dd struct {
   240  		ID  string `json:"kid"`
   241  		PEM string `json:"value"`
   242  	}
   243  	err = json.NewDecoder(resp.Body).Decode(&dd)
   244  	resp.Body.Close()
   245  
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	pk, err := jwt.ParseRSAPublicKeyFromPEM([]byte(dd.PEM))
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	c.cachedKeysMu.Lock()
   256  	defer c.cachedKeysMu.Unlock()
   257  
   258  	if c.cachedKeys == nil {
   259  		c.cachedKeys = make(map[string]*rsa.PublicKey)
   260  	}
   261  
   262  	// With old versions of CF, the KID will be empty.
   263  	// That seems OK as it'll now be empty here too.
   264  	c.cachedKeys[dd.ID] = pk
   265  
   266  	if dd.ID != idWanted {
   267  		return nil, errors.New("still can't find it")
   268  	}
   269  
   270  	return pk, nil
   271  }
   272  
   273  // Find the public key to verify the JWT, and check the algorithm.
   274  func (c *Client) cfKeyFunc(t *jwt.Token) (interface{}, error) {
   275  	// Ensure that RS256 is used. This might seem overkill to care,
   276  	// but since the JWT spec actually allows a None algorithm which
   277  	// we definitely don't want, so instead we whitelist what we will allow.
   278  	if t.Method.Alg() != "RS256" {
   279  		return nil, errors.New("bad token9")
   280  	}
   281  
   282  	// Get Key ID
   283  	kid, ok := t.Header["kid"]
   284  	if !ok {
   285  		// some versions of Cloud Foundry don't return a key ID - if so, let's
   286  		// just hope for the best.
   287  		kid = ""
   288  	}
   289  
   290  	kidS, ok := kid.(string)
   291  	if !ok {
   292  		return nil, errors.New("bad token 11")
   293  	}
   294  
   295  	rv := c.pubKeyForID(kidS)
   296  	if rv != nil {
   297  		return rv, nil
   298  	}
   299  
   300  	rv, err := c.fetchAndSaveLatestKey(kidS)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	return rv, nil
   306  }
   307  
   308  // ValidateAccessToken will validate the given access token, ensure it matches
   309  // the client ID, and return the claims reported within.
   310  func (c *Client) ValidateAccessToken(at, expectedClientID string) (jwt.MapClaims, error) {
   311  	token, err := jwt.Parse(at, c.cfKeyFunc)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	if !token.Valid {
   317  		return nil, errors.New("bad token 1")
   318  	}
   319  
   320  	mapClaims, ok := token.Claims.(jwt.MapClaims)
   321  	if !ok {
   322  		return nil, errors.New("bad token 2")
   323  	}
   324  
   325  	if !mapClaims.VerifyIssuer(c.ExternalURL+"/oauth/token", true) {
   326  		return nil, errors.New("bad token 3")
   327  	}
   328  
   329  	// Never, ever, ever, skip a client ID check (common error).
   330  	cid, _ := mapClaims["client_id"].(string)
   331  	if cid != expectedClientID {
   332  		return nil, errors.New("very bad token 4")
   333  	}
   334  
   335  	return mapClaims, nil
   336  }