github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/auth.go (about)

     1  /*
     2  *
     3  * Gosora Authentication Interface
     4  * Copyright Azareal 2017 - 2020
     5  *
     6   */
     7  package common
     8  
     9  import (
    10  	"crypto/sha256"
    11  	"crypto/subtle"
    12  	"database/sql"
    13  	"encoding/hex"
    14  	"errors"
    15  	"net/http"
    16  	"strconv"
    17  	"strings"
    18  
    19  	"github.com/Azareal/Gosora/common/gauth"
    20  	qgen "github.com/Azareal/Gosora/query_gen"
    21  
    22  	//"golang.org/x/crypto/argon2"
    23  	"golang.org/x/crypto/bcrypt"
    24  )
    25  
    26  // TODO: Write more authentication tests
    27  var Auth AuthInt
    28  
    29  const SaltLength int = 32
    30  const SessionLength int = 80
    31  
    32  // ErrMismatchedHashAndPassword is thrown whenever a hash doesn't match it's unhashed password
    33  var ErrMismatchedHashAndPassword = bcrypt.ErrMismatchedHashAndPassword
    34  
    35  // nolint
    36  var ErrHashNotExist = errors.New("We don't recognise that hashing algorithm")
    37  var ErrTooFewHashParams = errors.New("You haven't provided enough hash parameters")
    38  
    39  // ErrPasswordTooLong is silly, but we don't want bcrypt to bork on us
    40  var ErrPasswordTooLong = errors.New("The password you selected is too long")
    41  var ErrWrongPassword = errors.New("That's not the correct password.")
    42  var ErrBadMFAToken = errors.New("I'm not sure where you got that from, but that's not a valid 2FA token")
    43  var ErrWrongMFAToken = errors.New("That 2FA token isn't correct")
    44  var ErrNoMFAToken = errors.New("This user doesn't have 2FA setup")
    45  var ErrSecretError = errors.New("There was a glitch in the system. Please contact your local administrator.")
    46  var ErrNoUserByName = errors.New("We couldn't find an account with that username.")
    47  var DefaultHashAlgo = "bcrypt" // Override this in the configuration file, not here
    48  
    49  //func(realPassword string, password string, salt string) (err error)
    50  var CheckPasswordFuncs = map[string]func(string, string, string) error{
    51  	"bcrypt": BcryptCheckPassword,
    52  	//"argon2": Argon2CheckPassword,
    53  }
    54  
    55  //func(password string) (hashedPassword string, salt string, err error)
    56  var GeneratePasswordFuncs = map[string]func(string) (string, string, error){
    57  	"bcrypt": BcryptGeneratePassword,
    58  	//"argon2": Argon2GeneratePassword,
    59  }
    60  
    61  // TODO: Redirect 2b to bcrypt too?
    62  var HashPrefixes = map[string]string{
    63  	"$2a$": "bcrypt",
    64  	//"argon2$": "argon2",
    65  }
    66  
    67  // AuthInt is the main authentication interface.
    68  type AuthInt interface {
    69  	Authenticate(name, password string) (uid int, err error, requiresExtraAuth bool)
    70  	ValidateMFAToken(mfaToken string, uid int) error
    71  	Logout(w http.ResponseWriter, uid int)
    72  	ForceLogout(uid int) error
    73  	SetCookies(w http.ResponseWriter, uid int, session string)
    74  	SetProvisionalCookies(w http.ResponseWriter, uid int, session, signedSession string) // To avoid logging someone in until they've passed the MFA check
    75  	GetCookies(r *http.Request) (uid int, session string, err error)
    76  	SessionCheck(w http.ResponseWriter, r *http.Request) (u *User, halt bool)
    77  	CreateSession(uid int) (session string, err error)
    78  	CreateProvisionalSession(uid int) (provSession, signedSession string, err error) // To avoid logging someone in until they've passed the MFA check
    79  }
    80  
    81  // DefaultAuth is the default authenticator used by Gosora, may be swapped with an alternate authenticator in some situations. E.g. To support LDAP.
    82  type DefaultAuth struct {
    83  	login         *sql.Stmt
    84  	logout        *sql.Stmt
    85  	updateSession *sql.Stmt
    86  }
    87  
    88  // NewDefaultAuth is a factory for spitting out DefaultAuths
    89  func NewDefaultAuth() (*DefaultAuth, error) {
    90  	acc := qgen.NewAcc()
    91  	return &DefaultAuth{
    92  		login:         acc.Select("users").Columns("uid, password, salt").Where("name = ?").Prepare(),
    93  		logout:        acc.Update("users").Set("session = ''").Where("uid = ?").Prepare(),
    94  		updateSession: acc.Update("users").Set("session = ?").Where("uid = ?").Prepare(),
    95  	}, acc.FirstError()
    96  }
    97  
    98  // Authenticate checks if a specific username and password is valid and returns the UID for the corresponding user, if so. Otherwise, a user safe error.
    99  // IF MFA is enabled, then pass it back a flag telling the caller that authentication isn't complete yet
   100  // TODO: Find a better way of handling errors we don't want to reach the user
   101  func (auth *DefaultAuth) Authenticate(name, password string) (uid int, err error, requiresExtraAuth bool) {
   102  	var realPassword, salt string
   103  	err = auth.login.QueryRow(name).Scan(&uid, &realPassword, &salt)
   104  	if err == ErrNoRows {
   105  		return 0, ErrNoUserByName, false
   106  	} else if err != nil {
   107  		LogError(err)
   108  		return 0, ErrSecretError, false
   109  	}
   110  
   111  	err = CheckPassword(realPassword, password, salt)
   112  	if err == ErrMismatchedHashAndPassword {
   113  		return 0, ErrWrongPassword, false
   114  	} else if err != nil {
   115  		LogError(err)
   116  		return 0, ErrSecretError, false
   117  	}
   118  
   119  	_, err = MFAstore.Get(uid)
   120  	if err != sql.ErrNoRows && err != nil {
   121  		LogError(err)
   122  		return 0, ErrSecretError, false
   123  	}
   124  	if err != ErrNoRows {
   125  		return uid, nil, true
   126  	}
   127  
   128  	return uid, nil, false
   129  }
   130  
   131  func (auth *DefaultAuth) ValidateMFAToken(mfaToken string, uid int) error {
   132  	mfaItem, err := MFAstore.Get(uid)
   133  	if err != sql.ErrNoRows && err != nil {
   134  		LogError(err)
   135  		return ErrSecretError
   136  	}
   137  	if err == ErrNoRows {
   138  		return ErrNoMFAToken
   139  	}
   140  
   141  	ok, err := VerifyGAuthToken(mfaItem.Secret, mfaToken)
   142  	if err != nil {
   143  		return ErrBadMFAToken
   144  	}
   145  	if ok {
   146  		return nil
   147  	}
   148  
   149  	for i, scratch := range mfaItem.Scratch {
   150  		if subtle.ConstantTimeCompare([]byte(scratch), []byte(mfaToken)) == 1 {
   151  			err = mfaItem.BurnScratch(i)
   152  			if err != nil {
   153  				LogError(err)
   154  				return ErrSecretError
   155  			}
   156  			return nil
   157  		}
   158  	}
   159  
   160  	return ErrWrongMFAToken
   161  }
   162  
   163  // ForceLogout logs the user out of every computer, not just the one they logged out of
   164  func (auth *DefaultAuth) ForceLogout(uid int) error {
   165  	_, err := auth.logout.Exec(uid)
   166  	if err != nil {
   167  		LogError(err)
   168  		return ErrSecretError
   169  	}
   170  
   171  	// Flush the user out of the cache
   172  	if uc := Users.GetCache(); uc != nil {
   173  		uc.Remove(uid)
   174  	}
   175  	return nil
   176  }
   177  
   178  func setCookie(w http.ResponseWriter, cookie *http.Cookie, sameSite string) {
   179  	if v := cookie.String(); v != "" {
   180  		switch sameSite {
   181  		case "lax":
   182  			v = v + "; SameSite=lax"
   183  		case "strict":
   184  			v = v + "; SameSite"
   185  		}
   186  		w.Header().Add("Set-Cookie", v)
   187  	}
   188  }
   189  
   190  func deleteCookie(w http.ResponseWriter, cookie *http.Cookie) {
   191  	cookie.MaxAge = -1
   192  	http.SetCookie(w, cookie)
   193  }
   194  
   195  // Logout logs you out of the computer you requested the logout for, but not the other computers you're logged in with
   196  func (auth *DefaultAuth) Logout(w http.ResponseWriter, _ int) {
   197  	cookie := http.Cookie{Name: "uid", Value: "", Path: "/"}
   198  	deleteCookie(w, &cookie)
   199  	cookie = http.Cookie{Name: "session", Value: "", Path: "/"}
   200  	deleteCookie(w, &cookie)
   201  }
   202  
   203  // TODO: Set the cookie domain
   204  // SetCookies sets the two cookies required for the current user to be recognised as a specific user in future requests
   205  func (auth *DefaultAuth) SetCookies(w http.ResponseWriter, uid int, session string) {
   206  	cookie := http.Cookie{Name: "uid", Value: strconv.Itoa(uid), Path: "/", MaxAge: int(Year)}
   207  	setCookie(w, &cookie, "lax")
   208  	cookie = http.Cookie{Name: "session", Value: session, Path: "/", MaxAge: int(Year)}
   209  	setCookie(w, &cookie, "lax")
   210  }
   211  
   212  // TODO: Set the cookie domain
   213  // SetProvisionalCookies sets the two cookies required for guests to be recognised as having passed the initial login but not having passed the additional checks (e.g. multi-factor authentication)
   214  func (auth *DefaultAuth) SetProvisionalCookies(w http.ResponseWriter, uid int, provSession, signedSession string) {
   215  	cookie := http.Cookie{Name: "uid", Value: strconv.Itoa(uid), Path: "/", MaxAge: int(Year)}
   216  	setCookie(w, &cookie, "lax")
   217  	cookie = http.Cookie{Name: "provSession", Value: provSession, Path: "/", MaxAge: int(Year)}
   218  	setCookie(w, &cookie, "lax")
   219  	cookie = http.Cookie{Name: "signedSession", Value: signedSession, Path: "/", MaxAge: int(Year)}
   220  	setCookie(w, &cookie, "lax")
   221  }
   222  
   223  // GetCookies fetches the current user's session cookies
   224  func (auth *DefaultAuth) GetCookies(r *http.Request) (uid int, session string, err error) {
   225  	// Are there any session cookies..?
   226  	cookie, err := r.Cookie("uid")
   227  	if err != nil {
   228  		return 0, "", err
   229  	}
   230  	uid, err = strconv.Atoi(cookie.Value)
   231  	if err != nil {
   232  		return 0, "", err
   233  	}
   234  	cookie, err = r.Cookie("session")
   235  	if err != nil {
   236  		return 0, "", err
   237  	}
   238  	return uid, cookie.Value, err
   239  }
   240  
   241  // SessionCheck checks if a user has session cookies and whether they're valid
   242  func (auth *DefaultAuth) SessionCheck(w http.ResponseWriter, r *http.Request) (user *User, halt bool) {
   243  	uid, session, err := auth.GetCookies(r)
   244  	if err != nil {
   245  		return &GuestUser, false
   246  	}
   247  
   248  	// Is this session valid..?
   249  	user, err = Users.Get(uid)
   250  	if err == ErrNoRows {
   251  		return &GuestUser, false
   252  	} else if err != nil {
   253  		InternalError(err, w, r)
   254  		return &GuestUser, true
   255  	}
   256  
   257  	// We need to do a constant time compare, otherwise someone might be able to deduce the session character by character based on how long it takes to do the comparison. Change this at your own peril.
   258  	if user.Session == "" || subtle.ConstantTimeCompare([]byte(session), []byte(user.Session)) != 1 {
   259  		return &GuestUser, false
   260  	}
   261  
   262  	return user, false
   263  }
   264  
   265  // CreateSession generates a new session to allow a remote client to stay logged in as a specific user
   266  func (auth *DefaultAuth) CreateSession(uid int) (session string, err error) {
   267  	session, err = GenerateSafeString(SessionLength)
   268  	if err != nil {
   269  		return "", err
   270  	}
   271  
   272  	_, err = auth.updateSession.Exec(session, uid)
   273  	if err != nil {
   274  		return "", err
   275  	}
   276  
   277  	// Flush the user data from the cache
   278  	ucache := Users.GetCache()
   279  	if ucache != nil {
   280  		ucache.Remove(uid)
   281  	}
   282  	return session, nil
   283  }
   284  
   285  func (auth *DefaultAuth) CreateProvisionalSession(uid int) (provSession, signedSession string, err error) {
   286  	provSession, err = GenerateSafeString(SessionLength)
   287  	if err != nil {
   288  		return "", "", err
   289  	}
   290  
   291  	h := sha256.New()
   292  	h.Write([]byte(SessionSigningKeyBox.Load().(string)))
   293  	h.Write([]byte(provSession))
   294  	h.Write([]byte(strconv.Itoa(uid)))
   295  	return provSession, hex.EncodeToString(h.Sum(nil)), nil
   296  }
   297  
   298  func CheckPassword(realPassword, password, salt string) (err error) {
   299  	blasted := strings.Split(realPassword, "$")
   300  	prefix := blasted[0]
   301  	if len(blasted) > 1 {
   302  		prefix += "$" + blasted[1] + "$"
   303  	}
   304  	algo, ok := HashPrefixes[prefix]
   305  	if !ok {
   306  		return ErrHashNotExist
   307  	}
   308  	checker := CheckPasswordFuncs[algo]
   309  	return checker(realPassword, password, salt)
   310  }
   311  
   312  func GeneratePassword(password string) (hash, salt string, err error) {
   313  	gen, ok := GeneratePasswordFuncs[DefaultHashAlgo]
   314  	if !ok {
   315  		return "", "", ErrHashNotExist
   316  	}
   317  	return gen(password)
   318  }
   319  
   320  func BcryptCheckPassword(realPassword, password, salt string) (err error) {
   321  	return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password+salt))
   322  }
   323  
   324  // Note: The salt is in the hash, therefore the salt parameter is blank
   325  func BcryptGeneratePassword(password string) (hash, salt string, err error) {
   326  	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
   327  	if err != nil {
   328  		return "", "", err
   329  	}
   330  	return string(hashedPassword), salt, nil
   331  }
   332  
   333  /*const (
   334  	argon2Time    uint32 = 3
   335  	argon2Memory  uint32 = 32 * 1024
   336  	argon2Threads uint8  = 4
   337  	argon2KeyLen  uint32 = 32
   338  )
   339  
   340  func Argon2CheckPassword(realPassword, password, salt string) (err error) {
   341  	split := strings.Split(realPassword, "$")
   342  	// TODO: Better validation
   343  	if len(split) < 5 {
   344  		return ErrTooFewHashParams
   345  	}
   346  	realKey, _ := base64.StdEncoding.DecodeString(split[len(split)-1])
   347  	time, _ := strconv.Atoi(split[1])
   348  	memory, _ := strconv.Atoi(split[2])
   349  	threads, _ := strconv.Atoi(split[3])
   350  	keyLen, _ := strconv.Atoi(split[4])
   351  	key := argon2.Key([]byte(password), []byte(salt), uint32(time), uint32(memory), uint8(threads), uint32(keyLen))
   352  	if subtle.ConstantTimeCompare(realKey, key) != 1 {
   353  		return ErrMismatchedHashAndPassword
   354  	}
   355  	return nil
   356  }
   357  
   358  func Argon2GeneratePassword(password string) (hash, salt string, err error) {
   359  	sbytes := make([]byte, SaltLength)
   360  	_, err = rand.Read(sbytes)
   361  	if err != nil {
   362  		return "", "", err
   363  	}
   364  	key := argon2.Key([]byte(password), sbytes, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
   365  	hash = base64.StdEncoding.EncodeToString(key)
   366  	return fmt.Sprintf("argon2$%d%d%d%d%s%s", argon2Time, argon2Memory, argon2Threads, argon2KeyLen, salt, hash), string(sbytes), nil
   367  }
   368  */
   369  
   370  // TODO: Test this with Google Authenticator proper
   371  func FriendlyGAuthSecret(secret string) (out string) {
   372  	for i, char := range secret {
   373  		out += string(char)
   374  		if (i+1)%4 == 0 {
   375  			out += " "
   376  		}
   377  	}
   378  	return strings.TrimSpace(out)
   379  }
   380  func GenerateGAuthSecret() (string, error) {
   381  	return GenerateStd32SafeString(14)
   382  }
   383  func VerifyGAuthToken(secret, token string) (bool, error) {
   384  	trueToken, err := gauth.GetTOTPToken(secret)
   385  	return subtle.ConstantTimeCompare([]byte(trueToken), []byte(token)) == 1, err
   386  }