github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/auth/auth.go (about)

     1  package auth
     2  
import (
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"net/http"
	"strconv"
	"strings"

	"github.com/sohaha/zlsgo/zerror"
	"github.com/sohaha/zlsgo/znet"
	"github.com/sohaha/zlsgo/zstring"
)
    14  
    15  const UserKey = "auth_user"
    16  
    17  type (
    18  	Accounts  map[string]string
    19  	authPairs []authPair
    20  	authPair  struct {
    21  		value string
    22  		user  string
    23  	}
    24  )
    25  
    26  func (a authPairs) searchCredential(authValue string) (string, bool) {
    27  	if authValue == "" {
    28  		return "", false
    29  	}
    30  
    31  	for _, pair := range a {
    32  		if subtle.ConstantTimeCompare(zstring.String2Bytes(pair.value), zstring.String2Bytes(authValue)) == 1 {
    33  			return pair.user, true
    34  		}
    35  	}
    36  	return "", false
    37  }
    38  
    39  func New(accounts Accounts) znet.Handler {
    40  	return BasicRealm(accounts, "")
    41  }
    42  
    43  func BasicRealm(accounts Accounts, realm string) znet.Handler {
    44  	if realm == "" {
    45  		realm = "Authorization Required"
    46  	}
    47  	realm = "Basic realm=" + strconv.Quote(realm)
    48  	pairs, err := processAccounts(accounts)
    49  	zerror.Panic(err)
    50  
    51  	return func(c *znet.Context) {
    52  		user, found := pairs.searchCredential(c.GetHeader("Authorization"))
    53  		if !found {
    54  			c.SetHeader("WWW-Authenticate", realm)
    55  			c.Abort(http.StatusUnauthorized)
    56  			return
    57  		}
    58  
    59  		c.WithValue(UserKey, user)
    60  		c.Next()
    61  	}
    62  }
    63  
    64  func processAccounts(accounts Accounts) (authPairs, error) {
    65  	length := len(accounts)
    66  	if length == 0 {
    67  		return nil, errors.New("empty list of authorized credentials")
    68  	}
    69  	pairs := make(authPairs, 0, length)
    70  	for user, password := range accounts {
    71  		if user == "" {
    72  			return nil, errors.New("user can not be empty")
    73  		}
    74  		value := authorizationHeader(user, password)
    75  		pairs = append(pairs, authPair{
    76  			value: value,
    77  			user:  user,
    78  		})
    79  	}
    80  	return pairs, nil
    81  }
    82  
    83  func authorizationHeader(user, password string) string {
    84  	base := user + ":" + password
    85  	return "Basic " + base64.StdEncoding.EncodeToString(zstring.String2Bytes(base))
    86  }