github.com/volatiletech/authboss@v2.4.1+incompatible/otp/twofactor/twofactor_recover.go (about)

     1  // Package twofactor allows authentication via one time passwords
     2  package twofactor
     3  
     4  import (
     5  	"crypto/rand"
     6  	"io"
     7  	"net/http"
     8  	"strings"
     9  
    10  	"github.com/volatiletech/authboss"
    11  	"golang.org/x/crypto/bcrypt"
    12  )
    13  
// Recovery handles recovery-code management for two-factor authentication.
//
// It embeds *authboss.Authboss, so the configuration, router, responder and
// storage used by its handlers are all reached through that embedded value.
type Recovery struct {
	*authboss.Authboss
}
    18  
    19  // Setup the module to provide recovery regeneration routes
    20  func (rc *Recovery) Setup() error {
    21  	var unauthedResponse authboss.MWRespondOnFailure
    22  	if rc.Config.Modules.ResponseOnUnauthed != 0 {
    23  		unauthedResponse = rc.Config.Modules.ResponseOnUnauthed
    24  	} else if rc.Config.Modules.RoutesRedirectOnUnauthed {
    25  		unauthedResponse = authboss.RespondRedirect
    26  	}
    27  	middleware := authboss.MountedMiddleware2(rc.Authboss, true, authboss.RequireFullAuth, unauthedResponse)
    28  	rc.Authboss.Core.Router.Get("/2fa/recovery/regen", middleware(rc.Authboss.Core.ErrorHandler.Wrap(rc.GetRegen)))
    29  	rc.Authboss.Core.Router.Post("/2fa/recovery/regen", middleware(rc.Authboss.Core.ErrorHandler.Wrap(rc.PostRegen)))
    30  
    31  	return rc.Authboss.Core.ViewRenderer.Load(PageRecovery2FA)
    32  }
    33  
    34  // GetRegen shows a button that enables a user to regen their codes
    35  // as well as how many codes are currently remaining.
    36  func (rc *Recovery) GetRegen(w http.ResponseWriter, r *http.Request) error {
    37  	abUser, err := rc.CurrentUser(r)
    38  	if err != nil {
    39  		return err
    40  	}
    41  	user := abUser.(User)
    42  
    43  	var nCodes int
    44  	codes := user.GetRecoveryCodes()
    45  	if len(codes) != 0 {
    46  		nCodes++
    47  	}
    48  	for _, c := range codes {
    49  		if c == ',' {
    50  			nCodes++
    51  		}
    52  	}
    53  
    54  	data := authboss.HTMLData{DataNumRecoveryCodes: nCodes}
    55  	return rc.Authboss.Core.Responder.Respond(w, r, http.StatusOK, PageRecovery2FA, data)
    56  }
    57  
    58  // PostRegen regenerates the codes
    59  func (rc *Recovery) PostRegen(w http.ResponseWriter, r *http.Request) error {
    60  	abUser, err := rc.CurrentUser(r)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	user := abUser.(User)
    65  
    66  	codes, err := GenerateRecoveryCodes()
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	hashedCodes, err := BCryptRecoveryCodes(codes)
    72  	if err != nil {
    73  		return err
    74  	}
    75  
    76  	user.PutRecoveryCodes(EncodeRecoveryCodes(hashedCodes))
    77  	if err = rc.Authboss.Config.Storage.Server.Save(r.Context(), user); err != nil {
    78  		return err
    79  	}
    80  
    81  	data := authboss.HTMLData{DataRecoveryCodes: codes}
    82  	return rc.Authboss.Core.Responder.Respond(w, r, http.StatusOK, PageRecovery2FA, data)
    83  }
    84  
    85  // GenerateRecoveryCodes creates 10 recovery codes of the form:
    86  // abd34-1b24do (using alphabet, of length recoveryCodeLength).
    87  func GenerateRecoveryCodes() ([]string, error) {
    88  	byt := make([]byte, 10*recoveryCodeLength)
    89  	if _, err := io.ReadFull(rand.Reader, byt); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	codes := make([]string, 10)
    94  	for i := range codes {
    95  		builder := new(strings.Builder)
    96  		for j := 0; j < recoveryCodeLength; j++ {
    97  			if recoveryCodeLength/2 == j {
    98  				builder.WriteByte('-')
    99  			}
   100  
   101  			randNumber := byt[i*recoveryCodeLength+j] % byte(len(alphabet))
   102  			builder.WriteByte(alphabet[randNumber])
   103  		}
   104  		codes[i] = builder.String()
   105  	}
   106  
   107  	return codes, nil
   108  }
   109  
   110  // BCryptRecoveryCodes hashes each recovery code given and return them in a new
   111  // slice.
   112  func BCryptRecoveryCodes(codes []string) ([]string, error) {
   113  	cryptedCodes := make([]string, len(codes))
   114  	for i, c := range codes {
   115  		hash, err := bcrypt.GenerateFromPassword([]byte(c), bcrypt.DefaultCost)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  
   120  		cryptedCodes[i] = string(hash)
   121  	}
   122  
   123  	return cryptedCodes, nil
   124  }
   125  
   126  // UseRecoveryCode deletes the code that was used from the string slice and
   127  // returns it, the bool is true if a code was used
   128  func UseRecoveryCode(codes []string, inputCode string) ([]string, bool) {
   129  	input := []byte(inputCode)
   130  	use := -1
   131  
   132  	for i, c := range codes {
   133  		err := bcrypt.CompareHashAndPassword([]byte(c), input)
   134  		if err == nil {
   135  			use = i
   136  			break
   137  		}
   138  	}
   139  
   140  	if use < 0 {
   141  		return nil, false
   142  	}
   143  
   144  	ret := make([]string, len(codes)-1)
   145  	for j := range codes {
   146  		if j == use {
   147  			continue
   148  		}
   149  		set := j
   150  		if j > use {
   151  			set--
   152  		}
   153  		ret[set] = codes[j]
   154  	}
   155  
   156  	return ret, true
   157  }
   158  
   159  // EncodeRecoveryCodes is an alias for strings.Join(",")
   160  func EncodeRecoveryCodes(codes []string) string { return strings.Join(codes, ",") }
   161  
   162  // DecodeRecoveryCodes is an alias for strings.Split(",")
   163  func DecodeRecoveryCodes(codes string) []string { return strings.Split(codes, ",") }