golang.org/x/oauth2@v0.18.0/internal/token.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package internal
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"math"
    15  	"mime"
    16  	"net/http"
    17  	"net/url"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  	"time"
    23  )
    24  
// Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
// This type is a mirror of oauth2.Token and exists to break
// an otherwise-circular dependency. Other internal packages
// should convert this Token into an oauth2.Token before use.
type Token struct {
	// AccessToken is the token that authorizes and authenticates
	// the requests.
	AccessToken string

	// TokenType is the type of token.
	// The Type method returns either this or "Bearer", the default.
	// NOTE(review): no Type method is defined on this mirror type in
	// this file; the sentence above presumably describes oauth2.Token.
	TokenType string

	// RefreshToken is a token that's used by the application
	// (as opposed to the user) to refresh the access token
	// if it expires.
	RefreshToken string

	// Expiry is the optional expiration time of the access token.
	//
	// If zero, TokenSource implementations will reuse the same
	// token forever and RefreshToken or equivalent
	// mechanisms for that TokenSource will not be used.
	Expiry time.Time

	// Raw optionally contains extra metadata from the server
	// when updating a token. doTokenRoundTrip stores the parsed
	// url.Values for form-encoded responses and a
	// map[string]interface{} for JSON responses.
	Raw interface{}
}
    57  
// tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token or error in JSON form.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
//
// A single struct covers both the success and the error payload; fields that
// are absent from the response keep their zero value.
type tokenJSON struct {
	AccessToken  string         `json:"access_token"`
	TokenType    string         `json:"token_type"`
	RefreshToken string         `json:"refresh_token"`
	ExpiresIn    expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
	// error fields
	// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
	ErrorCode        string `json:"error"`             // non-empty marks the response as an error
	ErrorDescription string `json:"error_description"` // optional human-readable detail
	ErrorURI         string `json:"error_uri"`         // optional link to further information
}
    72  
    73  func (e *tokenJSON) expiry() (t time.Time) {
    74  	if v := e.ExpiresIn; v != 0 {
    75  		return time.Now().Add(time.Duration(v) * time.Second)
    76  	}
    77  	return
    78  }
    79  
    80  type expirationTime int32
    81  
    82  func (e *expirationTime) UnmarshalJSON(b []byte) error {
    83  	if len(b) == 0 || string(b) == "null" {
    84  		return nil
    85  	}
    86  	var n json.Number
    87  	err := json.Unmarshal(b, &n)
    88  	if err != nil {
    89  		return err
    90  	}
    91  	i, err := n.Int64()
    92  	if err != nil {
    93  		return err
    94  	}
    95  	if i > math.MaxInt32 {
    96  		i = math.MaxInt32
    97  	}
    98  	*e = expirationTime(i)
    99  	return nil
   100  }
   101  
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
// The tokenURL argument is ignored.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
   109  
   110  // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
   111  type AuthStyle int
   112  
   113  const (
   114  	AuthStyleUnknown  AuthStyle = 0
   115  	AuthStyleInParams AuthStyle = 1
   116  	AuthStyleInHeader AuthStyle = 2
   117  )
   118  
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
//
// The zero value is ready to use; Get allocates the underlying cache on
// first call.
type LazyAuthStyleCache struct {
	v atomic.Value // of *AuthStyleCache
}
   134  
   135  func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
   136  	if c, ok := lc.v.Load().(*AuthStyleCache); ok {
   137  		return c
   138  	}
   139  	c := new(AuthStyleCache)
   140  	if !lc.v.CompareAndSwap(nil, c) {
   141  		c = lc.v.Load().(*AuthStyleCache)
   142  	}
   143  	return c
   144  }
   145  
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and
// small.
//
// The zero value is usable: the map is allocated by setAuthStyle on the
// first insertion.
type AuthStyleCache struct {
	mu sync.Mutex           // guards m
	m  map[string]AuthStyle // keyed by tokenURL
}
   155  
   156  // lookupAuthStyle reports which auth style we last used with tokenURL
   157  // when calling RetrieveToken and whether we have ever done so.
   158  func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
   159  	c.mu.Lock()
   160  	defer c.mu.Unlock()
   161  	style, ok = c.m[tokenURL]
   162  	return
   163  }
   164  
   165  // setAuthStyle adds an entry to authStyleCache, documented above.
   166  func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
   167  	c.mu.Lock()
   168  	defer c.mu.Unlock()
   169  	if c.m == nil {
   170  		c.m = make(map[string]AuthStyle)
   171  	}
   172  	c.m[tokenURL] = v
   173  }
   174  
   175  // newTokenRequest returns a new *http.Request to retrieve a new token
   176  // from tokenURL using the provided clientID, clientSecret, and POST
   177  // body parameters.
   178  //
   179  // inParams is whether the clientID & clientSecret should be encoded
   180  // as the POST body. An 'inParams' value of true means to send it in
   181  // the POST body (along with any values in v); false means to send it
   182  // in the Authorization header.
   183  func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
   184  	if authStyle == AuthStyleInParams {
   185  		v = cloneURLValues(v)
   186  		if clientID != "" {
   187  			v.Set("client_id", clientID)
   188  		}
   189  		if clientSecret != "" {
   190  			v.Set("client_secret", clientSecret)
   191  		}
   192  	}
   193  	req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   198  	if authStyle == AuthStyleInHeader {
   199  		req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
   200  	}
   201  	return req, nil
   202  }
   203  
   204  func cloneURLValues(v url.Values) url.Values {
   205  	v2 := make(url.Values, len(v))
   206  	for k, vv := range v {
   207  		v2[k] = append([]string(nil), vv...)
   208  	}
   209  	return v2
   210  }
   211  
   212  func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
   213  	needsAuthStyleProbe := authStyle == 0
   214  	if needsAuthStyleProbe {
   215  		if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
   216  			authStyle = style
   217  			needsAuthStyleProbe = false
   218  		} else {
   219  			authStyle = AuthStyleInHeader // the first way we'll try
   220  		}
   221  	}
   222  	req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	token, err := doTokenRoundTrip(ctx, req)
   227  	if err != nil && needsAuthStyleProbe {
   228  		// If we get an error, assume the server wants the
   229  		// clientID & clientSecret in a different form.
   230  		// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
   231  		// In summary:
   232  		// - Reddit only accepts client secret in the Authorization header
   233  		// - Dropbox accepts either it in URL param or Auth header, but not both.
   234  		// - Google only accepts URL param (not spec compliant?), not Auth header
   235  		// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
   236  		//
   237  		// We used to maintain a big table in this code of all the sites and which way
   238  		// they went, but maintaining it didn't scale & got annoying.
   239  		// So just try both ways.
   240  		authStyle = AuthStyleInParams // the second way we'll try
   241  		req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
   242  		token, err = doTokenRoundTrip(ctx, req)
   243  	}
   244  	if needsAuthStyleProbe && err == nil {
   245  		styleCache.setAuthStyle(tokenURL, authStyle)
   246  	}
   247  	// Don't overwrite `RefreshToken` with an empty value
   248  	// if this was a token refreshing request.
   249  	if token != nil && token.RefreshToken == "" {
   250  		token.RefreshToken = v.Get("refresh_token")
   251  	}
   252  	return token, err
   253  }
   254  
   255  func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
   256  	r, err := ContextClient(ctx).Do(req.WithContext(ctx))
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
   261  	r.Body.Close()
   262  	if err != nil {
   263  		return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
   264  	}
   265  
   266  	failureStatus := r.StatusCode < 200 || r.StatusCode > 299
   267  	retrieveError := &RetrieveError{
   268  		Response: r,
   269  		Body:     body,
   270  		// attempt to populate error detail below
   271  	}
   272  
   273  	var token *Token
   274  	content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
   275  	switch content {
   276  	case "application/x-www-form-urlencoded", "text/plain":
   277  		// some endpoints return a query string
   278  		vals, err := url.ParseQuery(string(body))
   279  		if err != nil {
   280  			if failureStatus {
   281  				return nil, retrieveError
   282  			}
   283  			return nil, fmt.Errorf("oauth2: cannot parse response: %v", err)
   284  		}
   285  		retrieveError.ErrorCode = vals.Get("error")
   286  		retrieveError.ErrorDescription = vals.Get("error_description")
   287  		retrieveError.ErrorURI = vals.Get("error_uri")
   288  		token = &Token{
   289  			AccessToken:  vals.Get("access_token"),
   290  			TokenType:    vals.Get("token_type"),
   291  			RefreshToken: vals.Get("refresh_token"),
   292  			Raw:          vals,
   293  		}
   294  		e := vals.Get("expires_in")
   295  		expires, _ := strconv.Atoi(e)
   296  		if expires != 0 {
   297  			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
   298  		}
   299  	default:
   300  		var tj tokenJSON
   301  		if err = json.Unmarshal(body, &tj); err != nil {
   302  			if failureStatus {
   303  				return nil, retrieveError
   304  			}
   305  			return nil, fmt.Errorf("oauth2: cannot parse json: %v", err)
   306  		}
   307  		retrieveError.ErrorCode = tj.ErrorCode
   308  		retrieveError.ErrorDescription = tj.ErrorDescription
   309  		retrieveError.ErrorURI = tj.ErrorURI
   310  		token = &Token{
   311  			AccessToken:  tj.AccessToken,
   312  			TokenType:    tj.TokenType,
   313  			RefreshToken: tj.RefreshToken,
   314  			Expiry:       tj.expiry(),
   315  			Raw:          make(map[string]interface{}),
   316  		}
   317  		json.Unmarshal(body, &token.Raw) // no error checks for optional fields
   318  	}
   319  	// according to spec, servers should respond status 400 in error case
   320  	// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
   321  	// but some unorthodox servers respond 200 in error case
   322  	if failureStatus || retrieveError.ErrorCode != "" {
   323  		return nil, retrieveError
   324  	}
   325  	if token.AccessToken == "" {
   326  		return nil, errors.New("oauth2: server response missing access_token")
   327  	}
   328  	return token, nil
   329  }
   330  
   331  // mirrors oauth2.RetrieveError
   332  type RetrieveError struct {
   333  	Response         *http.Response
   334  	Body             []byte
   335  	ErrorCode        string
   336  	ErrorDescription string
   337  	ErrorURI         string
   338  }
   339  
   340  func (r *RetrieveError) Error() string {
   341  	if r.ErrorCode != "" {
   342  		s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
   343  		if r.ErrorDescription != "" {
   344  			s += fmt.Sprintf(" %q", r.ErrorDescription)
   345  		}
   346  		if r.ErrorURI != "" {
   347  			s += fmt.Sprintf(" %q", r.ErrorURI)
   348  		}
   349  		return s
   350  	}
   351  	return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
   352  }