github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/privilege_set.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql_db 16 17 import ( 18 "sort" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 ) 23 24 // PrivilegeSet is a set containing privileges. Due to the nested sets potentially returning empty sets, this also acts 25 // as the singular location to modify all nested sets. 26 type PrivilegeSet struct { 27 globalStatic map[sql.PrivilegeType]struct{} 28 globalDynamic map[string]bool 29 databases map[string]PrivilegeSetDatabase 30 } 31 32 var _ sql.PrivilegeSet = PrivilegeSet{} 33 34 // NewPrivilegeSet returns a new PrivilegeSet. 35 func NewPrivilegeSet() PrivilegeSet { 36 return PrivilegeSet{ 37 make(map[sql.PrivilegeType]struct{}), 38 make(map[string]bool), 39 make(map[string]PrivilegeSetDatabase), 40 } 41 } 42 43 // NewPrivilegeSetWithAllPrivileges returns a new PrivilegeSet with every global static privilege added. 
func NewPrivilegeSetWithAllPrivileges() PrivilegeSet {
	allStatic := []sql.PrivilegeType{
		sql.PrivilegeType_Select,
		sql.PrivilegeType_Insert,
		sql.PrivilegeType_Update,
		sql.PrivilegeType_Delete,
		sql.PrivilegeType_Create,
		sql.PrivilegeType_Drop,
		sql.PrivilegeType_Reload,
		sql.PrivilegeType_Shutdown,
		sql.PrivilegeType_Process,
		sql.PrivilegeType_File,
		sql.PrivilegeType_GrantOption,
		sql.PrivilegeType_References,
		sql.PrivilegeType_Index,
		sql.PrivilegeType_Alter,
		sql.PrivilegeType_ShowDB,
		sql.PrivilegeType_Super,
		sql.PrivilegeType_CreateTempTable,
		sql.PrivilegeType_LockTables,
		sql.PrivilegeType_Execute,
		sql.PrivilegeType_ReplicationSlave,
		sql.PrivilegeType_ReplicationClient,
		sql.PrivilegeType_CreateView,
		sql.PrivilegeType_ShowView,
		sql.PrivilegeType_CreateRoutine,
		sql.PrivilegeType_AlterRoutine,
		sql.PrivilegeType_CreateUser,
		sql.PrivilegeType_Event,
		sql.PrivilegeType_Trigger,
		sql.PrivilegeType_CreateTablespace,
		sql.PrivilegeType_CreateRole,
		sql.PrivilegeType_DropRole,
	}
	static := make(map[sql.PrivilegeType]struct{}, len(allStatic))
	for _, priv := range allStatic {
		static[priv] = struct{}{}
	}
	return PrivilegeSet{
		globalStatic:  static,
		globalDynamic: make(map[string]bool),
		databases:     make(map[string]PrivilegeSetDatabase),
	}
}

// AddGlobalStatic grants each of the given global static privileges. Privileges that are already present are left
// untouched.
func (ps PrivilegeSet) AddGlobalStatic(privileges ...sql.PrivilegeType) {
	for i := range privileges {
		ps.globalStatic[privileges[i]] = struct{}{}
	}
}

// AddGlobalDynamic grants each of the given global dynamic privileges. Names are stored lowercased. Every entry
// records withGrantOption, replacing whatever value was previously stored for that privilege.
func (ps PrivilegeSet) AddGlobalDynamic(withGrantOption bool, privileges ...string) {
	for _, name := range privileges {
		key := strings.ToLower(name)
		ps.globalDynamic[key] = withGrantOption
	}
}

// AddDatabase adds the given database privilege(s), creating the database entry if it does not exist yet.
func (ps PrivilegeSet) AddDatabase(dbName string, privileges ...sql.PrivilegeType) {
	dbSet := ps.getUseableDb(dbName)
	for _, priv := range privileges {
		dbSet.privs[priv] = struct{}{}
	}
}

// AddTable adds the given table privilege(s), creating the database and table entries if they do not exist yet.
func (ps PrivilegeSet) AddTable(dbName string, tblName string, privileges ...sql.PrivilegeType) {
	tblSet := ps.getUseableDb(dbName).getUseableTbl(tblName)
	for _, priv := range privileges {
		tblSet.privs[priv] = struct{}{}
	}
}

// AddRoutine adds the given privilege(s) to a stored procedure (isProc is true) or a stored function (isProc is
// false). It creates the database and routine entries if they do not exist yet.
func (ps PrivilegeSet) AddRoutine(dbName string, procName string, isProc bool, privileges ...sql.PrivilegeType) {
	procSet := ps.getUseableDb(dbName).getUseableRoutine(procName, isProc)
	for _, priv := range privileges {
		procSet.privs[priv] = struct{}{}
	}
}

// AddColumn adds the given column privilege(s). It creates the database, table and column entries if they do not
// exist yet.
func (ps PrivilegeSet) AddColumn(dbName string, tblName string, colName string, privileges ...sql.PrivilegeType) {
	colSet := ps.getUseableDb(dbName).getUseableTbl(tblName).getUseableCol(colName)
	for _, priv := range privileges {
		colSet.privs[priv] = struct{}{}
	}
}

// RemoveGlobalStatic removes the given global static privilege(s). Privileges that are not present are ignored.
func (ps PrivilegeSet) RemoveGlobalStatic(privileges ...sql.PrivilegeType) {
	for _, priv := range privileges {
		delete(ps.globalStatic, priv)
	}
}

// RemoveGlobalDynamic removes the given global dynamic privilege(s). Names are matched case-insensitively, mirroring
// AddGlobalDynamic and HasDynamic. Privileges that are not present are ignored.
func (ps PrivilegeSet) RemoveGlobalDynamic(privileges ...string) {
	for _, priv := range privileges {
		// AddGlobalDynamic stores keys lowercased, so the delete key must be lowercased too. Otherwise a privilege
		// added as "FOO" could never be removed as "FOO".
		delete(ps.globalDynamic, strings.ToLower(priv))
	}
}

// RemoveDatabase removes the given database privilege(s).
func (ps PrivilegeSet) RemoveDatabase(dbName string, privileges ...sql.PrivilegeType) {
	// We don't use the getUseableDb function since we don't want to create a new map if it doesn't already exist.
	// For an unknown database, Database returns a placeholder with nil maps, and deleting from a nil map is a no-op.
	dbSet := ps.Database(dbName).(PrivilegeSetDatabase)
	for _, priv := range privileges {
		delete(dbSet.privs, priv)
	}
	// Only drop the database entry once nothing is left at any level. Checking just the database-level privileges
	// would also throw away the table, column and routine privileges stored beneath this database.
	if !dbSet.HasPrivileges() {
		delete(ps.databases, strings.ToLower(dbName))
	}
}

// RemoveTable removes the given table privilege(s). A table without an entry is left alone.
func (ps PrivilegeSet) RemoveTable(dbName string, tblName string, privileges ...sql.PrivilegeType) {
	// We don't use the getUseable functions since we don't want to create new maps if they don't already exist.
	// A missing table comes back as a placeholder with a nil privs map, which delete tolerates.
	tblSet := ps.Database(dbName).Table(tblName).(PrivilegeSetTable)
	for _, priv := range privileges {
		delete(tblSet.privs, priv)
	}
}

// RemoveColumn removes the given column privilege(s). A column without an entry is left alone.
func (ps PrivilegeSet) RemoveColumn(dbName string, tblName string, colName string, privileges ...sql.PrivilegeType) {
	// We don't use the getUseable functions since we don't want to create new maps if they don't already exist.
	// A missing column comes back as a placeholder with a nil privs map, which delete tolerates.
	colSet := ps.Database(dbName).Table(tblName).Column(colName).(PrivilegeSetColumn)
	for _, priv := range privileges {
		delete(colSet.privs, priv)
	}
}

// RemoveRoutine removes the given privilege(s) from a stored procedure (isProc is true) or a stored function (isProc
// is false). Once the routine has no privileges left, its entry is deleted. Routines and databases without an entry
// are left alone; nothing is created.
func (ps PrivilegeSet) RemoveRoutine(dbName string, procName string, isProc bool, privileges ...sql.PrivilegeType) {
	dbSet, ok := ps.databases[strings.ToLower(dbName)]
	if !ok {
		return
	}
	// Routine entries are keyed by the lowercased name (see getUseableRoutine). Using the same key for the lookup and
	// the delete is what allows an emptied mixed-case routine to actually be removed.
	key := routineKey{name: strings.ToLower(procName), isProc: isProc}
	procSet, ok := dbSet.routines[key]
	if !ok {
		return
	}
	for _, priv := range privileges {
		delete(procSet.privs, priv)
	}
	if len(procSet.privs) == 0 {
		delete(dbSet.routines, key)
	}
}

// Has reports whether every one of the given global static privileges is present. It returns true when called with
// no arguments.
func (ps PrivilegeSet) Has(privileges ...sql.PrivilegeType) bool {
	for _, priv := range privileges {
		if _, ok := ps.globalStatic[priv]; !ok {
			return false
		}
	}
	return true
}

// HasDynamic returns whether all of the given global dynamic privilege(s) exist, regardless of whether they were
// granted WITH GRANT OPTION. Names are matched case-insensitively. It returns true when called with no arguments.
func (ps PrivilegeSet) HasDynamic(privileges ...string) bool {
	for _, priv := range privileges {
		if _, ok := ps.globalDynamic[strings.ToLower(priv)]; !ok {
			return false
		}
	}
	return true
}

// HasPrivileges returns whether this PrivilegeSet has any privileges at any level: global, database, table, column
// or routine.
func (ps PrivilegeSet) HasPrivileges() bool {
	if len(ps.globalStatic) > 0 || len(ps.globalDynamic) > 0 {
		return true
	}
	for _, dbSet := range ps.databases {
		if dbSet.HasPrivileges() {
			return true
		}
	}
	return false
}

// GlobalCount returns the combined number of global static and global dynamic privileges.
func (ps PrivilegeSet) GlobalCount() int {
	return len(ps.globalStatic) + len(ps.globalDynamic)
}

// Count returns the number of global static privileges. Global dynamic privileges are not included.
func (ps PrivilegeSet) Count() int {
	return len(ps.globalStatic)
}

// Database returns the set of privileges for the given database, looked up case-insensitively. If the database has no
// entry, it returns an empty placeholder carrying the given name; the placeholder is not added to this set.
func (ps PrivilegeSet) Database(dbName string) sql.PrivilegeSetDatabase {
	dbSet, ok := ps.databases[strings.ToLower(dbName)]
	if ok {
		return dbSet
	}
	return PrivilegeSetDatabase{name: dbName}
}

// GetDatabases returns, sorted by name, every database that has a database-level privilege or a privilege on one of
// its tables, columns or routines.
func (ps PrivilegeSet) GetDatabases() []sql.PrivilegeSetDatabase {
	// Delegate to getDatabases so the filtering and ordering rules live in a single place.
	native := ps.getDatabases()
	dbSets := make([]sql.PrivilegeSetDatabase, len(native))
	for i, dbSet := range native {
		dbSets[i] = dbSet
	}
	return dbSets
}

// getDatabases returns all databases of the native type that carry at least one privilege, sorted by name.
func (ps PrivilegeSet) getDatabases() []PrivilegeSetDatabase {
	dbSets := make([]PrivilegeSetDatabase, 0, len(ps.databases))
	for _, dbSet := range ps.databases {
		// Only return databases that have a database-level privilege, or a privilege on an underlying table, column
		// or routine. Otherwise, there is no difference between the returned database and the zero-value for any
		// database.
		if dbSet.HasPrivileges() {
			dbSets = append(dbSets, dbSet)
		}
	}
	sort.Slice(dbSets, func(i, j int) bool {
		return dbSets[i].name < dbSets[j].name
	})
	return dbSets
}

// UnionWith merges the given set of privileges into the calling set. A global dynamic privilege present on both sides
// keeps WITH GRANT OPTION if either side has it.
func (ps PrivilegeSet) UnionWith(other PrivilegeSet) {
	for priv := range other.globalStatic {
		ps.globalStatic[priv] = struct{}{}
	}
	for priv, withGrantOption := range other.globalDynamic {
		// A missing key reads as false, so the local option only matters when the privilege already exists here.
		ps.globalDynamic[priv] = ps.globalDynamic[priv] || withGrantOption
	}
	for _, otherDbSet := range other.databases {
		ps.getUseableDb(otherDbSet.name).unionWith(otherDbSet)
	}
}

// ClearGlobal removes all global static and global dynamic privileges. Database-level and deeper privileges are kept.
func (ps *PrivilegeSet) ClearGlobal() {
	ps.globalStatic = make(map[sql.PrivilegeType]struct{})
	ps.globalDynamic = make(map[string]bool)
}

// ClearDatabase removes all privileges for the given database, including the table, column and routine privileges
// stored beneath it.
func (ps PrivilegeSet) ClearDatabase(dbName string) {
	lowerDbName := strings.ToLower(dbName)
	dbSet, ok := ps.databases[lowerDbName]
	if ok {
		// clear empties the privs map in place. Maps are shared between copies of a struct, so any copy of this
		// PrivilegeSetDatabase obtained earlier also sees no database-level privileges. Deleting the entry then drops
		// the table, column and routine privileges stored under it from this set.
		dbSet.clear()
		delete(ps.databases, lowerDbName)
	}
}

// ClearTable removes all privileges for the given table, including its column privileges. Like RemoveTable, it looks
// the table up rather than using the getUseable helpers, so clearing a table without an entry does not create one.
func (ps PrivilegeSet) ClearTable(dbName string, tblName string) {
	// The placeholder returned for a missing database or table has nil maps, and deleting from nil maps is a no-op.
	ps.Database(dbName).Table(tblName).(PrivilegeSetTable).clear()
}

// ClearColumn removes all privileges for the given column. Like RemoveColumn, it looks the column up rather than using
// the getUseable helpers, so clearing a column without an entry does not create one.
func (ps PrivilegeSet) ClearColumn(dbName string, tblName string, colName string) {
	// The placeholder returned for a missing database, table or column has a nil privs map, so clear is a no-op.
	ps.Database(dbName).Table(tblName).Column(colName).(PrivilegeSetColumn).clear()
}

// ClearRoutine removes all privileges from a stored procedure (isProc is true) or a stored function (isProc is false).
// The routine's entry is kept but emptied. Routines and databases without an entry are left alone; nothing is created.
func (ps PrivilegeSet) ClearRoutine(dbName string, procName string, isProc bool) {
	dbSet, ok := ps.databases[strings.ToLower(dbName)]
	if !ok {
		return
	}
	// Routine entries are keyed by the lowercased name (see getUseableRoutine).
	if routineSet, ok := dbSet.routines[routineKey{name: strings.ToLower(procName), isProc: isProc}]; ok {
		routineSet.clear()
	}
}

// ClearAll removes all privileges by replacing every top-level map with a fresh, empty one.
func (ps *PrivilegeSet) ClearAll() {
	ps.globalStatic = make(map[sql.PrivilegeType]struct{})
	ps.globalDynamic = make(map[string]bool)
	ps.databases = make(map[string]PrivilegeSetDatabase)
}

// Equals returns whether the given set of privileges is equivalent to the calling set.
//
// A global dynamic privilege held WITH GRANT OPTION is not considered equivalent to the same privilege held without
// it. Equals panics if otherPrivSet is not a PrivilegeSet.
func (ps PrivilegeSet) Equals(otherPrivSet sql.PrivilegeSet) bool {
	otherPs := otherPrivSet.(PrivilegeSet)
	if len(ps.globalStatic) != len(otherPs.globalStatic) ||
		len(ps.globalDynamic) != len(otherPs.globalDynamic) ||
		len(ps.databases) != len(otherPs.databases) {
		return false
	}
	for priv := range ps.globalStatic {
		if _, ok := otherPs.globalStatic[priv]; !ok {
			return false
		}
	}
	for priv, withGrantOption := range ps.globalDynamic {
		otherWithGrantOption, ok := otherPs.globalDynamic[priv]
		if !ok || otherWithGrantOption != withGrantOption {
			return false
		}
	}
	for dbName, dbSet := range ps.databases {
		if !dbSet.Equals(otherPs.databases[dbName]) {
			return false
		}
	}
	return true
}

// Copy returns a deep copy of the calling PrivilegeSet; the result shares no maps with the original.
func (ps PrivilegeSet) Copy() PrivilegeSet {
	duplicate := NewPrivilegeSet()
	// UnionWith rebuilds every nested set through the getUseable helpers, which allocate fresh maps for entries that
	// do not exist yet. Because duplicate starts empty, nothing ends up shared with ps.
	duplicate.UnionWith(ps)
	return duplicate
}

// ToSlice returns the global static privileges in ascending order. Global dynamic privileges are not included.
func (ps PrivilegeSet) ToSlice() []sql.PrivilegeType {
	sorted := make([]sql.PrivilegeType, 0, len(ps.globalStatic))
	for priv := range ps.globalStatic {
		sorted = append(sorted, priv)
	}
	sort.Slice(sorted, func(a, b int) bool { return sorted[a] < sorted[b] })
	return sorted
}

// ToSliceDynamic returns, in ascending order, the global dynamic privileges whose WITH GRANT OPTION flag equals
// withGrantOption. Names are returned in uppercase.
func (ps PrivilegeSet) ToSliceDynamic(withGrantOption bool) []string {
	names := make([]string, 0, len(ps.globalDynamic))
	for priv, hasGrantOption := range ps.globalDynamic {
		if hasGrantOption != withGrantOption {
			continue
		}
		names = append(names, strings.ToUpper(priv))
	}
	sort.Strings(names)
	return names
}

// getUseableDb is used internally to fetch the database entry stored under the lowercased name. If none exists, it
// creates an empty entry (keeping the name as given), stores it, and returns it.
func (ps PrivilegeSet) getUseableDb(dbName string) PrivilegeSetDatabase {
	key := strings.ToLower(dbName)
	if existing, found := ps.databases[key]; found {
		return existing
	}
	created := PrivilegeSetDatabase{
		name:     dbName,
		privs:    make(map[sql.PrivilegeType]struct{}),
		tables:   make(map[string]PrivilegeSetTable),
		routines: make(map[routineKey]PrivilegeSetRoutine),
	}
	ps.databases[key] = created
	return created
}

// routineKey is used as a key for the routines map in PrivilegeSetDatabase.
type routineKey struct {
	name   string
	isProc bool // true for proc, false for func
}

// PrivilegeSetDatabase is a set containing database-level privileges, together with the privilege sets of the tables
// and routines in that database.
409 type PrivilegeSetDatabase struct { 410 name string 411 privs map[sql.PrivilegeType]struct{} 412 tables map[string]PrivilegeSetTable 413 routines map[routineKey]PrivilegeSetRoutine 414 } 415 416 var _ sql.PrivilegeSetDatabase = PrivilegeSetDatabase{} 417 418 // Name returns the name of the database that this privilege set belongs to. 419 func (ps PrivilegeSetDatabase) Name() string { 420 return ps.name 421 } 422 423 // Has returns whether the given database privilege(s) exists. 424 func (ps PrivilegeSetDatabase) Has(privileges ...sql.PrivilegeType) bool { 425 for _, priv := range privileges { 426 if _, ok := ps.privs[priv]; !ok { 427 return false 428 } 429 } 430 return true 431 } 432 433 // HasPrivileges returns whether this database has either database-level privileges, or privileges on a table or column 434 // contained within this database. 435 func (ps PrivilegeSetDatabase) HasPrivileges() bool { 436 if len(ps.privs) > 0 { 437 return true 438 } 439 for _, tblSet := range ps.tables { 440 if tblSet.HasPrivileges() { 441 return true 442 } 443 } 444 for _, routineSet := range ps.routines { 445 if routineSet.HasPrivileges() { 446 return true 447 } 448 } 449 450 return false 451 } 452 453 // Count returns the number of database privileges. 454 func (ps PrivilegeSetDatabase) Count() int { 455 return len(ps.privs) 456 } 457 458 // Table returns the set of privileges for the given table. Returns an empty set if the table does not exist. 459 func (ps PrivilegeSetDatabase) Table(tblName string) sql.PrivilegeSetTable { 460 tblSet, ok := ps.tables[strings.ToLower(tblName)] 461 if ok { 462 return tblSet 463 } 464 return PrivilegeSetTable{name: tblName} 465 } 466 467 // GetTables returns all tables. 468 func (ps PrivilegeSetDatabase) GetTables() []sql.PrivilegeSetTable { 469 tblSets := make([]sql.PrivilegeSetTable, 0, len(ps.tables)) 470 for _, tblSet := range ps.tables { 471 // Only return tables that have a table-level privilege, or a privilege on an underlying column. 
472 // Otherwise, there is no difference between the returned table and the zero-value for any table. 473 if tblSet.HasPrivileges() { 474 tblSets = append(tblSets, tblSet) 475 } 476 } 477 sort.Slice(tblSets, func(i, j int) bool { 478 return tblSets[i].Name() < tblSets[j].Name() 479 }) 480 return tblSets 481 } 482 483 // getTables returns all tables of the native type. 484 func (ps PrivilegeSetDatabase) getTables() []PrivilegeSetTable { 485 tblSets := make([]PrivilegeSetTable, 0, len(ps.tables)) 486 for _, tblSet := range ps.tables { 487 // Only return tables that have a table-level privilege, or a privilege on an underlying column. 488 // Otherwise, there is no difference between the returned table and the zero-value for any table. 489 if tblSet.HasPrivileges() { 490 tblSets = append(tblSets, tblSet) 491 } 492 } 493 sort.Slice(tblSets, func(i, j int) bool { 494 return tblSets[i].name < tblSets[j].name 495 }) 496 return tblSets 497 } 498 499 // Routine returns the set of privileges for the given procedure or function 500 func (ps PrivilegeSetDatabase) Routine(routineName string, isProc bool) sql.PrivilegeSetRoutine { 501 routineName = strings.ToLower(routineName) 502 set, ok := ps.routines[routineKey{routineName, isProc}] 503 if ok { 504 return set 505 } 506 return PrivilegeSetRoutine{name: routineName, isProc: isProc} 507 } 508 509 // GetRoutines returns all routines. 
func (ps PrivilegeSetDatabase) GetRoutines() []sql.PrivilegeSetRoutine {
	// Delegate to getRoutines so the filtering and ordering rules live in a single place. The result is never nil.
	native := ps.getRoutines()
	routineSets := make([]sql.PrivilegeSetRoutine, len(native))
	for i, routineSet := range native {
		routineSets[i] = routineSet
	}
	return routineSets
}

// getRoutines returns the routines of this database that carry at least one privilege, as the native type. They are
// sorted by routine name, then by routine type ("FUNCTION" before "PROCEDURE"). The result is never nil.
func (ps PrivilegeSetDatabase) getRoutines() []PrivilegeSetRoutine {
	routineSets := make([]PrivilegeSetRoutine, 0, len(ps.routines))
	for _, routineSet := range ps.routines {
		// Skip routines without privileges, matching the table, column and database getters. ClearRoutine, for
		// example, empties a routine's privileges without removing its entry, and such an entry is
		// indistinguishable from a routine that was never granted anything.
		if routineSet.HasPrivileges() {
			routineSets = append(routineSets, routineSet)
		}
	}
	sort.Slice(routineSets, func(i, j int) bool {
		// Compare name against name; only fall back to the routine type when the names tie.
		if routineSets[i].name != routineSets[j].name {
			return routineSets[i].name < routineSets[j].name
		}
		return routineSets[i].RoutineType() < routineSets[j].RoutineType()
	})
	return routineSets
}

// Equals returns whether the given set of privileges is equivalent to the calling set.
func (ps PrivilegeSetDatabase) Equals(otherPsd sql.PrivilegeSetDatabase) bool {
	otherPs := otherPsd.(PrivilegeSetDatabase)
	// Every nested collection must match in size. The loops below only walk the calling set's entries, so without
	// these checks an extra entry on the other side (e.g. an additional routine) would go unnoticed.
	if len(ps.privs) != len(otherPs.privs) ||
		len(ps.tables) != len(otherPs.tables) ||
		len(ps.routines) != len(otherPs.routines) {
		return false
	}
	for priv := range ps.privs {
		if _, ok := otherPs.privs[priv]; !ok {
			return false
		}
	}
	for tblName, tblSet := range ps.tables {
		if !tblSet.Equals(otherPs.tables[tblName]) {
			return false
		}
	}
	for key, routineSet := range ps.routines {
		if !routineSet.Equals(otherPs.routines[key]) {
			return false
		}
	}
	return true
}

// ToSlice returns all of the database-level privileges as a slice sorted in ascending order.
func (ps PrivilegeSetDatabase) ToSlice() []sql.PrivilegeType {
	privs := make([]sql.PrivilegeType, 0, len(ps.privs))
	for priv := range ps.privs {
		privs = append(privs, priv)
	}
	sort.Slice(privs, func(i, j int) bool {
		return privs[i] < privs[j]
	})
	return privs
}

// getUseableTbl is used internally to either retrieve an existing table, or create a new one (stored under its
// lowercased name) that is returned.
func (ps PrivilegeSetDatabase) getUseableTbl(tblName string) PrivilegeSetTable {
	key := strings.ToLower(tblName)
	if existing, found := ps.tables[key]; found {
		return existing
	}
	created := PrivilegeSetTable{
		name:    tblName,
		privs:   make(map[sql.PrivilegeType]struct{}),
		columns: make(map[string]PrivilegeSetColumn),
	}
	ps.tables[key] = created
	return created
}

// getUseableRoutine is used internally to fetch the routine entry keyed by the lowercased name and routine kind. If
// none exists, it creates an empty entry (keeping the name as given), stores it, and returns it.
func (ps PrivilegeSetDatabase) getUseableRoutine(routineName string, isProc bool) PrivilegeSetRoutine {
	key := routineKey{name: strings.ToLower(routineName), isProc: isProc}
	if existing, found := ps.routines[key]; found {
		return existing
	}
	created := PrivilegeSetRoutine{
		name:   routineName,
		isProc: isProc,
		privs:  make(map[sql.PrivilegeType]struct{}),
	}
	ps.routines[key] = created
	return created
}

// unionWith merges every database-level, table, column and routine privilege of otherPs into the calling set. It
// creates nested entries as needed.
func (ps PrivilegeSetDatabase) unionWith(otherPs PrivilegeSetDatabase) {
	for priv := range otherPs.privs {
		ps.privs[priv] = struct{}{}
	}
	for _, tbl := range otherPs.tables {
		dst := ps.getUseableTbl(tbl.name)
		dst.unionWith(tbl)
	}
	for _, routine := range otherPs.routines {
		dst := ps.getUseableRoutine(routine.name, routine.isProc)
		dst.unionWith(routine)
	}
}

// clear removes all database-level privileges in place. Table and routine privileges are left untouched.
func (ps PrivilegeSetDatabase) clear() {
	for priv := range ps.privs {
		delete(ps.privs, priv)
	}
}

// PrivilegeSetTable is a set containing table-level privileges, together with the privilege sets of that table's
// columns.
type PrivilegeSetTable struct {
	name    string
	privs   map[sql.PrivilegeType]struct{}
	columns map[string]PrivilegeSetColumn
}

var _ sql.PrivilegeSetTable = PrivilegeSetTable{}

// Name returns the name of the table that this privilege set belongs to.
func (ps PrivilegeSetTable) Name() string {
	return ps.name
}

// Has reports whether every one of the given table-level privileges is present. It returns true when called with no
// arguments.
func (ps PrivilegeSetTable) Has(privileges ...sql.PrivilegeType) bool {
	for _, want := range privileges {
		if _, found := ps.privs[want]; !found {
			return false
		}
	}
	return true
}

// HasPrivileges reports whether this table carries any privilege, either at the table level or on one of its columns.
func (ps PrivilegeSetTable) HasPrivileges() bool {
	if len(ps.privs) != 0 {
		return true
	}
	for _, col := range ps.columns {
		if col.HasPrivileges() {
			return true
		}
	}
	return false
}

// Count returns the number of table-level privileges. Column privileges are not included.
func (ps PrivilegeSetTable) Count() int {
	return len(ps.privs)
}

// Column returns the privilege set of the given column, looked up case-insensitively. If the column has no entry, it
// returns an empty placeholder carrying the given name instead.
func (ps PrivilegeSetTable) Column(colName string) sql.PrivilegeSetColumn {
	if col, found := ps.columns[strings.ToLower(colName)]; found {
		return col
	}
	return PrivilegeSetColumn{name: colName}
}

// GetColumns returns, sorted by name, every column of this table that has at least one privilege. Columns without
// privileges are omitted, since they are indistinguishable from the zero value.
func (ps PrivilegeSetTable) GetColumns() []sql.PrivilegeSetColumn {
	native := ps.getColumns()
	colSets := make([]sql.PrivilegeSetColumn, len(native))
	for i, col := range native {
		colSets[i] = col
	}
	return colSets
}

// getColumns returns all columns of the native type.
func (ps PrivilegeSetTable) getColumns() []PrivilegeSetColumn {
	colSets := make([]PrivilegeSetColumn, 0, len(ps.columns))
	for _, col := range ps.columns {
		// A column without privileges is indistinguishable from the zero value, so it is skipped.
		if col.Count() == 0 {
			continue
		}
		colSets = append(colSets, col)
	}
	sort.Slice(colSets, func(a, b int) bool {
		return colSets[a].name < colSets[b].name
	})
	return colSets
}

// Equals reports whether otherPst holds exactly the same table-level privileges and equal column sets. The table names
// themselves are not compared. It panics if otherPst is not a PrivilegeSetTable.
func (ps PrivilegeSetTable) Equals(otherPst sql.PrivilegeSetTable) bool {
	other := otherPst.(PrivilegeSetTable)
	if len(ps.privs) != len(other.privs) || len(ps.columns) != len(other.columns) {
		return false
	}
	for priv := range ps.privs {
		if _, found := other.privs[priv]; !found {
			return false
		}
	}
	for colName, col := range ps.columns {
		if !col.Equals(other.columns[colName]) {
			return false
		}
	}
	return true
}

// ToSlice returns the table-level privileges as a slice sorted in ascending order.
func (ps PrivilegeSetTable) ToSlice() []sql.PrivilegeType {
	sorted := make([]sql.PrivilegeType, 0, len(ps.privs))
	for priv := range ps.privs {
		sorted = append(sorted, priv)
	}
	sort.Slice(sorted, func(a, b int) bool { return sorted[a] < sorted[b] })
	return sorted
}

// getUseableCol is used internally to either retrieve an existing column, or create a new one (stored under its
// lowercased name) that is returned.
func (ps PrivilegeSetTable) getUseableCol(colName string) PrivilegeSetColumn {
	key := strings.ToLower(colName)
	if existing, found := ps.columns[key]; found {
		return existing
	}
	created := PrivilegeSetColumn{
		name:  colName,
		privs: make(map[sql.PrivilegeType]struct{}),
	}
	ps.columns[key] = created
	return created
}

// unionWith merges every table-level and column privilege of otherPs into the calling set. It creates column entries
// as needed.
func (ps PrivilegeSetTable) unionWith(otherPs PrivilegeSetTable) {
	for priv := range otherPs.privs {
		ps.privs[priv] = struct{}{}
	}
	for _, col := range otherPs.columns {
		dst := ps.getUseableCol(col.name)
		dst.unionWith(col)
	}
}

// clear removes, in place, all table-level privileges and every column entry of this table.
func (ps PrivilegeSetTable) clear() {
	for priv := range ps.privs {
		delete(ps.privs, priv)
	}
	for colKey := range ps.columns {
		delete(ps.columns, colKey)
	}
}

// PrivilegeSetColumn is a set containing the privileges granted on a single column.
type PrivilegeSetColumn struct {
	name  string
	privs map[sql.PrivilegeType]struct{}
}

var _ sql.PrivilegeSetColumn = PrivilegeSetColumn{}

// Name returns the name of the column that this privilege set belongs to.
func (ps PrivilegeSetColumn) Name() string {
	return ps.name
}

// Has reports whether every one of the given column privileges is present. It returns true when called with no
// arguments.
func (ps PrivilegeSetColumn) Has(privileges ...sql.PrivilegeType) bool {
	for _, want := range privileges {
		if _, found := ps.privs[want]; !found {
			return false
		}
	}
	return true
}

// HasPrivileges reports whether this column has at least one privilege.
func (ps PrivilegeSetColumn) HasPrivileges() bool {
	return len(ps.privs) != 0
}

// Count returns the number of column privileges.
func (ps PrivilegeSetColumn) Count() int {
	return len(ps.privs)
}

// Equals reports whether otherPsc holds exactly the same column privileges. The column names are not compared. It
// panics if otherPsc is not a PrivilegeSetColumn.
func (ps PrivilegeSetColumn) Equals(otherPsc sql.PrivilegeSetColumn) bool {
	other := otherPsc.(PrivilegeSetColumn)
	if len(ps.privs) != len(other.privs) {
		return false
	}
	for priv := range ps.privs {
		if _, found := other.privs[priv]; !found {
			return false
		}
	}
	return true
}

// ToSlice returns the column privileges as a slice sorted in ascending order.
func (ps PrivilegeSetColumn) ToSlice() []sql.PrivilegeType {
	sorted := make([]sql.PrivilegeType, 0, len(ps.privs))
	for priv := range ps.privs {
		sorted = append(sorted, priv)
	}
	sort.Slice(sorted, func(a, b int) bool { return sorted[a] < sorted[b] })
	return sorted
}

// unionWith adds every privilege of otherPs to the calling column set.
func (ps PrivilegeSetColumn) unionWith(otherPs PrivilegeSetColumn) {
	for priv := range otherPs.privs {
		ps.privs[priv] = struct{}{}
	}
}

// clear removes all column privileges in place.
func (ps PrivilegeSetColumn) clear() {
	for priv := range ps.privs {
		delete(ps.privs, priv)
	}
}

// PrivilegeSetRoutine is a set containing the privileges granted on a single stored procedure or stored function.
type PrivilegeSetRoutine struct {
	name   string
	isProc bool // true = procedure, false = function
	privs  map[sql.PrivilegeType]struct{}
}

// unionWith adds every privilege of otherPs to the calling routine set.
func (ps PrivilegeSetRoutine) unionWith(otherPs PrivilegeSetRoutine) {
	for priv := range otherPs.privs {
		ps.privs[priv] = struct{}{}
	}
}

// clear removes all routine privileges.
877 func (ps PrivilegeSetRoutine) clear() { 878 for priv := range ps.privs { 879 delete(ps.privs, priv) 880 } 881 } 882 883 var _ sql.PrivilegeSetRoutine = PrivilegeSetRoutine{} 884 885 // RoutineName returns the name of the routine that this privilege set belongs to. 886 func (ps PrivilegeSetRoutine) RoutineName() string { 887 return ps.name 888 } 889 890 // RoutineType returns the type of routine this is (PROCEDURE or FUNCTION). 891 func (ps PrivilegeSetRoutine) RoutineType() string { 892 if ps.isProc { 893 return "PROCEDURE" 894 } else { 895 return "FUNCTION" 896 } 897 } 898 899 // Count returns the number of routine privileges. 900 func (ps PrivilegeSetRoutine) Count() int { 901 return len(ps.privs) 902 } 903 904 // Has returns whether the given column privilege(s) exists. 905 func (ps PrivilegeSetRoutine) Has(privileges ...sql.PrivilegeType) bool { 906 for _, priv := range privileges { 907 if _, ok := ps.privs[priv]; !ok { 908 return false 909 } 910 } 911 return true 912 } 913 914 // HasPrivileges returns whether this routine has any privileges. 915 func (ps PrivilegeSetRoutine) HasPrivileges() bool { 916 return len(ps.privs) > 0 917 } 918 919 // ToSlice returns all of the privileges contained as a sorted slice. 920 func (ps PrivilegeSetRoutine) ToSlice() []sql.PrivilegeType { 921 privs := make([]sql.PrivilegeType, 0, len(ps.privs)) 922 for priv := range ps.privs { 923 privs = append(privs, priv) 924 } 925 sort.Slice(privs, func(i, j int) bool { 926 return privs[i] < privs[j] 927 }) 928 return privs 929 } 930 931 // Equals returns whether the given set of privileges is equivalent to the calling set. 
932 func (ps PrivilegeSetRoutine) Equals(otherPs sql.PrivilegeSetRoutine) bool { 933 if ps.RoutineName() != otherPs.RoutineName() { 934 return false 935 } 936 if ps.RoutineType() != otherPs.RoutineType() { 937 return false 938 } 939 940 thisSlice := ps.ToSlice() 941 thatSlice := otherPs.ToSlice() 942 943 if len(thisSlice) != len(thatSlice) { 944 return false 945 } 946 for i, val := range thisSlice { 947 if val != thatSlice[i] { 948 return false 949 } 950 } 951 return true 952 }