go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/oauth/manager.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package oauth 9 10 import ( 11 "context" 12 "crypto/hmac" 13 "crypto/sha512" 14 "encoding/base64" 15 "encoding/json" 16 "fmt" 17 "net/http" 18 19 "github.com/coreos/go-oidc/v3/oidc" 20 "golang.org/x/oauth2" 21 22 "go.charczuk.com/sdk/r2" 23 "go.charczuk.com/sdk/uuid" 24 ) 25 26 // New returns a new manager mutated by a given set of options. 27 func New(ctx context.Context, options ...Option) (*Manager, error) { 28 oidcProvider, err := oidc.NewProvider(ctx, "https://accounts.google.com") 29 if err != nil { 30 return nil, err 31 } 32 manager := &Manager{ 33 oauth2: oauth2.Config{ 34 Endpoint: oidcProvider.Endpoint(), 35 Scopes: DefaultScopes, 36 }, 37 } 38 for _, option := range options { 39 if err := option(manager); err != nil { 40 return nil, err 41 } 42 } 43 if len(manager.Secret) == 0 { 44 return nil, ErrSecretRequired 45 } 46 manager.verifier = oidcProvider.Verifier(&oidc.Config{ 47 ClientID: manager.oauth2.ClientID, 48 }) 49 return manager, nil 50 } 51 52 // MustNew returns a new manager mutated by a given set of options 53 // and will panic on error. 54 func MustNew(ctx context.Context, options ...Option) *Manager { 55 m, err := New(ctx, options...) 56 if err != nil { 57 panic(err) 58 } 59 return m 60 } 61 62 // Manager is the oauth manager. 63 type Manager struct { 64 Secret []byte 65 HostedDomain string 66 AllowedDomains []string 67 68 oauth2 oauth2.Config 69 verifier *oidc.IDTokenVerifier 70 } 71 72 // OAuthURL is the auth url for google with a given clientID. 73 // This is typically the link that a user will click on to start the auth process. 74 func (m *Manager) OAuthURL(r *http.Request, stateOptions ...StateOption) (oauthURL string, err error) { 75 var state string 76 state, err = SerializeState(m.CreateState(stateOptions...)) 77 if err != nil { 78 return 79 } 80 var opts []oauth2.AuthCodeOption 81 if len(m.HostedDomain) > 0 { 82 opts = append(opts, oauth2.SetAuthURLParam("hd", m.HostedDomain)) 83 } 84 oauthURL = m.oauth2.AuthCodeURL(state, opts...) 85 return 86 } 87 88 // Finish processes the returned code, exchanging for an access token, and fetches the user profile. 89 func (m *Manager) Finish(r *http.Request) (result *Result, err error) { 90 code := r.URL.Query().Get("code") 91 if len(code) == 0 { 92 err = ErrCodeMissing 93 return 94 } 95 96 state := r.URL.Query().Get("state") 97 result = new(Result) 98 if state != "" { 99 var deserialized State 100 deserialized, err = DeserializeState(state) 101 if err != nil { 102 return 103 } 104 result.State = deserialized 105 } 106 err = m.ValidateState(result.State) 107 if err != nil { 108 return 109 } 110 111 // Handle the exchange code to initiate a transport. 112 tok, err := m.oauth2.Exchange(r.Context(), code) 113 if err != nil { 114 err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err) 115 return 116 } 117 118 // Extract the ID Token from OAuth2 token. 119 rawIDToken, ok := tok.Extra("id_token").(string) 120 if !ok { 121 err = fmt.Errorf("%w: id_token missing", ErrFailedCodeExchange) 122 return 123 } 124 125 // Parse and verify ID Token payload. 126 idToken, err := m.verifier.Verify(r.Context(), rawIDToken) 127 if err != nil { 128 err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err) 129 return 130 } 131 132 var claims GoogleClaims 133 if err = idToken.Claims(&claims); err != nil { 134 err = fmt.Errorf("%w: %v", ErrFailedCodeExchange, err) 135 return 136 } 137 138 result.Response.AccessToken = tok.AccessToken 139 result.Response.TokenType = tok.TokenType 140 result.Response.RefreshToken = tok.RefreshToken 141 result.Response.Expiry = tok.Expiry 142 143 result.Profile, err = m.FetchProfile(r.Context(), tok.AccessToken) 144 if err != nil { 145 return 146 } 147 return 148 } 149 150 // FetchProfile gets a google profile for an access token. 151 func (m *Manager) FetchProfile(ctx context.Context, accessToken string) (profile Profile, err error) { 152 res, err := r2.New("https://www.googleapis.com/oauth2/v1/userinfo", 153 r2.OptGet(), 154 r2.OptContext(ctx), 155 r2.OptQuery("alt", "json"), 156 r2.OptHeader("Authorization", fmt.Sprintf("Bearer %s", accessToken)), 157 ).Do() 158 if err != nil { 159 return 160 } 161 defer res.Body.Close() 162 if code := res.StatusCode; code < 200 || code > 299 { 163 err = ErrGoogleResponseStatus 164 return 165 } 166 if err = json.NewDecoder(res.Body).Decode(&profile); err != nil { 167 err = fmt.Errorf("%v: %w", ErrProfileJSONUnmarshal, err) 168 return 169 } 170 return 171 } 172 173 // CreateState creates auth state. 174 func (m *Manager) CreateState(options ...StateOption) (state State) { 175 for _, opt := range options { 176 opt(&state) 177 } 178 if len(m.Secret) > 0 && state.Token == "" && state.SecureToken == "" { 179 state.Token = uuid.V4().String() 180 state.SecureToken = m.hash(state.Token) 181 } 182 return 183 } 184 185 // -------------------------------------------------------------------------------- 186 // Validation Helpers 187 // -------------------------------------------------------------------------------- 188 189 // ValidateState validates oauth state. 190 func (m *Manager) ValidateState(state State) error { 191 if len(m.Secret) > 0 { 192 expected := m.hash(state.Token) 193 actual := state.SecureToken 194 if !hmac.Equal([]byte(expected), []byte(actual)) { 195 return ErrInvalidAntiforgeryToken 196 } 197 } 198 return nil 199 } 200 201 // -------------------------------------------------------------------------------- 202 // internal helpers 203 // -------------------------------------------------------------------------------- 204 205 func (m *Manager) hash(plaintext string) string { 206 return base64.URLEncoding.EncodeToString(m.hmac([]byte(plaintext))) 207 } 208 209 // hmac hashes data with the given key. 210 func (m *Manager) hmac(plainText []byte) []byte { 211 mac := hmac.New(sha512.New, m.Secret) 212 _, _ = mac.Write([]byte(plainText)) 213 return mac.Sum(nil) 214 }