github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/privilege/privileges/privileges.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package privileges 15 16 import ( 17 "fmt" 18 "strings" 19 20 "github.com/insionng/yougam/libraries/juju/errors" 21 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/context" 23 "github.com/insionng/yougam/libraries/pingcap/tidb/model" 24 "github.com/insionng/yougam/libraries/pingcap/tidb/mysql" 25 "github.com/insionng/yougam/libraries/pingcap/tidb/privilege" 26 "github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx/variable" 27 "github.com/insionng/yougam/libraries/pingcap/tidb/terror" 28 "github.com/insionng/yougam/libraries/pingcap/tidb/util/sqlexec" 29 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 30 ) 31 32 // privilege error codes. 33 const ( 34 codeInvalidPrivilegeType terror.ErrCode = 1 35 codeInvalidUserNameFormat = 2 36 ) 37 38 var ( 39 errInvalidPrivilegeType = terror.ClassPrivilege.New(codeInvalidPrivilegeType, "unknown privilege type") 40 errInvalidUserNameFormat = terror.ClassPrivilege.New(codeInvalidUserNameFormat, "wrong username format") 41 ) 42 43 var _ privilege.Checker = (*UserPrivileges)(nil) 44 45 type privileges struct { 46 Level ast.GrantLevelType 47 privs map[mysql.PrivilegeType]bool 48 } 49 50 func (ps *privileges) contain(p mysql.PrivilegeType) bool { 51 if ps.privs == nil { 52 return false 53 } 54 _, ok := ps.privs[p] 55 return ok 56 } 57 58 func (ps *privileges) add(p mysql.PrivilegeType) { 59 if ps.privs == nil { 60 ps.privs = make(map[mysql.PrivilegeType]bool) 61 } 62 ps.privs[p] = true 63 } 64 65 func (ps *privileges) String() string { 66 switch ps.Level { 67 case ast.GrantLevelGlobal: 68 return ps.globalPrivToString() 69 case ast.GrantLevelDB: 70 return ps.dbPrivToString() 71 case ast.GrantLevelTable: 72 return ps.tablePrivToString() 73 } 74 return "" 75 } 76 77 func (ps *privileges) globalPrivToString() string { 78 if len(ps.privs) == len(mysql.AllGlobalPrivs) { 79 return mysql.AllPrivilegeLiteral 80 } 81 pstrs := make([]string, 0, len(ps.privs)) 82 // Iterate AllGlobalPrivs to get stable order result. 83 for _, p := range mysql.AllGlobalPrivs { 84 _, ok := ps.privs[p] 85 if !ok { 86 continue 87 } 88 s, _ := mysql.Priv2Str[p] 89 pstrs = append(pstrs, s) 90 } 91 return strings.Join(pstrs, ",") 92 } 93 94 func (ps *privileges) dbPrivToString() string { 95 if len(ps.privs) == len(mysql.AllDBPrivs) { 96 return mysql.AllPrivilegeLiteral 97 } 98 pstrs := make([]string, 0, len(ps.privs)) 99 // Iterate AllDBPrivs to get stable order result. 100 for _, p := range mysql.AllDBPrivs { 101 _, ok := ps.privs[p] 102 if !ok { 103 continue 104 } 105 s, _ := mysql.Priv2SetStr[p] 106 pstrs = append(pstrs, s) 107 } 108 return strings.Join(pstrs, ",") 109 } 110 111 func (ps *privileges) tablePrivToString() string { 112 if len(ps.privs) == len(mysql.AllTablePrivs) { 113 return mysql.AllPrivilegeLiteral 114 } 115 pstrs := make([]string, 0, len(ps.privs)) 116 // Iterate AllTablePrivs to get stable order result. 117 for _, p := range mysql.AllTablePrivs { 118 _, ok := ps.privs[p] 119 if !ok { 120 continue 121 } 122 s, _ := mysql.Priv2Str[p] 123 pstrs = append(pstrs, s) 124 } 125 return strings.Join(pstrs, ",") 126 } 127 128 type userPrivileges struct { 129 User string 130 Host string 131 // Global privileges 132 GlobalPrivs *privileges 133 // DBName-privileges 134 DBPrivs map[string]*privileges 135 // DBName-TableName-privileges 136 TablePrivs map[string]map[string]*privileges 137 } 138 139 func (ps *userPrivileges) ShowGrants() []string { 140 gs := []string{} 141 // Show global grants 142 g := ps.GlobalPrivs.String() 143 if len(g) > 0 { 144 s := fmt.Sprintf(`GRANT %s ON *.* TO '%s'@'%s'`, g, ps.User, ps.Host) 145 gs = append(gs, s) 146 } 147 // Show db scope grants 148 for d, p := range ps.DBPrivs { 149 g := p.String() 150 if len(g) > 0 { 151 s := fmt.Sprintf(`GRANT %s ON %s.* TO '%s'@'%s'`, g, d, ps.User, ps.Host) 152 gs = append(gs, s) 153 } 154 } 155 // Show table scope grants 156 for d, dps := range ps.TablePrivs { 157 for t, p := range dps { 158 g := p.String() 159 if len(g) > 0 { 160 s := fmt.Sprintf(`GRANT %s ON %s.%s TO '%s'@'%s'`, g, d, t, ps.User, ps.Host) 161 gs = append(gs, s) 162 } 163 } 164 } 165 return gs 166 } 167 168 // UserPrivileges implements privilege.Checker interface. 169 // This is used to check privilege for the current user. 170 type UserPrivileges struct { 171 User string 172 privs *userPrivileges 173 } 174 175 // Check implements Checker.Check interface. 176 func (p *UserPrivileges) Check(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) { 177 if p.privs == nil { 178 // Lazy load 179 if len(p.User) == 0 { 180 // User current user 181 p.User = variable.GetSessionVars(ctx).User 182 if len(p.User) == 0 { 183 // In embedded db mode, user does not need to login. So we do not have username. 184 // TODO: remove this check latter. 185 return true, nil 186 } 187 } 188 err := p.loadPrivileges(ctx) 189 if err != nil { 190 return false, errors.Trace(err) 191 } 192 } 193 // Check global scope privileges. 194 ok := p.privs.GlobalPrivs.contain(privilege) 195 if ok { 196 return true, nil 197 } 198 // Check db scope privileges. 199 dbp, ok := p.privs.DBPrivs[db.Name.O] 200 if ok { 201 ok = dbp.contain(privilege) 202 if ok { 203 return true, nil 204 } 205 } 206 if tbl == nil { 207 return false, nil 208 } 209 // Check table scope privileges. 210 dbTbl, ok := p.privs.TablePrivs[db.Name.O] 211 if !ok { 212 return false, nil 213 } 214 tblp, ok := dbTbl[tbl.Name.O] 215 if !ok { 216 return false, nil 217 } 218 return tblp.contain(privilege), nil 219 } 220 221 func (p *UserPrivileges) loadPrivileges(ctx context.Context) error { 222 strs := strings.Split(p.User, "@") 223 if len(strs) != 2 { 224 return errInvalidUserNameFormat.Gen("Wrong username format: %s", p.User) 225 } 226 username, host := strs[0], strs[1] 227 p.privs = &userPrivileges{ 228 User: username, 229 Host: host, 230 } 231 // Load privileges from mysql.User/DB/Table_privs/Column_privs table 232 err := p.loadGlobalPrivileges(ctx) 233 if err != nil { 234 return errors.Trace(err) 235 } 236 err = p.loadDBScopePrivileges(ctx) 237 if err != nil { 238 return errors.Trace(err) 239 } 240 err = p.loadTableScopePrivileges(ctx) 241 if err != nil { 242 return errors.Trace(err) 243 } 244 // TODO: consider column scope privilege latter. 245 return nil 246 } 247 248 // mysql.User/mysql.DB table privilege columns start from index 3. 249 // See: booststrap.go CreateUserTable/CreateDBPrivTable 250 const userTablePrivColumnStartIndex = 3 251 const dbTablePrivColumnStartIndex = 3 252 253 func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error { 254 sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, 255 mysql.SystemDB, mysql.UserTable, p.privs.User, p.privs.Host) 256 rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) 257 if err != nil { 258 return errors.Trace(err) 259 } 260 defer rs.Close() 261 ps := &privileges{Level: ast.GrantLevelGlobal} 262 fs, err := rs.Fields() 263 if err != nil { 264 return errors.Trace(err) 265 } 266 for { 267 row, err := rs.Next() 268 if err != nil { 269 return errors.Trace(err) 270 } 271 if row == nil { 272 break 273 } 274 for i := userTablePrivColumnStartIndex; i < len(fs); i++ { 275 d := row.Data[i] 276 if d.Kind() != types.KindMysqlEnum { 277 return errInvalidPrivilegeType.Gen("Privilege should be mysql.Enum: %v(%T)", d, d) 278 } 279 ed := d.GetMysqlEnum() 280 if ed.String() != "Y" { 281 continue 282 } 283 f := fs[i] 284 p, ok := mysql.Col2PrivType[f.ColumnAsName.O] 285 if !ok { 286 return errInvalidPrivilegeType.Gen("Unknown Privilege Type!") 287 } 288 ps.add(p) 289 } 290 } 291 p.privs.GlobalPrivs = ps 292 return nil 293 } 294 295 func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error { 296 sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, 297 mysql.SystemDB, mysql.DBTable, p.privs.User, p.privs.Host) 298 rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) 299 if err != nil { 300 return errors.Trace(err) 301 } 302 defer rs.Close() 303 ps := make(map[string]*privileges) 304 fs, err := rs.Fields() 305 if err != nil { 306 return errors.Trace(err) 307 } 308 for { 309 row, err := rs.Next() 310 if err != nil { 311 return errors.Trace(err) 312 } 313 if row == nil { 314 break 315 } 316 // DB 317 dbStr := row.Data[1].GetString() 318 ps[dbStr] = &privileges{Level: ast.GrantLevelDB} 319 for i := dbTablePrivColumnStartIndex; i < len(fs); i++ { 320 d := row.Data[i] 321 if d.Kind() != types.KindMysqlEnum { 322 return errInvalidPrivilegeType.Gen("Privilege should be mysql.Enum: %v(%T)", d, d) 323 } 324 ed := d.GetMysqlEnum() 325 if ed.String() != "Y" { 326 continue 327 } 328 f := fs[i] 329 p, ok := mysql.Col2PrivType[f.ColumnAsName.O] 330 if !ok { 331 return errInvalidPrivilegeType.Gen("Unknown Privilege Type!") 332 } 333 ps[dbStr].add(p) 334 } 335 } 336 p.privs.DBPrivs = ps 337 return nil 338 } 339 340 func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error { 341 sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, 342 mysql.SystemDB, mysql.TablePrivTable, p.privs.User, p.privs.Host) 343 rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) 344 if err != nil { 345 return errors.Trace(err) 346 } 347 defer rs.Close() 348 ps := make(map[string]map[string]*privileges) 349 for { 350 row, err := rs.Next() 351 if err != nil { 352 return errors.Trace(err) 353 } 354 if row == nil { 355 break 356 } 357 // DB 358 dbStr := row.Data[1].GetString() 359 // Table_name 360 tblStr := row.Data[3].GetString() 361 _, ok := ps[dbStr] 362 if !ok { 363 ps[dbStr] = make(map[string]*privileges) 364 } 365 ps[dbStr][tblStr] = &privileges{Level: ast.GrantLevelTable} 366 // Table_priv 367 tblPrivs := row.Data[6].GetMysqlSet() 368 pvs := strings.Split(tblPrivs.Name, ",") 369 for _, d := range pvs { 370 p, ok := mysql.SetStr2Priv[d] 371 if !ok { 372 return errInvalidPrivilegeType.Gen("Unknown Privilege Type!") 373 } 374 ps[dbStr][tblStr].add(p) 375 } 376 } 377 p.privs.TablePrivs = ps 378 return nil 379 } 380 381 // ShowGrants implements privilege.Checker ShowGrants interface. 382 func (p *UserPrivileges) ShowGrants(ctx context.Context, user string) ([]string, error) { 383 // If user is current user 384 if user == p.User { 385 return p.privs.ShowGrants(), nil 386 } 387 userp := &UserPrivileges{User: user} 388 err := userp.loadPrivileges(ctx) 389 if err != nil { 390 return nil, errors.Trace(err) 391 } 392 return userp.privs.ShowGrants(), nil 393 }