github.com/volatiletech/authboss@v2.4.1+incompatible/oauth2/oauth2.go (about)

     1  // Package oauth2 allows users to be created and authenticated
     2  // via oauth2 services like facebook, google etc. Currently
     3  // only the web server flow is supported.
     4  //
     5  // The general flow looks like this:
     6  //   1. User goes to the Start handler, which stores a state nonce and the
     7  //      query parameters in the session, then redirects to the OAuth service.
     8  //   2. OAuth service redirects back to the End handler, which verifies the
     9  //      state against the session, then exchanges the code it received for
    10  //      an access token using the oauth2 library.
    11  //   3. Calls the OAuth2Provider.FindUserDetails which should return the user's
    12  //      details in a generic form.
    13  //   4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in
    14  //      order to create a user object we can work with.
    15  //   5. Saves the user in the database, logs them in, redirects.
    16  //
    17  // In order to do this there are a number of parts:
    18  //   1. The configuration of a provider
    19  //      (handled by authboss.Config.Modules.OAuth2Providers).
    20  //   2. The flow of redirection of client, parameter passing etc
    21  //      (handled by this package)
    22  //   3. The HTTP call to the service once a token has been retrieved to
    23  //      get user details (handled by OAuth2Provider.FindUserDetails)
    24  //   4. The creation of a user from the user details returned from the
    25  //      FindUserDetails (authboss.OAuth2ServerStorer)
    26  //
    27  // Of these parts, the responsibility of the authboss library consumer
    28  // is on 1, 3, and 4. Configuration of providers that should be used is totally
    29  // up to the consumer. The FindUserDetails function is typically up to the
    30  // user, but we have some basic ones included in this package too.
    31  // The creation of users from the map[string]string returned by
    32  // FindUserDetails is handled by the OAuth2ServerStorer implementation.
    33  package oauth2
    34  
    35  import (
    36  	"context"
    37  	"crypto/rand"
    38  	"encoding/base64"
    39  	"encoding/json"
    40  	"fmt"
    41  	"io"
    42  	"net/http"
    43  	"net/url"
    44  	"path"
    45  	"path/filepath"
    46  	"sort"
    47  	"strings"
    48  
    49  	"github.com/pkg/errors"
    50  	"golang.org/x/oauth2"
    51  
    52  	"github.com/volatiletech/authboss"
    53  )
    54  
    55  // FormValue constants
    56  const (
    57  	FormValueOAuth2State = "state"
    58  	FormValueOAuth2Redir = "redir"
    59  )
    60  
    61  var (
    62  	errOAuthStateValidation = errors.New("could not validate oauth2 state param")
    63  )
    64  
    65  // OAuth2 module
    66  type OAuth2 struct {
    67  	*authboss.Authboss
    68  }
    69  
    70  func init() {
    71  	authboss.RegisterModule("oauth2", &OAuth2{})
    72  }
    73  
    74  // Init module
    75  func (o *OAuth2) Init(ab *authboss.Authboss) error {
    76  	o.Authboss = ab
    77  
    78  	// Do annoying sorting on keys so we can have predictible
    79  	// route registration (both for consistency inside the router but
    80  	// also for tests -_-)
    81  	var keys []string
    82  	for k := range o.Authboss.Config.Modules.OAuth2Providers {
    83  		keys = append(keys, k)
    84  	}
    85  	sort.Strings(keys)
    86  
    87  	for _, provider := range keys {
    88  		cfg := o.Authboss.Config.Modules.OAuth2Providers[provider]
    89  		provider = strings.ToLower(provider)
    90  
    91  		init := fmt.Sprintf("/oauth2/%s", provider)
    92  		callback := fmt.Sprintf("/oauth2/callback/%s", provider)
    93  
    94  		o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start))
    95  		o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End))
    96  
    97  		if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 {
    98  			callback = path.Join(mount, callback)
    99  		}
   100  
   101  		cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback
   102  	}
   103  
   104  	return nil
   105  }
   106  
   107  // Start the oauth2 process
   108  func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error {
   109  	logger := o.Authboss.RequestLogger(r)
   110  
   111  	provider := strings.ToLower(filepath.Base(r.URL.Path))
   112  	logger.Infof("started oauth2 flow for provider: %s", provider)
   113  	cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
   114  	if !ok {
   115  		return errors.Errorf("oauth2 provider %q not found", provider)
   116  	}
   117  
   118  	// Create nonce
   119  	nonce := make([]byte, 32)
   120  	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
   121  		return errors.Wrap(err, "failed to create nonce")
   122  	}
   123  
   124  	state := base64.URLEncoding.EncodeToString(nonce)
   125  	authboss.PutSession(w, authboss.SessionOAuth2State, state)
   126  
   127  	// This clearly ignores the fact that query parameters can have multiple
   128  	// values but I guess we're ignoring that
   129  	passAlongs := make(map[string]string)
   130  	for k, vals := range r.URL.Query() {
   131  		for _, val := range vals {
   132  			passAlongs[k] = val
   133  		}
   134  	}
   135  
   136  	if len(passAlongs) > 0 {
   137  		byt, err := json.Marshal(passAlongs)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt))
   142  	} else {
   143  		authboss.DelSession(w, authboss.SessionOAuth2Params)
   144  	}
   145  
   146  	authCodeUrl := cfg.OAuth2Config.AuthCodeURL(state)
   147  
   148  	extraParams := cfg.AdditionalParams.Encode()
   149  	if len(extraParams) > 0 {
   150  		authCodeUrl = fmt.Sprintf("%s&%s", authCodeUrl, extraParams)
   151  	}
   152  
   153  	ro := authboss.RedirectOptions{
   154  		Code:         http.StatusTemporaryRedirect,
   155  		RedirectPath: authCodeUrl,
   156  	}
   157  	return o.Authboss.Core.Redirector.Redirect(w, r, ro)
   158  }
   159  
// exchanger is the function End uses to trade the authorization code for a
// token. It defaults to the method expression (*oauth2.Config).Exchange, so
// the config is passed as the first argument; it is a package variable so
// tests can replace it with a mock.
var exchanger = (*oauth2.Config).Exchange
   162  
   163  // End the oauth2 process, this is the handler for the oauth2 callback
   164  // that the third party will redirect to.
   165  func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error {
   166  	logger := o.Authboss.RequestLogger(r)
   167  	provider := strings.ToLower(filepath.Base(r.URL.Path))
   168  	logger.Infof("finishing oauth2 flow for provider: %s", provider)
   169  
   170  	// This shouldn't happen because the router should 404 first, but just in case
   171  	cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
   172  	if !ok {
   173  		return errors.Errorf("oauth2 provider %q not found", provider)
   174  	}
   175  
   176  	wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State)
   177  	if !ok {
   178  		return errors.New("oauth2 endpoint hit without session state")
   179  	}
   180  
   181  	// Verify we got the same state in the session as was passed to us in the
   182  	// query parameter.
   183  	state := r.FormValue(FormValueOAuth2State)
   184  	if state != wantState {
   185  		return errOAuthStateValidation
   186  	}
   187  
   188  	rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params)
   189  	var params map[string]string
   190  	if ok {
   191  		if err := json.Unmarshal([]byte(rawParams), &params); err != nil {
   192  			return errors.Wrap(err, "failed to decode oauth2 params")
   193  		}
   194  	}
   195  
   196  	authboss.DelSession(w, authboss.SessionOAuth2State)
   197  	authboss.DelSession(w, authboss.SessionOAuth2Params)
   198  
   199  	hasErr := r.FormValue("error")
   200  	if len(hasErr) > 0 {
   201  		reason := r.FormValue("error_reason")
   202  		logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason)
   203  
   204  		handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r)
   205  		if err != nil {
   206  			return err
   207  		} else if handled {
   208  			return nil
   209  		}
   210  
   211  		ro := authboss.RedirectOptions{
   212  			Code:         http.StatusTemporaryRedirect,
   213  			RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK,
   214  			Failure:      fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)),
   215  		}
   216  		return o.Authboss.Core.Redirector.Redirect(w, r, ro)
   217  	}
   218  
   219  	// Get the code which we can use to make an access token
   220  	code := r.FormValue("code")
   221  	token, err := exchanger(cfg.OAuth2Config, r.Context(), code)
   222  	if err != nil {
   223  		return errors.Wrap(err, "could not validate oauth2 code")
   224  	}
   225  
   226  	details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token)
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server)
   232  	user, err := storer.NewFromOAuth2(r.Context(), provider, details)
   233  	if err != nil {
   234  		return errors.Wrap(err, "failed to create oauth2 user from values")
   235  	}
   236  
   237  	user.PutOAuth2Provider(provider)
   238  	user.PutOAuth2AccessToken(token.AccessToken)
   239  	user.PutOAuth2Expiry(token.Expiry)
   240  	if len(token.RefreshToken) != 0 {
   241  		user.PutOAuth2RefreshToken(token.RefreshToken)
   242  	}
   243  
   244  	if err := storer.SaveOAuth2(r.Context(), user); err != nil {
   245  		return err
   246  	}
   247  
   248  	r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
   249  
   250  	handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r)
   251  	if err != nil {
   252  		return err
   253  	} else if handled {
   254  		return nil
   255  	}
   256  
   257  	// Fully log user in
   258  	authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID()))
   259  	authboss.DelSession(w, authboss.SessionHalfAuthKey)
   260  
   261  	// Create a query string from all the pieces we've received
   262  	// as passthru from the original request.
   263  	redirect := o.Authboss.Config.Paths.OAuth2LoginOK
   264  	query := make(url.Values)
   265  	for k, v := range params {
   266  		switch k {
   267  		case authboss.CookieRemember:
   268  			if v == "true" {
   269  				r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{}))
   270  			}
   271  		case FormValueOAuth2Redir:
   272  			redirect = v
   273  		default:
   274  			query.Set(k, v)
   275  		}
   276  	}
   277  
   278  	handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r)
   279  	if err != nil {
   280  		return err
   281  	} else if handled {
   282  		return nil
   283  	}
   284  
   285  	if len(query) > 0 {
   286  		redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
   287  	}
   288  
   289  	ro := authboss.RedirectOptions{
   290  		Code:         http.StatusTemporaryRedirect,
   291  		RedirectPath: redirect,
   292  		Success:      fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)),
   293  	}
   294  	return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro)
   295  }
   296  
   297  // RMTrue is a dummy struct implementing authboss.RememberValuer
   298  // in order to tell the remember me module to remember them.
   299  type RMTrue struct{}
   300  
   301  // GetShouldRemember always returns true
   302  func (RMTrue) GetShouldRemember() bool { return true }