github.com/blend/go-sdk@v1.20240719.1/oauth/manager.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package oauth
     9  
    10  import (
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha512"
    14  	"encoding/base64"
    15  	"encoding/json"
    16  	"fmt"
    17  	"net/http"
    18  	"strings"
    19  
    20  	"golang.org/x/oauth2"
    21  	"golang.org/x/oauth2/google"
    22  
    23  	"github.com/blend/go-sdk/ex"
    24  	"github.com/blend/go-sdk/r2"
    25  	"github.com/blend/go-sdk/uuid"
    26  	"github.com/blend/go-sdk/webutil"
    27  )
    28  
    29  const (
    30  	googleIssuerURL = "https://www.googleapis.com/oauth2"
    31  )
    32  
    33  // New returns a new Google Auth manager if options do not
    34  // specify an endpoint, PublicKeyCache and Issuer
    35  func New(options ...Option) (*Manager, error) {
    36  	manager := &Manager{
    37  		Config: oauth2.Config{
    38  			Endpoint: google.Endpoint,
    39  			Scopes:   DefaultScopes,
    40  		},
    41  		PublicKeyCache: NewPublicKeyCache(GoogleKeysURL),
    42  		Issuer:         googleIssuerURL,
    43  	}
    44  
    45  	for _, option := range options {
    46  		if err := option(manager); err != nil {
    47  			return nil, err
    48  		}
    49  	}
    50  	return manager, nil
    51  }
    52  
    53  // MustNew returns a new manager mutated by a given set of options
    54  // and will panic on error.
    55  func MustNew(options ...Option) *Manager {
    56  	m, err := New(options...)
    57  	if err != nil {
    58  		panic(err)
    59  	}
    60  	return m
    61  }
    62  
// Manager is the oauth manager. It builds the authorization URL (OAuthURL).
// On the callback (Finish), it exchanges the code, verifies and validates the
// ID token, and fetches the user profile.
type Manager struct {
	// Config is the embedded oauth2 client configuration. It supplies
	// AuthCodeURL and Exchange, and its ClientID is checked against the token
	// audience.
	oauth2.Config
	// Tracer, when non-nil, is started at the beginning of Finish. The
	// finisher it returns (if non-nil) receives the final result and error.
	Tracer Tracer

	// Secret is the HMAC key for the anti-forgery state token. When it is
	// empty, CreateState does not generate a token and ValidateState accepts
	// any state.
	Secret []byte

	// HostedDomain, when non-empty, is sent as the "hd" parameter on the
	// authorization URL built by OAuthURL.
	//
	// AllowedDomains, when non-empty, makes ValidateJWTGoogle reject tokens
	// whose "hd" claim is empty or does not case-insensitively match one of
	// the listed domains.
	HostedDomain   string
	AllowedDomains []string

	// Issuer is the provider base URL. FetchProfile requests
	// Issuer + "/v1/userinfo". The built-in validators in this file do not
	// compare it with the token's "iss" claim.
	Issuer string

	// ValidateJWT validates the parsed ID token claims during Finish. When it
	// is nil, ValidateJWTGoogle is used.
	ValidateJWT ValidateJWTFunc

	// FetchProfileDefaults are extra r2 options appended after FetchProfile's
	// built-in request options.
	//
	// PublicKeyCache supplies the key function that Finish passes to
	// ParseTokenJWT to verify the ID token.
	FetchProfileDefaults []r2.Option
	PublicKeyCache       *PublicKeyCache
}
    80  
    81  // OAuthURL is the auth url for google with a given clientID.
    82  // This is typically the link that a user will click on to start the auth process.
    83  func (m *Manager) OAuthURL(r *http.Request, stateOptions ...StateOption) (oauthURL string, err error) {
    84  	var state string
    85  	state, err = SerializeState(m.CreateState(stateOptions...))
    86  	if err != nil {
    87  		return
    88  	}
    89  	var opts []oauth2.AuthCodeOption
    90  	if len(m.HostedDomain) > 0 {
    91  		opts = append(opts, oauth2.SetAuthURLParam("hd", m.HostedDomain))
    92  	}
    93  	oauthURL = m.AuthCodeURL(state, opts...)
    94  	return
    95  }
    96  
    97  // Finish processes the returned code, exchanging for an access token, and fetches the user profile.
    98  func (m *Manager) Finish(r *http.Request) (result *Result, err error) {
    99  	if m.Tracer != nil {
   100  		tf := m.Tracer.Start(r.Context(), &m.Config)
   101  		if tf != nil {
   102  			defer func() { tf.Finish(r.Context(), &m.Config, result, err) }()
   103  		}
   104  	}
   105  
   106  	// grab the code off the request.
   107  	code := r.URL.Query().Get("code")
   108  	if len(code) == 0 {
   109  		err = ErrCodeMissing
   110  		return
   111  	}
   112  
   113  	// fetch the state
   114  	state := r.URL.Query().Get("state")
   115  	result = &Result{}
   116  	if len(state) > 0 {
   117  		var deserialized State
   118  		deserialized, err = DeserializeState(state)
   119  		if err != nil {
   120  			return
   121  		}
   122  		result.State = deserialized
   123  	}
   124  	err = m.ValidateState(result.State)
   125  	if err != nil {
   126  		return
   127  	}
   128  
   129  	// Handle the exchange code to initiate a transport.
   130  	var tok *oauth2.Token
   131  	tok, err = m.Exchange(r.Context(), code)
   132  	if err != nil {
   133  		err = ex.New(ErrFailedCodeExchange, ex.OptInner(err))
   134  		return
   135  	}
   136  
   137  	jwtClaims, err := ParseTokenJWT(tok, m.PublicKeyCache.Keyfunc(r.Context()))
   138  	if err != nil {
   139  		err = ex.New(ErrInvalidJWT, ex.OptInner(err))
   140  		return
   141  	}
   142  
   143  	// define the JWT validate function handler
   144  	validateJWT := m.ValidateJWT
   145  	if validateJWT == nil {
   146  		validateJWT = ValidateJWTGoogle
   147  	}
   148  
   149  	// validate the JWT
   150  	if err = validateJWT(m, jwtClaims); err != nil {
   151  		return
   152  	}
   153  
   154  	result.Response.AccessToken = tok.AccessToken
   155  	result.Response.TokenType = tok.TokenType
   156  	result.Response.RefreshToken = tok.RefreshToken
   157  	result.Response.Expiry = tok.Expiry
   158  	result.Response.HostedDomain = jwtClaims.HD
   159  
   160  	var prof Profile
   161  	prof, err = m.FetchProfile(r.Context(), tok.AccessToken)
   162  	if err != nil {
   163  		return
   164  	}
   165  	result.Profile = prof
   166  	return
   167  }
   168  
   169  // FetchProfile gets a google profile for an access token.
   170  func (m *Manager) FetchProfile(ctx context.Context, accessToken string) (profile Profile, err error) {
   171  	res, err := r2.New(m.Issuer+"/v1/userinfo", append([]r2.Option{
   172  		r2.OptGet(),
   173  		r2.OptContext(ctx),
   174  		r2.OptQueryValue("alt", "json"),
   175  		r2.OptHeaderValue(webutil.HeaderAuthorization, fmt.Sprintf("Bearer %s", accessToken)),
   176  	}, m.FetchProfileDefaults...)...).Do()
   177  	if err != nil {
   178  		return
   179  	}
   180  	defer res.Body.Close()
   181  	if code := res.StatusCode; code < 200 || code > 299 {
   182  		err = ex.New(ErrGoogleResponseStatus, ex.OptMessagef("status code: %d", res.StatusCode))
   183  		return
   184  	}
   185  	if err = json.NewDecoder(res.Body).Decode(&profile); err != nil {
   186  		err = ex.New(ErrProfileJSONUnmarshal, ex.OptInner(err))
   187  		return
   188  	}
   189  	return
   190  }
   191  
   192  // CreateState creates auth state.
   193  func (m *Manager) CreateState(options ...StateOption) (state State) {
   194  	for _, opt := range options {
   195  		opt(&state)
   196  	}
   197  	if len(m.Secret) > 0 && state.Token == "" && state.SecureToken == "" {
   198  		state.Token = uuid.V4().String()
   199  		state.SecureToken = m.hash(state.Token)
   200  	}
   201  	return
   202  }
   203  
   204  // --------------------------------------------------------------------------------
   205  // Validation Helpers
   206  // --------------------------------------------------------------------------------
   207  
   208  // ValidateState validates oauth state.
   209  func (m *Manager) ValidateState(state State) error {
   210  	if len(m.Secret) > 0 {
   211  		expected := m.hash(state.Token)
   212  		actual := state.SecureToken
   213  		if !hmac.Equal([]byte(expected), []byte(actual)) {
   214  			return ErrInvalidAntiforgeryToken
   215  		}
   216  	}
   217  	return nil
   218  }
   219  
   220  // ValidateJWTGoogle returns if the google issued jwt is valid or not.
   221  func ValidateJWTGoogle(m *Manager, jwtClaims *GoogleClaims) error {
   222  	if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) {
   223  		return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience))
   224  	}
   225  	if jwtClaims.StandardClaims.Issuer != GoogleIssuer && jwtClaims.StandardClaims.Issuer != GoogleIssuerAlternate {
   226  		return ex.New(ErrInvalidJWTIssuer, ex.OptMessagef("issuer: %s", jwtClaims.StandardClaims.Issuer))
   227  	}
   228  	if len(m.AllowedDomains) > 0 {
   229  		if strings.TrimSpace(jwtClaims.HD) == "" {
   230  			return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: likely gmail.com, but empty"))
   231  		}
   232  		var matchedDomain bool
   233  		for _, domain := range m.AllowedDomains {
   234  			if strings.EqualFold(domain, jwtClaims.HD) {
   235  				matchedDomain = true
   236  				break
   237  			}
   238  		}
   239  		if !matchedDomain {
   240  			return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: %s", jwtClaims.HD))
   241  		}
   242  	}
   243  	return nil
   244  }
   245  
   246  // ValidateJWTOkta returns if the okta issued jwt is valid or not.
   247  func ValidateJWTOkta(m *Manager, jwtClaims *GoogleClaims) error {
   248  	if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) {
   249  		return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience))
   250  	}
   251  	return nil
   252  }
   253  
   254  // --------------------------------------------------------------------------------
   255  // internal helpers
   256  // --------------------------------------------------------------------------------
   257  
   258  func (m *Manager) hash(plaintext string) string {
   259  	return base64.URLEncoding.EncodeToString(m.hmac([]byte(plaintext)))
   260  }
   261  
   262  // hmac hashes data with the given key.
   263  func (m *Manager) hmac(plainText []byte) []byte {
   264  	mac := hmac.New(sha512.New, m.Secret)
   265  	_, _ = mac.Write([]byte(plainText))
   266  	return mac.Sum(nil)
   267  }