github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/auth/auth.go (about) 1 package auth 2 3 import ( 4 "crypto/subtle" 5 "encoding/base64" 6 "errors" 7 "net/http" 8 "strconv" 9 10 "github.com/sohaha/zlsgo/zerror" 11 "github.com/sohaha/zlsgo/znet" 12 "github.com/sohaha/zlsgo/zstring" 13 ) 14 15 const UserKey = "auth_user" 16 17 type ( 18 Accounts map[string]string 19 authPairs []authPair 20 authPair struct { 21 value string 22 user string 23 } 24 ) 25 26 func (a authPairs) searchCredential(authValue string) (string, bool) { 27 if authValue == "" { 28 return "", false 29 } 30 31 for _, pair := range a { 32 if subtle.ConstantTimeCompare(zstring.String2Bytes(pair.value), zstring.String2Bytes(authValue)) == 1 { 33 return pair.user, true 34 } 35 } 36 return "", false 37 } 38 39 func New(accounts Accounts) znet.Handler { 40 return BasicRealm(accounts, "") 41 } 42 43 func BasicRealm(accounts Accounts, realm string) znet.Handler { 44 if realm == "" { 45 realm = "Authorization Required" 46 } 47 realm = "Basic realm=" + strconv.Quote(realm) 48 pairs, err := processAccounts(accounts) 49 zerror.Panic(err) 50 51 return func(c *znet.Context) { 52 user, found := pairs.searchCredential(c.GetHeader("Authorization")) 53 if !found { 54 c.SetHeader("WWW-Authenticate", realm) 55 c.Abort(http.StatusUnauthorized) 56 return 57 } 58 59 c.WithValue(UserKey, user) 60 c.Next() 61 } 62 } 63 64 func processAccounts(accounts Accounts) (authPairs, error) { 65 length := len(accounts) 66 if length == 0 { 67 return nil, errors.New("empty list of authorized credentials") 68 } 69 pairs := make(authPairs, 0, length) 70 for user, password := range accounts { 71 if user == "" { 72 return nil, errors.New("user can not be empty") 73 } 74 value := authorizationHeader(user, password) 75 pairs = append(pairs, authPair{ 76 value: value, 77 user: user, 78 }) 79 } 80 return pairs, nil 81 } 82 83 func authorizationHeader(user, password string) string { 84 base := user + ":" + password 85 return "Basic " + base64.StdEncoding.EncodeToString(zstring.String2Bytes(base)) 86 }