github.com/volatiletech/authboss@v2.4.1+incompatible/oauth2/oauth2.go (about) 1 // Package oauth2 allows users to be created and authenticated 2 // via oauth2 services like facebook, google etc. Currently 3 // only the web server flow is supported. 4 // 5 // The general flow looks like this: 6 // 1. User goes to Start handler and has his session packed with goodies 7 // then redirects to the OAuth service. 8 // 2. OAuth service returns to OAuthCallback which extracts state and 9 // parameters and generally checks that everything is ok. It uses the 10 // token received to get an access token from the oauth2 library 11 // 3. Calls the OAuth2Provider.FindUserDetails which should return the user's 12 // details in a generic form. 13 // 4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in 14 // order to create a user object we can work with. 15 // 5. Saves the user in the database, logs them in, redirects. 16 // 17 // In order to do this there are a number of parts: 18 // 1. The configuration of a provider 19 // (handled by authboss.Config.Modules.OAuth2Providers). 20 // 2. The flow of redirection of client, parameter passing etc 21 // (handled by this package) 22 // 3. The HTTP call to the service once a token has been retrieved to 23 // get user details (handled by OAuth2Provider.FindUserDetails) 24 // 4. The creation of a user from the user details returned from the 25 // FindUserDetails (authboss.OAuth2ServerStorer) 26 // 27 // Of these parts, the responsibility of the authboss library consumer 28 // is on 1, 3, and 4. Configuration of providers that should be used is totally 29 // up to the consumer. The FindUserDetails function is typically up to the 30 // user, but we have some basic ones included in this package too. 31 // The creation of users from the FindUserDetail's map[string]string return 32 // is handled as part of the implementation of the OAuth2ServerStorer. 33 package oauth2 34 35 import ( 36 "context" 37 "crypto/rand" 38 "encoding/base64" 39 "encoding/json" 40 "fmt" 41 "io" 42 "net/http" 43 "net/url" 44 "path" 45 "path/filepath" 46 "sort" 47 "strings" 48 49 "github.com/pkg/errors" 50 "golang.org/x/oauth2" 51 52 "github.com/volatiletech/authboss" 53 ) 54 55 // FormValue constants 56 const ( 57 FormValueOAuth2State = "state" 58 FormValueOAuth2Redir = "redir" 59 ) 60 61 var ( 62 errOAuthStateValidation = errors.New("could not validate oauth2 state param") 63 ) 64 65 // OAuth2 module 66 type OAuth2 struct { 67 *authboss.Authboss 68 } 69 70 func init() { 71 authboss.RegisterModule("oauth2", &OAuth2{}) 72 } 73 74 // Init module 75 func (o *OAuth2) Init(ab *authboss.Authboss) error { 76 o.Authboss = ab 77 78 // Do annoying sorting on keys so we can have predictible 79 // route registration (both for consistency inside the router but 80 // also for tests -_-) 81 var keys []string 82 for k := range o.Authboss.Config.Modules.OAuth2Providers { 83 keys = append(keys, k) 84 } 85 sort.Strings(keys) 86 87 for _, provider := range keys { 88 cfg := o.Authboss.Config.Modules.OAuth2Providers[provider] 89 provider = strings.ToLower(provider) 90 91 init := fmt.Sprintf("/oauth2/%s", provider) 92 callback := fmt.Sprintf("/oauth2/callback/%s", provider) 93 94 o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start)) 95 o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End)) 96 97 if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 { 98 callback = path.Join(mount, callback) 99 } 100 101 cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback 102 } 103 104 return nil 105 } 106 107 // Start the oauth2 process 108 func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error { 109 logger := o.Authboss.RequestLogger(r) 110 111 provider := strings.ToLower(filepath.Base(r.URL.Path)) 112 logger.Infof("started oauth2 flow for provider: %s", provider) 113 cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider] 114 if !ok { 115 return errors.Errorf("oauth2 provider %q not found", provider) 116 } 117 118 // Create nonce 119 nonce := make([]byte, 32) 120 if _, err := io.ReadFull(rand.Reader, nonce); err != nil { 121 return errors.Wrap(err, "failed to create nonce") 122 } 123 124 state := base64.URLEncoding.EncodeToString(nonce) 125 authboss.PutSession(w, authboss.SessionOAuth2State, state) 126 127 // This clearly ignores the fact that query parameters can have multiple 128 // values but I guess we're ignoring that 129 passAlongs := make(map[string]string) 130 for k, vals := range r.URL.Query() { 131 for _, val := range vals { 132 passAlongs[k] = val 133 } 134 } 135 136 if len(passAlongs) > 0 { 137 byt, err := json.Marshal(passAlongs) 138 if err != nil { 139 return err 140 } 141 authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt)) 142 } else { 143 authboss.DelSession(w, authboss.SessionOAuth2Params) 144 } 145 146 authCodeUrl := cfg.OAuth2Config.AuthCodeURL(state) 147 148 extraParams := cfg.AdditionalParams.Encode() 149 if len(extraParams) > 0 { 150 authCodeUrl = fmt.Sprintf("%s&%s", authCodeUrl, extraParams) 151 } 152 153 ro := authboss.RedirectOptions{ 154 Code: http.StatusTemporaryRedirect, 155 RedirectPath: authCodeUrl, 156 } 157 return o.Authboss.Core.Redirector.Redirect(w, r, ro) 158 } 159 160 // for testing, mocked out at the beginning 161 var exchanger = (*oauth2.Config).Exchange 162 163 // End the oauth2 process, this is the handler for the oauth2 callback 164 // that the third party will redirect to. 165 func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error { 166 logger := o.Authboss.RequestLogger(r) 167 provider := strings.ToLower(filepath.Base(r.URL.Path)) 168 logger.Infof("finishing oauth2 flow for provider: %s", provider) 169 170 // This shouldn't happen because the router should 404 first, but just in case 171 cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider] 172 if !ok { 173 return errors.Errorf("oauth2 provider %q not found", provider) 174 } 175 176 wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State) 177 if !ok { 178 return errors.New("oauth2 endpoint hit without session state") 179 } 180 181 // Verify we got the same state in the session as was passed to us in the 182 // query parameter. 183 state := r.FormValue(FormValueOAuth2State) 184 if state != wantState { 185 return errOAuthStateValidation 186 } 187 188 rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params) 189 var params map[string]string 190 if ok { 191 if err := json.Unmarshal([]byte(rawParams), ¶ms); err != nil { 192 return errors.Wrap(err, "failed to decode oauth2 params") 193 } 194 } 195 196 authboss.DelSession(w, authboss.SessionOAuth2State) 197 authboss.DelSession(w, authboss.SessionOAuth2Params) 198 199 hasErr := r.FormValue("error") 200 if len(hasErr) > 0 { 201 reason := r.FormValue("error_reason") 202 logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason) 203 204 handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r) 205 if err != nil { 206 return err 207 } else if handled { 208 return nil 209 } 210 211 ro := authboss.RedirectOptions{ 212 Code: http.StatusTemporaryRedirect, 213 RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK, 214 Failure: fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)), 215 } 216 return o.Authboss.Core.Redirector.Redirect(w, r, ro) 217 } 218 219 // Get the code which we can use to make an access token 220 code := r.FormValue("code") 221 token, err := exchanger(cfg.OAuth2Config, r.Context(), code) 222 if err != nil { 223 return errors.Wrap(err, "could not validate oauth2 code") 224 } 225 226 details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token) 227 if err != nil { 228 return err 229 } 230 231 storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server) 232 user, err := storer.NewFromOAuth2(r.Context(), provider, details) 233 if err != nil { 234 return errors.Wrap(err, "failed to create oauth2 user from values") 235 } 236 237 user.PutOAuth2Provider(provider) 238 user.PutOAuth2AccessToken(token.AccessToken) 239 user.PutOAuth2Expiry(token.Expiry) 240 if len(token.RefreshToken) != 0 { 241 user.PutOAuth2RefreshToken(token.RefreshToken) 242 } 243 244 if err := storer.SaveOAuth2(r.Context(), user); err != nil { 245 return err 246 } 247 248 r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user)) 249 250 handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r) 251 if err != nil { 252 return err 253 } else if handled { 254 return nil 255 } 256 257 // Fully log user in 258 authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID())) 259 authboss.DelSession(w, authboss.SessionHalfAuthKey) 260 261 // Create a query string from all the pieces we've received 262 // as passthru from the original request. 263 redirect := o.Authboss.Config.Paths.OAuth2LoginOK 264 query := make(url.Values) 265 for k, v := range params { 266 switch k { 267 case authboss.CookieRemember: 268 if v == "true" { 269 r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{})) 270 } 271 case FormValueOAuth2Redir: 272 redirect = v 273 default: 274 query.Set(k, v) 275 } 276 } 277 278 handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r) 279 if err != nil { 280 return err 281 } else if handled { 282 return nil 283 } 284 285 if len(query) > 0 { 286 redirect = fmt.Sprintf("%s?%s", redirect, query.Encode()) 287 } 288 289 ro := authboss.RedirectOptions{ 290 Code: http.StatusTemporaryRedirect, 291 RedirectPath: redirect, 292 Success: fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)), 293 } 294 return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro) 295 } 296 297 // RMTrue is a dummy struct implementing authboss.RememberValuer 298 // in order to tell the remember me module to remember them. 299 type RMTrue struct{} 300 301 // GetShouldRemember always returns true 302 func (RMTrue) GetShouldRemember() bool { return true }