github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/privilege_set.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql_db
    16  
    17  import (
    18  	"sort"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
// PrivilegeSet is a set containing privileges. Due to the nested sets potentially returning empty sets, this also acts
// as the singular location to modify all nested sets.
//
// NOTE: NewPrivilegeSet and NewPrivilegeSetWithAllPrivileges build this struct with positional composite literals, so
// the field order below must not change.
type PrivilegeSet struct {
	// globalStatic holds the global static privileges.
	globalStatic  map[sql.PrivilegeType]struct{}
	// globalDynamic maps a dynamic privilege name (lowercased by AddGlobalDynamic) to whether it carries WITH GRANT OPTION.
	globalDynamic map[string]bool
	// databases is keyed by lowercased database name (see getUseableDb).
	databases     map[string]PrivilegeSetDatabase
}

var _ sql.PrivilegeSet = PrivilegeSet{}
    33  
    34  // NewPrivilegeSet returns a new PrivilegeSet.
    35  func NewPrivilegeSet() PrivilegeSet {
    36  	return PrivilegeSet{
    37  		make(map[sql.PrivilegeType]struct{}),
    38  		make(map[string]bool),
    39  		make(map[string]PrivilegeSetDatabase),
    40  	}
    41  }
    42  
    43  // NewPrivilegeSetWithAllPrivileges returns a new PrivilegeSet with every global static privilege added.
    44  func NewPrivilegeSetWithAllPrivileges() PrivilegeSet {
    45  	return PrivilegeSet{
    46  		map[sql.PrivilegeType]struct{}{
    47  			sql.PrivilegeType_Select:            {},
    48  			sql.PrivilegeType_Insert:            {},
    49  			sql.PrivilegeType_Update:            {},
    50  			sql.PrivilegeType_Delete:            {},
    51  			sql.PrivilegeType_Create:            {},
    52  			sql.PrivilegeType_Drop:              {},
    53  			sql.PrivilegeType_Reload:            {},
    54  			sql.PrivilegeType_Shutdown:          {},
    55  			sql.PrivilegeType_Process:           {},
    56  			sql.PrivilegeType_File:              {},
    57  			sql.PrivilegeType_GrantOption:       {},
    58  			sql.PrivilegeType_References:        {},
    59  			sql.PrivilegeType_Index:             {},
    60  			sql.PrivilegeType_Alter:             {},
    61  			sql.PrivilegeType_ShowDB:            {},
    62  			sql.PrivilegeType_Super:             {},
    63  			sql.PrivilegeType_CreateTempTable:   {},
    64  			sql.PrivilegeType_LockTables:        {},
    65  			sql.PrivilegeType_Execute:           {},
    66  			sql.PrivilegeType_ReplicationSlave:  {},
    67  			sql.PrivilegeType_ReplicationClient: {},
    68  			sql.PrivilegeType_CreateView:        {},
    69  			sql.PrivilegeType_ShowView:          {},
    70  			sql.PrivilegeType_CreateRoutine:     {},
    71  			sql.PrivilegeType_AlterRoutine:      {},
    72  			sql.PrivilegeType_CreateUser:        {},
    73  			sql.PrivilegeType_Event:             {},
    74  			sql.PrivilegeType_Trigger:           {},
    75  			sql.PrivilegeType_CreateTablespace:  {},
    76  			sql.PrivilegeType_CreateRole:        {},
    77  			sql.PrivilegeType_DropRole:          {},
    78  		},
    79  		make(map[string]bool),
    80  		make(map[string]PrivilegeSetDatabase),
    81  	}
    82  }
    83  
    84  // AddGlobalStatic adds the given global static privilege(s).
    85  func (ps PrivilegeSet) AddGlobalStatic(privileges ...sql.PrivilegeType) {
    86  	for _, priv := range privileges {
    87  		ps.globalStatic[priv] = struct{}{}
    88  	}
    89  }
    90  
    91  // AddGlobalDynamic adds the given global dynamic privilege(s).
    92  func (ps PrivilegeSet) AddGlobalDynamic(withGrantOption bool, privileges ...string) {
    93  	for _, priv := range privileges {
    94  		ps.globalDynamic[strings.ToLower(priv)] = withGrantOption
    95  	}
    96  }
    97  
    98  // AddDatabase adds the given database privilege(s).
    99  func (ps PrivilegeSet) AddDatabase(dbName string, privileges ...sql.PrivilegeType) {
   100  	dbSet := ps.getUseableDb(dbName)
   101  	for _, priv := range privileges {
   102  		dbSet.privs[priv] = struct{}{}
   103  	}
   104  }
   105  
   106  // AddTable adds the given table privilege(s).
   107  func (ps PrivilegeSet) AddTable(dbName string, tblName string, privileges ...sql.PrivilegeType) {
   108  	tblSet := ps.getUseableDb(dbName).getUseableTbl(tblName)
   109  	for _, priv := range privileges {
   110  		tblSet.privs[priv] = struct{}{}
   111  	}
   112  }
   113  
   114  func (ps PrivilegeSet) AddRoutine(dbName string, procName string, isProc bool, privileges ...sql.PrivilegeType) {
   115  	procSet := ps.getUseableDb(dbName).getUseableRoutine(procName, isProc)
   116  	for _, priv := range privileges {
   117  		procSet.privs[priv] = struct{}{}
   118  	}
   119  }
   120  
   121  // AddColumn adds the given column privilege(s).
   122  func (ps PrivilegeSet) AddColumn(dbName string, tblName string, colName string, privileges ...sql.PrivilegeType) {
   123  	colSet := ps.getUseableDb(dbName).getUseableTbl(tblName).getUseableCol(colName)
   124  	for _, priv := range privileges {
   125  		colSet.privs[priv] = struct{}{}
   126  	}
   127  }
   128  
   129  // RemoveGlobalStatic removes the given global static privilege(s).
   130  func (ps PrivilegeSet) RemoveGlobalStatic(privileges ...sql.PrivilegeType) {
   131  	for _, priv := range privileges {
   132  		delete(ps.globalStatic, priv)
   133  	}
   134  }
   135  
   136  // RemoveGlobalDynamic removes the given global dynamic privilege(s).
   137  func (ps PrivilegeSet) RemoveGlobalDynamic(privileges ...string) {
   138  	for _, priv := range privileges {
   139  		delete(ps.globalDynamic, priv)
   140  	}
   141  }
   142  
   143  // RemoveDatabase removes the given database privilege(s).
   144  func (ps PrivilegeSet) RemoveDatabase(dbName string, privileges ...sql.PrivilegeType) {
   145  	// We don't use the getUseableDb function since we don't want to create a new map if it doesn't already exist
   146  	dbSet := ps.Database(dbName).(PrivilegeSetDatabase)
   147  	if len(dbSet.privs) > 0 {
   148  		for _, priv := range privileges {
   149  			delete(dbSet.privs, priv)
   150  		}
   151  	}
   152  
   153  	if len(dbSet.privs) == 0 {
   154  		delete(ps.databases, strings.ToLower(dbName))
   155  	}
   156  }
   157  
   158  // RemoveTable removes the given table privilege(s).
   159  func (ps PrivilegeSet) RemoveTable(dbName string, tblName string, privileges ...sql.PrivilegeType) {
   160  	// We don't use the getUseable functions since we don't want to create new maps if they don't already exist
   161  	tblSet := ps.Database(dbName).Table(tblName).(PrivilegeSetTable)
   162  	if len(tblSet.privs) > 0 {
   163  		for _, priv := range privileges {
   164  			delete(tblSet.privs, priv)
   165  		}
   166  	}
   167  }
   168  
   169  // RemoveColumn removes the given column privilege(s).
   170  func (ps PrivilegeSet) RemoveColumn(dbName string, tblName string, colName string, privileges ...sql.PrivilegeType) {
   171  	// We don't use the getUseable functions since we don't want to create new maps if they don't already exist
   172  	colSet := ps.Database(dbName).Table(tblName).Column(colName).(PrivilegeSetColumn)
   173  	if len(colSet.privs) > 0 {
   174  		for _, priv := range privileges {
   175  			delete(colSet.privs, priv)
   176  		}
   177  	}
   178  }
   179  
   180  func (ps PrivilegeSet) RemoveRoutine(dbName string, procName string, isProc bool, privileges ...sql.PrivilegeType) {
   181  	procSet := ps.getUseableDb(dbName).getUseableRoutine(procName, isProc)
   182  	for _, priv := range privileges {
   183  		delete(procSet.privs, priv)
   184  	}
   185  
   186  	if len(procSet.privs) == 0 {
   187  		delete(ps.getUseableDb(dbName).routines, routineKey{name: procName, isProc: isProc})
   188  	}
   189  }
   190  
   191  // Has returns whether the given global static privilege(s) exists.
   192  func (ps PrivilegeSet) Has(privileges ...sql.PrivilegeType) bool {
   193  	for _, priv := range privileges {
   194  		if _, ok := ps.globalStatic[priv]; !ok {
   195  			return false
   196  		}
   197  	}
   198  	return true
   199  }
   200  
   201  // HasDynamic returns whether the given global dynamic privilege(s) exists.
   202  func (ps PrivilegeSet) HasDynamic(privileges ...string) bool {
   203  	for _, priv := range privileges {
   204  		if _, ok := ps.globalDynamic[strings.ToLower(priv)]; !ok {
   205  			return false
   206  		}
   207  	}
   208  	return true
   209  }
   210  
   211  // HasPrivileges returns whether this PrivilegeSet has any privileges at any level.
   212  func (ps PrivilegeSet) HasPrivileges() bool {
   213  	if len(ps.globalStatic) > 0 || len(ps.globalDynamic) > 0 {
   214  		return true
   215  	}
   216  	for _, dbSet := range ps.databases {
   217  		if dbSet.HasPrivileges() {
   218  			return true
   219  		}
   220  	}
   221  	return false
   222  }
   223  
   224  // GlobalCount returns the combined number of global static and global dynamic privileges.
   225  func (ps PrivilegeSet) GlobalCount() int {
   226  	return len(ps.globalStatic) + len(ps.globalDynamic)
   227  }
   228  
   229  // Count returns the number of global static privileges, while not including global dynamic privileges.
   230  func (ps PrivilegeSet) Count() int {
   231  	return len(ps.globalStatic)
   232  }
   233  
   234  // Database returns the set of privileges for the given database. Returns an empty set if the database does not exist.
   235  func (ps PrivilegeSet) Database(dbName string) sql.PrivilegeSetDatabase {
   236  	dbSet, ok := ps.databases[strings.ToLower(dbName)]
   237  	if ok {
   238  		return dbSet
   239  	}
   240  	return PrivilegeSetDatabase{name: dbName}
   241  }
   242  
   243  // GetDatabases returns all databases.
   244  func (ps PrivilegeSet) GetDatabases() []sql.PrivilegeSetDatabase {
   245  	dbSets := make([]sql.PrivilegeSetDatabase, 0, len(ps.databases))
   246  	for _, dbSet := range ps.databases {
   247  		// Only return databases that have a database-level privilege, or a privilege on an underlying table or column.
   248  		// Otherwise, there is no difference between the returned database and the zero-value for any database.
   249  		if dbSet.HasPrivileges() {
   250  			dbSets = append(dbSets, dbSet)
   251  		}
   252  	}
   253  	sort.Slice(dbSets, func(i, j int) bool {
   254  		return dbSets[i].Name() < dbSets[j].Name()
   255  	})
   256  	return dbSets
   257  }
   258  
   259  // getDatabases returns all databases of the native type.
   260  func (ps PrivilegeSet) getDatabases() []PrivilegeSetDatabase {
   261  	dbSets := make([]PrivilegeSetDatabase, 0, len(ps.databases))
   262  	for _, dbSet := range ps.databases {
   263  		// Only return databases that have a database-level privilege, or a privilege on an underlying table or column.
   264  		// Otherwise, there is no difference between the returned database and the zero-value for any database.
   265  		if dbSet.HasPrivileges() {
   266  			dbSets = append(dbSets, dbSet)
   267  		}
   268  	}
   269  	sort.Slice(dbSets, func(i, j int) bool {
   270  		return dbSets[i].name < dbSets[j].name
   271  	})
   272  	return dbSets
   273  }
   274  
   275  // UnionWith merges the given set of privileges to the calling set of privileges.
   276  func (ps PrivilegeSet) UnionWith(other PrivilegeSet) {
   277  	for priv := range other.globalStatic {
   278  		ps.globalStatic[priv] = struct{}{}
   279  	}
   280  	for priv, withGrantOption := range other.globalDynamic {
   281  		localWithGrantOption, _ := ps.globalDynamic[priv]
   282  		ps.globalDynamic[priv] = localWithGrantOption || withGrantOption
   283  	}
   284  	for _, otherDbSet := range other.databases {
   285  		ps.getUseableDb(otherDbSet.name).unionWith(otherDbSet)
   286  	}
   287  }
   288  
   289  // ClearGlobal removes all global privileges.
   290  func (ps *PrivilegeSet) ClearGlobal() {
   291  	ps.globalStatic = make(map[sql.PrivilegeType]struct{})
   292  	ps.globalDynamic = make(map[string]bool)
   293  }
   294  
   295  // ClearDatabase removes all privileges for the given database.
   296  func (ps PrivilegeSet) ClearDatabase(dbName string) {
   297  	lowerDbName := strings.ToLower(dbName)
   298  	dbSet, ok := ps.databases[lowerDbName]
   299  	if ok {
   300  		dbSet.clear()
   301  		delete(ps.databases, lowerDbName)
   302  	}
   303  }
   304  
   305  // ClearTable removes all privileges for the given table.
   306  func (ps PrivilegeSet) ClearTable(dbName string, tblName string) {
   307  	ps.getUseableDb(dbName).getUseableTbl(tblName).clear()
   308  }
   309  
   310  // ClearColumn removes all privileges for the given column.
   311  func (ps PrivilegeSet) ClearColumn(dbName string, tblName string, colName string) {
   312  	ps.getUseableDb(dbName).getUseableTbl(tblName).getUseableCol(colName).clear()
   313  }
   314  
   315  func (ps PrivilegeSet) ClearRoutine(dbName string, procName string, isProc bool) {
   316  	ps.getUseableDb(dbName).getUseableRoutine(procName, isProc).clear()
   317  }
   318  
   319  // ClearAll removes all privileges.
   320  func (ps *PrivilegeSet) ClearAll() {
   321  	ps.globalStatic = make(map[sql.PrivilegeType]struct{})
   322  	ps.globalDynamic = make(map[string]bool)
   323  	ps.databases = make(map[string]PrivilegeSetDatabase)
   324  }
   325  
   326  // Equals returns whether the given set of privileges is equivalent to the calling set.
   327  func (ps PrivilegeSet) Equals(otherPrivSet sql.PrivilegeSet) bool {
   328  	otherPs := otherPrivSet.(PrivilegeSet)
   329  	if len(ps.globalStatic) != len(otherPs.globalStatic) ||
   330  		len(ps.globalDynamic) != len(otherPs.globalDynamic) ||
   331  		len(ps.databases) != len(otherPs.databases) {
   332  		return false
   333  	}
   334  	for priv := range ps.globalStatic {
   335  		if _, ok := otherPs.globalStatic[priv]; !ok {
   336  			return false
   337  		}
   338  	}
   339  	for priv := range ps.globalDynamic {
   340  		if _, ok := otherPs.globalDynamic[priv]; !ok {
   341  			return false
   342  		}
   343  	}
   344  	for dbName, dbSet := range ps.databases {
   345  		if !dbSet.Equals(otherPs.databases[dbName]) {
   346  			return false
   347  		}
   348  	}
   349  	return true
   350  }
   351  
   352  // Copy returns a duplicate of the calling PrivilegeSet.
   353  func (ps PrivilegeSet) Copy() PrivilegeSet {
   354  	newPs := NewPrivilegeSet()
   355  	newPs.UnionWith(ps)
   356  	return newPs
   357  }
   358  
   359  // ToSlice returns all of the global static privileges contained as a sorted slice.
   360  func (ps PrivilegeSet) ToSlice() []sql.PrivilegeType {
   361  	privs := make([]sql.PrivilegeType, 0, len(ps.globalStatic))
   362  	for priv := range ps.globalStatic {
   363  		privs = append(privs, priv)
   364  	}
   365  	sort.Slice(privs, func(i, j int) bool {
   366  		return privs[i] < privs[j]
   367  	})
   368  	return privs
   369  }
   370  
   371  // ToSliceDynamic returns all of the global dynamic privileges that match the given "WITH GRANT OPTION". Privileges will
   372  // be uppercase.
   373  func (ps PrivilegeSet) ToSliceDynamic(withGrantOption bool) []string {
   374  	privs := make([]string, 0, len(ps.globalDynamic))
   375  	for priv, option := range ps.globalDynamic {
   376  		if option == withGrantOption {
   377  			privs = append(privs, strings.ToUpper(priv))
   378  		}
   379  	}
   380  	sort.Slice(privs, func(i, j int) bool {
   381  		return privs[i] < privs[j]
   382  	})
   383  	return privs
   384  }
   385  
   386  // getUseableDb is used internally to either retrieve an existing database, or create a new one that is returned.
   387  func (ps PrivilegeSet) getUseableDb(dbName string) PrivilegeSetDatabase {
   388  	lowerDbName := strings.ToLower(dbName)
   389  	dbSet, ok := ps.databases[lowerDbName]
   390  	if !ok {
   391  		dbSet = PrivilegeSetDatabase{
   392  			name:     dbName,
   393  			privs:    make(map[sql.PrivilegeType]struct{}),
   394  			tables:   make(map[string]PrivilegeSetTable),
   395  			routines: make(map[routineKey]PrivilegeSetRoutine),
   396  		}
   397  		ps.databases[lowerDbName] = dbSet
   398  	}
   399  	return dbSet
   400  }
   401  
// routineKey is used as a key for the routines map in PrivilegeSetDatabase. getUseableRoutine and Routine lowercase the
// routine name before building a key, so any lookup or delete must lowercase the name too.
//
// NOTE: Routine builds this struct positionally (routineKey{routineName, isProc}), so the field order must not change.
type routineKey struct {
	name   string
	isProc bool // true for proc, false for func
}
   407  
// PrivilegeSetDatabase is a set containing database-level privileges, together with the table and routine sets that
// belong to the database.
//
// The value returned by PrivilegeSet.Database for an unknown database has nil maps. It may be read but must not be
// written to; writable instances come from getUseableDb.
type PrivilegeSetDatabase struct {
	// name keeps the caller's spelling; the key in PrivilegeSet.databases is the lowercased form.
	name     string
	// privs holds the database-level privileges.
	privs    map[sql.PrivilegeType]struct{}
	// tables is keyed by lowercased table name (see getUseableTbl).
	tables   map[string]PrivilegeSetTable
	// routines is keyed by lowercased routine name plus procedure/function flag (see getUseableRoutine).
	routines map[routineKey]PrivilegeSetRoutine
}

var _ sql.PrivilegeSetDatabase = PrivilegeSetDatabase{}
   417  
   418  // Name returns the name of the database that this privilege set belongs to.
   419  func (ps PrivilegeSetDatabase) Name() string {
   420  	return ps.name
   421  }
   422  
   423  // Has returns whether the given database privilege(s) exists.
   424  func (ps PrivilegeSetDatabase) Has(privileges ...sql.PrivilegeType) bool {
   425  	for _, priv := range privileges {
   426  		if _, ok := ps.privs[priv]; !ok {
   427  			return false
   428  		}
   429  	}
   430  	return true
   431  }
   432  
   433  // HasPrivileges returns whether this database has either database-level privileges, or privileges on a table or column
   434  // contained within this database.
   435  func (ps PrivilegeSetDatabase) HasPrivileges() bool {
   436  	if len(ps.privs) > 0 {
   437  		return true
   438  	}
   439  	for _, tblSet := range ps.tables {
   440  		if tblSet.HasPrivileges() {
   441  			return true
   442  		}
   443  	}
   444  	for _, routineSet := range ps.routines {
   445  		if routineSet.HasPrivileges() {
   446  			return true
   447  		}
   448  	}
   449  
   450  	return false
   451  }
   452  
   453  // Count returns the number of database privileges.
   454  func (ps PrivilegeSetDatabase) Count() int {
   455  	return len(ps.privs)
   456  }
   457  
   458  // Table returns the set of privileges for the given table. Returns an empty set if the table does not exist.
   459  func (ps PrivilegeSetDatabase) Table(tblName string) sql.PrivilegeSetTable {
   460  	tblSet, ok := ps.tables[strings.ToLower(tblName)]
   461  	if ok {
   462  		return tblSet
   463  	}
   464  	return PrivilegeSetTable{name: tblName}
   465  }
   466  
   467  // GetTables returns all tables.
   468  func (ps PrivilegeSetDatabase) GetTables() []sql.PrivilegeSetTable {
   469  	tblSets := make([]sql.PrivilegeSetTable, 0, len(ps.tables))
   470  	for _, tblSet := range ps.tables {
   471  		// Only return tables that have a table-level privilege, or a privilege on an underlying column.
   472  		// Otherwise, there is no difference between the returned table and the zero-value for any table.
   473  		if tblSet.HasPrivileges() {
   474  			tblSets = append(tblSets, tblSet)
   475  		}
   476  	}
   477  	sort.Slice(tblSets, func(i, j int) bool {
   478  		return tblSets[i].Name() < tblSets[j].Name()
   479  	})
   480  	return tblSets
   481  }
   482  
   483  // getTables returns all tables of the native type.
   484  func (ps PrivilegeSetDatabase) getTables() []PrivilegeSetTable {
   485  	tblSets := make([]PrivilegeSetTable, 0, len(ps.tables))
   486  	for _, tblSet := range ps.tables {
   487  		// Only return tables that have a table-level privilege, or a privilege on an underlying column.
   488  		// Otherwise, there is no difference between the returned table and the zero-value for any table.
   489  		if tblSet.HasPrivileges() {
   490  			tblSets = append(tblSets, tblSet)
   491  		}
   492  	}
   493  	sort.Slice(tblSets, func(i, j int) bool {
   494  		return tblSets[i].name < tblSets[j].name
   495  	})
   496  	return tblSets
   497  }
   498  
   499  // Routine returns the set of privileges for the given procedure or function
   500  func (ps PrivilegeSetDatabase) Routine(routineName string, isProc bool) sql.PrivilegeSetRoutine {
   501  	routineName = strings.ToLower(routineName)
   502  	set, ok := ps.routines[routineKey{routineName, isProc}]
   503  	if ok {
   504  		return set
   505  	}
   506  	return PrivilegeSetRoutine{name: routineName, isProc: isProc}
   507  }
   508  
   509  // GetRoutines returns all routines.
   510  func (ps PrivilegeSetDatabase) GetRoutines() []sql.PrivilegeSetRoutine {
   511  	if ps.routines == nil || len(ps.routines) == 0 {
   512  		return []sql.PrivilegeSetRoutine{}
   513  	}
   514  
   515  	routineSets := make([]sql.PrivilegeSetRoutine, 0, len(ps.routines))
   516  
   517  	for _, routine := range ps.routines {
   518  		routineSets = append(routineSets, routine)
   519  	}
   520  
   521  	sort.Slice(routineSets, func(a, b int) bool {
   522  		if routineSets[a].RoutineName() != routineSets[b].RoutineName() {
   523  			return routineSets[a].RoutineName() < routineSets[b].RoutineName()
   524  		}
   525  		return routineSets[a].RoutineType() < routineSets[b].RoutineType()
   526  	})
   527  
   528  	return routineSets
   529  }
   530  
   531  func (ps PrivilegeSetDatabase) getRoutines() []PrivilegeSetRoutine {
   532  	if ps.routines == nil || len(ps.routines) == 0 {
   533  		return []PrivilegeSetRoutine{}
   534  	}
   535  
   536  	routineSets := make([]PrivilegeSetRoutine, 0, len(ps.routines))
   537  	for _, routine := range ps.routines {
   538  		routineSets = append(routineSets, routine)
   539  	}
   540  
   541  	sort.Slice(routineSets, func(i, j int) bool {
   542  		if routineSets[i].RoutineName() != routineSets[j].RoutineType() {
   543  			return routineSets[i].RoutineName() < routineSets[j].RoutineName()
   544  		}
   545  		return routineSets[i].RoutineType() < routineSets[j].RoutineType()
   546  	})
   547  
   548  	return routineSets
   549  }
   550  
   551  // Equals returns whether the given set of privileges is equivalent to the calling set.
   552  func (ps PrivilegeSetDatabase) Equals(otherPsd sql.PrivilegeSetDatabase) bool {
   553  	otherPs := otherPsd.(PrivilegeSetDatabase)
   554  	if len(ps.privs) != len(otherPs.privs) ||
   555  		len(ps.tables) != len(otherPs.tables) {
   556  		return false
   557  	}
   558  	for priv := range ps.privs {
   559  		if _, ok := otherPs.privs[priv]; !ok {
   560  			return false
   561  		}
   562  	}
   563  	for tblName, tblSet := range ps.tables {
   564  		if !tblSet.Equals(otherPs.tables[tblName]) {
   565  			return false
   566  		}
   567  	}
   568  	for routineKey, routineSet := range ps.routines {
   569  		if !routineSet.Equals(otherPs.routines[routineKey]) {
   570  			return false
   571  		}
   572  	}
   573  
   574  	return true
   575  }
   576  
   577  // ToSlice returns all of the database privileges contained as a sorted slice.
   578  func (ps PrivilegeSetDatabase) ToSlice() []sql.PrivilegeType {
   579  	privs := make([]sql.PrivilegeType, 0, len(ps.privs))
   580  	for priv := range ps.privs {
   581  		privs = append(privs, priv)
   582  	}
   583  	sort.Slice(privs, func(i, j int) bool {
   584  		return privs[i] < privs[j]
   585  	})
   586  	return privs
   587  }
   588  
   589  // getUseableTbl is used internally to either retrieve an existing table, or create a new one that is returned.
   590  func (ps PrivilegeSetDatabase) getUseableTbl(tblName string) PrivilegeSetTable {
   591  	lowerTblName := strings.ToLower(tblName)
   592  	tblSet, ok := ps.tables[lowerTblName]
   593  	if !ok {
   594  		tblSet = PrivilegeSetTable{
   595  			name:    tblName,
   596  			privs:   make(map[sql.PrivilegeType]struct{}),
   597  			columns: make(map[string]PrivilegeSetColumn),
   598  		}
   599  		ps.tables[lowerTblName] = tblSet
   600  	}
   601  	return tblSet
   602  }
   603  
   604  func (ps PrivilegeSetDatabase) getUseableRoutine(routineName string, isProc bool) PrivilegeSetRoutine {
   605  	lowerProcName := strings.ToLower(routineName)
   606  	rKey := routineKey{name: lowerProcName, isProc: isProc}
   607  
   608  	routineSet, ok := ps.routines[rKey]
   609  	if !ok {
   610  		routineSet = PrivilegeSetRoutine{
   611  			name:   routineName,
   612  			privs:  make(map[sql.PrivilegeType]struct{}),
   613  			isProc: isProc,
   614  		}
   615  		ps.routines[rKey] = routineSet
   616  	}
   617  	return routineSet
   618  }
   619  
   620  // unionWith merges the given set of privileges to the calling set of privileges.
   621  func (ps PrivilegeSetDatabase) unionWith(otherPs PrivilegeSetDatabase) {
   622  	for priv := range otherPs.privs {
   623  		ps.privs[priv] = struct{}{}
   624  	}
   625  	for _, otherTblSet := range otherPs.tables {
   626  		ps.getUseableTbl(otherTblSet.name).unionWith(otherTblSet)
   627  	}
   628  	for _, otherRoutineSet := range otherPs.routines {
   629  		ps.getUseableRoutine(otherRoutineSet.name, otherRoutineSet.isProc).unionWith(otherRoutineSet)
   630  	}
   631  }
   632  
   633  // clear removes all database privileges.
   634  func (ps PrivilegeSetDatabase) clear() {
   635  	for priv := range ps.privs {
   636  		delete(ps.privs, priv)
   637  	}
   638  }
   639  
// PrivilegeSetTable is a set containing table-level privileges, together with the column sets of that table.
//
// The value returned by PrivilegeSetDatabase.Table for an unknown table has nil maps. It may be read but must not be
// written to; writable instances come from getUseableTbl.
type PrivilegeSetTable struct {
	// name keeps the caller's spelling; the key in PrivilegeSetDatabase.tables is the lowercased form.
	name    string
	privs   map[sql.PrivilegeType]struct{}
	// columns is keyed by lowercased column name (see getUseableCol).
	columns map[string]PrivilegeSetColumn
}

var _ sql.PrivilegeSetTable = PrivilegeSetTable{}
   648  
   649  // Name returns the name of the table that this privilege set belongs to.
   650  func (ps PrivilegeSetTable) Name() string {
   651  	return ps.name
   652  }
   653  
   654  // Has returns whether the given table privilege(s) exists.
   655  func (ps PrivilegeSetTable) Has(privileges ...sql.PrivilegeType) bool {
   656  	for _, priv := range privileges {
   657  		if _, ok := ps.privs[priv]; !ok {
   658  			return false
   659  		}
   660  	}
   661  	return true
   662  }
   663  
   664  // HasPrivileges returns whether this table has either table-level privileges, or privileges on a column contained
   665  // within this table.
   666  func (ps PrivilegeSetTable) HasPrivileges() bool {
   667  	if len(ps.privs) > 0 {
   668  		return true
   669  	}
   670  	for _, colSet := range ps.columns {
   671  		if colSet.Count() > 0 {
   672  			return true
   673  		}
   674  	}
   675  	return false
   676  }
   677  
   678  // Count returns the number of table privileges.
   679  func (ps PrivilegeSetTable) Count() int {
   680  	return len(ps.privs)
   681  }
   682  
   683  // Column returns the set of privileges for the given column. Returns an empty set if the column does not exist.
   684  func (ps PrivilegeSetTable) Column(colName string) sql.PrivilegeSetColumn {
   685  	colSet, ok := ps.columns[strings.ToLower(colName)]
   686  	if ok {
   687  		return colSet
   688  	}
   689  	return PrivilegeSetColumn{name: colName}
   690  }
   691  
   692  // GetColumns returns all columns.
   693  func (ps PrivilegeSetTable) GetColumns() []sql.PrivilegeSetColumn {
   694  	colSets := make([]sql.PrivilegeSetColumn, 0, len(ps.columns))
   695  	for _, colSet := range ps.columns {
   696  		// Only return columns that have privileges. Otherwise, there is no difference between the returned column and
   697  		// the zero-value for any column.
   698  		if colSet.Count() > 0 {
   699  			colSets = append(colSets, colSet)
   700  		}
   701  	}
   702  	sort.Slice(colSets, func(i, j int) bool {
   703  		return colSets[i].Name() < colSets[j].Name()
   704  	})
   705  	return colSets
   706  }
   707  
   708  // getColumns returns all columns of the native type.
   709  func (ps PrivilegeSetTable) getColumns() []PrivilegeSetColumn {
   710  	colSets := make([]PrivilegeSetColumn, 0, len(ps.columns))
   711  	for _, colSet := range ps.columns {
   712  		// Only return columns that have privileges. Otherwise, there is no difference between the returned column and
   713  		// the zero-value for any column.
   714  		if colSet.Count() > 0 {
   715  			colSets = append(colSets, colSet)
   716  		}
   717  	}
   718  	sort.Slice(colSets, func(i, j int) bool {
   719  		return colSets[i].name < colSets[j].name
   720  	})
   721  	return colSets
   722  }
   723  
   724  // Equals returns whether the given set of privileges is equivalent to the calling set.
   725  func (ps PrivilegeSetTable) Equals(otherPst sql.PrivilegeSetTable) bool {
   726  	otherPs := otherPst.(PrivilegeSetTable)
   727  	if len(ps.privs) != len(otherPs.privs) ||
   728  		len(ps.columns) != len(otherPs.columns) {
   729  		return false
   730  	}
   731  	for priv := range ps.privs {
   732  		if _, ok := otherPs.privs[priv]; !ok {
   733  			return false
   734  		}
   735  	}
   736  	for colName, colSet := range ps.columns {
   737  		if !colSet.Equals(otherPs.columns[colName]) {
   738  			return false
   739  		}
   740  	}
   741  	return true
   742  }
   743  
   744  // ToSlice returns all of the table privileges contained as a sorted slice.
   745  func (ps PrivilegeSetTable) ToSlice() []sql.PrivilegeType {
   746  	privs := make([]sql.PrivilegeType, 0, len(ps.privs))
   747  	for priv := range ps.privs {
   748  		privs = append(privs, priv)
   749  	}
   750  	sort.Slice(privs, func(i, j int) bool {
   751  		return privs[i] < privs[j]
   752  	})
   753  	return privs
   754  }
   755  
   756  // getUseableCol is used internally to either retrieve an existing column, or create a new one that is returned.
   757  func (ps PrivilegeSetTable) getUseableCol(colName string) PrivilegeSetColumn {
   758  	lowerColName := strings.ToLower(colName)
   759  	colSet, ok := ps.columns[lowerColName]
   760  	if !ok {
   761  		colSet = PrivilegeSetColumn{
   762  			name:  colName,
   763  			privs: make(map[sql.PrivilegeType]struct{}),
   764  		}
   765  		ps.columns[lowerColName] = colSet
   766  	}
   767  	return colSet
   768  }
   769  
   770  // unionWith merges the given set of privileges to the calling set of privileges.
   771  func (ps PrivilegeSetTable) unionWith(otherPs PrivilegeSetTable) {
   772  	for priv := range otherPs.privs {
   773  		ps.privs[priv] = struct{}{}
   774  	}
   775  	for _, otherColSet := range otherPs.columns {
   776  		ps.getUseableCol(otherColSet.name).unionWith(otherColSet)
   777  	}
   778  }
   779  
   780  // clear removes all table privileges.
   781  func (ps PrivilegeSetTable) clear() {
   782  	for priv := range ps.privs {
   783  		delete(ps.privs, priv)
   784  	}
   785  	for col := range ps.columns {
   786  		delete(ps.columns, col)
   787  	}
   788  }
   789  
// PrivilegeSetColumn is a set containing column privileges.
//
// The value returned by PrivilegeSetTable.Column for an unknown column has a nil privs map. It may be read but must
// not be written to; writable instances come from getUseableCol.
type PrivilegeSetColumn struct {
	// name keeps the caller's spelling; the key in PrivilegeSetTable.columns is the lowercased form.
	name  string
	privs map[sql.PrivilegeType]struct{}
}

var _ sql.PrivilegeSetColumn = PrivilegeSetColumn{}
   797  
   798  // Name returns the name of the column that this privilege set belongs to.
   799  func (ps PrivilegeSetColumn) Name() string {
   800  	return ps.name
   801  }
   802  
   803  // Has returns whether the given column privilege(s) exists.
   804  func (ps PrivilegeSetColumn) Has(privileges ...sql.PrivilegeType) bool {
   805  	for _, priv := range privileges {
   806  		if _, ok := ps.privs[priv]; !ok {
   807  			return false
   808  		}
   809  	}
   810  	return true
   811  }
   812  
   813  // HasPrivileges returns whether this column has any privileges.
   814  func (ps PrivilegeSetColumn) HasPrivileges() bool {
   815  	return len(ps.privs) > 0
   816  }
   817  
   818  // Count returns the number of column privileges.
   819  func (ps PrivilegeSetColumn) Count() int {
   820  	return len(ps.privs)
   821  }
   822  
   823  // Equals returns whether the given set of privileges is equivalent to the calling set.
   824  func (ps PrivilegeSetColumn) Equals(otherPsc sql.PrivilegeSetColumn) bool {
   825  	otherPs := otherPsc.(PrivilegeSetColumn)
   826  	if len(ps.privs) != len(otherPs.privs) {
   827  		return false
   828  	}
   829  	for priv := range ps.privs {
   830  		if _, ok := otherPs.privs[priv]; !ok {
   831  			return false
   832  		}
   833  	}
   834  	return true
   835  }
   836  
   837  // ToSlice returns all of the column privileges contained as a sorted slice.
   838  func (ps PrivilegeSetColumn) ToSlice() []sql.PrivilegeType {
   839  	privs := make([]sql.PrivilegeType, 0, len(ps.privs))
   840  	for priv := range ps.privs {
   841  		privs = append(privs, priv)
   842  	}
   843  	sort.Slice(privs, func(i, j int) bool {
   844  		return privs[i] < privs[j]
   845  	})
   846  	return privs
   847  }
   848  
   849  // unionWith merges the given set of privileges to the calling set of privileges.
   850  func (ps PrivilegeSetColumn) unionWith(otherPs PrivilegeSetColumn) {
   851  	for priv := range otherPs.privs {
   852  		ps.privs[priv] = struct{}{}
   853  	}
   854  }
   855  
   856  // clear removes all column privileges.
   857  func (ps PrivilegeSetColumn) clear() {
   858  	for priv := range ps.privs {
   859  		delete(ps.privs, priv)
   860  	}
   861  }
   862  
// PrivilegeSetRoutine is a set containing the privileges granted on a single routine (a procedure or a function).
type PrivilegeSetRoutine struct {
	// name is the routine's name, as returned by RoutineName.
	name   string
	isProc bool // true = procedure, false = function
	// privs holds the granted privileges; the empty struct values make this map a set.
	privs  map[sql.PrivilegeType]struct{}
}
   868  
   869  // unionWith merges the given set of privileges to the calling set of privileges.
   870  func (ps PrivilegeSetRoutine) unionWith(otherPs PrivilegeSetRoutine) {
   871  	for priv := range otherPs.privs {
   872  		ps.privs[priv] = struct{}{}
   873  	}
   874  }
   875  
   876  // clear removes all routine privileges.
   877  func (ps PrivilegeSetRoutine) clear() {
   878  	for priv := range ps.privs {
   879  		delete(ps.privs, priv)
   880  	}
   881  }
   882  
// Compile-time check that PrivilegeSetRoutine implements sql.PrivilegeSetRoutine.
var _ sql.PrivilegeSetRoutine = PrivilegeSetRoutine{}
   884  
   885  // RoutineName returns the name of the routine that this privilege set belongs to.
   886  func (ps PrivilegeSetRoutine) RoutineName() string {
   887  	return ps.name
   888  }
   889  
   890  // RoutineType returns the type of routine this is (PROCEDURE or FUNCTION).
   891  func (ps PrivilegeSetRoutine) RoutineType() string {
   892  	if ps.isProc {
   893  		return "PROCEDURE"
   894  	} else {
   895  		return "FUNCTION"
   896  	}
   897  }
   898  
   899  // Count returns the number of routine privileges.
   900  func (ps PrivilegeSetRoutine) Count() int {
   901  	return len(ps.privs)
   902  }
   903  
   904  // Has returns whether the given column privilege(s) exists.
   905  func (ps PrivilegeSetRoutine) Has(privileges ...sql.PrivilegeType) bool {
   906  	for _, priv := range privileges {
   907  		if _, ok := ps.privs[priv]; !ok {
   908  			return false
   909  		}
   910  	}
   911  	return true
   912  }
   913  
   914  // HasPrivileges returns whether this routine has any privileges.
   915  func (ps PrivilegeSetRoutine) HasPrivileges() bool {
   916  	return len(ps.privs) > 0
   917  }
   918  
   919  // ToSlice returns all of the privileges contained as a sorted slice.
   920  func (ps PrivilegeSetRoutine) ToSlice() []sql.PrivilegeType {
   921  	privs := make([]sql.PrivilegeType, 0, len(ps.privs))
   922  	for priv := range ps.privs {
   923  		privs = append(privs, priv)
   924  	}
   925  	sort.Slice(privs, func(i, j int) bool {
   926  		return privs[i] < privs[j]
   927  	})
   928  	return privs
   929  }
   930  
   931  // Equals returns whether the given set of privileges is equivalent to the calling set.
   932  func (ps PrivilegeSetRoutine) Equals(otherPs sql.PrivilegeSetRoutine) bool {
   933  	if ps.RoutineName() != otherPs.RoutineName() {
   934  		return false
   935  	}
   936  	if ps.RoutineType() != otherPs.RoutineType() {
   937  		return false
   938  	}
   939  
   940  	thisSlice := ps.ToSlice()
   941  	thatSlice := otherPs.ToSlice()
   942  
   943  	if len(thisSlice) != len(thatSlice) {
   944  		return false
   945  	}
   946  	for i, val := range thisSlice {
   947  		if val != thatSlice[i] {
   948  			return false
   949  		}
   950  	}
   951  	return true
   952  }