go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/oauth/manager.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package oauth
     9  
    10  import (
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha512"
    14  	"encoding/base64"
    15  	"encoding/json"
    16  	"fmt"
    17  	"net/http"
    18  
    19  	"github.com/coreos/go-oidc/v3/oidc"
    20  	"golang.org/x/oauth2"
    21  
    22  	"go.charczuk.com/sdk/r2"
    23  	"go.charczuk.com/sdk/uuid"
    24  )
    25  
    26  // New returns a new manager mutated by a given set of options.
    27  func New(ctx context.Context, options ...Option) (*Manager, error) {
    28  	oidcProvider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  	manager := &Manager{
    33  		oauth2: oauth2.Config{
    34  			Endpoint: oidcProvider.Endpoint(),
    35  			Scopes:   DefaultScopes,
    36  		},
    37  	}
    38  	for _, option := range options {
    39  		if err := option(manager); err != nil {
    40  			return nil, err
    41  		}
    42  	}
    43  	if len(manager.Secret) == 0 {
    44  		return nil, ErrSecretRequired
    45  	}
    46  	manager.verifier = oidcProvider.Verifier(&oidc.Config{
    47  		ClientID: manager.oauth2.ClientID,
    48  	})
    49  	return manager, nil
    50  }
    51  
    52  // MustNew returns a new manager mutated by a given set of options
    53  // and will panic on error.
    54  func MustNew(ctx context.Context, options ...Option) *Manager {
    55  	m, err := New(ctx, options...)
    56  	if err != nil {
    57  		panic(err)
    58  	}
    59  	return m
    60  }
    61  
// Manager is the oauth manager.
//
// It holds the settings needed to build the authorization URL, validate the
// anti-forgery state on the callback, exchange the code for tokens and verify
// the returned ID token.
type Manager struct {
	// Secret is the HMAC key used to sign and validate the state token.
	// New rejects a manager whose Secret is empty.
	Secret []byte
	// HostedDomain, when non-empty, is sent as the "hd" parameter on the
	// authorization URL.
	HostedDomain string
	// AllowedDomains is not read anywhere in this file.
	// NOTE(review): presumably consumed by validation code elsewhere — confirm.
	AllowedDomains []string

	// oauth2 is the underlying oauth2 configuration used to build the
	// authorization URL and to exchange the authorization code.
	oauth2 oauth2.Config
	// verifier checks the ID token returned by the code exchange.
	verifier *oidc.IDTokenVerifier
}
    71  
    72  // OAuthURL is the auth url for google with a given clientID.
    73  // This is typically the link that a user will click on to start the auth process.
    74  func (m *Manager) OAuthURL(r *http.Request, stateOptions ...StateOption) (oauthURL string, err error) {
    75  	var state string
    76  	state, err = SerializeState(m.CreateState(stateOptions...))
    77  	if err != nil {
    78  		return
    79  	}
    80  	var opts []oauth2.AuthCodeOption
    81  	if len(m.HostedDomain) > 0 {
    82  		opts = append(opts, oauth2.SetAuthURLParam("hd", m.HostedDomain))
    83  	}
    84  	oauthURL = m.oauth2.AuthCodeURL(state, opts...)
    85  	return
    86  }
    87  
    88  // Finish processes the returned code, exchanging for an access token, and fetches the user profile.
    89  func (m *Manager) Finish(r *http.Request) (result *Result, err error) {
    90  	code := r.URL.Query().Get("code")
    91  	if len(code) == 0 {
    92  		err = ErrCodeMissing
    93  		return
    94  	}
    95  
    96  	state := r.URL.Query().Get("state")
    97  	result = new(Result)
    98  	if state != "" {
    99  		var deserialized State
   100  		deserialized, err = DeserializeState(state)
   101  		if err != nil {
   102  			return
   103  		}
   104  		result.State = deserialized
   105  	}
   106  	err = m.ValidateState(result.State)
   107  	if err != nil {
   108  		return
   109  	}
   110  
   111  	// Handle the exchange code to initiate a transport.
   112  	tok, err := m.oauth2.Exchange(r.Context(), code)
   113  	if err != nil {
   114  		err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err)
   115  		return
   116  	}
   117  
   118  	// Extract the ID Token from OAuth2 token.
   119  	rawIDToken, ok := tok.Extra("id_token").(string)
   120  	if !ok {
   121  		err = fmt.Errorf("%w: id_token missing", ErrFailedCodeExchange)
   122  		return
   123  	}
   124  
   125  	// Parse and verify ID Token payload.
   126  	idToken, err := m.verifier.Verify(r.Context(), rawIDToken)
   127  	if err != nil {
   128  		err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err)
   129  		return
   130  	}
   131  
   132  	var claims GoogleClaims
   133  	if err = idToken.Claims(&claims); err != nil {
   134  		err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err)
   135  		return
   136  	}
   137  
   138  	result.Response.AccessToken = tok.AccessToken
   139  	result.Response.TokenType = tok.TokenType
   140  	result.Response.RefreshToken = tok.RefreshToken
   141  	result.Response.Expiry = tok.Expiry
   142  
   143  	result.Profile, err = m.FetchProfile(r.Context(), tok.AccessToken)
   144  	if err != nil {
   145  		return
   146  	}
   147  	return
   148  }
   149  
   150  // FetchProfile gets a google profile for an access token.
   151  func (m *Manager) FetchProfile(ctx context.Context, accessToken string) (profile Profile, err error) {
   152  	res, err := r2.New("https://www.googleapis.com/oauth2/v1/userinfo",
   153  		r2.OptGet(),
   154  		r2.OptContext(ctx),
   155  		r2.OptQuery("alt", "json"),
   156  		r2.OptHeader("Authorization", fmt.Sprintf("Bearer %s", accessToken)),
   157  	).Do()
   158  	if err != nil {
   159  		return
   160  	}
   161  	defer res.Body.Close()
   162  	if code := res.StatusCode; code < 200 || code > 299 {
   163  		err = ErrGoogleResponseStatus
   164  		return
   165  	}
   166  	if err = json.NewDecoder(res.Body).Decode(&profile); err != nil {
   167  		err = fmt.Errorf("%v: %w", ErrProfileJSONUnmarshal, err)
   168  		return
   169  	}
   170  	return
   171  }
   172  
   173  // CreateState creates auth state.
   174  func (m *Manager) CreateState(options ...StateOption) (state State) {
   175  	for _, opt := range options {
   176  		opt(&state)
   177  	}
   178  	if len(m.Secret) > 0 && state.Token == "" && state.SecureToken == "" {
   179  		state.Token = uuid.V4().String()
   180  		state.SecureToken = m.hash(state.Token)
   181  	}
   182  	return
   183  }
   184  
   185  // --------------------------------------------------------------------------------
   186  // Validation Helpers
   187  // --------------------------------------------------------------------------------
   188  
   189  // ValidateState validates oauth state.
   190  func (m *Manager) ValidateState(state State) error {
   191  	if len(m.Secret) > 0 {
   192  		expected := m.hash(state.Token)
   193  		actual := state.SecureToken
   194  		if !hmac.Equal([]byte(expected), []byte(actual)) {
   195  			return ErrInvalidAntiforgeryToken
   196  		}
   197  	}
   198  	return nil
   199  }
   200  
   201  // --------------------------------------------------------------------------------
   202  // internal helpers
   203  // --------------------------------------------------------------------------------
   204  
   205  func (m *Manager) hash(plaintext string) string {
   206  	return base64.URLEncoding.EncodeToString(m.hmac([]byte(plaintext)))
   207  }
   208  
   209  // hmac hashes data with the given key.
   210  func (m *Manager) hmac(plainText []byte) []byte {
   211  	mac := hmac.New(sha512.New, m.Secret)
   212  	_, _ = mac.Write([]byte(plainText))
   213  	return mac.Sum(nil)
   214  }