github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/privilege/privileges/privileges.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package privileges
    15  
    16  import (
    17  	"crypto/tls"
    18  	"crypto/x509"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/whtcorpsinc/BerolinaSQL/auth"
    23  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    24  	"github.com/whtcorpsinc/milevadb/schemareplicant"
    25  	"github.com/whtcorpsinc/milevadb/schemareplicant/perfschema"
    26  	"github.com/whtcorpsinc/milevadb/privilege"
    27  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    28  	"github.com/whtcorpsinc/milevadb/types"
    29  	"github.com/whtcorpsinc/milevadb/soliton"
    30  	"github.com/whtcorpsinc/milevadb/soliton/logutil"
    31  	"go.uber.org/zap"
    32  )
    33  
    34  // SkipWithGrant causes the server to start without using the privilege system at all.
    35  var SkipWithGrant = false
    36  
    37  var _ privilege.Manager = (*UserPrivileges)(nil)
    38  
    39  // UserPrivileges implements privilege.Manager interface.
    40  // This is used to check privilege for the current user.
    41  type UserPrivileges struct {
    42  	user string
    43  	host string
    44  	*Handle
    45  }
    46  
    47  // RequestVerification implements the Manager interface.
    48  func (p *UserPrivileges) RequestVerification(activeRoles []*auth.RoleIdentity, EDB, causet, column string, priv allegrosql.PrivilegeType) bool {
    49  	if SkipWithGrant {
    50  		return true
    51  	}
    52  
    53  	if p.user == "" && p.host == "" {
    54  		return true
    55  	}
    56  
    57  	// Skip check for system databases.
    58  	// See https://dev.allegrosql.com/doc/refman/5.7/en/information-schemaReplicant.html
    59  	dbLowerName := strings.ToLower(EDB)
    60  	switch dbLowerName {
    61  	case soliton.InformationSchemaName.L:
    62  		switch priv {
    63  		case allegrosql.CreatePriv, allegrosql.AlterPriv, allegrosql.DropPriv, allegrosql.IndexPriv, allegrosql.CreateViewPriv,
    64  			allegrosql.InsertPriv, allegrosql.UFIDelatePriv, allegrosql.DeletePriv:
    65  			return false
    66  		}
    67  		return true
    68  	// We should be very careful of limiting privileges, so ignore `allegrosql` for now.
    69  	case soliton.PerformanceSchemaName.L, soliton.MetricSchemaName.L:
    70  		if (dbLowerName == soliton.PerformanceSchemaName.L && perfschema.IsPredefinedTable(causet)) ||
    71  			(dbLowerName == soliton.MetricSchemaName.L && schemareplicant.IsMetricTable(causet)) {
    72  			switch priv {
    73  			case allegrosql.CreatePriv, allegrosql.AlterPriv, allegrosql.DropPriv, allegrosql.IndexPriv, allegrosql.InsertPriv, allegrosql.UFIDelatePriv, allegrosql.DeletePriv:
    74  				return false
    75  			case allegrosql.SelectPriv:
    76  				return true
    77  			}
    78  		}
    79  	}
    80  
    81  	mysqlPriv := p.Handle.Get()
    82  	return mysqlPriv.RequestVerification(activeRoles, p.user, p.host, EDB, causet, column, priv)
    83  }
    84  
    85  // RequestVerificationWithUser implements the Manager interface.
    86  func (p *UserPrivileges) RequestVerificationWithUser(EDB, causet, column string, priv allegrosql.PrivilegeType, user *auth.UserIdentity) bool {
    87  	if SkipWithGrant {
    88  		return true
    89  	}
    90  
    91  	if user == nil {
    92  		return false
    93  	}
    94  
    95  	// Skip check for INFORMATION_SCHEMA database.
    96  	// See https://dev.allegrosql.com/doc/refman/5.7/en/information-schemaReplicant.html
    97  	if strings.EqualFold(EDB, "INFORMATION_SCHEMA") {
    98  		return true
    99  	}
   100  
   101  	mysqlPriv := p.Handle.Get()
   102  	return mysqlPriv.RequestVerification(nil, user.Username, user.Hostname, EDB, causet, column, priv)
   103  }
   104  
   105  // GetEncodedPassword implements the Manager interface.
   106  func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
   107  	mysqlPriv := p.Handle.Get()
   108  	record := mysqlPriv.connectionVerification(user, host)
   109  	if record == nil {
   110  		logutil.BgLogger().Error("get user privilege record fail",
   111  			zap.String("user", user), zap.String("host", host))
   112  		return ""
   113  	}
   114  	pwd := record.AuthenticationString
   115  	if len(pwd) != 0 && len(pwd) != allegrosql.PWDHashLen+1 {
   116  		logutil.BgLogger().Error("user password from system EDB not like sha1sum", zap.String("user", user))
   117  		return ""
   118  	}
   119  	return pwd
   120  }
   121  
   122  // GetAuthWithoutVerification implements the Manager interface.
   123  func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
   124  	if SkipWithGrant {
   125  		p.user = user
   126  		p.host = host
   127  		success = true
   128  		return
   129  	}
   130  
   131  	mysqlPriv := p.Handle.Get()
   132  	record := mysqlPriv.connectionVerification(user, host)
   133  	if record == nil {
   134  		logutil.BgLogger().Error("get user privilege record fail",
   135  			zap.String("user", user), zap.String("host", host))
   136  		return
   137  	}
   138  
   139  	u = record.User
   140  	h = record.Host
   141  	p.user = user
   142  	p.host = h
   143  	success = true
   144  	return
   145  }
   146  
   147  // ConnectionVerification implements the Manager interface.
   148  func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
   149  	if SkipWithGrant {
   150  		p.user = user
   151  		p.host = host
   152  		success = true
   153  		return
   154  	}
   155  
   156  	mysqlPriv := p.Handle.Get()
   157  	record := mysqlPriv.connectionVerification(user, host)
   158  	if record == nil {
   159  		logutil.BgLogger().Error("get user privilege record fail",
   160  			zap.String("user", user), zap.String("host", host))
   161  		return
   162  	}
   163  
   164  	u = record.User
   165  	h = record.Host
   166  
   167  	globalPriv := mysqlPriv.matchGlobalPriv(user, host)
   168  	if globalPriv != nil {
   169  		if !p.checkSSL(globalPriv, tlsState) {
   170  			logutil.BgLogger().Error("global priv check ssl fail",
   171  				zap.String("user", user), zap.String("host", host))
   172  			success = false
   173  			return
   174  		}
   175  	}
   176  
   177  	// Login a locked account is not allowed.
   178  	locked := record.AccountLocked
   179  	if locked {
   180  		logutil.BgLogger().Error("try to login a locked account",
   181  			zap.String("user", user), zap.String("host", host))
   182  		success = false
   183  		return
   184  	}
   185  
   186  	pwd := record.AuthenticationString
   187  	if len(pwd) != 0 && len(pwd) != allegrosql.PWDHashLen+1 {
   188  		logutil.BgLogger().Error("user password from system EDB not like sha1sum", zap.String("user", user))
   189  		return
   190  	}
   191  
   192  	// empty password
   193  	if len(pwd) == 0 && len(authentication) == 0 {
   194  		p.user = user
   195  		p.host = h
   196  		success = true
   197  		return
   198  	}
   199  
   200  	if len(pwd) == 0 || len(authentication) == 0 {
   201  		return
   202  	}
   203  
   204  	hpwd, err := auth.DecodePassword(pwd)
   205  	if err != nil {
   206  		logutil.BgLogger().Error("decode password string failed", zap.Error(err))
   207  		return
   208  	}
   209  
   210  	if !auth.CheckScrambledPassword(salt, hpwd, authentication) {
   211  		return
   212  	}
   213  
   214  	p.user = user
   215  	p.host = h
   216  	success = true
   217  	return
   218  }
   219  
   220  type checkResult int
   221  
   222  const (
   223  	notCheck checkResult = iota
   224  	pass
   225  	fail
   226  )
   227  
   228  func (p *UserPrivileges) checkSSL(priv *globalPrivRecord, tlsState *tls.ConnectionState) bool {
   229  	if priv.Broken {
   230  		logutil.BgLogger().Info("ssl check failure, due to broken global_priv record",
   231  			zap.String("user", priv.User), zap.String("host", priv.Host))
   232  		return false
   233  	}
   234  	switch priv.Priv.SSLType {
   235  	case SslTypeNotSpecified, SslTypeNone:
   236  		return true
   237  	case SslTypeAny:
   238  		r := tlsState != nil
   239  		if !r {
   240  			logutil.BgLogger().Info("ssl check failure, require ssl but not use ssl",
   241  				zap.String("user", priv.User), zap.String("host", priv.Host))
   242  		}
   243  		return r
   244  	case SslTypeX509:
   245  		if tlsState == nil {
   246  			logutil.BgLogger().Info("ssl check failure, require x509 but not use ssl",
   247  				zap.String("user", priv.User), zap.String("host", priv.Host))
   248  			return false
   249  		}
   250  		hasCert := false
   251  		for _, chain := range tlsState.VerifiedChains {
   252  			if len(chain) > 0 {
   253  				hasCert = true
   254  				break
   255  			}
   256  		}
   257  		if !hasCert {
   258  			logutil.BgLogger().Info("ssl check failure, require x509 but no verified cert",
   259  				zap.String("user", priv.User), zap.String("host", priv.Host))
   260  		}
   261  		return hasCert
   262  	case SslTypeSpecified:
   263  		if tlsState == nil {
   264  			logutil.BgLogger().Info("ssl check failure, require subject/issuer/cipher but not use ssl",
   265  				zap.String("user", priv.User), zap.String("host", priv.Host))
   266  			return false
   267  		}
   268  		if len(priv.Priv.SSLCipher) > 0 && priv.Priv.SSLCipher != soliton.TLSCipher2String(tlsState.CipherSuite) {
   269  			logutil.BgLogger().Info("ssl check failure for cipher", zap.String("user", priv.User), zap.String("host", priv.Host),
   270  				zap.String("require", priv.Priv.SSLCipher), zap.String("given", soliton.TLSCipher2String(tlsState.CipherSuite)))
   271  			return false
   272  		}
   273  		var (
   274  			hasCert      = false
   275  			matchIssuer  checkResult
   276  			matchSubject checkResult
   277  			matchSAN     checkResult
   278  		)
   279  		for _, chain := range tlsState.VerifiedChains {
   280  			if len(chain) == 0 {
   281  				continue
   282  			}
   283  			cert := chain[0]
   284  			if len(priv.Priv.X509Issuer) > 0 {
   285  				given := soliton.X509NameOnline(cert.Issuer)
   286  				if priv.Priv.X509Issuer == given {
   287  					matchIssuer = pass
   288  				} else if matchIssuer == notCheck {
   289  					matchIssuer = fail
   290  					logutil.BgLogger().Info("ssl check failure for issuer", zap.String("user", priv.User), zap.String("host", priv.Host),
   291  						zap.String("require", priv.Priv.X509Issuer), zap.String("given", given))
   292  				}
   293  			}
   294  			if len(priv.Priv.X509Subject) > 0 {
   295  				given := soliton.X509NameOnline(cert.Subject)
   296  				if priv.Priv.X509Subject == given {
   297  					matchSubject = pass
   298  				} else if matchSubject == notCheck {
   299  					matchSubject = fail
   300  					logutil.BgLogger().Info("ssl check failure for subject", zap.String("user", priv.User), zap.String("host", priv.Host),
   301  						zap.String("require", priv.Priv.X509Subject), zap.String("given", given))
   302  				}
   303  			}
   304  			if len(priv.Priv.SANs) > 0 {
   305  				matchOne := checkCertSAN(priv, cert, priv.Priv.SANs)
   306  				if matchOne {
   307  					matchSAN = pass
   308  				} else if matchSAN == notCheck {
   309  					matchSAN = fail
   310  				}
   311  			}
   312  			hasCert = true
   313  		}
   314  		checkResult := hasCert && matchIssuer != fail && matchSubject != fail && matchSAN != fail
   315  		if !checkResult && !hasCert {
   316  			logutil.BgLogger().Info("ssl check failure, require issuer/subject/SAN but no verified cert",
   317  				zap.String("user", priv.User), zap.String("host", priv.Host))
   318  		}
   319  		return checkResult
   320  	default:
   321  		panic(fmt.Sprintf("support ssl_type: %d", priv.Priv.SSLType))
   322  	}
   323  }
   324  
   325  func checkCertSAN(priv *globalPrivRecord, cert *x509.Certificate, sans map[soliton.SANType][]string) (r bool) {
   326  	r = true
   327  	for typ, requireOr := range sans {
   328  		var (
   329  			unsupported bool
   330  			given       []string
   331  		)
   332  		switch typ {
   333  		case soliton.URI:
   334  			for _, uri := range cert.URIs {
   335  				given = append(given, uri.String())
   336  			}
   337  		case soliton.DNS:
   338  			given = cert.DNSNames
   339  		case soliton.IP:
   340  			for _, ip := range cert.IPAddresses {
   341  				given = append(given, ip.String())
   342  			}
   343  		default:
   344  			unsupported = true
   345  		}
   346  		if unsupported {
   347  			logutil.BgLogger().Warn("skip unsupported SAN type", zap.String("type", string(typ)),
   348  				zap.String("user", priv.User), zap.String("host", priv.Host))
   349  			continue
   350  		}
   351  		var givenMatchOne bool
   352  		for _, req := range requireOr {
   353  			for _, give := range given {
   354  				if req == give {
   355  					givenMatchOne = true
   356  					break
   357  				}
   358  			}
   359  		}
   360  		if !givenMatchOne {
   361  			logutil.BgLogger().Info("ssl check failure for subject", zap.String("user", priv.User), zap.String("host", priv.Host),
   362  				zap.String("require", priv.Priv.SAN), zap.Strings("given", given), zap.String("type", string(typ)))
   363  			r = false
   364  			return
   365  		}
   366  	}
   367  	return
   368  }
   369  
   370  // DBIsVisible implements the Manager interface.
   371  func (p *UserPrivileges) DBIsVisible(activeRoles []*auth.RoleIdentity, EDB string) bool {
   372  	if SkipWithGrant {
   373  		return true
   374  	}
   375  	mysqlPriv := p.Handle.Get()
   376  	if mysqlPriv.DBIsVisible(p.user, p.host, EDB) {
   377  		return true
   378  	}
   379  	allRoles := mysqlPriv.FindAllRole(activeRoles)
   380  	for _, role := range allRoles {
   381  		if mysqlPriv.DBIsVisible(role.Username, role.Hostname, EDB) {
   382  			return true
   383  		}
   384  	}
   385  	return false
   386  }
   387  
   388  // UserPrivilegesTable implements the Manager interface.
   389  func (p *UserPrivileges) UserPrivilegesTable() [][]types.Causet {
   390  	mysqlPriv := p.Handle.Get()
   391  	return mysqlPriv.UserPrivilegesTable()
   392  }
   393  
   394  // ShowGrants implements privilege.Manager ShowGrants interface.
   395  func (p *UserPrivileges) ShowGrants(ctx stochastikctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) (grants []string, err error) {
   396  	if SkipWithGrant {
   397  		return nil, ErrNonexistingGrant.GenWithStackByArgs("root", "%")
   398  	}
   399  	mysqlPrivilege := p.Handle.Get()
   400  	u := user.Username
   401  	h := user.Hostname
   402  	if len(user.AuthUsername) > 0 && len(user.AuthHostname) > 0 {
   403  		u = user.AuthUsername
   404  		h = user.AuthHostname
   405  	}
   406  	grants = mysqlPrivilege.showGrants(u, h, roles)
   407  	if len(grants) == 0 {
   408  		err = ErrNonexistingGrant.GenWithStackByArgs(u, h)
   409  	}
   410  
   411  	return
   412  }
   413  
   414  // ActiveRoles implements privilege.Manager ActiveRoles interface.
   415  func (p *UserPrivileges) ActiveRoles(ctx stochastikctx.Context, roleList []*auth.RoleIdentity) (bool, string) {
   416  	if SkipWithGrant {
   417  		return true, ""
   418  	}
   419  	mysqlPrivilege := p.Handle.Get()
   420  	u := p.user
   421  	h := p.host
   422  	for _, r := range roleList {
   423  		ok := mysqlPrivilege.FindRole(u, h, r)
   424  		if !ok {
   425  			logutil.BgLogger().Error("find role failed", zap.Stringer("role", r))
   426  			return false, r.String()
   427  		}
   428  	}
   429  	ctx.GetStochastikVars().ActiveRoles = roleList
   430  	return true, ""
   431  }
   432  
   433  // FindEdge implements privilege.Manager FindRelationship interface.
   434  func (p *UserPrivileges) FindEdge(ctx stochastikctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool {
   435  	if SkipWithGrant {
   436  		return false
   437  	}
   438  	mysqlPrivilege := p.Handle.Get()
   439  	ok := mysqlPrivilege.FindRole(user.Username, user.Hostname, role)
   440  	if !ok {
   441  		logutil.BgLogger().Error("find role failed", zap.Stringer("role", role))
   442  		return false
   443  	}
   444  	return true
   445  }
   446  
   447  // GetDefaultRoles returns all default roles for certain user.
   448  func (p *UserPrivileges) GetDefaultRoles(user, host string) []*auth.RoleIdentity {
   449  	if SkipWithGrant {
   450  		return make([]*auth.RoleIdentity, 0, 10)
   451  	}
   452  	mysqlPrivilege := p.Handle.Get()
   453  	ret := mysqlPrivilege.getDefaultRoles(user, host)
   454  	return ret
   455  }
   456  
   457  // GetAllRoles return all roles of user.
   458  func (p *UserPrivileges) GetAllRoles(user, host string) []*auth.RoleIdentity {
   459  	if SkipWithGrant {
   460  		return make([]*auth.RoleIdentity, 0, 10)
   461  	}
   462  
   463  	mysqlPrivilege := p.Handle.Get()
   464  	return mysqlPrivilege.getAllRoles(user, host)
   465  }