github.com/blend/go-sdk@v1.20240719.1/oauth/manager.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package oauth 9 10 import ( 11 "context" 12 "crypto/hmac" 13 "crypto/sha512" 14 "encoding/base64" 15 "encoding/json" 16 "fmt" 17 "net/http" 18 "strings" 19 20 "golang.org/x/oauth2" 21 "golang.org/x/oauth2/google" 22 23 "github.com/blend/go-sdk/ex" 24 "github.com/blend/go-sdk/r2" 25 "github.com/blend/go-sdk/uuid" 26 "github.com/blend/go-sdk/webutil" 27 ) 28 29 const ( 30 googleIssuerURL = "https://www.googleapis.com/oauth2" 31 ) 32 33 // New returns a new Google Auth manager if options do not 34 // specify an endpoint, PublicKeyCache and Issuer 35 func New(options ...Option) (*Manager, error) { 36 manager := &Manager{ 37 Config: oauth2.Config{ 38 Endpoint: google.Endpoint, 39 Scopes: DefaultScopes, 40 }, 41 PublicKeyCache: NewPublicKeyCache(GoogleKeysURL), 42 Issuer: googleIssuerURL, 43 } 44 45 for _, option := range options { 46 if err := option(manager); err != nil { 47 return nil, err 48 } 49 } 50 return manager, nil 51 } 52 53 // MustNew returns a new manager mutated by a given set of options 54 // and will panic on error. 55 func MustNew(options ...Option) *Manager { 56 m, err := New(options...) 57 if err != nil { 58 panic(err) 59 } 60 return m 61 } 62 63 // Manager is the oauth manager. 64 type Manager struct { 65 oauth2.Config 66 Tracer Tracer 67 68 Secret []byte 69 70 HostedDomain string 71 AllowedDomains []string 72 73 Issuer string 74 75 ValidateJWT ValidateJWTFunc 76 77 FetchProfileDefaults []r2.Option 78 PublicKeyCache *PublicKeyCache 79 } 80 81 // OAuthURL is the auth url for google with a given clientID. 82 // This is typically the link that a user will click on to start the auth process. 83 func (m *Manager) OAuthURL(r *http.Request, stateOptions ...StateOption) (oauthURL string, err error) { 84 var state string 85 state, err = SerializeState(m.CreateState(stateOptions...)) 86 if err != nil { 87 return 88 } 89 var opts []oauth2.AuthCodeOption 90 if len(m.HostedDomain) > 0 { 91 opts = append(opts, oauth2.SetAuthURLParam("hd", m.HostedDomain)) 92 } 93 oauthURL = m.AuthCodeURL(state, opts...) 94 return 95 } 96 97 // Finish processes the returned code, exchanging for an access token, and fetches the user profile. 98 func (m *Manager) Finish(r *http.Request) (result *Result, err error) { 99 if m.Tracer != nil { 100 tf := m.Tracer.Start(r.Context(), &m.Config) 101 if tf != nil { 102 defer func() { tf.Finish(r.Context(), &m.Config, result, err) }() 103 } 104 } 105 106 // grab the code off the request. 107 code := r.URL.Query().Get("code") 108 if len(code) == 0 { 109 err = ErrCodeMissing 110 return 111 } 112 113 // fetch the state 114 state := r.URL.Query().Get("state") 115 result = &Result{} 116 if len(state) > 0 { 117 var deserialized State 118 deserialized, err = DeserializeState(state) 119 if err != nil { 120 return 121 } 122 result.State = deserialized 123 } 124 err = m.ValidateState(result.State) 125 if err != nil { 126 return 127 } 128 129 // Handle the exchange code to initiate a transport. 130 var tok *oauth2.Token 131 tok, err = m.Exchange(r.Context(), code) 132 if err != nil { 133 err = ex.New(ErrFailedCodeExchange, ex.OptInner(err)) 134 return 135 } 136 137 jwtClaims, err := ParseTokenJWT(tok, m.PublicKeyCache.Keyfunc(r.Context())) 138 if err != nil { 139 err = ex.New(ErrInvalidJWT, ex.OptInner(err)) 140 return 141 } 142 143 // define the JWT validate function handler 144 validateJWT := m.ValidateJWT 145 if validateJWT == nil { 146 validateJWT = ValidateJWTGoogle 147 } 148 149 // validate the JWT 150 if err = validateJWT(m, jwtClaims); err != nil { 151 return 152 } 153 154 result.Response.AccessToken = tok.AccessToken 155 result.Response.TokenType = tok.TokenType 156 result.Response.RefreshToken = tok.RefreshToken 157 result.Response.Expiry = tok.Expiry 158 result.Response.HostedDomain = jwtClaims.HD 159 160 var prof Profile 161 prof, err = m.FetchProfile(r.Context(), tok.AccessToken) 162 if err != nil { 163 return 164 } 165 result.Profile = prof 166 return 167 } 168 169 // FetchProfile gets a google profile for an access token. 170 func (m *Manager) FetchProfile(ctx context.Context, accessToken string) (profile Profile, err error) { 171 res, err := r2.New(m.Issuer+"/v1/userinfo", append([]r2.Option{ 172 r2.OptGet(), 173 r2.OptContext(ctx), 174 r2.OptQueryValue("alt", "json"), 175 r2.OptHeaderValue(webutil.HeaderAuthorization, fmt.Sprintf("Bearer %s", accessToken)), 176 }, m.FetchProfileDefaults...)...).Do() 177 if err != nil { 178 return 179 } 180 defer res.Body.Close() 181 if code := res.StatusCode; code < 200 || code > 299 { 182 err = ex.New(ErrGoogleResponseStatus, ex.OptMessagef("status code: %d", res.StatusCode)) 183 return 184 } 185 if err = json.NewDecoder(res.Body).Decode(&profile); err != nil { 186 err = ex.New(ErrProfileJSONUnmarshal, ex.OptInner(err)) 187 return 188 } 189 return 190 } 191 192 // CreateState creates auth state. 193 func (m *Manager) CreateState(options ...StateOption) (state State) { 194 for _, opt := range options { 195 opt(&state) 196 } 197 if len(m.Secret) > 0 && state.Token == "" && state.SecureToken == "" { 198 state.Token = uuid.V4().String() 199 state.SecureToken = m.hash(state.Token) 200 } 201 return 202 } 203 204 // -------------------------------------------------------------------------------- 205 // Validation Helpers 206 // -------------------------------------------------------------------------------- 207 208 // ValidateState validates oauth state. 209 func (m *Manager) ValidateState(state State) error { 210 if len(m.Secret) > 0 { 211 expected := m.hash(state.Token) 212 actual := state.SecureToken 213 if !hmac.Equal([]byte(expected), []byte(actual)) { 214 return ErrInvalidAntiforgeryToken 215 } 216 } 217 return nil 218 } 219 220 // ValidateJWTGoogle returns if the google issued jwt is valid or not. 221 func ValidateJWTGoogle(m *Manager, jwtClaims *GoogleClaims) error { 222 if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) { 223 return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience)) 224 } 225 if jwtClaims.StandardClaims.Issuer != GoogleIssuer && jwtClaims.StandardClaims.Issuer != GoogleIssuerAlternate { 226 return ex.New(ErrInvalidJWTIssuer, ex.OptMessagef("issuer: %s", jwtClaims.StandardClaims.Issuer)) 227 } 228 if len(m.AllowedDomains) > 0 { 229 if strings.TrimSpace(jwtClaims.HD) == "" { 230 return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: likely gmail.com, but empty")) 231 } 232 var matchedDomain bool 233 for _, domain := range m.AllowedDomains { 234 if strings.EqualFold(domain, jwtClaims.HD) { 235 matchedDomain = true 236 break 237 } 238 } 239 if !matchedDomain { 240 return ex.New(ErrInvalidJWTHostedDomain, ex.OptMessagef("hosted domain: %s", jwtClaims.HD)) 241 } 242 } 243 return nil 244 } 245 246 // ValidateJWTOkta returns if the okta issued jwt is valid or not. 247 func ValidateJWTOkta(m *Manager, jwtClaims *GoogleClaims) error { 248 if !jwtClaims.StandardClaims.VerifyAudience(m.Config.ClientID, true) { 249 return ex.New(ErrInvalidJWTAudience, ex.OptMessagef("audience: %s", jwtClaims.StandardClaims.Audience)) 250 } 251 return nil 252 } 253 254 // -------------------------------------------------------------------------------- 255 // internal helpers 256 // -------------------------------------------------------------------------------- 257 258 func (m *Manager) hash(plaintext string) string { 259 return base64.URLEncoding.EncodeToString(m.hmac([]byte(plaintext))) 260 } 261 262 // hmac hashes data with the given key. 263 func (m *Manager) hmac(plainText []byte) []byte { 264 mac := hmac.New(sha512.New, m.Secret) 265 _, _ = mac.Write([]byte(plainText)) 266 return mac.Sum(nil) 267 }