// github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/privileged_database_provider.go

// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql_db

import (
	"strings"
	"time"

	"github.com/dolthub/go-mysql-server/sql/fulltext"

	"github.com/dolthub/go-mysql-server/sql"
)

// PrivilegedDatabaseProvider is a wrapper around a normal sql.DatabaseProvider that takes a context's client's
// privileges into consideration when returning a sql.Database. In addition, any databases obtained from the wrapped
// provider are wrapped with PrivilegedDatabase. The "mysql" database is the exception: it is served by the grant
// tables themselves and returned unwrapped.
type PrivilegedDatabaseProvider struct {
	// grantTables is consulted for the current user's active privilege set, and is the database returned when
	// "mysql" is requested.
	grantTables *MySQLDb

	// provider is the wrapped provider that actually resolves and lists databases.
	provider sql.DatabaseProvider
}

// Compile-time check that PrivilegedDatabaseProvider satisfies sql.DatabaseProvider.
var _ sql.DatabaseProvider = PrivilegedDatabaseProvider{}

// NewPrivilegedDatabaseProvider returns a new PrivilegedDatabaseProvider that checks privileges against grantTables and
// delegates database resolution to p. As a sql.DatabaseProvider may be added to an analyzer when Grant Tables are
// disabled (and Grant Tables may be enabled or disabled at any time), a new PrivilegedDatabaseProvider is returned
// whenever the sql.DatabaseProvider is needed (as long as Grant Tables are enabled) rather than wrapping a
// sql.DatabaseProvider when it is provided to the analyzer.
40 func NewPrivilegedDatabaseProvider(grantTables *MySQLDb, p sql.DatabaseProvider) sql.DatabaseProvider { 41 return PrivilegedDatabaseProvider{ 42 grantTables: grantTables, 43 provider: p, 44 } 45 } 46 47 // Database implements the interface sql.DatabaseProvider. 48 func (pdp PrivilegedDatabaseProvider) Database(ctx *sql.Context, name string) (sql.Database, error) { 49 if strings.ToLower(name) == "mysql" { 50 return pdp.grantTables, nil 51 } 52 53 db, providerErr := pdp.provider.Database(ctx, name) 54 if sql.ErrDatabaseNotFound.Is(providerErr) { 55 // continue to priv check below, which will deny access or return not found as appropriate, before returning this 56 // original not found error 57 } else if providerErr != nil { 58 return nil, providerErr 59 } 60 61 checkName := name 62 if adb, ok := db.(sql.AliasedDatabase); ok { 63 checkName = adb.AliasedName() 64 } 65 66 privSet := pdp.grantTables.UserActivePrivilegeSet(ctx) 67 // If the user has no global static privileges or database-relevant privileges then the database is not accessible. 68 if privSet.Count() == 0 && !privSet.Database(checkName).HasPrivileges() { 69 return nil, sql.ErrDatabaseAccessDeniedForUser.New(pdp.usernameFromCtx(ctx), checkName) 70 } 71 72 if providerErr != nil { 73 return nil, providerErr 74 } 75 76 return NewPrivilegedDatabase(pdp.grantTables, db), nil 77 } 78 79 // HasDatabase implements the interface sql.DatabaseProvider. 
80 func (pdp PrivilegedDatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool { 81 db, err := pdp.provider.Database(ctx, name) 82 if sql.ErrDatabaseNotFound.Is(err) { 83 // continue to check below, which will deny access or return not found as appropriate 84 } else if err != nil { 85 return false 86 } 87 88 if adb, ok := db.(sql.AliasedDatabase); ok { 89 name = adb.AliasedName() 90 } 91 92 privSet := pdp.grantTables.UserActivePrivilegeSet(ctx) 93 // If the user has no global static privileges or database-relevant privileges then the database is not accessible. 94 if privSet.Count() == 0 && !privSet.Database(name).HasPrivileges() { 95 return false 96 } 97 98 return pdp.provider.HasDatabase(ctx, name) 99 } 100 101 // AllDatabases implements the interface sql.DatabaseProvider. 102 func (pdp PrivilegedDatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database { 103 privilegeSet := pdp.grantTables.UserActivePrivilegeSet(ctx) 104 privilegeSetCount := privilegeSet.Count() 105 106 var databasesWithAccess []sql.Database 107 allDatabases := pdp.provider.AllDatabases(ctx) 108 for _, db := range allDatabases { 109 // If the user has any global static privileges or database-relevant privileges then the database is accessible 110 checkName := db.Name() 111 112 if adb, ok := db.(sql.AliasedDatabase); ok { 113 checkName = adb.AliasedName() 114 } 115 116 if privilegeSetCount > 0 || privilegeSet.Database(checkName).HasPrivileges() { 117 databasesWithAccess = append(databasesWithAccess, NewPrivilegedDatabase(pdp.grantTables, db)) 118 } 119 } 120 return databasesWithAccess 121 } 122 123 // usernameFromCtx returns the username from the context, properly formatted for returned errors. 
124 func (pdp PrivilegedDatabaseProvider) usernameFromCtx(ctx *sql.Context) string { 125 client := ctx.Session.Client() 126 return User{User: client.User, Host: client.Address}.UserHostToString("'") 127 } 128 129 // PrivilegedDatabase is a wrapper around a normal sql.Database that takes a context's client's privileges into 130 // consideration when returning a sql.Table. 131 type PrivilegedDatabase struct { 132 grantTables *MySQLDb 133 db sql.Database 134 //TODO: this should also handle views as the relevant privilege exists 135 } 136 137 var _ sql.Database = PrivilegedDatabase{} 138 var _ sql.VersionedDatabase = PrivilegedDatabase{} 139 var _ sql.TableCreator = PrivilegedDatabase{} 140 var _ sql.TableDropper = PrivilegedDatabase{} 141 var _ sql.TableRenamer = PrivilegedDatabase{} 142 var _ sql.TriggerDatabase = PrivilegedDatabase{} 143 var _ sql.StoredProcedureDatabase = PrivilegedDatabase{} 144 var _ sql.EventDatabase = PrivilegedDatabase{} 145 var _ sql.TableCopierDatabase = PrivilegedDatabase{} 146 var _ sql.ReadOnlyDatabase = PrivilegedDatabase{} 147 var _ sql.TemporaryTableDatabase = PrivilegedDatabase{} 148 var _ sql.CollatedDatabase = PrivilegedDatabase{} 149 var _ sql.ViewDatabase = PrivilegedDatabase{} 150 var _ fulltext.Database = PrivilegedDatabase{} 151 152 // NewPrivilegedDatabase returns a new PrivilegedDatabase. 153 func NewPrivilegedDatabase(grantTables *MySQLDb, db sql.Database) sql.Database { 154 return PrivilegedDatabase{ 155 grantTables: grantTables, 156 db: db, 157 } 158 } 159 160 // Name implements the interface sql.Database. 161 func (pdb PrivilegedDatabase) Name() string { 162 return pdb.db.Name() 163 } 164 165 // GetTableInsensitive implements the interface sql.Database. 
166 func (pdb PrivilegedDatabase) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) { 167 checkName := pdb.db.Name() 168 if adb, ok := pdb.db.(sql.AliasedDatabase); ok { 169 checkName = adb.AliasedName() 170 } 171 172 privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) 173 dbSet := privSet.Database(checkName) 174 // If there are no usable privileges for this database then the table is inaccessible. 175 if privSet.Count() == 0 && !dbSet.HasPrivileges() { 176 return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName) 177 } 178 179 tblSet := dbSet.Table(tblName) 180 // If the user has no global static privileges, database-level privileges, or table-relevant privileges then the 181 // table is not accessible. 182 if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() { 183 return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName) 184 } 185 return pdb.db.GetTableInsensitive(ctx, tblName) 186 } 187 188 // GetTableNames implements the interface sql.Database. 189 func (pdb PrivilegedDatabase) GetTableNames(ctx *sql.Context) ([]string, error) { 190 var tablesWithAccess []string 191 var err error 192 privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) 193 194 checkName := pdb.db.Name() 195 if adb, ok := pdb.db.(sql.AliasedDatabase); ok { 196 checkName = adb.AliasedName() 197 } 198 199 dbSet := privSet.Database(checkName) 200 // If there are no usable privileges for this database then no table is accessible. 201 privSetCount := privSet.Count() 202 if privSetCount == 0 && !dbSet.HasPrivileges() { 203 return nil, nil 204 } 205 206 tblNames, err := pdb.db.GetTableNames(ctx) 207 if err != nil { 208 return nil, err 209 } 210 dbSetCount := dbSet.Count() 211 for _, tblName := range tblNames { 212 // If the user has any global static privileges, database-level privileges, or table-relevant privileges then a 213 // table is accessible. 
214 if privSetCount > 0 || dbSetCount > 0 || dbSet.Table(tblName).HasPrivileges() { 215 tablesWithAccess = append(tablesWithAccess, tblName) 216 } 217 } 218 return tablesWithAccess, nil 219 } 220 221 // GetTableInsensitiveAsOf returns a new sql.VersionedDatabase. 222 func (pdb PrivilegedDatabase) GetTableInsensitiveAsOf(ctx *sql.Context, tblName string, asOf interface{}) (sql.Table, bool, error) { 223 db, ok := pdb.db.(sql.VersionedDatabase) 224 if !ok { 225 return nil, false, sql.ErrAsOfNotSupported.New(pdb.db.Name()) 226 } 227 228 privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) 229 230 checkName := pdb.db.Name() 231 if adb, ok := pdb.db.(sql.AliasedDatabase); ok { 232 checkName = adb.AliasedName() 233 } 234 235 dbSet := privSet.Database(checkName) 236 // If there are no usable privileges for this database then the table is inaccessible. 237 if privSet.Count() == 0 && !dbSet.HasPrivileges() { 238 return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName) 239 } 240 241 tblSet := dbSet.Table(tblName) 242 // If the user has no global static privileges, database-level privileges, or table-relevant privileges then the 243 // table is not accessible. 244 if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() { 245 return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName) 246 } 247 return db.GetTableInsensitiveAsOf(ctx, tblName, asOf) 248 } 249 250 // GetTableNamesAsOf returns a new sql.VersionedDatabase. 
251 func (pdb PrivilegedDatabase) GetTableNamesAsOf(ctx *sql.Context, asOf interface{}) ([]string, error) { 252 db, ok := pdb.db.(sql.VersionedDatabase) 253 if !ok { 254 return nil, nil 255 } 256 257 var tablesWithAccess []string 258 var err error 259 privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) 260 261 checkName := pdb.db.Name() 262 if adb, ok := pdb.db.(sql.AliasedDatabase); ok { 263 checkName = adb.AliasedName() 264 } 265 266 dbSet := privSet.Database(checkName) 267 // If there are no usable privileges for this database then no table is accessible. 268 if privSet.Count() == 0 && !dbSet.HasPrivileges() { 269 return nil, nil 270 } 271 272 tblNames, err := db.GetTableNamesAsOf(ctx, asOf) 273 if err != nil { 274 return nil, err 275 } 276 privSetCount := privSet.Count() 277 dbSetCount := dbSet.Count() 278 for _, tblName := range tblNames { 279 // If the user has any global static privileges, database-level privileges, or table-relevant privileges then a 280 // table is accessible. 281 if privSetCount > 0 || dbSetCount > 0 && dbSet.Table(tblName).HasPrivileges() { 282 tablesWithAccess = append(tablesWithAccess, tblName) 283 } 284 } 285 286 return tablesWithAccess, nil 287 } 288 289 // CreateTable implements the interface sql.TableCreator. 290 func (pdb PrivilegedDatabase) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { 291 if db, ok := pdb.db.(sql.TableCreator); ok { 292 return db.CreateTable(ctx, name, schema, collation, comment) 293 } 294 return sql.ErrCreateTableNotSupported.New(pdb.db.Name()) 295 } 296 297 // DropTable implements the interface sql.TableDropper. 298 func (pdb PrivilegedDatabase) DropTable(ctx *sql.Context, name string) error { 299 if db, ok := pdb.db.(sql.TableDropper); ok { 300 return db.DropTable(ctx, name) 301 } 302 return sql.ErrDropTableNotSupported.New(pdb.db.Name()) 303 } 304 305 // CreateFulltextTableNames implements the interface fulltext.Database. 
306 func (pdb PrivilegedDatabase) CreateFulltextTableNames(ctx *sql.Context, parentTable string, parentIndexName string) (fulltext.IndexTableNames, error) { 307 if db, ok := pdb.db.(fulltext.Database); ok { 308 return db.CreateFulltextTableNames(ctx, parentTable, parentIndexName) 309 } 310 return fulltext.IndexTableNames{}, sql.ErrFullTextDatabaseNotSupported.New() 311 } 312 313 // RenameTable implements the interface sql.TableRenamer. 314 func (pdb PrivilegedDatabase) RenameTable(ctx *sql.Context, oldName, newName string) error { 315 if db, ok := pdb.db.(sql.TableRenamer); ok { 316 return db.RenameTable(ctx, oldName, newName) 317 } 318 return sql.ErrRenameTableNotSupported.New(pdb.db.Name()) 319 } 320 321 // GetTriggers implements the interface sql.TriggerDatabase. 322 func (pdb PrivilegedDatabase) GetTriggers(ctx *sql.Context) ([]sql.TriggerDefinition, error) { 323 if db, ok := pdb.db.(sql.TriggerDatabase); ok { 324 return db.GetTriggers(ctx) 325 } 326 return nil, sql.ErrTriggersNotSupported.New(pdb.db.Name()) 327 } 328 329 // CreateTrigger implements the interface sql.TriggerDatabase. 330 func (pdb PrivilegedDatabase) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error { 331 if db, ok := pdb.db.(sql.TriggerDatabase); ok { 332 return db.CreateTrigger(ctx, definition) 333 } 334 return sql.ErrTriggersNotSupported.New(pdb.db.Name()) 335 } 336 337 // DropTrigger implements the interface sql.TriggerDatabase. 338 func (pdb PrivilegedDatabase) DropTrigger(ctx *sql.Context, name string) error { 339 if db, ok := pdb.db.(sql.TriggerDatabase); ok { 340 return db.DropTrigger(ctx, name) 341 } 342 return sql.ErrTriggersNotSupported.New(pdb.db.Name()) 343 } 344 345 // GetStoredProcedure implements the interface sql.StoredProcedureDatabase. 
346 func (pdb PrivilegedDatabase) GetStoredProcedure(ctx *sql.Context, name string) (sql.StoredProcedureDetails, bool, error) { 347 if db, ok := pdb.db.(sql.StoredProcedureDatabase); ok { 348 return db.GetStoredProcedure(ctx, name) 349 } 350 return sql.StoredProcedureDetails{}, false, sql.ErrStoredProceduresNotSupported.New(pdb.db.Name()) 351 } 352 353 // GetStoredProcedures implements the interface sql.StoredProcedureDatabase. 354 func (pdb PrivilegedDatabase) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureDetails, error) { 355 if db, ok := pdb.db.(sql.StoredProcedureDatabase); ok { 356 return db.GetStoredProcedures(ctx) 357 } 358 return nil, sql.ErrStoredProceduresNotSupported.New(pdb.db.Name()) 359 } 360 361 // SaveStoredProcedure implements the interface sql.StoredProcedureDatabase. 362 func (pdb PrivilegedDatabase) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error { 363 if db, ok := pdb.db.(sql.StoredProcedureDatabase); ok { 364 return db.SaveStoredProcedure(ctx, spd) 365 } 366 return sql.ErrStoredProceduresNotSupported.New(pdb.db.Name()) 367 } 368 369 // DropStoredProcedure implements the interface sql.StoredProcedureDatabase. 
370 func (pdb PrivilegedDatabase) DropStoredProcedure(ctx *sql.Context, name string) error { 371 if db, ok := pdb.db.(sql.StoredProcedureDatabase); ok { 372 return db.DropStoredProcedure(ctx, name) 373 } 374 return sql.ErrStoredProceduresNotSupported.New(pdb.db.Name()) 375 } 376 377 // GetEvent implements sql.EventDatabase 378 func (pdb PrivilegedDatabase) GetEvent(ctx *sql.Context, name string) (sql.EventDefinition, bool, error) { 379 if db, ok := pdb.db.(sql.EventDatabase); ok { 380 return db.GetEvent(ctx, name) 381 } 382 return sql.EventDefinition{}, false, sql.ErrEventsNotSupported.New(pdb.db.Name()) 383 } 384 385 // GetEvents implements sql.EventDatabase 386 func (pdb PrivilegedDatabase) GetEvents(ctx *sql.Context) ([]sql.EventDefinition, interface{}, error) { 387 if db, ok := pdb.db.(sql.EventDatabase); ok { 388 return db.GetEvents(ctx) 389 } 390 return nil, nil, sql.ErrEventsNotSupported.New(pdb.db.Name()) 391 } 392 393 // SaveEvent implements sql.EventDatabase 394 func (pdb PrivilegedDatabase) SaveEvent(ctx *sql.Context, ed sql.EventDefinition) (bool, error) { 395 if db, ok := pdb.db.(sql.EventDatabase); ok { 396 return db.SaveEvent(ctx, ed) 397 } 398 return false, sql.ErrEventsNotSupported.New(pdb.db.Name()) 399 } 400 401 // DropEvent implements sql.EventDatabase 402 func (pdb PrivilegedDatabase) DropEvent(ctx *sql.Context, name string) error { 403 if db, ok := pdb.db.(sql.EventDatabase); ok { 404 return db.DropEvent(ctx, name) 405 } 406 return sql.ErrEventsNotSupported.New(pdb.db.Name()) 407 } 408 409 // UpdateEvent implements sql.EventDatabase 410 func (pdb PrivilegedDatabase) UpdateEvent(ctx *sql.Context, originalName string, ed sql.EventDefinition) (bool, error) { 411 if db, ok := pdb.db.(sql.EventDatabase); ok { 412 return db.UpdateEvent(ctx, originalName, ed) 413 } 414 return false, sql.ErrEventsNotSupported.New(pdb.db.Name()) 415 } 416 417 // NeedsToReloadEvents implements sql.EventDatabase 418 func (pdb PrivilegedDatabase) NeedsToReloadEvents(ctx 
*sql.Context, token interface{}) (bool, error) { 419 if db, ok := pdb.db.(sql.EventDatabase); ok { 420 return db.NeedsToReloadEvents(ctx, token) 421 } 422 return false, sql.ErrEventsNotSupported.New(pdb.db.Name()) 423 } 424 425 func (pdb PrivilegedDatabase) UpdateLastExecuted(ctx *sql.Context, eventName string, lastExecuted time.Time) error { 426 if db, ok := pdb.db.(sql.EventDatabase); ok { 427 return db.UpdateLastExecuted(ctx, eventName, lastExecuted) 428 } 429 return sql.ErrEventsNotSupported.New(pdb.db.Name()) 430 } 431 432 // CreateView implements sql.ViewDatabase 433 func (pdb PrivilegedDatabase) CreateView(ctx *sql.Context, name string, selectStatement, createViewStmt string) error { 434 if db, ok := pdb.db.(sql.ViewDatabase); ok { 435 return db.CreateView(ctx, name, selectStatement, createViewStmt) 436 } 437 return sql.ErrViewsNotSupported.New(pdb.db.Name()) 438 } 439 440 // DropView implements sql.ViewDatabase 441 func (pdb PrivilegedDatabase) DropView(ctx *sql.Context, name string) error { 442 if db, ok := pdb.db.(sql.ViewDatabase); ok { 443 return db.DropView(ctx, name) 444 } 445 return sql.ErrViewsNotSupported.New(pdb.db.Name()) 446 } 447 448 // GetViewDefinition implements sql.ViewDatabase 449 func (pdb PrivilegedDatabase) GetViewDefinition(ctx *sql.Context, viewName string) (sql.ViewDefinition, bool, error) { 450 if db, ok := pdb.db.(sql.ViewDatabase); ok { 451 return db.GetViewDefinition(ctx, viewName) 452 } 453 return sql.ViewDefinition{}, false, sql.ErrViewsNotSupported.New(pdb.db.Name()) 454 } 455 456 // AllViews implements sql.ViewDatabase 457 func (pdb PrivilegedDatabase) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) { 458 if db, ok := pdb.db.(sql.ViewDatabase); ok { 459 return db.AllViews(ctx) 460 } 461 return nil, sql.ErrViewsNotSupported.New(pdb.db.Name()) 462 } 463 464 // CopyTableData implements the interface sql.TableCopierDatabase. 
465 func (pdb PrivilegedDatabase) CopyTableData(ctx *sql.Context, sourceTable string, destinationTable string) (uint64, error) { 466 if db, ok := pdb.db.(sql.TableCopierDatabase); ok { 467 // Privilege checking is handled in the analyzer 468 return db.CopyTableData(ctx, sourceTable, destinationTable) 469 } 470 return 0, sql.ErrTableCopyingNotSupported.New() 471 } 472 473 // IsReadOnly implements the interface sql.ReadOnlyDatabase. 474 func (pdb PrivilegedDatabase) IsReadOnly() bool { 475 if db, ok := pdb.db.(sql.ReadOnlyDatabase); ok { 476 return db.IsReadOnly() 477 } 478 return false 479 } 480 481 // GetAllTemporaryTables implements the interface sql.TemporaryTableDatabase. 482 func (pdb PrivilegedDatabase) GetAllTemporaryTables(ctx *sql.Context) ([]sql.Table, error) { 483 if db, ok := pdb.db.(sql.TemporaryTableDatabase); ok { 484 return db.GetAllTemporaryTables(ctx) 485 } 486 // All current temp table checks skip if not implemented, same is iterating over an empty slice 487 return nil, nil 488 } 489 490 // GetCollation implements the interface sql.CollatedDatabase. 491 func (pdb PrivilegedDatabase) GetCollation(ctx *sql.Context) sql.CollationID { 492 if db, ok := pdb.db.(sql.CollatedDatabase); ok { 493 return db.GetCollation(ctx) 494 } 495 return sql.Collation_Default 496 } 497 498 // SetCollation implements the interface sql.CollatedDatabase. 499 func (pdb PrivilegedDatabase) SetCollation(ctx *sql.Context, collation sql.CollationID) error { 500 if db, ok := pdb.db.(sql.CollatedDatabase); ok { 501 return db.SetCollation(ctx, collation) 502 } 503 return sql.ErrDatabaseCollationsNotSupported.New(pdb.db.Name()) 504 } 505 506 // Unwrap returns the wrapped sql.Database. 507 func (pdb PrivilegedDatabase) Unwrap() sql.Database { 508 return pdb.db 509 } 510 511 // usernameFromCtx returns the username from the context, properly formatted for returned errors. 
512 func (pdb PrivilegedDatabase) usernameFromCtx(ctx *sql.Context) string { 513 client := ctx.Session.Client() 514 return User{User: client.User, Host: client.Address}.UserHostToString("'") 515 }