
     1  /*
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     6  */
     8  package oauth
    10  import (
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha512"
    14  	"encoding/base64"
    15  	"encoding/json"
    16  	"fmt"
    17  	"net/http"
    18  	"strings"
    20  	""
    21  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  )
    // googleIssuerURL is the default Manager.Issuer assigned by New; FetchProfile
// appends "/v1/userinfo" to the manager's Issuer to build the profile URL.
const googleIssuerURL = ""
    33  // New returns a new Google Auth manager if options do not
    34  // specify an endpoint, PublicKeyCache and Issuer
    35  func New(options ...Option) (*Manager, error) {
    36  	manager := &Manager{
    37  		Config: oauth2.Config{
    38  			Endpoint: google.Endpoint,
    39  			Scopes:   DefaultScopes,
    40  		},
    41  		PublicKeyCache: NewPublicKeyCache(GoogleKeysURL),
    42  		Issuer:         googleIssuerURL,
    43  	}
    45  	for _, option := range options {
    46  		if err := option(manager); err != nil {
    47  			return nil, err
    48  		}
    49  	}
    50  	return manager, nil
    51  }
    53  // MustNew returns a new manager mutated by a given set of options
    54  // and will panic on error.
    55  func MustNew(options ...Option) *Manager {
    56  	m, err := New(options...)
    57  	if err != nil {
    58  		panic(err)
    59  	}
    60  	return m
    61  }
    63  // Manager is the oauth manager.
    64  type Manager struct {
    65  	oauth2.Config
    66  	Tracer Tracer
    68  	Secret []byte
    70  	HostedDomain   string
    71  	AllowedDomains []string
    73  	Issuer string
    75  	ValidateJWT ValidateJWTFunc
    77  	FetchProfileDefaults []r2.Option
    78  	PublicKeyCache       *PublicKeyCache
    79  }
    81  // OAuthURL is the auth url for google with a given clientID.
    82  // This is typically the link that a user will click on to start the auth process.
    83  func (m *Manager) OAuthURL(r *http.Request, stateOptions ...StateOption) (oauthURL string, err error) {
    84  	var state string
    85  	state, err = SerializeState(m.CreateState(stateOptions...))
    86  	if err != nil {
    87  		return
    88  	}
    89  	var opts []oauth2.AuthCodeOption
    90  	if len(m.HostedDomain) > 0 {
    91  		opts = append(opts, oauth2.SetAuthURLParam("hd", m.HostedDomain))
    92  	}
    93  	oauthURL = m.AuthCodeURL(state, opts...)
    94  	return
    95  }
    97  // Finish processes the returned code, exchanging for an access token, and fetches the user profile.
    98  func (m *Manager) Finish(r *http.Request) (result *Result, err error) {
    99  	if m.Tracer != nil {
   100  		tf := m.Tracer.Start(r.Context(), &m.Config)
   101  		if tf != nil {
   102  			defer func() { tf.Finish(r.Context(), &m.Config, result, err) }()
   103  		}
   104  	}
   106  	// grab the code off the request.
   107  	code := r.URL.Query().Get("code")
   108  	if len(code) == 0 {
   109  		err = ErrCodeMissing
   110  		return
   111  	}
   113  	// fetch the state
   114  	state := r.URL.Query().Get("state")
   115  	result = &Result{}
   116  	if len(state) > 0 {
   117  		var deserialized State
   118  		deserialized, err = DeserializeState(state)
   119  		if err != nil {
   120  			return
   121  		}
   122  		result.State = deserialized
   123  	}
   124  	err = m.ValidateState(result.State)
   125  	if err != nil {
   126  		return
   127  	}
   129  	// Handle the exchange code to initiate a transport.
   130  	var tok *oauth2.Token
   131  	tok, err = m.Exchange(r.Context(), code)
   132  	if err != nil {
   133  		err = ex.New(ErrFailedCodeExchange, ex.OptInner(err))
   134  		return
   135  	}
   137  	jwtClaims, err := ParseTokenJWT(tok, m.PublicKeyCache.Keyfunc(r.Context()))
   138  	if err != nil {
   139  		err = ex.New(ErrInvalidJWT, ex.OptInner(err))
   140  		return
   141  	}
   143  	// define the JWT validate function handler
   144  	validateJWT := m.ValidateJWT
   145  	if validateJWT == nil {
   146  		validateJWT = ValidateJWTGoogle
   147  	}
   149  	// validate the JWT
   150  	if err = validateJWT(m, jwtClaims); err != nil {
   151  		return
   152  	}
   154  	result.Response.AccessToken = tok.AccessToken
   155  	result.Response.TokenType = tok.TokenType
   156  	result.Response.RefreshToken = tok.RefreshToken
   157  	result.Response.Expiry = tok.Expiry
   158  	result.Response.HostedDomain = jwtClaims.HD
   160  	var prof Profile
   161  	prof, err = m.FetchProfile(r.Context(), tok.AccessToken)
   162  	if err != nil {
   163  		return
   164  	}
   165  	result.Profile = prof
   166  	return
   167  }
   169  // FetchProfile gets a google profile for an access token.
   170  func (m *Manager) FetchProfile(ctx context.Context, accessToken string) (profile Profile, err error) {
   171  	res, err := r2.New(m.Issuer+"/v1/userinfo", append([]r2.Option{
   172  		r2.OptGet(),
   173  		r2.OptContext(ctx),
   174  		r2.OptQueryValue("alt", "json"),
   175  		r2.OptHeaderValue(webutil.HeaderAuthorization, fmt.Sprintf("Bearer %s", accessToken)),
   176  	}, m.FetchProfileDefaults...)...).Do()
   177  	if err != nil {
   178  		return
   179  	}
   180  	defer res.Body.Close()
   181  	if code := res.StatusCode; code < 200 || code > 299 {
   182  		err = ex.New(ErrGoogleResponseStatus, ex.OptMessagef("status code: %d", res.StatusCode))
   183  		return
   184  	}
   185  	if err = json.NewDecoder(res.Body).Decode(&profile); err != nil {
   186  		err = ex.New(ErrProfileJSONUnmarshal, ex.OptInner(err))
   187  		return
   188  	}
   189  	return
   190  }
   192  // CreateState creates auth state.
   193  func (m *Manager) CreateState(options ...StateOption) (state State) {
   194  	for _, opt := range options {
   195  		opt(&state)
   196  	}
   197  	if len(m.Secret) > 0 && state.Token == "" && state.SecureToken == "" {
   198  		state.Token = uuid.V4().String()
   199  		state.SecureToken = m.hash(state.Token)
   200  	}
   201  	return
   202  }
   205  // Validation Helpers
   208  // ValidateState validates oauth state.
   209  func (m *Manager) ValidateState(state State) error {
   210  	if len(m.Secret) > 0 {
   211  		expected := m.hash(state.Token)
   212  		actual := state.SecureToken
   213  		if !hmac.Equal([]byte(expected), []byte(actual)) {
   214  			return ErrInvalidAntiforgeryToken
   215  		}
   216  	}
   217  	return nil
   218  }
   220  // ValidateJWTGoogle returns if the google issued jwt is valid or not.
   221  func ValidateJWTGoogle(m *Manager, jwtClaims *GoogleClaims) error {
   222  	if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) {
   223  		return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience))
   224  	}
   225  	if jwtClaims.StandardClaims.Issuer != GoogleIssuer && jwtClaims.StandardClaims.Issuer != GoogleIssuerAlternate {
   226  		return ex.New(ErrInvalidJWTIssuer, ex.OptMessagef("issuer: %s", jwtClaims.StandardClaims.Issuer))
   227  	}
   228  	if len(m.AllowedDomains) > 0 {
   229  		if strings.TrimSpace(jwtClaims.HD) == "" {
   230  			return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: likely, but empty"))
   231  		}
   232  		var matchedDomain bool
   233  		for _, domain := range m.AllowedDomains {
   234  			if strings.EqualFold(domain, jwtClaims.HD) {
   235  				matchedDomain = true
   236  				break
   237  			}
   238  		}
   239  		if !matchedDomain {
   240  			return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: %s", jwtClaims.HD))
   241  		}
   242  	}
   243  	return nil
   244  }
   246  // ValidateJWTOkta returns if the okta issued jwt is valid or not.
   247  func ValidateJWTOkta(m *Manager, jwtClaims *GoogleClaims) error {
   248  	if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) {
   249  		return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience))
   250  	}
   251  	return nil
   252  }
   255  // internal helpers
   258  func (m *Manager) hash(plaintext string) string {
   259  	return base64.URLEncoding.EncodeToString(m.hmac([]byte(plaintext)))
   260  }
   262  // hmac hashes data with the given key.
   263  func (m *Manager) hmac(plainText []byte) []byte {
   264  	mac := hmac.New(sha512.New, m.Secret)
   265  	_, _ = mac.Write([]byte(plainText))
   266  	return mac.Sum(nil)
   267  }