github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/privilege/privileges/cache.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package privileges
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"encoding/json"
    20  	"fmt"
    21  	"net"
    22  	"sort"
    23  	"strings"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    28  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    29  	"github.com/whtcorpsinc/BerolinaSQL/auth"
    30  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    31  	"github.com/whtcorpsinc/errors"
    32  	"github.com/whtcorpsinc/milevadb/soliton"
    33  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    34  	"github.com/whtcorpsinc/milevadb/soliton/logutil"
    35  	"github.com/whtcorpsinc/milevadb/soliton/replog"
    36  	"github.com/whtcorpsinc/milevadb/soliton/sqlexec"
    37  	"github.com/whtcorpsinc/milevadb/soliton/stringutil"
    38  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    39  	"github.com/whtcorpsinc/milevadb/types"
    40  	"go.uber.org/zap"
    41  )
    42  
    43  var (
    44  	userTablePrivilegeMask = computePrivMask(allegrosql.AllGlobalPrivs)
    45  	dbTablePrivilegeMask   = computePrivMask(allegrosql.AllDBPrivs)
    46  	blockPrivMask          = computePrivMask(allegrosql.AllTablePrivs)
    47  )
    48  
    49  const globalDBVisible = allegrosql.CreatePriv | allegrosql.SelectPriv | allegrosql.InsertPriv | allegrosql.UFIDelatePriv | allegrosql.DeletePriv | allegrosql.ShowDBPriv | allegrosql.DropPriv | allegrosql.AlterPriv | allegrosql.IndexPriv | allegrosql.CreateViewPriv | allegrosql.ShowViewPriv | allegrosql.GrantPriv | allegrosql.TriggerPriv | allegrosql.ReferencesPriv | allegrosql.InterDircutePriv
    50  
    51  func computePrivMask(privs []allegrosql.PrivilegeType) allegrosql.PrivilegeType {
    52  	var mask allegrosql.PrivilegeType
    53  	for _, p := range privs {
    54  		mask |= p
    55  	}
    56  	return mask
    57  }
    58  
// baseRecord is used to represent a base record in the privilege cache.
// It only stores the Host and User fields (plus caches derived from Host),
// and it is meant to be embedded in the other record types.
// assignUserOrHost fills it in from a result row.
type baseRecord struct {
	Host string // max length 60, primary key
	User string // max length 32, primary key

	// patChars/patTypes are compiled from Host with
	// stringutil.CompilePattern(Host, '\\'), cached for pattern match performance.
	patChars []byte
	patTypes []byte

	// IPv4 with netmask (e.g. `127.0.0.0/255.255.255.0`), cached for host match
	// performance. It is nil when Host is not in that form (see parseHostIPNet).
	hostIPNet *net.IPNet
}
    72  
// UserRecord is used to represent a user record in the privilege cache.
// decodeUserTableRow builds one from each allegrosql.user row.
type UserRecord struct {
	baseRecord

	// AuthenticationString is the raw authentication_string column.
	AuthenticationString string
	// Privileges is the OR of every privilege enum column whose value is "Y".
	Privileges           allegrosql.PrivilegeType
	AccountLocked        bool // A role record when this field is true (account_locked = "Y")
}
    81  
    82  // NewUserRecord return a UserRecord, only use for unit test.
    83  func NewUserRecord(host, user string) UserRecord {
    84  	return UserRecord{
    85  		baseRecord: baseRecord{
    86  			Host: host,
    87  			User: user,
    88  		},
    89  	}
    90  }
    91  
// globalPrivRecord caches one allegrosql.global_priv row.
type globalPrivRecord struct {
	baseRecord

	// Priv is decoded from the JSON stored in the priv column.
	Priv   GlobalPrivValue
	// Broken is set by decodeGlobalPrivTableRow when the priv JSON or its SAN
	// value cannot be parsed.
	Broken bool
}
    98  
// SSLType is the enum value for GlobalPrivValue.SSLType.
// The values are compatible with the MyALLEGROSQL storage json value.
type SSLType int

const (
	// SslTypeNotSpecified indicates that no SSL requirement was specified (value -1).
	SslTypeNotSpecified SSLType = iota - 1
	// SslTypeNone indicates not require use ssl.
	SslTypeNone
	// SslTypeAny indicates require use ssl but not validate cert.
	SslTypeAny
	// SslTypeX509 indicates require use ssl and validate cert.
	SslTypeX509
	// SslTypeSpecified indicates require use ssl and validate cert's subject or issuer.
	SslTypeSpecified
)
   115  
// GlobalPrivValue is the JSON format stored in the priv column of allegrosql.global_priv.
// Because of `omitempty`, a zero SSLType (SslTypeNone) and empty strings are
// omitted when the value is marshalled.
type GlobalPrivValue struct {
	SSLType     SSLType                      `json:"ssl_type,omitempty"`
	SSLCipher   string                       `json:"ssl_cipher,omitempty"`
	X509Issuer  string                       `json:"x509_issuer,omitempty"`
	X509Subject string                       `json:"x509_subject,omitempty"`
	SAN         string                       `json:"san,omitempty"`
	// SANs is never (de)serialized; decodeGlobalPrivTableRow derives it from SAN.
	SANs        map[soliton.SANType][]string `json:"-"`
}
   125  
   126  // RequireStr returns describe string after `REQUIRE` clause.
   127  func (g *GlobalPrivValue) RequireStr() string {
   128  	require := "NONE"
   129  	switch g.SSLType {
   130  	case SslTypeAny:
   131  		require = "SSL"
   132  	case SslTypeX509:
   133  		require = "X509"
   134  	case SslTypeSpecified:
   135  		var s []string
   136  		if len(g.SSLCipher) > 0 {
   137  			s = append(s, "CIPHER")
   138  			s = append(s, "'"+g.SSLCipher+"'")
   139  		}
   140  		if len(g.X509Issuer) > 0 {
   141  			s = append(s, "ISSUER")
   142  			s = append(s, "'"+g.X509Issuer+"'")
   143  		}
   144  		if len(g.X509Subject) > 0 {
   145  			s = append(s, "SUBJECT")
   146  			s = append(s, "'"+g.X509Subject+"'")
   147  		}
   148  		if len(g.SAN) > 0 {
   149  			s = append(s, "SAN")
   150  			s = append(s, "'"+g.SAN+"'")
   151  		}
   152  		if len(s) > 0 {
   153  			require = strings.Join(s, " ")
   154  		}
   155  	}
   156  	return require
   157  }
   158  
// dbRecord caches one allegrosql.EDB row: the database-level privileges that
// User@Host holds on the databases matched by the EDB pattern.
type dbRecord struct {
	baseRecord

	EDB        string
	Privileges allegrosql.PrivilegeType

	// dbPatChars/dbPatTypes are compiled from strings.ToUpper(EDB) so that
	// dbRecord.match can compare case-insensitively without recompiling.
	dbPatChars []byte
	dbPatTypes []byte
}
   168  
// blocksPrivRecord caches one allegrosql.blocks_priv row: the table-level
// privileges that User@Host holds on table EDB.TableName.
// NOTE(review): decodeTablesPrivTableRow does not populate Grantor or
// Timestamp, although the load query selects them.
type blocksPrivRecord struct {
	baseRecord

	EDB                string
	TableName          string
	Grantor            string
	Timestamp          time.Time
	// TablePriv and DeferredCausetPriv are decoded from SET columns via decodeSetToPrivilege.
	TablePriv          allegrosql.PrivilegeType
	DeferredCausetPriv allegrosql.PrivilegeType
}
   179  
// columnsPrivRecord caches one allegrosql.columns_priv row: the column-level
// privileges that User@Host holds on column EDB.TableName.DeferredCausetName.
type columnsPrivRecord struct {
	baseRecord

	EDB                string
	TableName          string
	DeferredCausetName string
	// Timestamp is converted with time.Local when decoded.
	Timestamp          time.Time
	// DeferredCausetPriv is decoded from a SET column via decodeSetToPrivilege.
	DeferredCausetPriv allegrosql.PrivilegeType
}
   189  
// defaultRoleRecord is used to cache one allegrosql.default_roles row.
// The embedded baseRecord holds the account (USER/HOST columns), and
// DefaultRoleUser/DefaultRoleHost identify the default role assigned to it.
type defaultRoleRecord struct {
	baseRecord

	DefaultRoleUser string
	DefaultRoleHost string
}
   197  
// roleGraphEdgesTable caches the outgoing role edges of one user (or role).
// roleList maps "user@host" of each granted role to its identity; it is
// filled by decodeRoleEdgesTable from the FROM_USER/FROM_HOST columns.
type roleGraphEdgesTable struct {
	roleList map[string]*auth.RoleIdentity
}
   202  
   203  // Find method is used to find role from causet
   204  func (g roleGraphEdgesTable) Find(user, host string) bool {
   205  	if host == "" {
   206  		host = "%"
   207  	}
   208  	key := user + "@" + host
   209  	if g.roleList == nil {
   210  		return false
   211  	}
   212  	_, ok := g.roleList[key]
   213  	return ok
   214  }
   215  
// MyALLEGROSQLPrivilege is the in-memory cache of the allegrosql privilege tables.
type MyALLEGROSQLPrivilege struct {
	// In MyALLEGROSQL, a user identity consists of a user + host.
	// Either portion of user or host can contain wildcards,
	// requiring the privileges system to use a list-like
	// structure instead of a hash.

	// MilevaDB has a notable behavior difference from MyALLEGROSQL:
	// usernames can not contain wildcards. This means records can additionally
	// be indexed by exact user name. For example, EDB records are kept both in
	// a slice (p.EDB) and in a map keyed by user name (p.DBMap).

	// This helps when there are many users with non-full privileges
	// (i.e. many allegrosql.EDB entries).

	// User is sorted by SortUserTable from the most to the least specific host.
	User                []UserRecord
	UserMap             map[string][]UserRecord // Accelerate User searching; keyed by user name
	Global              map[string][]globalPrivRecord // keyed by user name
	EDB                 []dbRecord
	DBMap               map[string][]dbRecord // Accelerate EDB searching; keyed by user name
	TablesPriv          []blocksPrivRecord
	TablesPrivMap       map[string][]blocksPrivRecord // Accelerate TablesPriv searching; keyed by user name
	DeferredCausetsPriv []columnsPrivRecord
	DefaultRoles        []defaultRoleRecord
	RoleGraph           map[string]roleGraphEdgesTable // keyed by "user@host" of the grantee
}
   241  
   242  // FindAllRole is used to find all roles grant to this user.
   243  func (p *MyALLEGROSQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.RoleIdentity {
   244  	queue, head := make([]*auth.RoleIdentity, 0, len(activeRoles)), 0
   245  	queue = append(queue, activeRoles...)
   246  	// Using breadth first search to find all roles grant to this user.
   247  	visited, ret := make(map[string]bool), make([]*auth.RoleIdentity, 0)
   248  	for head < len(queue) {
   249  		role := queue[head]
   250  		if _, ok := visited[role.String()]; !ok {
   251  			visited[role.String()] = true
   252  			ret = append(ret, role)
   253  			key := role.Username + "@" + role.Hostname
   254  			if edgeTable, ok := p.RoleGraph[key]; ok {
   255  				for _, v := range edgeTable.roleList {
   256  					if _, ok := visited[v.String()]; !ok {
   257  						queue = append(queue, v)
   258  					}
   259  				}
   260  			}
   261  		}
   262  		head += 1
   263  	}
   264  	return ret
   265  }
   266  
   267  // FindRole is used to detect whether there is edges between users and roles.
   268  func (p *MyALLEGROSQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool {
   269  	rec := p.matchUser(user, host)
   270  	r := p.matchUser(role.Username, role.Hostname)
   271  	if rec != nil && r != nil {
   272  		key := rec.User + "@" + rec.Host
   273  		return p.RoleGraph[key].Find(role.Username, role.Hostname)
   274  	}
   275  	return false
   276  }
   277  
   278  // LoadAll loads the blocks from database to memory.
   279  func (p *MyALLEGROSQLPrivilege) LoadAll(ctx stochastikctx.Context) error {
   280  	err := p.LoadUserTable(ctx)
   281  	if err != nil {
   282  		logutil.BgLogger().Warn("load allegrosql.user fail", zap.Error(err))
   283  		return errLoadPrivilege.FastGen("allegrosql.user")
   284  	}
   285  
   286  	err = p.LoadGlobalPrivTable(ctx)
   287  	if err != nil {
   288  		return errors.Trace(err)
   289  	}
   290  
   291  	err = p.LoadDBTable(ctx)
   292  	if err != nil {
   293  		if !noSuchTable(err) {
   294  			logutil.BgLogger().Warn("load allegrosql.EDB fail", zap.Error(err))
   295  			return errLoadPrivilege.FastGen("allegrosql.EDB")
   296  		}
   297  		logutil.BgLogger().Warn("allegrosql.EDB maybe missing")
   298  	}
   299  
   300  	err = p.LoadTablesPrivTable(ctx)
   301  	if err != nil {
   302  		if !noSuchTable(err) {
   303  			logutil.BgLogger().Warn("load allegrosql.blocks_priv fail", zap.Error(err))
   304  			return errLoadPrivilege.FastGen("allegrosql.blocks_priv")
   305  		}
   306  		logutil.BgLogger().Warn("allegrosql.blocks_priv missing")
   307  	}
   308  
   309  	err = p.LoadDefaultRoles(ctx)
   310  	if err != nil {
   311  		if !noSuchTable(err) {
   312  			logutil.BgLogger().Warn("load allegrosql.roles", zap.Error(err))
   313  			return errLoadPrivilege.FastGen("allegrosql.roles")
   314  		}
   315  		logutil.BgLogger().Warn("allegrosql.default_roles missing")
   316  	}
   317  
   318  	err = p.LoadDeferredCausetsPrivTable(ctx)
   319  	if err != nil {
   320  		if !noSuchTable(err) {
   321  			logutil.BgLogger().Warn("load allegrosql.columns_priv", zap.Error(err))
   322  			return errLoadPrivilege.FastGen("allegrosql.columns_priv")
   323  		}
   324  		logutil.BgLogger().Warn("allegrosql.columns_priv missing")
   325  	}
   326  
   327  	err = p.LoadRoleGraph(ctx)
   328  	if err != nil {
   329  		if !noSuchTable(err) {
   330  			logutil.BgLogger().Warn("load allegrosql.role_edges", zap.Error(err))
   331  			return errLoadPrivilege.FastGen("allegrosql.role_edges")
   332  		}
   333  		logutil.BgLogger().Warn("allegrosql.role_edges missing")
   334  	}
   335  	return nil
   336  }
   337  
   338  func noSuchTable(err error) bool {
   339  	e1 := errors.Cause(err)
   340  	if e2, ok := e1.(*terror.Error); ok {
   341  		if terror.ErrCode(e2.Code()) == terror.ErrCode(allegrosql.ErrNoSuchTable) {
   342  			return true
   343  		}
   344  	}
   345  	return false
   346  }
   347  
   348  // LoadRoleGraph loads the allegrosql.role_edges causet from database.
   349  func (p *MyALLEGROSQLPrivilege) LoadRoleGraph(ctx stochastikctx.Context) error {
   350  	p.RoleGraph = make(map[string]roleGraphEdgesTable)
   351  	err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from allegrosql.role_edges;", p.decodeRoleEdgesTable)
   352  	if err != nil {
   353  		return errors.Trace(err)
   354  	}
   355  	return nil
   356  }
   357  
   358  // LoadUserTable loads the allegrosql.user causet from database.
   359  func (p *MyALLEGROSQLPrivilege) LoadUserTable(ctx stochastikctx.Context) error {
   360  	userPrivDefCauss := make([]string, 0, len(allegrosql.Priv2UserDefCaus))
   361  	for _, v := range allegrosql.Priv2UserDefCaus {
   362  		userPrivDefCauss = append(userPrivDefCauss, v)
   363  	}
   364  	query := fmt.Sprintf("select HIGH_PRIORITY Host,User,authentication_string,%s,account_locked from allegrosql.user;", strings.Join(userPrivDefCauss, ", "))
   365  	err := p.loadTable(ctx, query, p.decodeUserTableRow)
   366  	if err != nil {
   367  		return errors.Trace(err)
   368  	}
   369  	// See https://dev.allegrosql.com/doc/refman/8.0/en/connection-access.html
   370  	// When multiple matches are possible, the server must determine which of them to use. It resolves this issue as follows:
   371  	// 1. Whenever the server reads the user causet into memory, it sorts the rows.
   372  	// 2. When a client attempts to connect, the server looks through the rows in sorted order.
   373  	// 3. The server uses the first event that matches the client host name and user name.
   374  	// The server uses sorting rules that order rows with the most-specific Host values first.
   375  	p.SortUserTable()
   376  	p.buildUserMap()
   377  	return nil
   378  }
   379  
   380  func (p *MyALLEGROSQLPrivilege) buildUserMap() {
   381  	userMap := make(map[string][]UserRecord, len(p.User))
   382  	for _, record := range p.User {
   383  		userMap[record.User] = append(userMap[record.User], record)
   384  	}
   385  	p.UserMap = userMap
   386  }
   387  
   388  type sortedUserRecord []UserRecord
   389  
   390  func (s sortedUserRecord) Len() int {
   391  	return len(s)
   392  }
   393  
   394  func (s sortedUserRecord) Less(i, j int) bool {
   395  	x := s[i]
   396  	y := s[j]
   397  
   398  	// Compare two item by user's host first.
   399  	c1 := compareHost(x.Host, y.Host)
   400  	if c1 < 0 {
   401  		return true
   402  	}
   403  	if c1 > 0 {
   404  		return false
   405  	}
   406  
   407  	// Then, compare item by user's name value.
   408  	return x.User < y.User
   409  }
   410  
   411  // compareHost compares two host string using some special rules, return value 1, 0, -1 means > = <.
   412  // TODO: Check how MyALLEGROSQL do it exactly, instead of guess its rules.
   413  func compareHost(x, y string) int {
   414  	// The more-specific, the smaller it is.
   415  	// The pattern '%' means “any host” and is least specific.
   416  	if y == `%` {
   417  		if x == `%` {
   418  			return 0
   419  		}
   420  		return -1
   421  	}
   422  
   423  	// The empty string '' also means “any host” but sorts after '%'.
   424  	if y == "" {
   425  		if x == "" {
   426  			return 0
   427  		}
   428  		return -1
   429  	}
   430  
   431  	// One of them end with `%`.
   432  	xEnd := strings.HasSuffix(x, `%`)
   433  	yEnd := strings.HasSuffix(y, `%`)
   434  	if xEnd || yEnd {
   435  		switch {
   436  		case !xEnd && yEnd:
   437  			return -1
   438  		case xEnd && !yEnd:
   439  			return 1
   440  		case xEnd && yEnd:
   441  			// 192.168.199.% smaller than 192.168.%
   442  			// A not very accurate comparison, compare them by length.
   443  			if len(x) > len(y) {
   444  				return -1
   445  			}
   446  		}
   447  		return 0
   448  	}
   449  
   450  	// For other case, the order is nondeterministic.
   451  	switch x < y {
   452  	case true:
   453  		return -1
   454  	case false:
   455  		return 1
   456  	}
   457  	return 0
   458  }
   459  
   460  func (s sortedUserRecord) Swap(i, j int) {
   461  	s[i], s[j] = s[j], s[i]
   462  }
   463  
   464  // SortUserTable sorts p.User in the MyALLEGROSQLPrivilege struct.
   465  func (p MyALLEGROSQLPrivilege) SortUserTable() {
   466  	sort.Sort(sortedUserRecord(p.User))
   467  }
   468  
   469  // LoadGlobalPrivTable loads the allegrosql.global_priv causet from database.
   470  func (p *MyALLEGROSQLPrivilege) LoadGlobalPrivTable(ctx stochastikctx.Context) error {
   471  	return p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Priv from allegrosql.global_priv", p.decodeGlobalPrivTableRow)
   472  }
   473  
   474  // LoadDBTable loads the allegrosql.EDB causet from database.
   475  func (p *MyALLEGROSQLPrivilege) LoadDBTable(ctx stochastikctx.Context) error {
   476  	err := p.loadTable(ctx, "select HIGH_PRIORITY Host,EDB,User,Select_priv,Insert_priv,UFIDelate_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,InterDircute_priv,Create_view_priv,Show_view_priv from allegrosql.EDB order by host, EDB, user;", p.decodeDBTableRow)
   477  	if err != nil {
   478  		return err
   479  	}
   480  	p.buildDBMap()
   481  	return nil
   482  }
   483  
   484  func (p *MyALLEGROSQLPrivilege) buildDBMap() {
   485  	dbMap := make(map[string][]dbRecord, len(p.EDB))
   486  	for _, record := range p.EDB {
   487  		dbMap[record.User] = append(dbMap[record.User], record)
   488  	}
   489  	p.DBMap = dbMap
   490  }
   491  
   492  // LoadTablesPrivTable loads the allegrosql.blocks_priv causet from database.
   493  func (p *MyALLEGROSQLPrivilege) LoadTablesPrivTable(ctx stochastikctx.Context) error {
   494  	err := p.loadTable(ctx, "select HIGH_PRIORITY Host,EDB,User,Table_name,Grantor,Timestamp,Table_priv,DeferredCauset_priv from allegrosql.blocks_priv", p.decodeTablesPrivTableRow)
   495  	if err != nil {
   496  		return err
   497  	}
   498  	p.buildTablesPrivMap()
   499  	return nil
   500  }
   501  
   502  func (p *MyALLEGROSQLPrivilege) buildTablesPrivMap() {
   503  	blocksPrivMap := make(map[string][]blocksPrivRecord, len(p.TablesPriv))
   504  	for _, record := range p.TablesPriv {
   505  		blocksPrivMap[record.User] = append(blocksPrivMap[record.User], record)
   506  	}
   507  	p.TablesPrivMap = blocksPrivMap
   508  }
   509  
   510  // LoadDeferredCausetsPrivTable loads the allegrosql.columns_priv causet from database.
   511  func (p *MyALLEGROSQLPrivilege) LoadDeferredCausetsPrivTable(ctx stochastikctx.Context) error {
   512  	return p.loadTable(ctx, "select HIGH_PRIORITY Host,EDB,User,Table_name,DeferredCauset_name,Timestamp,DeferredCauset_priv from allegrosql.columns_priv", p.decodeDeferredCausetsPrivTableRow)
   513  }
   514  
   515  // LoadDefaultRoles loads the allegrosql.columns_priv causet from database.
   516  func (p *MyALLEGROSQLPrivilege) LoadDefaultRoles(ctx stochastikctx.Context) error {
   517  	return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from allegrosql.default_roles", p.decodeDefaultRoleTableRow)
   518  }
   519  
   520  func (p *MyALLEGROSQLPrivilege) loadTable(sctx stochastikctx.Context, allegrosql string,
   521  	decodeTableRow func(chunk.Row, []*ast.ResultField) error) error {
   522  	ctx := context.Background()
   523  	tmp, err := sctx.(sqlexec.ALLEGROSQLInterlockingDirectorate).InterDircute(ctx, allegrosql)
   524  	if err != nil {
   525  		return errors.Trace(err)
   526  	}
   527  	rs := tmp[0]
   528  	defer terror.Call(rs.Close)
   529  
   530  	fs := rs.Fields()
   531  	req := rs.NewChunk()
   532  	for {
   533  		err = rs.Next(context.TODO(), req)
   534  		if err != nil {
   535  			return errors.Trace(err)
   536  		}
   537  		if req.NumRows() == 0 {
   538  			return nil
   539  		}
   540  		it := chunk.NewIterator4Chunk(req)
   541  		for event := it.Begin(); event != it.End(); event = it.Next() {
   542  			err = decodeTableRow(event, fs)
   543  			if err != nil {
   544  				return errors.Trace(err)
   545  			}
   546  		}
   547  		// NOTE: decodeTableRow decodes data from a chunk Row, that is a shallow copy.
   548  		// The result will reference memory in the chunk, so the chunk must not be reused
   549  		// here, otherwise some werid bug will happen!
   550  		req = chunk.Renew(req, sctx.GetStochastikVars().MaxChunkSize)
   551  	}
   552  }
   553  
   554  // parseHostIPNet parses an IPv4 address and its subnet mask (e.g. `127.0.0.0/255.255.255.0`),
   555  // return the `IPNet` struct which represent the IP range info (e.g. `127.0.0.1 ~ 127.0.0.255`).
   556  // `IPNet` is used to check if a giving IP (e.g. `127.0.0.1`) is in its IP range by call `IPNet.Contains(ip)`.
   557  func parseHostIPNet(s string) *net.IPNet {
   558  	i := strings.IndexByte(s, '/')
   559  	if i < 0 {
   560  		return nil
   561  	}
   562  	hostIP := net.ParseIP(s[:i]).To4()
   563  	if hostIP == nil {
   564  		return nil
   565  	}
   566  	maskIP := net.ParseIP(s[i+1:]).To4()
   567  	if maskIP == nil {
   568  		return nil
   569  	}
   570  	mask := net.IPv4Mask(maskIP[0], maskIP[1], maskIP[2], maskIP[3])
   571  	// We must ensure that: <host_ip> & <netmask> == <host_ip>
   572  	// e.g. `127.0.0.1/255.0.0.0` is an illegal string,
   573  	// because `127.0.0.1` & `255.0.0.0` == `127.0.0.0`, but != `127.0.0.1`
   574  	// see https://dev.allegrosql.com/doc/refman/5.7/en/account-names.html
   575  	if !hostIP.Equal(hostIP.Mask(mask)) {
   576  		return nil
   577  	}
   578  	return &net.IPNet{
   579  		IP:   hostIP,
   580  		Mask: mask,
   581  	}
   582  }
   583  
   584  func (record *baseRecord) assignUserOrHost(event chunk.Row, i int, f *ast.ResultField) {
   585  	switch f.DeferredCausetAsName.L {
   586  	case "user":
   587  		record.User = event.GetString(i)
   588  	case "host":
   589  		record.Host = event.GetString(i)
   590  		record.patChars, record.patTypes = stringutil.CompilePattern(record.Host, '\\')
   591  		record.hostIPNet = parseHostIPNet(record.Host)
   592  	}
   593  }
   594  
   595  func (p *MyALLEGROSQLPrivilege) decodeUserTableRow(event chunk.Row, fs []*ast.ResultField) error {
   596  	var value UserRecord
   597  	for i, f := range fs {
   598  		switch {
   599  		case f.DeferredCausetAsName.L == "authentication_string":
   600  			value.AuthenticationString = event.GetString(i)
   601  		case f.DeferredCausetAsName.L == "account_locked":
   602  			if event.GetEnum(i).String() == "Y" {
   603  				value.AccountLocked = true
   604  			}
   605  		case f.DeferredCauset.Tp == allegrosql.TypeEnum:
   606  			if event.GetEnum(i).String() != "Y" {
   607  				continue
   608  			}
   609  			priv, ok := allegrosql.DefCaus2PrivType[f.DeferredCausetAsName.O]
   610  			if !ok {
   611  				return errInvalidPrivilegeType.GenWithStack(f.DeferredCausetAsName.O)
   612  			}
   613  			value.Privileges |= priv
   614  		default:
   615  			value.assignUserOrHost(event, i, f)
   616  		}
   617  	}
   618  	p.User = append(p.User, value)
   619  	return nil
   620  }
   621  
   622  func (p *MyALLEGROSQLPrivilege) decodeGlobalPrivTableRow(event chunk.Row, fs []*ast.ResultField) error {
   623  	var value globalPrivRecord
   624  	for i, f := range fs {
   625  		switch {
   626  		case f.DeferredCausetAsName.L == "priv":
   627  			privData := event.GetString(i)
   628  			if len(privData) > 0 {
   629  				var privValue GlobalPrivValue
   630  				err := json.Unmarshal(replog.Slice(privData), &privValue)
   631  				if err != nil {
   632  					logutil.BgLogger().Error("one user global priv data is broken, forbidden login until data be fixed",
   633  						zap.String("user", value.User), zap.String("host", value.Host))
   634  					value.Broken = true
   635  				} else {
   636  					value.Priv.SSLType = privValue.SSLType
   637  					value.Priv.SSLCipher = privValue.SSLCipher
   638  					value.Priv.X509Issuer = privValue.X509Issuer
   639  					value.Priv.X509Subject = privValue.X509Subject
   640  					value.Priv.SAN = privValue.SAN
   641  					if len(value.Priv.SAN) > 0 {
   642  						value.Priv.SANs, err = soliton.ParseAndCheckSAN(value.Priv.SAN)
   643  						if err != nil {
   644  							value.Broken = true
   645  						}
   646  					}
   647  				}
   648  			}
   649  		default:
   650  			value.assignUserOrHost(event, i, f)
   651  		}
   652  	}
   653  	if p.Global == nil {
   654  		p.Global = make(map[string][]globalPrivRecord)
   655  	}
   656  	p.Global[value.User] = append(p.Global[value.User], value)
   657  	return nil
   658  }
   659  
   660  func (p *MyALLEGROSQLPrivilege) decodeDBTableRow(event chunk.Row, fs []*ast.ResultField) error {
   661  	var value dbRecord
   662  	for i, f := range fs {
   663  		switch {
   664  		case f.DeferredCausetAsName.L == "EDB":
   665  			value.EDB = event.GetString(i)
   666  			value.dbPatChars, value.dbPatTypes = stringutil.CompilePattern(strings.ToUpper(value.EDB), '\\')
   667  		case f.DeferredCauset.Tp == allegrosql.TypeEnum:
   668  			if event.GetEnum(i).String() != "Y" {
   669  				continue
   670  			}
   671  			priv, ok := allegrosql.DefCaus2PrivType[f.DeferredCausetAsName.O]
   672  			if !ok {
   673  				return errInvalidPrivilegeType.GenWithStack("Unknown Privilege Type!")
   674  			}
   675  			value.Privileges |= priv
   676  		default:
   677  			value.assignUserOrHost(event, i, f)
   678  		}
   679  	}
   680  	p.EDB = append(p.EDB, value)
   681  	return nil
   682  }
   683  
   684  func (p *MyALLEGROSQLPrivilege) decodeTablesPrivTableRow(event chunk.Row, fs []*ast.ResultField) error {
   685  	var value blocksPrivRecord
   686  	for i, f := range fs {
   687  		switch {
   688  		case f.DeferredCausetAsName.L == "EDB":
   689  			value.EDB = event.GetString(i)
   690  		case f.DeferredCausetAsName.L == "block_name":
   691  			value.TableName = event.GetString(i)
   692  		case f.DeferredCausetAsName.L == "block_priv":
   693  			value.TablePriv = decodeSetToPrivilege(event.GetSet(i))
   694  		case f.DeferredCausetAsName.L == "column_priv":
   695  			value.DeferredCausetPriv = decodeSetToPrivilege(event.GetSet(i))
   696  		default:
   697  			value.assignUserOrHost(event, i, f)
   698  		}
   699  	}
   700  	p.TablesPriv = append(p.TablesPriv, value)
   701  	return nil
   702  }
   703  
   704  func (p *MyALLEGROSQLPrivilege) decodeRoleEdgesTable(event chunk.Row, fs []*ast.ResultField) error {
   705  	var fromUser, fromHost, toHost, toUser string
   706  	for i, f := range fs {
   707  		switch {
   708  		case f.DeferredCausetAsName.L == "from_host":
   709  			fromHost = event.GetString(i)
   710  		case f.DeferredCausetAsName.L == "from_user":
   711  			fromUser = event.GetString(i)
   712  		case f.DeferredCausetAsName.L == "to_host":
   713  			toHost = event.GetString(i)
   714  		case f.DeferredCausetAsName.L == "to_user":
   715  			toUser = event.GetString(i)
   716  		}
   717  	}
   718  	fromKey := fromUser + "@" + fromHost
   719  	toKey := toUser + "@" + toHost
   720  	roleGraph, ok := p.RoleGraph[toKey]
   721  	if !ok {
   722  		roleGraph = roleGraphEdgesTable{roleList: make(map[string]*auth.RoleIdentity)}
   723  		p.RoleGraph[toKey] = roleGraph
   724  	}
   725  	roleGraph.roleList[fromKey] = &auth.RoleIdentity{Username: fromUser, Hostname: fromHost}
   726  	return nil
   727  }
   728  
   729  func (p *MyALLEGROSQLPrivilege) decodeDefaultRoleTableRow(event chunk.Row, fs []*ast.ResultField) error {
   730  	var value defaultRoleRecord
   731  	for i, f := range fs {
   732  		switch {
   733  		case f.DeferredCausetAsName.L == "default_role_host":
   734  			value.DefaultRoleHost = event.GetString(i)
   735  		case f.DeferredCausetAsName.L == "default_role_user":
   736  			value.DefaultRoleUser = event.GetString(i)
   737  		default:
   738  			value.assignUserOrHost(event, i, f)
   739  		}
   740  	}
   741  	p.DefaultRoles = append(p.DefaultRoles, value)
   742  	return nil
   743  }
   744  
   745  func (p *MyALLEGROSQLPrivilege) decodeDeferredCausetsPrivTableRow(event chunk.Row, fs []*ast.ResultField) error {
   746  	var value columnsPrivRecord
   747  	for i, f := range fs {
   748  		switch {
   749  		case f.DeferredCausetAsName.L == "EDB":
   750  			value.EDB = event.GetString(i)
   751  		case f.DeferredCausetAsName.L == "block_name":
   752  			value.TableName = event.GetString(i)
   753  		case f.DeferredCausetAsName.L == "column_name":
   754  			value.DeferredCausetName = event.GetString(i)
   755  		case f.DeferredCausetAsName.L == "timestamp":
   756  			var err error
   757  			value.Timestamp, err = event.GetTime(i).GoTime(time.Local)
   758  			if err != nil {
   759  				return errors.Trace(err)
   760  			}
   761  		case f.DeferredCausetAsName.L == "column_priv":
   762  			value.DeferredCausetPriv = decodeSetToPrivilege(event.GetSet(i))
   763  		default:
   764  			value.assignUserOrHost(event, i, f)
   765  		}
   766  	}
   767  	p.DeferredCausetsPriv = append(p.DeferredCausetsPriv, value)
   768  	return nil
   769  }
   770  
   771  func decodeSetToPrivilege(s types.Set) allegrosql.PrivilegeType {
   772  	var ret allegrosql.PrivilegeType
   773  	if s.Name == "" {
   774  		return ret
   775  	}
   776  	for _, str := range strings.Split(s.Name, ",") {
   777  		priv, ok := allegrosql.SetStr2Priv[str]
   778  		if !ok {
   779  			logutil.BgLogger().Warn("unsupported privilege", zap.String("type", str))
   780  			continue
   781  		}
   782  		ret |= priv
   783  	}
   784  	return ret
   785  }
   786  
   787  // hostMatch checks if giving IP is in IP range of hostname.
   788  // In MyALLEGROSQL, the hostname of user can be set to `<IPv4>/<netmask>`
   789  // e.g. `127.0.0.0/255.255.255.0` represent IP range from `127.0.0.1` to `127.0.0.255`,
   790  // only IP addresses that satisfy this condition range can be login with this user.
   791  // See https://dev.allegrosql.com/doc/refman/5.7/en/account-names.html
   792  func (record *baseRecord) hostMatch(s string) bool {
   793  	if record.hostIPNet == nil {
   794  		return false
   795  	}
   796  	ip := net.ParseIP(s).To4()
   797  	if ip == nil {
   798  		return false
   799  	}
   800  	return record.hostIPNet.Contains(ip)
   801  }
   802  
   803  func (record *baseRecord) match(user, host string) bool {
   804  	return record.User == user && (patternMatch(host, record.patChars, record.patTypes) ||
   805  		record.hostMatch(host))
   806  }
   807  
   808  func (record *baseRecord) fullyMatch(user, host string) bool {
   809  	return record.User == user && record.Host == host
   810  }
   811  
   812  func (record *dbRecord) match(user, host, EDB string) bool {
   813  	return record.baseRecord.match(user, host) &&
   814  		patternMatch(strings.ToUpper(EDB), record.dbPatChars, record.dbPatTypes)
   815  }
   816  
   817  func (record *blocksPrivRecord) match(user, host, EDB, causet string) bool {
   818  	return record.baseRecord.match(user, host) &&
   819  		strings.EqualFold(record.EDB, EDB) &&
   820  		strings.EqualFold(record.TableName, causet)
   821  }
   822  
   823  func (record *columnsPrivRecord) match(user, host, EDB, causet, col string) bool {
   824  	return record.baseRecord.match(user, host) &&
   825  		strings.EqualFold(record.EDB, EDB) &&
   826  		strings.EqualFold(record.TableName, causet) &&
   827  		strings.EqualFold(record.DeferredCausetName, col)
   828  }
   829  
   830  // patternMatch matches "%" the same way as ".*" in regular memex, for example,
   831  // "10.0.%" would match "10.0.1" "10.0.1.118" ...
   832  func patternMatch(str string, patChars, patTypes []byte) bool {
   833  	return stringutil.DoMatch(str, patChars, patTypes)
   834  }
   835  
   836  // connectionVerification verifies the connection have access to MilevaDB server.
   837  func (p *MyALLEGROSQLPrivilege) connectionVerification(user, host string) *UserRecord {
   838  	for i := 0; i < len(p.User); i++ {
   839  		record := &p.User[i]
   840  		if record.match(user, host) {
   841  			return record
   842  		}
   843  	}
   844  	return nil
   845  }
   846  
   847  func (p *MyALLEGROSQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord {
   848  	uGlobal, exists := p.Global[user]
   849  	if !exists {
   850  		return nil
   851  	}
   852  	for i := 0; i < len(uGlobal); i++ {
   853  		record := &uGlobal[i]
   854  		if record.match(user, host) {
   855  			return record
   856  		}
   857  	}
   858  	return nil
   859  }
   860  
   861  func (p *MyALLEGROSQLPrivilege) matchUser(user, host string) *UserRecord {
   862  	records, exists := p.UserMap[user]
   863  	if exists {
   864  		for i := 0; i < len(records); i++ {
   865  			record := &records[i]
   866  			if record.match(user, host) {
   867  				return record
   868  			}
   869  		}
   870  	}
   871  	return nil
   872  }
   873  
   874  func (p *MyALLEGROSQLPrivilege) matchDB(user, host, EDB string) *dbRecord {
   875  	records, exists := p.DBMap[user]
   876  	if exists {
   877  		for i := 0; i < len(records); i++ {
   878  			record := &records[i]
   879  			if record.match(user, host, EDB) {
   880  				return record
   881  			}
   882  		}
   883  	}
   884  	return nil
   885  }
   886  
   887  func (p *MyALLEGROSQLPrivilege) matchTables(user, host, EDB, causet string) *blocksPrivRecord {
   888  	records, exists := p.TablesPrivMap[user]
   889  	if exists {
   890  		for i := 0; i < len(records); i++ {
   891  			record := &records[i]
   892  			if record.match(user, host, EDB, causet) {
   893  				return record
   894  			}
   895  		}
   896  	}
   897  	return nil
   898  }
   899  
   900  func (p *MyALLEGROSQLPrivilege) matchDeferredCausets(user, host, EDB, causet, column string) *columnsPrivRecord {
   901  	for i := 0; i < len(p.DeferredCausetsPriv); i++ {
   902  		record := &p.DeferredCausetsPriv[i]
   903  		if record.match(user, host, EDB, causet, column) {
   904  			return record
   905  		}
   906  	}
   907  	return nil
   908  }
   909  
   910  // RequestVerification checks whether the user have sufficient privileges to do the operation.
   911  func (p *MyALLEGROSQLPrivilege) RequestVerification(activeRoles []*auth.RoleIdentity, user, host, EDB, causet, column string, priv allegrosql.PrivilegeType) bool {
   912  	roleList := p.FindAllRole(activeRoles)
   913  	roleList = append(roleList, &auth.RoleIdentity{Username: user, Hostname: host})
   914  
   915  	var userPriv, dbPriv, blockPriv, columnPriv allegrosql.PrivilegeType
   916  	for _, r := range roleList {
   917  		userRecord := p.matchUser(r.Username, r.Hostname)
   918  		if userRecord != nil {
   919  			userPriv |= userRecord.Privileges
   920  		}
   921  	}
   922  	if userPriv&priv > 0 {
   923  		return true
   924  	}
   925  
   926  	for _, r := range roleList {
   927  		dbRecord := p.matchDB(r.Username, r.Hostname, EDB)
   928  		if dbRecord != nil {
   929  			dbPriv |= dbRecord.Privileges
   930  		}
   931  	}
   932  	if dbPriv&priv > 0 {
   933  		return true
   934  	}
   935  
   936  	for _, r := range roleList {
   937  		blockRecord := p.matchTables(r.Username, r.Hostname, EDB, causet)
   938  		if blockRecord != nil {
   939  			blockPriv |= blockRecord.TablePriv
   940  			if column != "" {
   941  				columnPriv |= blockRecord.DeferredCausetPriv
   942  			}
   943  		}
   944  	}
   945  	if blockPriv&priv > 0 || columnPriv&priv > 0 {
   946  		return true
   947  	}
   948  
   949  	columnPriv = 0
   950  	for _, r := range roleList {
   951  		columnRecord := p.matchDeferredCausets(r.Username, r.Hostname, EDB, causet, column)
   952  		if columnRecord != nil {
   953  			columnPriv |= columnRecord.DeferredCausetPriv
   954  		}
   955  	}
   956  	if columnPriv&priv > 0 {
   957  		return true
   958  	}
   959  
   960  	return priv == 0
   961  }
   962  
   963  // DBIsVisible checks whether the user can see the EDB.
   964  func (p *MyALLEGROSQLPrivilege) DBIsVisible(user, host, EDB string) bool {
   965  	if record := p.matchUser(user, host); record != nil {
   966  		if record.Privileges&globalDBVisible > 0 {
   967  			return true
   968  		}
   969  	}
   970  
   971  	// INFORMATION_SCHEMA is visible to all users.
   972  	if strings.EqualFold(EDB, "INFORMATION_SCHEMA") {
   973  		return true
   974  	}
   975  
   976  	if record := p.matchDB(user, host, EDB); record != nil {
   977  		if record.Privileges > 0 {
   978  			return true
   979  		}
   980  	}
   981  
   982  	for _, record := range p.TablesPriv {
   983  		if record.baseRecord.match(user, host) &&
   984  			strings.EqualFold(record.EDB, EDB) {
   985  			if record.TablePriv != 0 || record.DeferredCausetPriv != 0 {
   986  				return true
   987  			}
   988  		}
   989  	}
   990  
   991  	for _, record := range p.DeferredCausetsPriv {
   992  		if record.baseRecord.match(user, host) &&
   993  			strings.EqualFold(record.EDB, EDB) {
   994  			if record.DeferredCausetPriv != 0 {
   995  				return true
   996  			}
   997  		}
   998  	}
   999  
  1000  	return false
  1001  }
  1002  
  1003  func (p *MyALLEGROSQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentity) []string {
  1004  	var gs []string
  1005  	var hasGlobalGrant = false
  1006  	// Some privileges may granted from role inheritance.
  1007  	// We should find these inheritance relationship.
  1008  	allRoles := p.FindAllRole(roles)
  1009  	// Show global grants.
  1010  	var currentPriv allegrosql.PrivilegeType
  1011  	var hasGrantOptionPriv, userExists = false, false
  1012  	// Check whether user exists.
  1013  	if userList, ok := p.UserMap[user]; ok {
  1014  		for _, record := range userList {
  1015  			if record.fullyMatch(user, host) {
  1016  				userExists = true
  1017  				break
  1018  			}
  1019  		}
  1020  		if !userExists {
  1021  			return gs
  1022  		}
  1023  	}
  1024  	var g string
  1025  	for _, record := range p.User {
  1026  		if record.fullyMatch(user, host) {
  1027  			hasGlobalGrant = true
  1028  			if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1029  				hasGrantOptionPriv = true
  1030  				currentPriv |= (record.Privileges & ^allegrosql.GrantPriv)
  1031  				continue
  1032  			}
  1033  			currentPriv |= record.Privileges
  1034  		} else {
  1035  			for _, r := range allRoles {
  1036  				if record.baseRecord.match(r.Username, r.Hostname) {
  1037  					hasGlobalGrant = true
  1038  					if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1039  						hasGrantOptionPriv = true
  1040  						currentPriv |= (record.Privileges & ^allegrosql.GrantPriv)
  1041  						continue
  1042  					}
  1043  					currentPriv |= record.Privileges
  1044  				}
  1045  			}
  1046  		}
  1047  	}
  1048  	g = userPrivToString(currentPriv)
  1049  	if len(g) > 0 {
  1050  		var s string
  1051  		if hasGrantOptionPriv {
  1052  			s = fmt.Sprintf(`GRANT %s ON *.* TO '%s'@'%s' WITH GRANT OPTION`, g, user, host)
  1053  
  1054  		} else {
  1055  			s = fmt.Sprintf(`GRANT %s ON *.* TO '%s'@'%s'`, g, user, host)
  1056  
  1057  		}
  1058  		gs = append(gs, s)
  1059  	}
  1060  
  1061  	// This is a allegrosql convention.
  1062  	if len(gs) == 0 && hasGlobalGrant {
  1063  		var s string
  1064  		if hasGrantOptionPriv {
  1065  			s = fmt.Sprintf("GRANT USAGE ON *.* TO '%s'@'%s' WITH GRANT OPTION", user, host)
  1066  		} else {
  1067  			s = fmt.Sprintf("GRANT USAGE ON *.* TO '%s'@'%s'", user, host)
  1068  		}
  1069  		gs = append(gs, s)
  1070  	}
  1071  
  1072  	// Show EDB scope grants.
  1073  	dbPrivTable := make(map[string]allegrosql.PrivilegeType)
  1074  	for _, record := range p.EDB {
  1075  		if record.fullyMatch(user, host) {
  1076  			if _, ok := dbPrivTable[record.EDB]; ok {
  1077  				if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1078  					hasGrantOptionPriv = true
  1079  					dbPrivTable[record.EDB] |= (record.Privileges & ^allegrosql.GrantPriv)
  1080  					continue
  1081  				}
  1082  				dbPrivTable[record.EDB] |= record.Privileges
  1083  			} else {
  1084  				if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1085  					hasGrantOptionPriv = true
  1086  					dbPrivTable[record.EDB] = (record.Privileges & ^allegrosql.GrantPriv)
  1087  					continue
  1088  				}
  1089  				dbPrivTable[record.EDB] = record.Privileges
  1090  			}
  1091  		} else {
  1092  			for _, r := range allRoles {
  1093  				if record.baseRecord.match(r.Username, r.Hostname) {
  1094  					if _, ok := dbPrivTable[record.EDB]; ok {
  1095  						if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1096  							hasGrantOptionPriv = true
  1097  							dbPrivTable[record.EDB] |= (record.Privileges & ^allegrosql.GrantPriv)
  1098  							continue
  1099  						}
  1100  						dbPrivTable[record.EDB] |= record.Privileges
  1101  					} else {
  1102  						if (record.Privileges & allegrosql.GrantPriv) > 0 {
  1103  							hasGrantOptionPriv = true
  1104  							dbPrivTable[record.EDB] = (record.Privileges & ^allegrosql.GrantPriv)
  1105  							continue
  1106  						}
  1107  						dbPrivTable[record.EDB] = record.Privileges
  1108  					}
  1109  				}
  1110  			}
  1111  		}
  1112  	}
  1113  	for dbName, priv := range dbPrivTable {
  1114  		g := dbPrivToString(priv)
  1115  		if len(g) > 0 {
  1116  			var s string
  1117  			if hasGrantOptionPriv {
  1118  				s = fmt.Sprintf(`GRANT %s ON %s.* TO '%s'@'%s' WITH GRANT OPTION`, g, dbName, user, host)
  1119  
  1120  			} else {
  1121  				s = fmt.Sprintf(`GRANT %s ON %s.* TO '%s'@'%s'`, g, dbName, user, host)
  1122  
  1123  			}
  1124  			gs = append(gs, s)
  1125  		}
  1126  	}
  1127  
  1128  	// Show causet scope grants.
  1129  	blockPrivTable := make(map[string]allegrosql.PrivilegeType)
  1130  	for _, record := range p.TablesPriv {
  1131  		recordKey := record.EDB + "." + record.TableName
  1132  		if user == record.User && host == record.Host {
  1133  			if _, ok := dbPrivTable[record.EDB]; ok {
  1134  				if (record.TablePriv & allegrosql.GrantPriv) > 0 {
  1135  					hasGrantOptionPriv = true
  1136  					blockPrivTable[recordKey] |= (record.TablePriv & ^allegrosql.GrantPriv)
  1137  					continue
  1138  				}
  1139  				blockPrivTable[recordKey] |= record.TablePriv
  1140  			} else {
  1141  				if (record.TablePriv & allegrosql.GrantPriv) > 0 {
  1142  					hasGrantOptionPriv = true
  1143  					blockPrivTable[recordKey] = (record.TablePriv & ^allegrosql.GrantPriv)
  1144  					continue
  1145  				}
  1146  				blockPrivTable[recordKey] = record.TablePriv
  1147  			}
  1148  		} else {
  1149  			for _, r := range allRoles {
  1150  				if record.baseRecord.match(r.Username, r.Hostname) {
  1151  					if _, ok := dbPrivTable[record.EDB]; ok {
  1152  						if (record.TablePriv & allegrosql.GrantPriv) > 0 {
  1153  							hasGrantOptionPriv = true
  1154  							blockPrivTable[recordKey] |= (record.TablePriv & ^allegrosql.GrantPriv)
  1155  							continue
  1156  						}
  1157  						blockPrivTable[recordKey] |= record.TablePriv
  1158  					} else {
  1159  						if (record.TablePriv & allegrosql.GrantPriv) > 0 {
  1160  							hasGrantOptionPriv = true
  1161  							blockPrivTable[recordKey] = (record.TablePriv & ^allegrosql.GrantPriv)
  1162  							continue
  1163  						}
  1164  						blockPrivTable[recordKey] = record.TablePriv
  1165  					}
  1166  				}
  1167  			}
  1168  		}
  1169  	}
  1170  	for k, priv := range blockPrivTable {
  1171  		g := blockPrivToString(priv)
  1172  		if len(g) > 0 {
  1173  			var s string
  1174  			if hasGrantOptionPriv {
  1175  				s = fmt.Sprintf(`GRANT %s ON %s TO '%s'@'%s' WITH GRANT OPTION`, g, k, user, host)
  1176  			} else {
  1177  				s = fmt.Sprintf(`GRANT %s ON %s TO '%s'@'%s'`, g, k, user, host)
  1178  			}
  1179  			gs = append(gs, s)
  1180  		}
  1181  	}
  1182  
  1183  	// Show column scope grants, column and causet are combined.
  1184  	// A map of "EDB.Block" => Priv(col1, col2 ...)
  1185  	columnPrivTable := make(map[string]privOnDeferredCausets)
  1186  	for _, record := range p.DeferredCausetsPriv {
  1187  		if !collectDeferredCausetGrant(&record, user, host, columnPrivTable) {
  1188  			for _, r := range allRoles {
  1189  				collectDeferredCausetGrant(&record, r.Username, r.Hostname, columnPrivTable)
  1190  			}
  1191  		}
  1192  	}
  1193  	for k, v := range columnPrivTable {
  1194  		privDefCauss := privOnDeferredCausetsToString(v)
  1195  		s := fmt.Sprintf(`GRANT %s ON %s TO '%s'@'%s'`, privDefCauss, k, user, host)
  1196  		gs = append(gs, s)
  1197  	}
  1198  
  1199  	// Show role grants.
  1200  	graphKey := user + "@" + host
  1201  	edgeTable, ok := p.RoleGraph[graphKey]
  1202  	g = ""
  1203  	if ok {
  1204  		sortedRes := make([]string, 0, 10)
  1205  		for k := range edgeTable.roleList {
  1206  			role := strings.Split(k, "@")
  1207  			roleName, roleHost := role[0], role[1]
  1208  			tmp := fmt.Sprintf("'%s'@'%s'", roleName, roleHost)
  1209  			sortedRes = append(sortedRes, tmp)
  1210  		}
  1211  		sort.Strings(sortedRes)
  1212  		for i, r := range sortedRes {
  1213  			g += r
  1214  			if i != len(sortedRes)-1 {
  1215  				g += ", "
  1216  			}
  1217  		}
  1218  		s := fmt.Sprintf(`GRANT %s TO '%s'@'%s'`, g, user, host)
  1219  		gs = append(gs, s)
  1220  	}
  1221  	return gs
  1222  }
  1223  
  1224  type columnStr = string
  1225  type columnStrs = []columnStr
  1226  type privOnDeferredCausets = map[allegrosql.PrivilegeType]columnStrs
  1227  
  1228  func privOnDeferredCausetsToString(p privOnDeferredCausets) string {
  1229  	var buf bytes.Buffer
  1230  	idx := 0
  1231  	for _, priv := range allegrosql.AllDeferredCausetPrivs {
  1232  		v, ok := p[priv]
  1233  		if !ok || len(v) == 0 {
  1234  			continue
  1235  		}
  1236  
  1237  		if idx > 0 {
  1238  			buf.WriteString(", ")
  1239  		}
  1240  		fmt.Fprintf(&buf, "%s(", allegrosql.Priv2Str[priv])
  1241  		for i, col := range v {
  1242  			if i > 0 {
  1243  				fmt.Fprintf(&buf, ", ")
  1244  			}
  1245  			buf.WriteString(col)
  1246  		}
  1247  		buf.WriteString(")")
  1248  		idx++
  1249  	}
  1250  	return buf.String()
  1251  }
  1252  
  1253  func collectDeferredCausetGrant(record *columnsPrivRecord, user, host string, columnPrivTable map[string]privOnDeferredCausets) bool {
  1254  	if record.baseRecord.match(user, host) {
  1255  		recordKey := record.EDB + "." + record.TableName
  1256  		privDeferredCausets, ok := columnPrivTable[recordKey]
  1257  		if !ok {
  1258  			privDeferredCausets = make(map[allegrosql.PrivilegeType]columnStrs)
  1259  		}
  1260  
  1261  		for _, priv := range allegrosql.AllDeferredCausetPrivs {
  1262  			if priv&record.DeferredCausetPriv > 0 {
  1263  				old := privDeferredCausets[priv]
  1264  				privDeferredCausets[priv] = append(old, record.DeferredCausetName)
  1265  				columnPrivTable[recordKey] = privDeferredCausets
  1266  			}
  1267  		}
  1268  		return true
  1269  	}
  1270  	return false
  1271  }
  1272  
  1273  func userPrivToString(privs allegrosql.PrivilegeType) string {
  1274  	if privs == userTablePrivilegeMask {
  1275  		return allegrosql.AllPrivilegeLiteral
  1276  	}
  1277  	return privToString(privs, allegrosql.AllGlobalPrivs, allegrosql.Priv2Str)
  1278  }
  1279  
  1280  func dbPrivToString(privs allegrosql.PrivilegeType) string {
  1281  	if privs == dbTablePrivilegeMask {
  1282  		return allegrosql.AllPrivilegeLiteral
  1283  	}
  1284  	return privToString(privs, allegrosql.AllDBPrivs, allegrosql.Priv2SetStr)
  1285  }
  1286  
  1287  func blockPrivToString(privs allegrosql.PrivilegeType) string {
  1288  	if privs == blockPrivMask {
  1289  		return allegrosql.AllPrivilegeLiteral
  1290  	}
  1291  	return privToString(privs, allegrosql.AllTablePrivs, allegrosql.Priv2Str)
  1292  }
  1293  
  1294  func privToString(priv allegrosql.PrivilegeType, allPrivs []allegrosql.PrivilegeType, allPrivNames map[allegrosql.PrivilegeType]string) string {
  1295  	pstrs := make([]string, 0, 20)
  1296  	for _, p := range allPrivs {
  1297  		if priv&p == 0 {
  1298  			continue
  1299  		}
  1300  		s := allPrivNames[p]
  1301  		pstrs = append(pstrs, s)
  1302  	}
  1303  	return strings.Join(pstrs, ",")
  1304  }
  1305  
  1306  // UserPrivilegesTable provide data for INFORMATION_SCHEMA.USERS_PRIVILEGE causet.
  1307  func (p *MyALLEGROSQLPrivilege) UserPrivilegesTable() [][]types.Causet {
  1308  	var rows [][]types.Causet
  1309  	for _, user := range p.User {
  1310  		rows = appendUserPrivilegesTableRow(rows, user)
  1311  	}
  1312  	return rows
  1313  }
  1314  
  1315  func appendUserPrivilegesTableRow(rows [][]types.Causet, user UserRecord) [][]types.Causet {
  1316  	var isGranblock string
  1317  	if user.Privileges&allegrosql.GrantPriv > 0 {
  1318  		isGranblock = "YES"
  1319  	} else {
  1320  		isGranblock = "NO"
  1321  	}
  1322  	guarantee := fmt.Sprintf("'%s'@'%s'", user.User, user.Host)
  1323  
  1324  	for _, priv := range allegrosql.AllGlobalPrivs {
  1325  		if user.Privileges&priv > 0 {
  1326  			privilegeType := allegrosql.Priv2Str[priv]
  1327  			// +---------------------------+---------------+-------------------------+--------------+
  1328  			// | GRANTEE                   | TABLE_CATALOG | PRIVILEGE_TYPE          | IS_GRANTABLE |
  1329  			// +---------------------------+---------------+-------------------------+--------------+
  1330  			// | 'root'@'localhost'        | def           | SELECT                  | YES          |
  1331  			record := types.MakeCausets(guarantee, "def", privilegeType, isGranblock)
  1332  			rows = append(rows, record)
  1333  		}
  1334  	}
  1335  	return rows
  1336  }
  1337  
  1338  func (p *MyALLEGROSQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity {
  1339  	ret := make([]*auth.RoleIdentity, 0)
  1340  	for _, r := range p.DefaultRoles {
  1341  		if r.match(user, host) {
  1342  			ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost})
  1343  		}
  1344  	}
  1345  	return ret
  1346  }
  1347  
  1348  func (p *MyALLEGROSQLPrivilege) getAllRoles(user, host string) []*auth.RoleIdentity {
  1349  	key := user + "@" + host
  1350  	edgeTable, ok := p.RoleGraph[key]
  1351  	ret := make([]*auth.RoleIdentity, 0, len(edgeTable.roleList))
  1352  	if ok {
  1353  		for _, r := range edgeTable.roleList {
  1354  			ret = append(ret, r)
  1355  		}
  1356  	}
  1357  	return ret
  1358  }
  1359  
  1360  // Handle wraps MyALLEGROSQLPrivilege providing thread safe access.
  1361  type Handle struct {
  1362  	priv atomic.Value
  1363  }
  1364  
  1365  // NewHandle returns a Handle.
  1366  func NewHandle() *Handle {
  1367  	return &Handle{}
  1368  }
  1369  
  1370  // Get the MyALLEGROSQLPrivilege for read.
  1371  func (h *Handle) Get() *MyALLEGROSQLPrivilege {
  1372  	return h.priv.Load().(*MyALLEGROSQLPrivilege)
  1373  }
  1374  
  1375  // UFIDelate loads all the privilege info from ekv storage.
  1376  func (h *Handle) UFIDelate(ctx stochastikctx.Context) error {
  1377  	var priv MyALLEGROSQLPrivilege
  1378  	err := priv.LoadAll(ctx)
  1379  	if err != nil {
  1380  		return err
  1381  	}
  1382  
  1383  	h.priv.CausetStore(&priv)
  1384  	return nil
  1385  }