github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/mysql_db.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql_db 16 17 import ( 18 "bytes" 19 "crypto/sha1" 20 "encoding/hex" 21 "encoding/json" 22 "fmt" 23 "net" 24 "sort" 25 "strings" 26 "sync" 27 "sync/atomic" 28 29 flatbuffers "github.com/dolthub/flatbuffers/v23/go" 30 "github.com/dolthub/vitess/go/mysql" 31 32 "github.com/dolthub/go-mysql-server/sql" 33 "github.com/dolthub/go-mysql-server/sql/in_mem_table" 34 "github.com/dolthub/go-mysql-server/sql/mysql_db/serial" 35 ) 36 37 // MySQLDbPersistence is used to determine the behavior of how certain tables in MySQLDb will be persisted. 38 type MySQLDbPersistence interface { 39 Persist(ctx *sql.Context, data []byte) error 40 } 41 42 // NoopPersister is used when nothing in mysql db should be persisted 43 type NoopPersister struct{} 44 45 var _ MySQLDbPersistence = &NoopPersister{} 46 47 // Persist implements the MySQLDbPersistence interface 48 func (p *NoopPersister) Persist(ctx *sql.Context, data []byte) error { 49 return nil 50 } 51 52 type PlaintextAuthPlugin interface { 53 Authenticate(db *MySQLDb, user string, userEntry *User, pass string) (bool, error) 54 } 55 56 // MySQLDb are the collection of tables that are in the MySQL database 57 type MySQLDb struct { 58 enabled atomic.Bool 59 60 user *in_mem_table.IndexedSetTable[*User] 61 role_edges *in_mem_table.IndexedSetTable[*RoleEdge] 62 replica_source_info *in_mem_table.IndexedSetTable[*ReplicaSourceInfo] 63 64 help_topic *mysqlTable 65 help_keyword *mysqlTable 66 help_category *mysqlTable 67 help_relation *mysqlTable 68 69 db *in_mem_table.MultiIndexedSetTable[*User] 70 tables_priv *in_mem_table.MultiIndexedSetTable[*User] 71 procs_priv *in_mem_table.MultiIndexedSetTable[*User] 72 global_grants *in_mem_table.MultiIndexedSetTable[*User] 73 74 //TODO: add the rest of these tables 75 //columns_priv *mysqlTable 76 //proxies_priv *mysqlTable 77 //default_roles *mysqlTable 78 //password_history *mysqlTable 79 80 persister MySQLDbPersistence 81 plugins map[string]PlaintextAuthPlugin 82 83 lock sync.RWMutex 84 updateCounter atomic.Uint64 85 } 86 87 var _ sql.Database = (*MySQLDb)(nil) 88 var _ mysql.AuthServer = (*MySQLDb)(nil) 89 90 // CreateEmptyMySQLDb returns a collection of MySQL Tables that do not contain any data. 91 func CreateEmptyMySQLDb() *MySQLDb { 92 // original tables 93 mysqlDb := &MySQLDb{} 94 95 lock, rlock := &mysqlDb.lock, mysqlDb.lock.RLocker() 96 97 userSet, userTable := NewUserIndexedSetTable(lock, rlock) 98 mysqlDb.user = userTable 99 mysqlDb.role_edges = NewRoleEdgesIndexedSetTable(lock, rlock) 100 mysqlDb.replica_source_info = NewReplicaSourceInfoIndexedSetTable(lock, rlock) 101 102 // Help tables 103 mysqlDb.help_topic = newEmptyMySQLTable( 104 "help_topic", 105 helpTopicSchema, 106 mysqlDb) 107 mysqlDb.help_keyword = newEmptyMySQLTable( 108 "help_keyword", 109 helpKeywordSchema, 110 mysqlDb) 111 mysqlDb.help_category = newEmptyMySQLTable( 112 "help_category", 113 helpCategorySchema, 114 mysqlDb) 115 mysqlDb.help_relation = newEmptyMySQLTable( 116 "help_relation", 117 helpRelationSchema, 118 mysqlDb) 119 120 // multi tables 121 mysqlDb.db = NewUserDBIndexedSetTable(userSet, lock, rlock) 122 mysqlDb.tables_priv = NewUserTablesIndexedSetTable(userSet, lock, rlock) 123 mysqlDb.procs_priv = NewUserProcsIndexedSetTable(userSet, lock, rlock) 124 mysqlDb.global_grants = NewUserGlobalGrantsIndexedSetTable(userSet, lock, rlock) 125 126 // Start the counter at 1, all new sessions will start at zero so this forces an update for any new session 127 mysqlDb.updateCounter.Store(1) 128 129 return mysqlDb 130 } 131 132 type Reader struct { 133 users in_mem_table.IndexedSet[*User] 134 roleEdges in_mem_table.IndexedSet[*RoleEdge] 135 replicaSourceInfo in_mem_table.IndexedSet[*ReplicaSourceInfo] 136 137 close func() 138 } 139 140 type UserFetcher interface { 141 GetUser(u UserPrimaryKey) (res *User, ok bool) 142 GetUsersByUsername(username string) []*User 143 } 144 145 func (r *Reader) GetReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) (res *ReplicaSourceInfo, ok bool) { 146 sources := r.replicaSourceInfo.GetMany(ReplicaSourceInfoPrimaryKeyer{}, k) 147 if len(sources) > 1 { 148 panic("too many matching replica sources") 149 } 150 if len(sources) > 0 { 151 res = sources[0] 152 ok = true 153 } 154 return 155 } 156 157 func (r *Reader) GetUser(u UserPrimaryKey) (res *User, ok bool) { 158 users := r.users.GetMany(UserPrimaryKeyer{}, u) 159 if len(users) > 1 { 160 panic("too many matching users") 161 } 162 if len(users) > 0 { 163 res = users[0] 164 ok = true 165 } 166 return 167 } 168 169 func (r *Reader) GetUsersByUsername(username string) []*User { 170 return r.users.GetMany(UserSecondaryKeyer{}, UserSecondaryKey{ 171 User: username, 172 }) 173 } 174 175 func (r *Reader) GetToUserRoleEdges(key RoleEdgesToKey) []*RoleEdge { 176 return r.roleEdges.GetMany(RoleEdgeToKeyer{}, key) 177 } 178 179 func (r *Reader) VisitUsers(cb func(*User)) { 180 r.users.VisitEntries(cb) 181 } 182 183 func (r *Reader) VisitRoleEdges(cb func(*RoleEdge)) { 184 r.roleEdges.VisitEntries(cb) 185 } 186 187 func (r *Reader) VisitReplicaSourceInfos(cb func(*ReplicaSourceInfo)) { 188 r.replicaSourceInfo.VisitEntries(cb) 189 } 190 191 func (r *Reader) Close() { 192 if r.close != nil { 193 r.close() 194 r.close = nil 195 } 196 } 197 198 type Editor struct { 199 db *MySQLDb 200 reader *Reader 201 } 202 203 func (ed *Editor) GetReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) (res *ReplicaSourceInfo, ok bool) { 204 sources := ed.reader.replicaSourceInfo.GetMany(ReplicaSourceInfoPrimaryKeyer{}, k) 205 if len(sources) > 1 { 206 panic("too many matching replica sources") 207 } 208 if len(sources) > 0 { 209 res = sources[0] 210 ok = true 211 } 212 return 213 } 214 215 func (ed *Editor) GetUsersByUsername(username string) []*User { 216 return ed.reader.GetUsersByUsername(username) 217 } 218 219 func (ed *Editor) GetUser(u UserPrimaryKey) (res *User, ok bool) { 220 return ed.reader.GetUser(u) 221 } 222 223 func (ed *Editor) GetToUserRoleEdges(key RoleEdgesToKey) []*RoleEdge { 224 return ed.reader.GetToUserRoleEdges(key) 225 } 226 227 func (ed *Editor) VisitUsers(cb func(*User)) { 228 ed.reader.VisitUsers(cb) 229 } 230 231 func (ed *Editor) VisitRoleEdges(cb func(*RoleEdge)) { 232 ed.reader.VisitRoleEdges(cb) 233 } 234 235 func (ed *Editor) VisitReplicaSourceInfos(cb func(*ReplicaSourceInfo)) { 236 ed.reader.VisitReplicaSourceInfos(cb) 237 } 238 239 func (ed *Editor) PutUser(u *User) { 240 if old, ok := ed.reader.users.Get(u); ok { 241 ed.reader.users.Remove(old) 242 } 243 ed.reader.users.Put(u) 244 } 245 246 func (ed *Editor) RemoveUser(pk UserPrimaryKey) { 247 ed.reader.users.RemoveMany(UserPrimaryKeyer{}, pk) 248 } 249 250 func (ed *Editor) PutRoleEdge(re *RoleEdge) { 251 if old, ok := ed.reader.roleEdges.Get(re); ok { 252 ed.reader.roleEdges.Remove(old) 253 } 254 ed.reader.roleEdges.Put(re) 255 } 256 257 func (ed *Editor) RemoveRoleEdge(pk RoleEdgesPrimaryKey) { 258 ed.reader.roleEdges.RemoveMany(RoleEdgePrimaryKeyer{}, pk) 259 } 260 261 func (ed *Editor) RemoveRoleEdgesFromKey(key RoleEdgesFromKey) { 262 ed.reader.roleEdges.RemoveMany(RoleEdgeFromKeyer{}, key) 263 } 264 265 func (ed *Editor) RemoveRoleEdgesToKey(key RoleEdgesToKey) { 266 ed.reader.roleEdges.RemoveMany(RoleEdgeToKeyer{}, key) 267 } 268 269 func (ed *Editor) RemoveReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) { 270 ed.reader.replicaSourceInfo.RemoveMany(ReplicaSourceInfoPrimaryKeyer{}, k) 271 } 272 273 func (ed *Editor) PutReplicaSourceInfo(rsi *ReplicaSourceInfo) { 274 if old, ok := ed.reader.replicaSourceInfo.Get(rsi); ok { 275 ed.reader.replicaSourceInfo.Remove(old) 276 } 277 ed.reader.replicaSourceInfo.Put(rsi) 278 } 279 280 func (ed *Editor) Close() { 281 ed.db.updateCounter.Add(1) 282 ed.reader.Close() 283 ed.db.lock.Unlock() 284 } 285 286 func (db *MySQLDb) unlockedReader() *Reader { 287 return &Reader{ 288 db.user.Set(), 289 db.role_edges.Set(), 290 db.replica_source_info.Set(), 291 nil, 292 } 293 } 294 295 func (db *MySQLDb) Reader() *Reader { 296 db.lock.RLock() 297 return &Reader{ 298 db.user.Set(), 299 db.role_edges.Set(), 300 db.replica_source_info.Set(), 301 func() { 302 db.lock.RUnlock() 303 }, 304 } 305 } 306 307 func (db *MySQLDb) Editor() *Editor { 308 db.lock.Lock() 309 return &Editor{ 310 db, 311 db.unlockedReader(), 312 } 313 } 314 315 func (db *MySQLDb) Enabled() bool { 316 return db.enabled.Load() 317 } 318 319 func (db *MySQLDb) SetEnabled(v bool) { 320 db.enabled.Store(v) 321 } 322 323 // LoadPrivilegeData adds the given data to the MySQL Tables. It does not remove any current data, but will overwrite any 324 // pre-existing data. This has been deprecated in favor of LoadData. 325 func (db *MySQLDb) LoadPrivilegeData(ctx *sql.Context, users []*User, roleConnections []*RoleEdge) error { 326 db.SetEnabled(true) 327 328 ed := db.Editor() 329 defer ed.Close() 330 331 for _, user := range users { 332 if user == nil { 333 continue 334 } 335 ed.PutUser(user) 336 } 337 338 for _, role := range roleConnections { 339 if role == nil { 340 continue 341 } 342 ed.PutRoleEdge(role) 343 } 344 345 return nil 346 } 347 348 // LoadData adds the given data to the MySQL Tables. It does not remove any current data, but will overwrite any 349 // pre-existing data. 350 func (db *MySQLDb) LoadData(ctx *sql.Context, buf []byte) (err error) { 351 // Do nothing if data file doesn't exist or is empty 352 if buf == nil || len(buf) == 0 { 353 return nil 354 } 355 356 type privDataJson struct { 357 Users []*User 358 Roles []*RoleEdge 359 } 360 361 // if it's a json file, read it; will be rewritten as flatbuffer later 362 data := &privDataJson{} 363 if err := json.Unmarshal(buf, data); err == nil { 364 return db.LoadPrivilegeData(ctx, data.Users, data.Roles) 365 } 366 367 // Indicate that mysql db exists 368 db.SetEnabled(true) 369 370 // Recover from panics 371 defer func() { 372 if recover() != nil { 373 err = fmt.Errorf("ill formatted privileges file") 374 } 375 }() 376 377 // Deserialize the flatbuffer 378 serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0) 379 380 ed := db.Editor() 381 defer ed.Close() 382 383 // Fill in user table 384 for i := 0; i < serialMySQLDb.UserLength(); i++ { 385 serialUser := new(serial.User) 386 if !serialMySQLDb.User(serialUser, i) { 387 continue 388 } 389 user := LoadUser(serialUser) 390 ed.PutUser(user) 391 } 392 393 // Fill in Roles table 394 for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ { 395 serialRoleEdge := new(serial.RoleEdge) 396 if !serialMySQLDb.RoleEdges(serialRoleEdge, i) { 397 continue 398 } 399 role := LoadRoleEdge(serialRoleEdge) 400 ed.PutRoleEdge(role) 401 } 402 403 // Fill in the ReplicaSourceInfo table 404 for i := 0; i < serialMySQLDb.ReplicaSourceInfoLength(); i++ { 405 serialReplicaSourceInfo := new(serial.ReplicaSourceInfo) 406 if !serialMySQLDb.ReplicaSourceInfo(serialReplicaSourceInfo, i) { 407 continue 408 } 409 replicaSourceInfo := LoadReplicaSourceInfo(serialReplicaSourceInfo) 410 ed.PutReplicaSourceInfo(replicaSourceInfo) 411 } 412 413 // TODO: fill in other tables when they exist 414 return 415 } 416 417 // OverwriteUsersAndGrantData replaces the users and grant data served by this 418 // MySQL DB instance with the data which is present in the provided byte 419 // buffer, which is a persisted copy of a MySQLDb created with `Persist`. In 420 // contrast to LoadData, it *does* remove current data in the database. 421 // 422 // This interface is appropriate for replication, when a replica needs to be 423 // brought up to date with a primary server. 424 // 425 // This method does not support the legacy JSON serialization of users and 426 // grant data. In contrast to most methods which operate with persisted users 427 // and grants in *MySQLDb, this method _does_ restore persisted super users. 428 func (db *MySQLDb) OverwriteUsersAndGrantData(ctx *sql.Context, ed *Editor, buf []byte) (err error) { 429 // Recover from panics 430 defer func() { 431 if recover() != nil { 432 err = fmt.Errorf("ill formatted privileges file") 433 } 434 }() 435 436 // Deserialize the flatbuffer 437 serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0) 438 439 // In order to make certain we can read the entire serialized message, 440 // we load it fully into *User and *RoleEdge instances before we mutate 441 // our maps at all. 442 var users []*User 443 var edges []*RoleEdge 444 445 // Load all users 446 for i := 0; i < serialMySQLDb.UserLength(); i++ { 447 serialUser := new(serial.User) 448 if !serialMySQLDb.User(serialUser, i) { 449 continue 450 } 451 users = append(users, LoadUser(serialUser)) 452 } 453 for i := 0; i < serialMySQLDb.SuperUserLength(); i++ { 454 serialUser := new(serial.User) 455 if !serialMySQLDb.SuperUser(serialUser, i) { 456 continue 457 } 458 user := LoadUser(serialUser) 459 user.IsSuperUser = true 460 users = append(users, user) 461 } 462 463 // Load all role edges 464 for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ { 465 serialRoleEdge := new(serial.RoleEdge) 466 if !serialMySQLDb.RoleEdges(serialRoleEdge, i) { 467 continue 468 } 469 edges = append(edges, LoadRoleEdge(serialRoleEdge)) 470 } 471 472 ed.reader.users.Clear() 473 ed.reader.roleEdges.Clear() 474 for _, u := range users { 475 ed.PutUser(u) 476 } 477 for _, e := range edges { 478 ed.PutRoleEdge(e) 479 } 480 481 return 482 } 483 484 // SetPersister sets the custom persister to be used when the MySQL Db tables have been updated and need to be persisted. 485 func (db *MySQLDb) SetPersister(persister MySQLDbPersistence) { 486 db.persister = persister 487 } 488 489 func (db *MySQLDb) SetPlugins(plugins map[string]PlaintextAuthPlugin) { 490 db.plugins = plugins 491 } 492 493 func (db *MySQLDb) VerifyPlugin(plugin string) error { 494 _, ok := db.plugins[plugin] 495 if ok { 496 return nil 497 } 498 return fmt.Errorf(`must provide authentication plugin for unsupported authentication format`) 499 } 500 501 // AddRootAccount adds the root account to the list of accounts. 502 func (db *MySQLDb) AddRootAccount() { 503 ed := db.Editor() 504 defer ed.Close() 505 db.AddSuperUser(ed, "root", "localhost", "") 506 } 507 508 // AddSuperUser adds the given username and password to the list of accounts. This is a temporary function, which is 509 // meant to replace the "auth.New..." functions while the remaining functions are added. 510 func (db *MySQLDb) AddSuperUser(ed *Editor, username string, host string, password string) { 511 //TODO: remove this function and the called function 512 db.SetEnabled(true) 513 if len(password) > 0 { 514 hash := sha1.New() 515 hash.Write([]byte(password)) 516 s1 := hash.Sum(nil) 517 hash.Reset() 518 hash.Write(s1) 519 s2 := hash.Sum(nil) 520 password = "*" + strings.ToUpper(hex.EncodeToString(s2)) 521 } 522 523 if _, ok := ed.GetUser(UserPrimaryKey{ 524 Host: host, 525 User: username, 526 }); !ok { 527 addSuperUser(ed, username, host, password) 528 } 529 } 530 531 // GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and 532 // roles, roleSearch changes whether the search matches against user or role rules. 533 func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSearch bool) *User { 534 //TODO: Determine what the localhost is on the machine, then handle the conversion between IP and localhost. 535 // For now, loopback addresses are treated as localhost. 536 //TODO: Determine how to match anonymous roles (roles with an empty user string), which differs from users 537 //TODO: Treat '%' as a proper wildcard for hostnames, allowing for regex-like matches. 538 // Hostnames representing an IP address that have a wildcard have additional restrictions on what may match 539 //TODO: Match non-existent users to the most relevant anonymous user if multiple exist (''@'localhost' vs ''@'%') 540 // It appears that ''@'localhost' can use the privileges set on ''@'%', which seems to be unique behavior. 541 // For example, 'abc'@'localhost' CANNOT use any privileges set on 'abc'@'%'. 542 // Unknown if this is special for ''@'%', or applies to any matching anonymous user. 543 //TODO: Hostnames representing IPs can use masks, such as 'abc'@'54.244.85.0/255.255.255.0' 544 //TODO: Allow for CIDR notation in hostnames 545 //TODO: Which user do we choose when multiple host names match (e.g. host name with most characters matched, etc.) 546 547 if "127.0.0.1" == host || "::1" == host { 548 host = "localhost" 549 } 550 551 if user, ok := fetcher.GetUser(UserPrimaryKey{ 552 Host: host, 553 User: user, 554 }); ok { 555 return user 556 } 557 558 // First we check for matches on the same user, then we try the anonymous user 559 for _, targetUser := range []string{user, ""} { 560 users := fetcher.GetUsersByUsername(targetUser) 561 for _, user := range users { 562 //TODO: use the most specific match first, using "%" only if there isn't a more specific match 563 if host == user.Host || 564 (host == "localhost" && user.Host == "::1") || 565 (host == "localhost" && user.Host == "127.0.0.1") || 566 (user.Host == "%" && (!roleSearch || host == "")) { 567 return user 568 } 569 } 570 } 571 return nil 572 } 573 574 // UserActivePrivilegeSet fetches the User, and returns their entire active privilege set. This takes into account the 575 // active roles, which are set in the context, therefore the user is also pulled from the context. 576 func (db *MySQLDb) UserActivePrivilegeSet(ctx *sql.Context) PrivilegeSet { 577 if privSet, counter := ctx.Session.GetPrivilegeSet(); db.updateCounter.Load() == counter { 578 // If the counters are equal, we can guarantee that the privilege set exists and is valid 579 return privSet.(PrivilegeSet) 580 } 581 582 rd := db.Reader() 583 defer rd.Close() 584 585 client := ctx.Session.Client() 586 user := db.GetUser(rd, client.User, client.Address, false) 587 if user == nil { 588 return NewPrivilegeSet() 589 } 590 591 privSet := user.PrivilegeSet.Copy() 592 roleEdgeEntries := rd.GetToUserRoleEdges(RoleEdgesToKey{ 593 ToHost: user.Host, 594 ToUser: user.User, 595 }) 596 //TODO: filter the active roles using the context, rather than using every granted roles 597 //TODO: System variable "activate_all_roles_on_login", if set, will set all roles as active upon logging in 598 for _, roleEdgeEntry := range roleEdgeEntries { 599 roleEdge := roleEdgeEntry 600 role := db.GetUser(rd, roleEdge.FromUser, roleEdge.FromHost, true) 601 if role != nil { 602 privSet.UnionWith(role.PrivilegeSet) 603 } 604 } 605 606 ctx.Session.SetPrivilegeSet(privSet, db.updateCounter.Load()) 607 return privSet 608 } 609 610 // RoutineAdminCheck fetches the User from the context, and specifically evaluates, the permission check 611 // assuming the operation is for a stored procedure or function. This allows us to have more fine grain control over 612 // permissions for stored procedures (many of which are critical to Dolt). This method specifically checks exists 613 // for the use of AdminOnly procedures which require more fine-grained access control. For procedures which are 614 // not AdminOnly, then |UserHasPrivileges| should be used instead. 615 func (db *MySQLDb) RoutineAdminCheck(ctx *sql.Context, operations ...sql.PrivilegedOperation) bool { 616 privSet := db.UserActivePrivilegeSet(ctx) 617 618 if privSet.Has(sql.PrivilegeType_Super) { 619 // Superpowers allow you to fly and look through walls, surely you can execute whatever you want. 620 return true 621 } 622 623 for _, operation := range operations { 624 for _, operationPriv := range operation.StaticPrivileges { 625 database := operation.Database 626 if database == "" { 627 database = ctx.GetCurrentDatabase() 628 } 629 dbSet := privSet.Database(database) 630 routineSet := dbSet.Routine(operation.Routine, operation.IsProcedure) 631 if routineSet.Has(operationPriv) { 632 continue 633 } 634 635 // User does not have permission to perform the operation. 636 return false 637 } 638 } 639 return true 640 } 641 642 // UserHasPrivileges fetches the User, and returns whether they have the desired privileges necessary to perform the 643 // privileged operation(s). This takes into account the active roles, which are set in the context, therefore both 644 // the user and the active roles are pulled from the context. This method is sufficient for all MySQL behaviors. 645 // The one exception, currently, is for stored procedures and functions, which have a more fine-grained permission 646 // due to Dolt's use of the AdminOnly flag in procedure definitions. 647 // 648 // This functions implements the global/database/table|routine hierarchy of permissions. If a user has Execute permissions 649 // on the database, then they implicitly have that same permission on all tables and routines in that database. This 650 // is how all MySQL permissions work. 651 func (db *MySQLDb) UserHasPrivileges(ctx *sql.Context, operations ...sql.PrivilegedOperation) bool { 652 privSet := db.UserActivePrivilegeSet(ctx) 653 // Super users have all privileges, so if they have global super privs, then 654 // they have all dynamic privs and we don't need to check them. 655 if privSet.Has(sql.PrivilegeType_Super) { 656 return true 657 } 658 659 if !db.Enabled() { 660 return true 661 } 662 for _, operation := range operations { 663 664 for _, operationPriv := range operation.StaticPrivileges { 665 if privSet.Has(operationPriv) { 666 //TODO: Handle partial revokes 667 continue 668 } 669 database := operation.Database 670 if database == "" { 671 database = ctx.GetCurrentDatabase() 672 } 673 dbSet := privSet.Database(database) 674 if dbSet.Has(operationPriv) { 675 continue 676 } 677 tblSet := dbSet.Table(operation.Table) 678 if tblSet.Has(operationPriv) { 679 continue 680 } 681 682 // TODO: Complete the column check support. 683 // colSet := tblSet.Column(operation.Column) 684 // if colSet.Has(operationPriv) { 685 // continue 686 // } 687 688 routineSet := dbSet.Routine(operation.Routine, operation.IsProcedure) 689 if routineSet.Has(operationPriv) { 690 continue 691 } 692 693 // User does not have permission to perform the operation. 694 return false 695 } 696 697 for _, operationPriv := range operation.DynamicPrivileges { 698 if privSet.HasDynamic(operationPriv) { 699 continue 700 } 701 702 // Dynamic privileges are only allowed at a global scope, so no need to check 703 // for database, table, or column privileges. 704 return false 705 } 706 } 707 return true 708 } 709 710 // Name implements the interface sql.Database. 711 func (db *MySQLDb) Name() string { 712 return "mysql" 713 } 714 715 // GetTableInsensitive implements the interface sql.Database. 716 func (db *MySQLDb) GetTableInsensitive(_ *sql.Context, tblName string) (sql.Table, bool, error) { 717 switch strings.ToLower(tblName) { 718 case userTblName: 719 return db.user, true, nil 720 case roleEdgesTblName: 721 return db.role_edges, true, nil 722 case dbTblName: 723 return db.db, true, nil 724 case tablesPrivTblName: 725 return db.tables_priv, true, nil 726 case procsPrivTblName: 727 return db.procs_priv, true, nil 728 case replicaSourceInfoTblName: 729 return db.replica_source_info, true, nil 730 case helpTopicTableName: 731 return db.help_topic, true, nil 732 case helpKeywordTableName: 733 return db.help_keyword, true, nil 734 case helpCategoryTableName: 735 return db.help_category, true, nil 736 case helpRelationTableName: 737 return db.help_relation, true, nil 738 default: 739 return nil, false, nil 740 } 741 } 742 743 // GetTableNames implements the interface sql.Database. 744 func (db *MySQLDb) GetTableNames(ctx *sql.Context) ([]string, error) { 745 return []string{ 746 userTblName, 747 dbTblName, 748 tablesPrivTblName, 749 procsPrivTblName, 750 roleEdgesTblName, 751 replicaSourceInfoTblName, 752 helpTopicTableName, 753 helpKeywordTableName, 754 helpCategoryTableName, 755 helpRelationTableName, 756 }, nil 757 } 758 759 // AuthMethod implements the interface mysql.AuthServer. 760 func (db *MySQLDb) AuthMethod(user, addr string) (string, error) { 761 if !db.Enabled() { 762 return "mysql_native_password", nil 763 } 764 var host string 765 // TODO : need to check for network type instead of addr string if it's unix socket network, 766 // macOS passes empty addr, but ubuntu returns "@" as addr for `localhost` 767 if addr == "@" || addr == "" { 768 host = "localhost" 769 } else { 770 splitHost, _, err := net.SplitHostPort(addr) 771 if err != nil { 772 if err.(*net.AddrError).Err == "missing port in address" { 773 host = addr 774 } else { 775 return "", err 776 } 777 } else { 778 host = splitHost 779 } 780 } 781 782 rd := db.Reader() 783 defer rd.Close() 784 785 u := db.GetUser(rd, user, host, false) 786 if u == nil { 787 return "", mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "User not found '%v'", user) 788 } 789 if _, ok := db.plugins[u.Plugin]; ok { 790 return "mysql_clear_password", nil 791 } 792 return u.Plugin, nil 793 } 794 795 // Salt implements the interface mysql.AuthServer. 796 func (db *MySQLDb) Salt() ([]byte, error) { 797 return mysql.NewSalt() 798 } 799 800 // ValidateHash implements the interface mysql.AuthServer. This is called when the method used is "mysql_native_password". 801 func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, addr net.Addr) (mysql.Getter, error) { 802 var host string 803 var err error 804 if addr.Network() == "unix" { 805 host = "localhost" 806 } else { 807 host, _, err = net.SplitHostPort(addr.String()) 808 if err != nil { 809 if err.(*net.AddrError).Err == "missing port in address" { 810 host = addr.String() 811 } else { 812 return nil, err 813 } 814 } 815 } 816 817 rd := db.Reader() 818 defer rd.Close() 819 820 if !db.Enabled() { 821 return MysqlConnectionUser{User: user, Host: host}, nil 822 } 823 824 userEntry := db.GetUser(rd, user, host, false) 825 if userEntry == nil || userEntry.Locked { 826 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 827 } 828 if len(userEntry.Password) > 0 { 829 if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) { 830 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 831 } 832 } else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set 833 // a password was given and the account has no password set, therefore access is denied 834 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 835 } 836 837 return MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil 838 } 839 840 // Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password". 841 func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.Getter, error) { 842 var host string 843 var err error 844 if addr.Network() == "unix" { 845 host = "localhost" 846 } else { 847 host, _, err = net.SplitHostPort(addr.String()) 848 if err != nil { 849 if err.(*net.AddrError).Err == "missing port in address" { 850 host = addr.String() 851 } else { 852 return nil, err 853 } 854 } 855 } 856 857 rd := db.Reader() 858 defer rd.Close() 859 860 connUser := MysqlConnectionUser{User: user, Host: host} 861 if !db.Enabled() { 862 return connUser, nil 863 } 864 userEntry := db.GetUser(rd, user, host, false) 865 866 if userEntry.Plugin != "" { 867 authplugin, ok := db.plugins[userEntry.Plugin] 868 if !ok { 869 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'; auth plugin %s not registered with server", user, userEntry.Plugin) 870 } 871 pass, err := mysql.AuthServerReadPacketString(c) 872 if err != nil { 873 return nil, err 874 } 875 authed, err := authplugin.Authenticate(db, user, userEntry, pass) 876 if err != nil { 877 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v': %v", user, err) 878 } 879 if !authed { 880 return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 881 } 882 return connUser, nil 883 } 884 return nil, fmt.Errorf(`the only user login interface currently supported is "mysql_native_password"`) 885 } 886 887 // Persist passes along all changes to the integrator. 888 // 889 // This takes an Editor, instead of a Reader, since presumably we have just 890 // done a write. In any case, it's nice to not ACK a write until it is 891 // persisted, and the write lock which the Editor takes can help with not 892 // making these changes visible until it is persisted as well. 893 func (db *MySQLDb) Persist(ctx *sql.Context, ed *Editor) error { 894 // Extract all user entries from table, and sort 895 var users []*User 896 var superUsers []*User 897 ed.VisitUsers(func(u *User) { 898 if !u.IsSuperUser { 899 users = append(users, u) 900 } else { 901 superUsers = append(superUsers, u) 902 } 903 }) 904 sort.Slice(users, func(i, j int) bool { 905 if users[i].Host == users[j].Host { 906 return users[i].User < users[j].User 907 } 908 return users[i].Host < users[j].Host 909 }) 910 sort.Slice(superUsers, func(i, j int) bool { 911 if superUsers[i].Host == superUsers[j].Host { 912 return superUsers[i].User < superUsers[j].User 913 } 914 return superUsers[i].Host < superUsers[j].Host 915 }) 916 917 // Extract all role entries from table, and sort 918 var roles []*RoleEdge 919 ed.VisitRoleEdges(func(v *RoleEdge) { 920 roles = append(roles, v) 921 }) 922 sort.Slice(roles, func(i, j int) bool { 923 if roles[i].FromHost == roles[j].FromHost { 924 if roles[i].FromUser == roles[j].FromUser { 925 if roles[i].ToHost == roles[j].ToHost { 926 return roles[i].ToUser < roles[j].ToUser 927 } 928 return roles[i].ToHost < roles[j].ToHost 929 } 930 return roles[i].FromUser < roles[j].FromUser 931 } 932 return roles[i].FromHost < roles[j].FromHost 933 }) 934 935 // Extract all replica source info entries from table, and sort 936 var replicaSourceInfos []*ReplicaSourceInfo 937 ed.VisitReplicaSourceInfos(func(v *ReplicaSourceInfo) { 938 replicaSourceInfos = append(replicaSourceInfos, v) 939 }) 940 sort.Slice(replicaSourceInfos, func(i, j int) bool { 941 if replicaSourceInfos[i].Host == replicaSourceInfos[j].Host { 942 if replicaSourceInfos[i].Port == replicaSourceInfos[j].Port { 943 return replicaSourceInfos[i].User < replicaSourceInfos[j].User 944 } 945 return replicaSourceInfos[i].Port < replicaSourceInfos[j].Port 946 } 947 return replicaSourceInfos[i].Host < replicaSourceInfos[j].Host 948 }) 949 950 // TODO: serialize other tables when they exist 951 952 // Create flatbuffer 953 b := flatbuffers.NewBuilder(0) 954 user := serializeUser(b, users) 955 roleEdge := serializeRoleEdge(b, roles) 956 replicaSourceInfo := serializeReplicaSourceInfo(b, replicaSourceInfos) 957 superUser := serializeUser(b, superUsers) 958 959 // Write MySQL DB 960 serial.MySQLDbStart(b) 961 serial.MySQLDbAddUser(b, user) 962 serial.MySQLDbAddRoleEdges(b, roleEdge) 963 serial.MySQLDbAddReplicaSourceInfo(b, replicaSourceInfo) 964 serial.MySQLDbAddSuperUser(b, superUser) 965 mysqlDbOffset := serial.MySQLDbEnd(b) 966 967 // Finish writing 968 b.Finish(mysqlDbOffset) 969 970 // Persist data 971 return db.persister.Persist(ctx, b.FinishedBytes()) 972 } 973 974 // columnTemplate takes in a column as a template, and returns a new column with a different name based on the given 975 // template. 976 func columnTemplate(name string, source string, isPk bool, template *sql.Column) *sql.Column { 977 newCol := *template 978 if newCol.Default != nil { 979 newCol.Default = &(*newCol.Default) 980 } 981 newCol.Name = name 982 newCol.Source = source 983 newCol.PrimaryKey = isPk 984 return &newCol 985 } 986 987 // validateMysqlNativePassword was taken directly from vitess and validates the password hash for "mysql_native_password". 988 func validateMysqlNativePassword(authResponse, salt []byte, mysqlNativePassword string) bool { 989 // SERVER: recv(authResponse) 990 // hash_stage1=xor(authResponse, sha1(salt,hash)) 991 // candidate_hash2=sha1(hash_stage1) 992 // check(candidate_hash2==hash) 993 if len(authResponse) == 0 || len(mysqlNativePassword) == 0 { 994 return false 995 } 996 if mysqlNativePassword[0] == '*' { 997 mysqlNativePassword = mysqlNativePassword[1:] 998 } 999 1000 hash, err := hex.DecodeString(mysqlNativePassword) 1001 if err != nil { 1002 return false 1003 } 1004 1005 // scramble = SHA1(salt+hash) 1006 crypt := sha1.New() 1007 crypt.Write(salt) 1008 crypt.Write(hash) 1009 scramble := crypt.Sum(nil) 1010 1011 // token = scramble XOR stage1Hash 1012 for i := range scramble { 1013 scramble[i] ^= authResponse[i] 1014 } 1015 stage1Hash := scramble 1016 crypt.Reset() 1017 crypt.Write(stage1Hash) 1018 candidateHash2 := crypt.Sum(nil) 1019 1020 return bytes.Equal(candidateHash2, hash) 1021 } 1022 1023 // mustDefault enforces that no error occurred when constructing the column default value. 1024 func mustDefault(expr sql.Expression, outType sql.Type, representsLiteral bool, mayReturnNil bool) *sql.ColumnDefaultValue { 1025 colDef, err := sql.NewColumnDefaultValue(expr, outType, representsLiteral, !representsLiteral, mayReturnNil) 1026 if err != nil { 1027 panic(err) 1028 } 1029 return colDef 1030 } 1031 1032 type dummyPartition struct{} 1033 1034 var _ sql.Partition = dummyPartition{} 1035 1036 // Key implements the interface sql.Partition. 1037 func (d dummyPartition) Key() []byte { 1038 return nil 1039 }