github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/privilege/privileges/privileges.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package privileges
    15  
import (
	"fmt"
	"sort"
	"strings"

	"github.com/insionng/yougam/libraries/juju/errors"
	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
	"github.com/insionng/yougam/libraries/pingcap/tidb/context"
	"github.com/insionng/yougam/libraries/pingcap/tidb/model"
	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
	"github.com/insionng/yougam/libraries/pingcap/tidb/privilege"
	"github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx/variable"
	"github.com/insionng/yougam/libraries/pingcap/tidb/terror"
	"github.com/insionng/yougam/libraries/pingcap/tidb/util/sqlexec"
	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
)
    31  
    32  // privilege error codes.
    33  const (
    34  	codeInvalidPrivilegeType  terror.ErrCode = 1
    35  	codeInvalidUserNameFormat                = 2
    36  )
    37  
    38  var (
    39  	errInvalidPrivilegeType  = terror.ClassPrivilege.New(codeInvalidPrivilegeType, "unknown privilege type")
    40  	errInvalidUserNameFormat = terror.ClassPrivilege.New(codeInvalidUserNameFormat, "wrong username format")
    41  )
    42  
    43  var _ privilege.Checker = (*UserPrivileges)(nil)
    44  
    45  type privileges struct {
    46  	Level ast.GrantLevelType
    47  	privs map[mysql.PrivilegeType]bool
    48  }
    49  
    50  func (ps *privileges) contain(p mysql.PrivilegeType) bool {
    51  	if ps.privs == nil {
    52  		return false
    53  	}
    54  	_, ok := ps.privs[p]
    55  	return ok
    56  }
    57  
    58  func (ps *privileges) add(p mysql.PrivilegeType) {
    59  	if ps.privs == nil {
    60  		ps.privs = make(map[mysql.PrivilegeType]bool)
    61  	}
    62  	ps.privs[p] = true
    63  }
    64  
    65  func (ps *privileges) String() string {
    66  	switch ps.Level {
    67  	case ast.GrantLevelGlobal:
    68  		return ps.globalPrivToString()
    69  	case ast.GrantLevelDB:
    70  		return ps.dbPrivToString()
    71  	case ast.GrantLevelTable:
    72  		return ps.tablePrivToString()
    73  	}
    74  	return ""
    75  }
    76  
    77  func (ps *privileges) globalPrivToString() string {
    78  	if len(ps.privs) == len(mysql.AllGlobalPrivs) {
    79  		return mysql.AllPrivilegeLiteral
    80  	}
    81  	pstrs := make([]string, 0, len(ps.privs))
    82  	// Iterate AllGlobalPrivs to get stable order result.
    83  	for _, p := range mysql.AllGlobalPrivs {
    84  		_, ok := ps.privs[p]
    85  		if !ok {
    86  			continue
    87  		}
    88  		s, _ := mysql.Priv2Str[p]
    89  		pstrs = append(pstrs, s)
    90  	}
    91  	return strings.Join(pstrs, ",")
    92  }
    93  
    94  func (ps *privileges) dbPrivToString() string {
    95  	if len(ps.privs) == len(mysql.AllDBPrivs) {
    96  		return mysql.AllPrivilegeLiteral
    97  	}
    98  	pstrs := make([]string, 0, len(ps.privs))
    99  	// Iterate AllDBPrivs to get stable order result.
   100  	for _, p := range mysql.AllDBPrivs {
   101  		_, ok := ps.privs[p]
   102  		if !ok {
   103  			continue
   104  		}
   105  		s, _ := mysql.Priv2SetStr[p]
   106  		pstrs = append(pstrs, s)
   107  	}
   108  	return strings.Join(pstrs, ",")
   109  }
   110  
   111  func (ps *privileges) tablePrivToString() string {
   112  	if len(ps.privs) == len(mysql.AllTablePrivs) {
   113  		return mysql.AllPrivilegeLiteral
   114  	}
   115  	pstrs := make([]string, 0, len(ps.privs))
   116  	// Iterate AllTablePrivs to get stable order result.
   117  	for _, p := range mysql.AllTablePrivs {
   118  		_, ok := ps.privs[p]
   119  		if !ok {
   120  			continue
   121  		}
   122  		s, _ := mysql.Priv2Str[p]
   123  		pstrs = append(pstrs, s)
   124  	}
   125  	return strings.Join(pstrs, ",")
   126  }
   127  
   128  type userPrivileges struct {
   129  	User string
   130  	Host string
   131  	// Global privileges
   132  	GlobalPrivs *privileges
   133  	// DBName-privileges
   134  	DBPrivs map[string]*privileges
   135  	// DBName-TableName-privileges
   136  	TablePrivs map[string]map[string]*privileges
   137  }
   138  
   139  func (ps *userPrivileges) ShowGrants() []string {
   140  	gs := []string{}
   141  	// Show global grants
   142  	g := ps.GlobalPrivs.String()
   143  	if len(g) > 0 {
   144  		s := fmt.Sprintf(`GRANT %s ON *.* TO '%s'@'%s'`, g, ps.User, ps.Host)
   145  		gs = append(gs, s)
   146  	}
   147  	// Show db scope grants
   148  	for d, p := range ps.DBPrivs {
   149  		g := p.String()
   150  		if len(g) > 0 {
   151  			s := fmt.Sprintf(`GRANT %s ON %s.* TO '%s'@'%s'`, g, d, ps.User, ps.Host)
   152  			gs = append(gs, s)
   153  		}
   154  	}
   155  	// Show table scope grants
   156  	for d, dps := range ps.TablePrivs {
   157  		for t, p := range dps {
   158  			g := p.String()
   159  			if len(g) > 0 {
   160  				s := fmt.Sprintf(`GRANT %s ON %s.%s TO '%s'@'%s'`, g, d, t, ps.User, ps.Host)
   161  				gs = append(gs, s)
   162  			}
   163  		}
   164  	}
   165  	return gs
   166  }
   167  
   168  // UserPrivileges implements privilege.Checker interface.
   169  // This is used to check privilege for the current user.
   170  type UserPrivileges struct {
   171  	User  string
   172  	privs *userPrivileges
   173  }
   174  
   175  // Check implements Checker.Check interface.
   176  func (p *UserPrivileges) Check(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) {
   177  	if p.privs == nil {
   178  		// Lazy load
   179  		if len(p.User) == 0 {
   180  			// User current user
   181  			p.User = variable.GetSessionVars(ctx).User
   182  			if len(p.User) == 0 {
   183  				// In embedded db mode, user does not need to login. So we do not have username.
   184  				// TODO: remove this check latter.
   185  				return true, nil
   186  			}
   187  		}
   188  		err := p.loadPrivileges(ctx)
   189  		if err != nil {
   190  			return false, errors.Trace(err)
   191  		}
   192  	}
   193  	// Check global scope privileges.
   194  	ok := p.privs.GlobalPrivs.contain(privilege)
   195  	if ok {
   196  		return true, nil
   197  	}
   198  	// Check db scope privileges.
   199  	dbp, ok := p.privs.DBPrivs[db.Name.O]
   200  	if ok {
   201  		ok = dbp.contain(privilege)
   202  		if ok {
   203  			return true, nil
   204  		}
   205  	}
   206  	if tbl == nil {
   207  		return false, nil
   208  	}
   209  	// Check table scope privileges.
   210  	dbTbl, ok := p.privs.TablePrivs[db.Name.O]
   211  	if !ok {
   212  		return false, nil
   213  	}
   214  	tblp, ok := dbTbl[tbl.Name.O]
   215  	if !ok {
   216  		return false, nil
   217  	}
   218  	return tblp.contain(privilege), nil
   219  }
   220  
   221  func (p *UserPrivileges) loadPrivileges(ctx context.Context) error {
   222  	strs := strings.Split(p.User, "@")
   223  	if len(strs) != 2 {
   224  		return errInvalidUserNameFormat.Gen("Wrong username format: %s", p.User)
   225  	}
   226  	username, host := strs[0], strs[1]
   227  	p.privs = &userPrivileges{
   228  		User: username,
   229  		Host: host,
   230  	}
   231  	// Load privileges from mysql.User/DB/Table_privs/Column_privs table
   232  	err := p.loadGlobalPrivileges(ctx)
   233  	if err != nil {
   234  		return errors.Trace(err)
   235  	}
   236  	err = p.loadDBScopePrivileges(ctx)
   237  	if err != nil {
   238  		return errors.Trace(err)
   239  	}
   240  	err = p.loadTableScopePrivileges(ctx)
   241  	if err != nil {
   242  		return errors.Trace(err)
   243  	}
   244  	// TODO: consider column scope privilege latter.
   245  	return nil
   246  }
   247  
   248  // mysql.User/mysql.DB table privilege columns start from index 3.
   249  // See: booststrap.go CreateUserTable/CreateDBPrivTable
   250  const userTablePrivColumnStartIndex = 3
   251  const dbTablePrivColumnStartIndex = 3
   252  
   253  func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error {
   254  	sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
   255  		mysql.SystemDB, mysql.UserTable, p.privs.User, p.privs.Host)
   256  	rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
   257  	if err != nil {
   258  		return errors.Trace(err)
   259  	}
   260  	defer rs.Close()
   261  	ps := &privileges{Level: ast.GrantLevelGlobal}
   262  	fs, err := rs.Fields()
   263  	if err != nil {
   264  		return errors.Trace(err)
   265  	}
   266  	for {
   267  		row, err := rs.Next()
   268  		if err != nil {
   269  			return errors.Trace(err)
   270  		}
   271  		if row == nil {
   272  			break
   273  		}
   274  		for i := userTablePrivColumnStartIndex; i < len(fs); i++ {
   275  			d := row.Data[i]
   276  			if d.Kind() != types.KindMysqlEnum {
   277  				return errInvalidPrivilegeType.Gen("Privilege should be mysql.Enum: %v(%T)", d, d)
   278  			}
   279  			ed := d.GetMysqlEnum()
   280  			if ed.String() != "Y" {
   281  				continue
   282  			}
   283  			f := fs[i]
   284  			p, ok := mysql.Col2PrivType[f.ColumnAsName.O]
   285  			if !ok {
   286  				return errInvalidPrivilegeType.Gen("Unknown Privilege Type!")
   287  			}
   288  			ps.add(p)
   289  		}
   290  	}
   291  	p.privs.GlobalPrivs = ps
   292  	return nil
   293  }
   294  
   295  func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error {
   296  	sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
   297  		mysql.SystemDB, mysql.DBTable, p.privs.User, p.privs.Host)
   298  	rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
   299  	if err != nil {
   300  		return errors.Trace(err)
   301  	}
   302  	defer rs.Close()
   303  	ps := make(map[string]*privileges)
   304  	fs, err := rs.Fields()
   305  	if err != nil {
   306  		return errors.Trace(err)
   307  	}
   308  	for {
   309  		row, err := rs.Next()
   310  		if err != nil {
   311  			return errors.Trace(err)
   312  		}
   313  		if row == nil {
   314  			break
   315  		}
   316  		// DB
   317  		dbStr := row.Data[1].GetString()
   318  		ps[dbStr] = &privileges{Level: ast.GrantLevelDB}
   319  		for i := dbTablePrivColumnStartIndex; i < len(fs); i++ {
   320  			d := row.Data[i]
   321  			if d.Kind() != types.KindMysqlEnum {
   322  				return errInvalidPrivilegeType.Gen("Privilege should be mysql.Enum: %v(%T)", d, d)
   323  			}
   324  			ed := d.GetMysqlEnum()
   325  			if ed.String() != "Y" {
   326  				continue
   327  			}
   328  			f := fs[i]
   329  			p, ok := mysql.Col2PrivType[f.ColumnAsName.O]
   330  			if !ok {
   331  				return errInvalidPrivilegeType.Gen("Unknown Privilege Type!")
   332  			}
   333  			ps[dbStr].add(p)
   334  		}
   335  	}
   336  	p.privs.DBPrivs = ps
   337  	return nil
   338  }
   339  
   340  func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error {
   341  	sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
   342  		mysql.SystemDB, mysql.TablePrivTable, p.privs.User, p.privs.Host)
   343  	rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
   344  	if err != nil {
   345  		return errors.Trace(err)
   346  	}
   347  	defer rs.Close()
   348  	ps := make(map[string]map[string]*privileges)
   349  	for {
   350  		row, err := rs.Next()
   351  		if err != nil {
   352  			return errors.Trace(err)
   353  		}
   354  		if row == nil {
   355  			break
   356  		}
   357  		// DB
   358  		dbStr := row.Data[1].GetString()
   359  		// Table_name
   360  		tblStr := row.Data[3].GetString()
   361  		_, ok := ps[dbStr]
   362  		if !ok {
   363  			ps[dbStr] = make(map[string]*privileges)
   364  		}
   365  		ps[dbStr][tblStr] = &privileges{Level: ast.GrantLevelTable}
   366  		// Table_priv
   367  		tblPrivs := row.Data[6].GetMysqlSet()
   368  		pvs := strings.Split(tblPrivs.Name, ",")
   369  		for _, d := range pvs {
   370  			p, ok := mysql.SetStr2Priv[d]
   371  			if !ok {
   372  				return errInvalidPrivilegeType.Gen("Unknown Privilege Type!")
   373  			}
   374  			ps[dbStr][tblStr].add(p)
   375  		}
   376  	}
   377  	p.privs.TablePrivs = ps
   378  	return nil
   379  }
   380  
   381  // ShowGrants implements privilege.Checker ShowGrants interface.
   382  func (p *UserPrivileges) ShowGrants(ctx context.Context, user string) ([]string, error) {
   383  	// If user is current user
   384  	if user == p.User {
   385  		return p.privs.ShowGrants(), nil
   386  	}
   387  	userp := &UserPrivileges{User: user}
   388  	err := userp.loadPrivileges(ctx)
   389  	if err != nil {
   390  		return nil, errors.Trace(err)
   391  	}
   392  	return userp.privs.ShowGrants(), nil
   393  }