github.com/volatiletech/authboss@v2.4.1+incompatible/register/register.go (about)

     1  // Package register allows for user registration.
     2  package register
     3  
     4  import (
     5  	"context"
     6  	"net/http"
     7  	"sort"
     8  
     9  	"github.com/pkg/errors"
    10  
    11  	"github.com/volatiletech/authboss"
    12  	"golang.org/x/crypto/bcrypt"
    13  )
    14  
    15  // Pages
    16  const (
    17  	PageRegister = "register"
    18  )
    19  
    20  func init() {
    21  	authboss.RegisterModule("register", &Register{})
    22  }
    23  
    24  // Register module.
    25  type Register struct {
    26  	*authboss.Authboss
    27  }
    28  
    29  // Init the module.
    30  func (r *Register) Init(ab *authboss.Authboss) (err error) {
    31  	r.Authboss = ab
    32  
    33  	if _, ok := ab.Config.Storage.Server.(authboss.CreatingServerStorer); !ok {
    34  		return errors.New("register module activated but storer could not be upgraded to CreatingServerStorer")
    35  	}
    36  
    37  	if err := ab.Config.Core.ViewRenderer.Load(PageRegister); err != nil {
    38  		return err
    39  	}
    40  
    41  	sort.Strings(ab.Config.Modules.RegisterPreserveFields)
    42  
    43  	ab.Config.Core.Router.Get("/register", ab.Config.Core.ErrorHandler.Wrap(r.Get))
    44  	ab.Config.Core.Router.Post("/register", ab.Config.Core.ErrorHandler.Wrap(r.Post))
    45  
    46  	return nil
    47  }
    48  
    49  // Get the register page
    50  func (r *Register) Get(w http.ResponseWriter, req *http.Request) error {
    51  	return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, nil)
    52  }
    53  
    54  // Post to the register page
    55  func (r *Register) Post(w http.ResponseWriter, req *http.Request) error {
    56  	logger := r.RequestLogger(req)
    57  	validatable, err := r.Core.BodyReader.Read(PageRegister, req)
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	var arbitrary map[string]string
    63  	var preserve map[string]string
    64  	if arb, ok := validatable.(authboss.ArbitraryValuer); ok {
    65  		arbitrary = arb.GetValues()
    66  		preserve = make(map[string]string)
    67  
    68  		for k, v := range arbitrary {
    69  			if hasString(r.Config.Modules.RegisterPreserveFields, k) {
    70  				preserve[k] = v
    71  			}
    72  		}
    73  	}
    74  
    75  	errs := validatable.Validate()
    76  	if errs != nil {
    77  		logger.Info("registration validation failed")
    78  		data := authboss.HTMLData{
    79  			authboss.DataValidation: authboss.ErrorMap(errs),
    80  		}
    81  		if preserve != nil {
    82  			data[authboss.DataPreserve] = preserve
    83  		}
    84  		return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, data)
    85  	}
    86  
    87  	// Get values from request
    88  	userVals := authboss.MustHaveUserValues(validatable)
    89  	pid, password := userVals.GetPID(), userVals.GetPassword()
    90  
    91  	// Put values into newly created user for storage
    92  	storer := authboss.EnsureCanCreate(r.Config.Storage.Server)
    93  	user := authboss.MustBeAuthable(storer.New(req.Context()))
    94  
    95  	pass, err := bcrypt.GenerateFromPassword([]byte(password), r.Config.Modules.BCryptCost)
    96  	if err != nil {
    97  		return err
    98  	}
    99  
   100  	user.PutPID(pid)
   101  	user.PutPassword(string(pass))
   102  
   103  	if arbUser, ok := user.(authboss.ArbitraryUser); ok && arbitrary != nil {
   104  		arbUser.PutArbitrary(arbitrary)
   105  	}
   106  
   107  	err = storer.Create(req.Context(), user)
   108  	switch {
   109  	case err == authboss.ErrUserFound:
   110  		logger.Infof("user %s attempted to re-register", pid)
   111  		errs = []error{errors.New("user already exists")}
   112  		data := authboss.HTMLData{
   113  			authboss.DataValidation: authboss.ErrorMap(errs),
   114  		}
   115  		if preserve != nil {
   116  			data[authboss.DataPreserve] = preserve
   117  		}
   118  		return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, data)
   119  	case err != nil:
   120  		return err
   121  	}
   122  
   123  	req = req.WithContext(context.WithValue(req.Context(), authboss.CTXKeyUser, user))
   124  	handled, err := r.Events.FireAfter(authboss.EventRegister, w, req)
   125  	if err != nil {
   126  		return err
   127  	} else if handled {
   128  		return nil
   129  	}
   130  
   131  	// Log the user in, but only if the response wasn't handled previously
   132  	// by a module like confirm.
   133  	authboss.PutSession(w, authboss.SessionKey, pid)
   134  
   135  	logger.Infof("registered and logged in user %s", pid)
   136  	ro := authboss.RedirectOptions{
   137  		Code:         http.StatusTemporaryRedirect,
   138  		Success:      "Account successfully created, you are now logged in",
   139  		RedirectPath: r.Config.Paths.RegisterOK,
   140  	}
   141  	return r.Config.Core.Redirector.Redirect(w, req, ro)
   142  }
   143  
   144  // hasString checks to see if a sorted (ascending) array of
   145  // strings contains a string
   146  func hasString(arr []string, s string) bool {
   147  	index := sort.SearchStrings(arr, s)
   148  	if index < 0 || index >= len(arr) {
   149  		return false
   150  	}
   151  
   152  	return arr[index] == s
   153  }