github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/checker/privilege.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package checker
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/pingcap/errors"
    23  	"github.com/pingcap/tidb/pkg/parser"
    24  	"github.com/pingcap/tidb/pkg/parser/ast"
    25  	"github.com/pingcap/tidb/pkg/parser/mysql"
    26  	_ "github.com/pingcap/tidb/pkg/types/parser_driver" // for parser driver
    27  	"github.com/pingcap/tidb/pkg/util/dbutil"
    28  	"github.com/pingcap/tidb/pkg/util/filter"
    29  	"github.com/pingcap/tidb/pkg/util/stringutil"
    30  	"github.com/pingcap/tiflow/dm/pkg/log"
    31  	"github.com/pingcap/tiflow/pkg/container/sortmap"
    32  	"go.uber.org/zap"
    33  )
    34  
// tablePriv describes a privilege requirement on a single table.
type tablePriv struct {
	// wholeTable is true when the privilege is required on the whole table.
	// genTableLevelPrivs always sets it. VerifyPrivileges purges entries that
	// have neither wholeTable set nor any columns left.
	wholeTable bool
	// columns holds column names the privilege is required on. Nothing in this
	// file populates it; presumably reserved for column-level requirements.
	columns    map[string]struct{}
}
    39  
// dbPriv describes a privilege requirement inside a single database.
type dbPriv struct {
	// wholeDB is true when the privilege is required on the whole database
	// (reported as `db`.*). Table-level grants never clear such an entry; only
	// DB-level or global grants do.
	wholeDB bool
	// tables maps a table name to its requirement. genTableLevelPrivs keys it
	// by filter.Table.Name.
	tables  map[string]tablePriv
}
    44  
// priv describes where a single privilege type is required.
type priv struct {
	// needGlobal is true when the privilege must be granted ON *.*. Only a
	// global-level grant clears such a requirement; DB- and table-level grants
	// skip it.
	needGlobal bool
	// dbs maps a database (schema) name to the requirement inside that
	// database. DB- and table-level grants only remove entries from here when
	// needGlobal is false.
	dbs        map[string]dbPriv
}
    49  
    50  // SourceDumpPrivilegeChecker checks dump privileges of source DB.
    51  type SourceDumpPrivilegeChecker struct {
    52  	db                *sql.DB
    53  	dbinfo            *dbutil.DBConfig
    54  	checkTables       []filter.Table
    55  	consistency       string
    56  	dumpWholeInstance bool
    57  }
    58  
    59  // NewSourceDumpPrivilegeChecker returns a RealChecker.
    60  func NewSourceDumpPrivilegeChecker(
    61  	db *sql.DB,
    62  	dbinfo *dbutil.DBConfig,
    63  	checkTables []filter.Table,
    64  	consistency string,
    65  	dumpWholeInstance bool,
    66  ) RealChecker {
    67  	return &SourceDumpPrivilegeChecker{
    68  		db:                db,
    69  		dbinfo:            dbinfo,
    70  		checkTables:       checkTables,
    71  		consistency:       consistency,
    72  		dumpWholeInstance: dumpWholeInstance,
    73  	}
    74  }
    75  
    76  // Check implements the RealChecker interface.
    77  // We check RELOAD, SELECT, LOCK TABLES privileges according to consistency.
    78  func (pc *SourceDumpPrivilegeChecker) Check(ctx context.Context) *Result {
    79  	result := &Result{
    80  		Name:  pc.Name(),
    81  		Desc:  "check dump privileges of source DB",
    82  		State: StateFailure,
    83  		Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port),
    84  	}
    85  
    86  	grants, err := dbutil.ShowGrants(ctx, pc.db, "", "")
    87  	if err != nil {
    88  		markCheckError(result, err)
    89  		return result
    90  	}
    91  
    92  	dumpRequiredPrivs := make(map[mysql.PrivilegeType]priv)
    93  	// add required SELECT privilege
    94  	if pc.dumpWholeInstance {
    95  		dumpRequiredPrivs[mysql.SelectPriv] = priv{needGlobal: true}
    96  	} else {
    97  		dumpRequiredPrivs[mysql.SelectPriv] = priv{
    98  			needGlobal: false,
    99  			dbs:        genTableLevelPrivs(pc.checkTables),
   100  		}
   101  	}
   102  
   103  	switch pc.consistency {
   104  	case "auto", "flush":
   105  		dumpRequiredPrivs[mysql.ReloadPriv] = priv{needGlobal: true}
   106  	case "lock":
   107  		dumpRequiredPrivs[mysql.LockTablesPriv] = priv{needGlobal: true}
   108  	}
   109  
   110  	err2 := verifyPrivilegesWithResult(result, grants, dumpRequiredPrivs)
   111  	if err2 != nil {
   112  		result.Errors = append(result.Errors, err2)
   113  		result.Instruction = "Please grant the required privileges to the account."
   114  	} else {
   115  		result.State = StateSuccess
   116  	}
   117  	return result
   118  }
   119  
   120  // Name implements the RealChecker interface.
   121  func (pc *SourceDumpPrivilegeChecker) Name() string {
   122  	return "source db dump privilege checker"
   123  }
   124  
   125  /*****************************************************/
   126  
   127  // SourceReplicatePrivilegeChecker checks replication privileges of source DB.
   128  type SourceReplicatePrivilegeChecker struct {
   129  	db     *sql.DB
   130  	dbinfo *dbutil.DBConfig
   131  }
   132  
   133  // NewSourceReplicationPrivilegeChecker returns a RealChecker.
   134  func NewSourceReplicationPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig) RealChecker {
   135  	return &SourceReplicatePrivilegeChecker{db: db, dbinfo: dbinfo}
   136  }
   137  
   138  // Check implements the RealChecker interface.
   139  // We only check REPLICATION SLAVE, REPLICATION CLIENT privileges.
   140  func (pc *SourceReplicatePrivilegeChecker) Check(ctx context.Context) *Result {
   141  	result := &Result{
   142  		Name:  pc.Name(),
   143  		Desc:  "check replication privileges of source DB",
   144  		State: StateSuccess,
   145  		Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port),
   146  	}
   147  
   148  	grants, err := dbutil.ShowGrants(ctx, pc.db, "", "")
   149  	if err != nil {
   150  		markCheckError(result, err)
   151  		return result
   152  	}
   153  	replRequiredPrivs := map[mysql.PrivilegeType]priv{
   154  		mysql.ReplicationSlavePriv:  {needGlobal: true},
   155  		mysql.ReplicationClientPriv: {needGlobal: true},
   156  	}
   157  	err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs)
   158  	if err2 != nil {
   159  		result.Errors = append(result.Errors, err2)
   160  		result.State = StateFailure
   161  		result.Instruction = "Grant the required privileges to the account."
   162  	}
   163  	return result
   164  }
   165  
   166  // Name implements the RealChecker interface.
   167  func (pc *SourceReplicatePrivilegeChecker) Name() string {
   168  	return "source db replication privilege checker"
   169  }
   170  
   171  type TargetPrivilegeChecker struct {
   172  	db     *sql.DB
   173  	dbinfo *dbutil.DBConfig
   174  }
   175  
   176  func NewTargetPrivilegeChecker(db *sql.DB, dbinfo *dbutil.DBConfig) RealChecker {
   177  	return &TargetPrivilegeChecker{db: db, dbinfo: dbinfo}
   178  }
   179  
   180  func (t *TargetPrivilegeChecker) Name() string {
   181  	return "target db privilege checker"
   182  }
   183  
   184  func (t *TargetPrivilegeChecker) Check(ctx context.Context) *Result {
   185  	result := &Result{
   186  		Name:  t.Name(),
   187  		Desc:  "check privileges of target DB",
   188  		State: StateSuccess,
   189  		Extra: fmt.Sprintf("address of db instance - %s:%d", t.dbinfo.Host, t.dbinfo.Port),
   190  	}
   191  	grants, err := dbutil.ShowGrants(ctx, t.db, "", "")
   192  	if err != nil {
   193  		markCheckError(result, err)
   194  		return result
   195  	}
   196  	replRequiredPrivs := map[mysql.PrivilegeType]priv{
   197  		mysql.CreatePriv: {needGlobal: true},
   198  		mysql.SelectPriv: {needGlobal: true},
   199  		mysql.InsertPriv: {needGlobal: true},
   200  		mysql.UpdatePriv: {needGlobal: true},
   201  		mysql.DeletePriv: {needGlobal: true},
   202  		mysql.AlterPriv:  {needGlobal: true},
   203  		mysql.DropPriv:   {needGlobal: true},
   204  		mysql.IndexPriv:  {needGlobal: true},
   205  	}
   206  	err2 := verifyPrivilegesWithResult(result, grants, replRequiredPrivs)
   207  	if err2 != nil {
   208  		result.Errors = append(result.Errors, err2)
   209  		// because we cannot be very precisely sure about which table
   210  		// the binlog will write, so we only throw a warning here.
   211  		result.State = StateWarning
   212  	}
   213  	return result
   214  }
   215  
   216  func verifyPrivilegesWithResult(
   217  	result *Result,
   218  	grants []string,
   219  	requiredPriv map[mysql.PrivilegeType]priv,
   220  ) *Error {
   221  	lackedPriv, err := VerifyPrivileges(grants, requiredPriv)
   222  	if err != nil {
   223  		return NewError(err.Error())
   224  	}
   225  	if len(lackedPriv) == 0 {
   226  		return nil
   227  	}
   228  
   229  	lackedPrivStr := LackedPrivilegesAsStr(lackedPriv)
   230  	result.Instruction = "You need grant related privileges."
   231  	log.L().Info("lack privilege", zap.String("err msg", lackedPrivStr))
   232  	return NewError(lackedPrivStr)
   233  }
   234  
   235  // LackedPrivilegesAsStr format lacked privileges as string.
   236  // lack of privilege1: {tableID1, tableID2, ...}; lack of privilege2...
   237  func LackedPrivilegesAsStr(lackPriv map[mysql.PrivilegeType]priv) string {
   238  	var b strings.Builder
   239  
   240  	for _, pair := range sortmap.Sort(lackPriv) {
   241  		b.WriteString("lack of ")
   242  		b.WriteString(pair.Key.String())
   243  		if pair.Value.needGlobal {
   244  			b.WriteString(" global (*.*)")
   245  		}
   246  		b.WriteString(" privilege")
   247  		if len(pair.Value.dbs) == 0 {
   248  			b.WriteString("; ")
   249  			continue
   250  		}
   251  
   252  		b.WriteString(": {")
   253  		i := 0
   254  		for _, pair2 := range sortmap.Sort(pair.Value.dbs) {
   255  			if pair2.Value.wholeDB {
   256  				b.WriteString(dbutil.ColumnName(pair2.Key))
   257  				b.WriteString(".*; ")
   258  				continue
   259  			}
   260  
   261  			j := 0
   262  			for table := range pair2.Value.tables {
   263  				b.WriteString(dbutil.TableName(pair2.Key, table))
   264  				j++
   265  				if j != len(pair2.Value.tables) {
   266  					b.WriteString(", ")
   267  				}
   268  			}
   269  			i++
   270  			if i != len(pair.Value.dbs) {
   271  				b.WriteString("; ")
   272  			}
   273  		}
   274  		b.WriteString("}; ")
   275  	}
   276  
   277  	return b.String()
   278  }
   279  
   280  // VerifyPrivileges verify user privileges, returns lacked privileges. this function modifies lackPriv in place.
   281  // we expose it so other component can reuse it.
   282  func VerifyPrivileges(
   283  	grants []string,
   284  	lackPrivs map[mysql.PrivilegeType]priv,
   285  ) (map[mysql.PrivilegeType]priv, error) {
   286  	if len(grants) == 0 {
   287  		return nil, errors.New("there is no such grant defined for current user on host '%%'")
   288  	}
   289  
   290  	p := parser.New()
   291  	for _, grant := range grants {
   292  		if len(lackPrivs) == 0 {
   293  			break
   294  		}
   295  		node, err := p.ParseOneStmt(grant, "", "")
   296  		if err != nil {
   297  			return nil, errors.New(err.Error())
   298  		}
   299  		grantStmt, ok := node.(*ast.GrantStmt)
   300  		if !ok {
   301  			switch node.(type) {
   302  			case *ast.GrantProxyStmt, *ast.GrantRoleStmt:
   303  				continue
   304  			default:
   305  				return nil, errors.Errorf("%s is not grant statement", grant)
   306  			}
   307  		}
   308  
   309  		if len(grantStmt.Users) == 0 {
   310  			return nil, errors.Errorf("grant has no user %s", grant)
   311  		}
   312  
   313  		dbPatChar, dbPatType := stringutil.CompilePattern(grantStmt.Level.DBName, '\\')
   314  		tableName := grantStmt.Level.TableName
   315  		switch grantStmt.Level.Level {
   316  		case ast.GrantLevelGlobal:
   317  			for _, privElem := range grantStmt.Privs {
   318  				// all privileges available at a given privilege level (except GRANT OPTION)
   319  				// from https://dev.mysql.com/doc/refman/5.7/en/privileges-provided.html#priv_all
   320  				if privElem.Priv == mysql.AllPriv {
   321  					if _, ok := lackPrivs[mysql.GrantPriv]; ok {
   322  						lackPrivs = map[mysql.PrivilegeType]priv{
   323  							mysql.GrantPriv: {needGlobal: true},
   324  						}
   325  						continue
   326  					}
   327  					return nil, nil
   328  				}
   329  				// mysql> show master status;
   330  				// ERROR 1227 (42000): Access denied; you need (at least one of) the SUPER, REPLICATION CLIENT privilege(s) for this operation
   331  				if privElem.Priv == mysql.SuperPriv {
   332  					delete(lackPrivs, mysql.ReplicationClientPriv)
   333  				}
   334  				delete(lackPrivs, privElem.Priv)
   335  			}
   336  		case ast.GrantLevelDB:
   337  			for _, privElem := range grantStmt.Privs {
   338  				// all privileges available at a given privilege level (except GRANT OPTION)
   339  				// from https://dev.mysql.com/doc/refman/5.7/en/privileges-provided.html#priv_all
   340  				if privElem.Priv == mysql.AllPriv {
   341  					for _, privs := range lackPrivs {
   342  						if privs.needGlobal {
   343  							continue
   344  						}
   345  						for dbName := range privs.dbs {
   346  							if stringutil.DoMatch(dbName, dbPatChar, dbPatType) {
   347  								delete(privs.dbs, dbName)
   348  							}
   349  						}
   350  					}
   351  					continue
   352  				}
   353  				privs, ok := lackPrivs[privElem.Priv]
   354  				if !ok || privs.needGlobal {
   355  					continue
   356  				}
   357  				// dumpling could report error if an allow-list table is lack of privilege.
   358  				// we only check that SELECT is granted on all columns, otherwise we can't SHOW CREATE TABLE
   359  				if privElem.Priv == mysql.SelectPriv && len(privElem.Cols) != 0 {
   360  					continue
   361  				}
   362  				for dbName := range privs.dbs {
   363  					if stringutil.DoMatch(dbName, dbPatChar, dbPatType) {
   364  						delete(privs.dbs, dbName)
   365  					}
   366  				}
   367  			}
   368  		case ast.GrantLevelTable:
   369  			dbName := grantStmt.Level.DBName
   370  			for _, privElem := range grantStmt.Privs {
   371  				// all privileges available at a given privilege level (except GRANT OPTION)
   372  				// from https://dev.mysql.com/doc/refman/5.7/en/privileges-provided.html#priv_all
   373  				if privElem.Priv == mysql.AllPriv {
   374  					for _, privs := range lackPrivs {
   375  						if privs.needGlobal {
   376  							continue
   377  						}
   378  						dbPrivs, ok := privs.dbs[dbName]
   379  						if !ok || dbPrivs.wholeDB {
   380  							continue
   381  						}
   382  						delete(dbPrivs.tables, tableName)
   383  					}
   384  					continue
   385  				}
   386  				privs, ok := lackPrivs[privElem.Priv]
   387  				if !ok || privs.needGlobal {
   388  					continue
   389  				}
   390  				dbPrivs, ok := privs.dbs[dbName]
   391  				if !ok || dbPrivs.wholeDB {
   392  					continue
   393  				}
   394  				// dumpling could report error if an allow-list table is lack of privilege.
   395  				// we only check that SELECT is granted on all columns, otherwise we can't SHOW CREATE TABLE
   396  				if privElem.Priv == mysql.SelectPriv && len(privElem.Cols) != 0 {
   397  					continue
   398  				}
   399  				delete(dbPrivs.tables, tableName)
   400  			}
   401  		}
   402  	}
   403  
   404  	// purge empty leaves
   405  	for privName, privs := range lackPrivs {
   406  		for dbName, dbPrivs := range privs.dbs {
   407  			for tableName, tablePrivs := range dbPrivs.tables {
   408  				if !tablePrivs.wholeTable && len(tablePrivs.columns) == 0 {
   409  					delete(dbPrivs.tables, tableName)
   410  				}
   411  			}
   412  			if !dbPrivs.wholeDB && len(dbPrivs.tables) == 0 {
   413  				delete(privs.dbs, dbName)
   414  			}
   415  		}
   416  		if !privs.needGlobal && len(privs.dbs) == 0 {
   417  			delete(lackPrivs, privName)
   418  		}
   419  	}
   420  
   421  	return lackPrivs, nil
   422  }
   423  
   424  func genTableLevelPrivs(tables []filter.Table) map[string]dbPriv {
   425  	ret := make(map[string]dbPriv)
   426  	for _, table := range tables {
   427  		if _, ok := ret[table.Schema]; !ok {
   428  			ret[table.Schema] = dbPriv{wholeDB: false, tables: make(map[string]tablePriv)}
   429  		}
   430  		ret[table.Schema].tables[table.Name] = tablePriv{wholeTable: true}
   431  	}
   432  	return ret
   433  }