github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/mysql_db.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql_db
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/sha1"
    20  	"encoding/hex"
    21  	"encoding/json"
    22  	"fmt"
    23  	"net"
    24  	"sort"
    25  	"strings"
    26  	"sync"
    27  	"sync/atomic"
    28  
    29  	flatbuffers "github.com/dolthub/flatbuffers/v23/go"
    30  	"github.com/dolthub/vitess/go/mysql"
    31  
    32  	"github.com/dolthub/go-mysql-server/sql"
    33  	"github.com/dolthub/go-mysql-server/sql/in_mem_table"
    34  	"github.com/dolthub/go-mysql-server/sql/mysql_db/serial"
    35  )
    36  
    37  // MySQLDbPersistence is used to determine the behavior of how certain tables in MySQLDb will be persisted.
    38  type MySQLDbPersistence interface {
    39  	Persist(ctx *sql.Context, data []byte) error
    40  }
    41  
    42  // NoopPersister is used when nothing in mysql db should be persisted
    43  type NoopPersister struct{}
    44  
    45  var _ MySQLDbPersistence = &NoopPersister{}
    46  
    47  // Persist implements the MySQLDbPersistence interface
    48  func (p *NoopPersister) Persist(ctx *sql.Context, data []byte) error {
    49  	return nil
    50  }
    51  
    52  type PlaintextAuthPlugin interface {
    53  	Authenticate(db *MySQLDb, user string, userEntry *User, pass string) (bool, error)
    54  }
    55  
    56  // MySQLDb are the collection of tables that are in the MySQL database
    57  type MySQLDb struct {
    58  	enabled atomic.Bool
    59  
    60  	user                *in_mem_table.IndexedSetTable[*User]
    61  	role_edges          *in_mem_table.IndexedSetTable[*RoleEdge]
    62  	replica_source_info *in_mem_table.IndexedSetTable[*ReplicaSourceInfo]
    63  
    64  	help_topic    *mysqlTable
    65  	help_keyword  *mysqlTable
    66  	help_category *mysqlTable
    67  	help_relation *mysqlTable
    68  
    69  	db            *in_mem_table.MultiIndexedSetTable[*User]
    70  	tables_priv   *in_mem_table.MultiIndexedSetTable[*User]
    71  	procs_priv    *in_mem_table.MultiIndexedSetTable[*User]
    72  	global_grants *in_mem_table.MultiIndexedSetTable[*User]
    73  
    74  	//TODO: add the rest of these tables
    75  	//columns_priv     *mysqlTable
    76  	//proxies_priv     *mysqlTable
    77  	//default_roles    *mysqlTable
    78  	//password_history *mysqlTable
    79  
    80  	persister MySQLDbPersistence
    81  	plugins   map[string]PlaintextAuthPlugin
    82  
    83  	lock          sync.RWMutex
    84  	updateCounter atomic.Uint64
    85  }
    86  
    87  var _ sql.Database = (*MySQLDb)(nil)
    88  var _ mysql.AuthServer = (*MySQLDb)(nil)
    89  
    90  // CreateEmptyMySQLDb returns a collection of MySQL Tables that do not contain any data.
    91  func CreateEmptyMySQLDb() *MySQLDb {
    92  	// original tables
    93  	mysqlDb := &MySQLDb{}
    94  
    95  	lock, rlock := &mysqlDb.lock, mysqlDb.lock.RLocker()
    96  
    97  	userSet, userTable := NewUserIndexedSetTable(lock, rlock)
    98  	mysqlDb.user = userTable
    99  	mysqlDb.role_edges = NewRoleEdgesIndexedSetTable(lock, rlock)
   100  	mysqlDb.replica_source_info = NewReplicaSourceInfoIndexedSetTable(lock, rlock)
   101  
   102  	// Help tables
   103  	mysqlDb.help_topic = newEmptyMySQLTable(
   104  		"help_topic",
   105  		helpTopicSchema,
   106  		mysqlDb)
   107  	mysqlDb.help_keyword = newEmptyMySQLTable(
   108  		"help_keyword",
   109  		helpKeywordSchema,
   110  		mysqlDb)
   111  	mysqlDb.help_category = newEmptyMySQLTable(
   112  		"help_category",
   113  		helpCategorySchema,
   114  		mysqlDb)
   115  	mysqlDb.help_relation = newEmptyMySQLTable(
   116  		"help_relation",
   117  		helpRelationSchema,
   118  		mysqlDb)
   119  
   120  	// multi tables
   121  	mysqlDb.db = NewUserDBIndexedSetTable(userSet, lock, rlock)
   122  	mysqlDb.tables_priv = NewUserTablesIndexedSetTable(userSet, lock, rlock)
   123  	mysqlDb.procs_priv = NewUserProcsIndexedSetTable(userSet, lock, rlock)
   124  	mysqlDb.global_grants = NewUserGlobalGrantsIndexedSetTable(userSet, lock, rlock)
   125  
   126  	// Start the counter at 1, all new sessions will start at zero so this forces an update for any new session
   127  	mysqlDb.updateCounter.Store(1)
   128  
   129  	return mysqlDb
   130  }
   131  
   132  type Reader struct {
   133  	users             in_mem_table.IndexedSet[*User]
   134  	roleEdges         in_mem_table.IndexedSet[*RoleEdge]
   135  	replicaSourceInfo in_mem_table.IndexedSet[*ReplicaSourceInfo]
   136  
   137  	close func()
   138  }
   139  
   140  type UserFetcher interface {
   141  	GetUser(u UserPrimaryKey) (res *User, ok bool)
   142  	GetUsersByUsername(username string) []*User
   143  }
   144  
   145  func (r *Reader) GetReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) (res *ReplicaSourceInfo, ok bool) {
   146  	sources := r.replicaSourceInfo.GetMany(ReplicaSourceInfoPrimaryKeyer{}, k)
   147  	if len(sources) > 1 {
   148  		panic("too many matching replica sources")
   149  	}
   150  	if len(sources) > 0 {
   151  		res = sources[0]
   152  		ok = true
   153  	}
   154  	return
   155  }
   156  
   157  func (r *Reader) GetUser(u UserPrimaryKey) (res *User, ok bool) {
   158  	users := r.users.GetMany(UserPrimaryKeyer{}, u)
   159  	if len(users) > 1 {
   160  		panic("too many matching users")
   161  	}
   162  	if len(users) > 0 {
   163  		res = users[0]
   164  		ok = true
   165  	}
   166  	return
   167  }
   168  
   169  func (r *Reader) GetUsersByUsername(username string) []*User {
   170  	return r.users.GetMany(UserSecondaryKeyer{}, UserSecondaryKey{
   171  		User: username,
   172  	})
   173  }
   174  
   175  func (r *Reader) GetToUserRoleEdges(key RoleEdgesToKey) []*RoleEdge {
   176  	return r.roleEdges.GetMany(RoleEdgeToKeyer{}, key)
   177  }
   178  
   179  func (r *Reader) VisitUsers(cb func(*User)) {
   180  	r.users.VisitEntries(cb)
   181  }
   182  
   183  func (r *Reader) VisitRoleEdges(cb func(*RoleEdge)) {
   184  	r.roleEdges.VisitEntries(cb)
   185  }
   186  
   187  func (r *Reader) VisitReplicaSourceInfos(cb func(*ReplicaSourceInfo)) {
   188  	r.replicaSourceInfo.VisitEntries(cb)
   189  }
   190  
   191  func (r *Reader) Close() {
   192  	if r.close != nil {
   193  		r.close()
   194  		r.close = nil
   195  	}
   196  }
   197  
   198  type Editor struct {
   199  	db     *MySQLDb
   200  	reader *Reader
   201  }
   202  
   203  func (ed *Editor) GetReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) (res *ReplicaSourceInfo, ok bool) {
   204  	sources := ed.reader.replicaSourceInfo.GetMany(ReplicaSourceInfoPrimaryKeyer{}, k)
   205  	if len(sources) > 1 {
   206  		panic("too many matching replica sources")
   207  	}
   208  	if len(sources) > 0 {
   209  		res = sources[0]
   210  		ok = true
   211  	}
   212  	return
   213  }
   214  
   215  func (ed *Editor) GetUsersByUsername(username string) []*User {
   216  	return ed.reader.GetUsersByUsername(username)
   217  }
   218  
   219  func (ed *Editor) GetUser(u UserPrimaryKey) (res *User, ok bool) {
   220  	return ed.reader.GetUser(u)
   221  }
   222  
   223  func (ed *Editor) GetToUserRoleEdges(key RoleEdgesToKey) []*RoleEdge {
   224  	return ed.reader.GetToUserRoleEdges(key)
   225  }
   226  
   227  func (ed *Editor) VisitUsers(cb func(*User)) {
   228  	ed.reader.VisitUsers(cb)
   229  }
   230  
   231  func (ed *Editor) VisitRoleEdges(cb func(*RoleEdge)) {
   232  	ed.reader.VisitRoleEdges(cb)
   233  }
   234  
   235  func (ed *Editor) VisitReplicaSourceInfos(cb func(*ReplicaSourceInfo)) {
   236  	ed.reader.VisitReplicaSourceInfos(cb)
   237  }
   238  
   239  func (ed *Editor) PutUser(u *User) {
   240  	if old, ok := ed.reader.users.Get(u); ok {
   241  		ed.reader.users.Remove(old)
   242  	}
   243  	ed.reader.users.Put(u)
   244  }
   245  
   246  func (ed *Editor) RemoveUser(pk UserPrimaryKey) {
   247  	ed.reader.users.RemoveMany(UserPrimaryKeyer{}, pk)
   248  }
   249  
   250  func (ed *Editor) PutRoleEdge(re *RoleEdge) {
   251  	if old, ok := ed.reader.roleEdges.Get(re); ok {
   252  		ed.reader.roleEdges.Remove(old)
   253  	}
   254  	ed.reader.roleEdges.Put(re)
   255  }
   256  
   257  func (ed *Editor) RemoveRoleEdge(pk RoleEdgesPrimaryKey) {
   258  	ed.reader.roleEdges.RemoveMany(RoleEdgePrimaryKeyer{}, pk)
   259  }
   260  
   261  func (ed *Editor) RemoveRoleEdgesFromKey(key RoleEdgesFromKey) {
   262  	ed.reader.roleEdges.RemoveMany(RoleEdgeFromKeyer{}, key)
   263  }
   264  
   265  func (ed *Editor) RemoveRoleEdgesToKey(key RoleEdgesToKey) {
   266  	ed.reader.roleEdges.RemoveMany(RoleEdgeToKeyer{}, key)
   267  }
   268  
   269  func (ed *Editor) RemoveReplicaSourceInfo(k ReplicaSourceInfoPrimaryKey) {
   270  	ed.reader.replicaSourceInfo.RemoveMany(ReplicaSourceInfoPrimaryKeyer{}, k)
   271  }
   272  
   273  func (ed *Editor) PutReplicaSourceInfo(rsi *ReplicaSourceInfo) {
   274  	if old, ok := ed.reader.replicaSourceInfo.Get(rsi); ok {
   275  		ed.reader.replicaSourceInfo.Remove(old)
   276  	}
   277  	ed.reader.replicaSourceInfo.Put(rsi)
   278  }
   279  
   280  func (ed *Editor) Close() {
   281  	ed.db.updateCounter.Add(1)
   282  	ed.reader.Close()
   283  	ed.db.lock.Unlock()
   284  }
   285  
   286  func (db *MySQLDb) unlockedReader() *Reader {
   287  	return &Reader{
   288  		db.user.Set(),
   289  		db.role_edges.Set(),
   290  		db.replica_source_info.Set(),
   291  		nil,
   292  	}
   293  }
   294  
   295  func (db *MySQLDb) Reader() *Reader {
   296  	db.lock.RLock()
   297  	return &Reader{
   298  		db.user.Set(),
   299  		db.role_edges.Set(),
   300  		db.replica_source_info.Set(),
   301  		func() {
   302  			db.lock.RUnlock()
   303  		},
   304  	}
   305  }
   306  
   307  func (db *MySQLDb) Editor() *Editor {
   308  	db.lock.Lock()
   309  	return &Editor{
   310  		db,
   311  		db.unlockedReader(),
   312  	}
   313  }
   314  
   315  func (db *MySQLDb) Enabled() bool {
   316  	return db.enabled.Load()
   317  }
   318  
   319  func (db *MySQLDb) SetEnabled(v bool) {
   320  	db.enabled.Store(v)
   321  }
   322  
   323  // LoadPrivilegeData adds the given data to the MySQL Tables. It does not remove any current data, but will overwrite any
   324  // pre-existing data. This has been deprecated in favor of LoadData.
   325  func (db *MySQLDb) LoadPrivilegeData(ctx *sql.Context, users []*User, roleConnections []*RoleEdge) error {
   326  	db.SetEnabled(true)
   327  
   328  	ed := db.Editor()
   329  	defer ed.Close()
   330  
   331  	for _, user := range users {
   332  		if user == nil {
   333  			continue
   334  		}
   335  		ed.PutUser(user)
   336  	}
   337  
   338  	for _, role := range roleConnections {
   339  		if role == nil {
   340  			continue
   341  		}
   342  		ed.PutRoleEdge(role)
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  // LoadData adds the given data to the MySQL Tables. It does not remove any current data, but will overwrite any
   349  // pre-existing data.
   350  func (db *MySQLDb) LoadData(ctx *sql.Context, buf []byte) (err error) {
   351  	// Do nothing if data file doesn't exist or is empty
   352  	if buf == nil || len(buf) == 0 {
   353  		return nil
   354  	}
   355  
   356  	type privDataJson struct {
   357  		Users []*User
   358  		Roles []*RoleEdge
   359  	}
   360  
   361  	// if it's a json file, read it; will be rewritten as flatbuffer later
   362  	data := &privDataJson{}
   363  	if err := json.Unmarshal(buf, data); err == nil {
   364  		return db.LoadPrivilegeData(ctx, data.Users, data.Roles)
   365  	}
   366  
   367  	// Indicate that mysql db exists
   368  	db.SetEnabled(true)
   369  
   370  	// Recover from panics
   371  	defer func() {
   372  		if recover() != nil {
   373  			err = fmt.Errorf("ill formatted privileges file")
   374  		}
   375  	}()
   376  
   377  	// Deserialize the flatbuffer
   378  	serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0)
   379  
   380  	ed := db.Editor()
   381  	defer ed.Close()
   382  
   383  	// Fill in user table
   384  	for i := 0; i < serialMySQLDb.UserLength(); i++ {
   385  		serialUser := new(serial.User)
   386  		if !serialMySQLDb.User(serialUser, i) {
   387  			continue
   388  		}
   389  		user := LoadUser(serialUser)
   390  		ed.PutUser(user)
   391  	}
   392  
   393  	// Fill in Roles table
   394  	for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ {
   395  		serialRoleEdge := new(serial.RoleEdge)
   396  		if !serialMySQLDb.RoleEdges(serialRoleEdge, i) {
   397  			continue
   398  		}
   399  		role := LoadRoleEdge(serialRoleEdge)
   400  		ed.PutRoleEdge(role)
   401  	}
   402  
   403  	// Fill in the ReplicaSourceInfo table
   404  	for i := 0; i < serialMySQLDb.ReplicaSourceInfoLength(); i++ {
   405  		serialReplicaSourceInfo := new(serial.ReplicaSourceInfo)
   406  		if !serialMySQLDb.ReplicaSourceInfo(serialReplicaSourceInfo, i) {
   407  			continue
   408  		}
   409  		replicaSourceInfo := LoadReplicaSourceInfo(serialReplicaSourceInfo)
   410  		ed.PutReplicaSourceInfo(replicaSourceInfo)
   411  	}
   412  
   413  	// TODO: fill in other tables when they exist
   414  	return
   415  }
   416  
   417  // OverwriteUsersAndGrantData replaces the users and grant data served by this
   418  // MySQL DB instance with the data which is present in the provided byte
   419  // buffer, which is a persisted copy of a MySQLDb created with `Persist`. In
   420  // contrast to LoadData, it *does* remove current data in the database.
   421  //
   422  // This interface is appropriate for replication, when a replica needs to be
   423  // brought up to date with a primary server.
   424  //
   425  // This method does not support the legacy JSON serialization of users and
   426  // grant data. In contrast to most methods which operate with persisted users
   427  // and grants in *MySQLDb, this method _does_ restore persisted super users.
   428  func (db *MySQLDb) OverwriteUsersAndGrantData(ctx *sql.Context, ed *Editor, buf []byte) (err error) {
   429  	// Recover from panics
   430  	defer func() {
   431  		if recover() != nil {
   432  			err = fmt.Errorf("ill formatted privileges file")
   433  		}
   434  	}()
   435  
   436  	// Deserialize the flatbuffer
   437  	serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0)
   438  
   439  	// In order to make certain we can read the entire serialized message,
   440  	// we load it fully into *User and *RoleEdge instances before we mutate
   441  	// our maps at all.
   442  	var users []*User
   443  	var edges []*RoleEdge
   444  
   445  	// Load all users
   446  	for i := 0; i < serialMySQLDb.UserLength(); i++ {
   447  		serialUser := new(serial.User)
   448  		if !serialMySQLDb.User(serialUser, i) {
   449  			continue
   450  		}
   451  		users = append(users, LoadUser(serialUser))
   452  	}
   453  	for i := 0; i < serialMySQLDb.SuperUserLength(); i++ {
   454  		serialUser := new(serial.User)
   455  		if !serialMySQLDb.SuperUser(serialUser, i) {
   456  			continue
   457  		}
   458  		user := LoadUser(serialUser)
   459  		user.IsSuperUser = true
   460  		users = append(users, user)
   461  	}
   462  
   463  	// Load all role edges
   464  	for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ {
   465  		serialRoleEdge := new(serial.RoleEdge)
   466  		if !serialMySQLDb.RoleEdges(serialRoleEdge, i) {
   467  			continue
   468  		}
   469  		edges = append(edges, LoadRoleEdge(serialRoleEdge))
   470  	}
   471  
   472  	ed.reader.users.Clear()
   473  	ed.reader.roleEdges.Clear()
   474  	for _, u := range users {
   475  		ed.PutUser(u)
   476  	}
   477  	for _, e := range edges {
   478  		ed.PutRoleEdge(e)
   479  	}
   480  
   481  	return
   482  }
   483  
   484  // SetPersister sets the custom persister to be used when the MySQL Db tables have been updated and need to be persisted.
   485  func (db *MySQLDb) SetPersister(persister MySQLDbPersistence) {
   486  	db.persister = persister
   487  }
   488  
   489  func (db *MySQLDb) SetPlugins(plugins map[string]PlaintextAuthPlugin) {
   490  	db.plugins = plugins
   491  }
   492  
   493  func (db *MySQLDb) VerifyPlugin(plugin string) error {
   494  	_, ok := db.plugins[plugin]
   495  	if ok {
   496  		return nil
   497  	}
   498  	return fmt.Errorf(`must provide authentication plugin for unsupported authentication format`)
   499  }
   500  
   501  // AddRootAccount adds the root account to the list of accounts.
   502  func (db *MySQLDb) AddRootAccount() {
   503  	ed := db.Editor()
   504  	defer ed.Close()
   505  	db.AddSuperUser(ed, "root", "localhost", "")
   506  }
   507  
   508  // AddSuperUser adds the given username and password to the list of accounts. This is a temporary function, which is
   509  // meant to replace the "auth.New..." functions while the remaining functions are added.
   510  func (db *MySQLDb) AddSuperUser(ed *Editor, username string, host string, password string) {
   511  	//TODO: remove this function and the called function
   512  	db.SetEnabled(true)
   513  	if len(password) > 0 {
   514  		hash := sha1.New()
   515  		hash.Write([]byte(password))
   516  		s1 := hash.Sum(nil)
   517  		hash.Reset()
   518  		hash.Write(s1)
   519  		s2 := hash.Sum(nil)
   520  		password = "*" + strings.ToUpper(hex.EncodeToString(s2))
   521  	}
   522  
   523  	if _, ok := ed.GetUser(UserPrimaryKey{
   524  		Host: host,
   525  		User: username,
   526  	}); !ok {
   527  		addSuperUser(ed, username, host, password)
   528  	}
   529  }
   530  
   531  // GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and
   532  // roles, roleSearch changes whether the search matches against user or role rules.
   533  func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSearch bool) *User {
   534  	//TODO: Determine what the localhost is on the machine, then handle the conversion between IP and localhost.
   535  	// For now, loopback addresses are treated as localhost.
   536  	//TODO: Determine how to match anonymous roles (roles with an empty user string), which differs from users
   537  	//TODO: Treat '%' as a proper wildcard for hostnames, allowing for regex-like matches.
   538  	// Hostnames representing an IP address that have a wildcard have additional restrictions on what may match
   539  	//TODO: Match non-existent users to the most relevant anonymous user if multiple exist (''@'localhost' vs ''@'%')
   540  	// It appears that ''@'localhost' can use the privileges set on ''@'%', which seems to be unique behavior.
   541  	// For example, 'abc'@'localhost' CANNOT use any privileges set on 'abc'@'%'.
   542  	// Unknown if this is special for ''@'%', or applies to any matching anonymous user.
   543  	//TODO: Hostnames representing IPs can use masks, such as 'abc'@'54.244.85.0/255.255.255.0'
   544  	//TODO: Allow for CIDR notation in hostnames
   545  	//TODO: Which user do we choose when multiple host names match (e.g. host name with most characters matched, etc.)
   546  
   547  	if "127.0.0.1" == host || "::1" == host {
   548  		host = "localhost"
   549  	}
   550  
   551  	if user, ok := fetcher.GetUser(UserPrimaryKey{
   552  		Host: host,
   553  		User: user,
   554  	}); ok {
   555  		return user
   556  	}
   557  
   558  	// First we check for matches on the same user, then we try the anonymous user
   559  	for _, targetUser := range []string{user, ""} {
   560  		users := fetcher.GetUsersByUsername(targetUser)
   561  		for _, user := range users {
   562  			//TODO: use the most specific match first, using "%" only if there isn't a more specific match
   563  			if host == user.Host ||
   564  				(host == "localhost" && user.Host == "::1") ||
   565  				(host == "localhost" && user.Host == "127.0.0.1") ||
   566  				(user.Host == "%" && (!roleSearch || host == "")) {
   567  				return user
   568  			}
   569  		}
   570  	}
   571  	return nil
   572  }
   573  
   574  // UserActivePrivilegeSet fetches the User, and returns their entire active privilege set. This takes into account the
   575  // active roles, which are set in the context, therefore the user is also pulled from the context.
   576  func (db *MySQLDb) UserActivePrivilegeSet(ctx *sql.Context) PrivilegeSet {
   577  	if privSet, counter := ctx.Session.GetPrivilegeSet(); db.updateCounter.Load() == counter {
   578  		// If the counters are equal, we can guarantee that the privilege set exists and is valid
   579  		return privSet.(PrivilegeSet)
   580  	}
   581  
   582  	rd := db.Reader()
   583  	defer rd.Close()
   584  
   585  	client := ctx.Session.Client()
   586  	user := db.GetUser(rd, client.User, client.Address, false)
   587  	if user == nil {
   588  		return NewPrivilegeSet()
   589  	}
   590  
   591  	privSet := user.PrivilegeSet.Copy()
   592  	roleEdgeEntries := rd.GetToUserRoleEdges(RoleEdgesToKey{
   593  		ToHost: user.Host,
   594  		ToUser: user.User,
   595  	})
   596  	//TODO: filter the active roles using the context, rather than using every granted roles
   597  	//TODO: System variable "activate_all_roles_on_login", if set, will set all roles as active upon logging in
   598  	for _, roleEdgeEntry := range roleEdgeEntries {
   599  		roleEdge := roleEdgeEntry
   600  		role := db.GetUser(rd, roleEdge.FromUser, roleEdge.FromHost, true)
   601  		if role != nil {
   602  			privSet.UnionWith(role.PrivilegeSet)
   603  		}
   604  	}
   605  
   606  	ctx.Session.SetPrivilegeSet(privSet, db.updateCounter.Load())
   607  	return privSet
   608  }
   609  
   610  // RoutineAdminCheck fetches the User from the context, and specifically evaluates, the permission check
   611  // assuming the operation is for a stored procedure or function. This allows us to have more fine grain control over
   612  // permissions for stored procedures (many of which are critical to Dolt). This method specifically checks exists
   613  // for the use of AdminOnly procedures which require more fine-grained access control. For procedures which are
   614  // not AdminOnly, then |UserHasPrivileges| should be used instead.
   615  func (db *MySQLDb) RoutineAdminCheck(ctx *sql.Context, operations ...sql.PrivilegedOperation) bool {
   616  	privSet := db.UserActivePrivilegeSet(ctx)
   617  
   618  	if privSet.Has(sql.PrivilegeType_Super) {
   619  		// Superpowers allow you to fly and look through walls, surely you can execute whatever you want.
   620  		return true
   621  	}
   622  
   623  	for _, operation := range operations {
   624  		for _, operationPriv := range operation.StaticPrivileges {
   625  			database := operation.Database
   626  			if database == "" {
   627  				database = ctx.GetCurrentDatabase()
   628  			}
   629  			dbSet := privSet.Database(database)
   630  			routineSet := dbSet.Routine(operation.Routine, operation.IsProcedure)
   631  			if routineSet.Has(operationPriv) {
   632  				continue
   633  			}
   634  
   635  			// User does not have permission to perform the operation.
   636  			return false
   637  		}
   638  	}
   639  	return true
   640  }
   641  
   642  // UserHasPrivileges fetches the User, and returns whether they have the desired privileges necessary to perform the
   643  // privileged operation(s). This takes into account the active roles, which are set in the context, therefore both
   644  // the user and the active roles are pulled from the context. This method is sufficient for all MySQL behaviors.
   645  // The one exception, currently, is for stored procedures and functions, which have a more fine-grained permission
   646  // due to Dolt's use of the AdminOnly flag in procedure definitions.
   647  //
   648  // This functions implements the global/database/table|routine hierarchy of permissions. If a user has Execute permissions
   649  // on the database, then they implicitly have that same permission on all tables and routines in that database. This
   650  // is how all MySQL permissions work.
   651  func (db *MySQLDb) UserHasPrivileges(ctx *sql.Context, operations ...sql.PrivilegedOperation) bool {
   652  	privSet := db.UserActivePrivilegeSet(ctx)
   653  	// Super users have all privileges, so if they have global super privs, then
   654  	// they have all dynamic privs and we don't need to check them.
   655  	if privSet.Has(sql.PrivilegeType_Super) {
   656  		return true
   657  	}
   658  
   659  	if !db.Enabled() {
   660  		return true
   661  	}
   662  	for _, operation := range operations {
   663  
   664  		for _, operationPriv := range operation.StaticPrivileges {
   665  			if privSet.Has(operationPriv) {
   666  				//TODO: Handle partial revokes
   667  				continue
   668  			}
   669  			database := operation.Database
   670  			if database == "" {
   671  				database = ctx.GetCurrentDatabase()
   672  			}
   673  			dbSet := privSet.Database(database)
   674  			if dbSet.Has(operationPriv) {
   675  				continue
   676  			}
   677  			tblSet := dbSet.Table(operation.Table)
   678  			if tblSet.Has(operationPriv) {
   679  				continue
   680  			}
   681  
   682  			// TODO: Complete the column check support.
   683  			// colSet := tblSet.Column(operation.Column)
   684  			// if colSet.Has(operationPriv) {
   685  			//  	continue
   686  			// }
   687  
   688  			routineSet := dbSet.Routine(operation.Routine, operation.IsProcedure)
   689  			if routineSet.Has(operationPriv) {
   690  				continue
   691  			}
   692  
   693  			// User does not have permission to perform the operation.
   694  			return false
   695  		}
   696  
   697  		for _, operationPriv := range operation.DynamicPrivileges {
   698  			if privSet.HasDynamic(operationPriv) {
   699  				continue
   700  			}
   701  
   702  			// Dynamic privileges are only allowed at a global scope, so no need to check
   703  			// for database, table, or column privileges.
   704  			return false
   705  		}
   706  	}
   707  	return true
   708  }
   709  
   710  // Name implements the interface sql.Database.
   711  func (db *MySQLDb) Name() string {
   712  	return "mysql"
   713  }
   714  
   715  // GetTableInsensitive implements the interface sql.Database.
   716  func (db *MySQLDb) GetTableInsensitive(_ *sql.Context, tblName string) (sql.Table, bool, error) {
   717  	switch strings.ToLower(tblName) {
   718  	case userTblName:
   719  		return db.user, true, nil
   720  	case roleEdgesTblName:
   721  		return db.role_edges, true, nil
   722  	case dbTblName:
   723  		return db.db, true, nil
   724  	case tablesPrivTblName:
   725  		return db.tables_priv, true, nil
   726  	case procsPrivTblName:
   727  		return db.procs_priv, true, nil
   728  	case replicaSourceInfoTblName:
   729  		return db.replica_source_info, true, nil
   730  	case helpTopicTableName:
   731  		return db.help_topic, true, nil
   732  	case helpKeywordTableName:
   733  		return db.help_keyword, true, nil
   734  	case helpCategoryTableName:
   735  		return db.help_category, true, nil
   736  	case helpRelationTableName:
   737  		return db.help_relation, true, nil
   738  	default:
   739  		return nil, false, nil
   740  	}
   741  }
   742  
   743  // GetTableNames implements the interface sql.Database.
   744  func (db *MySQLDb) GetTableNames(ctx *sql.Context) ([]string, error) {
   745  	return []string{
   746  		userTblName,
   747  		dbTblName,
   748  		tablesPrivTblName,
   749  		procsPrivTblName,
   750  		roleEdgesTblName,
   751  		replicaSourceInfoTblName,
   752  		helpTopicTableName,
   753  		helpKeywordTableName,
   754  		helpCategoryTableName,
   755  		helpRelationTableName,
   756  	}, nil
   757  }
   758  
   759  // AuthMethod implements the interface mysql.AuthServer.
   760  func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
   761  	if !db.Enabled() {
   762  		return "mysql_native_password", nil
   763  	}
   764  	var host string
   765  	// TODO : need to check for network type instead of addr string if it's unix socket network,
   766  	//  macOS passes empty addr, but ubuntu returns "@" as addr for `localhost`
   767  	if addr == "@" || addr == "" {
   768  		host = "localhost"
   769  	} else {
   770  		splitHost, _, err := net.SplitHostPort(addr)
   771  		if err != nil {
   772  			if err.(*net.AddrError).Err == "missing port in address" {
   773  				host = addr
   774  			} else {
   775  				return "", err
   776  			}
   777  		} else {
   778  			host = splitHost
   779  		}
   780  	}
   781  
   782  	rd := db.Reader()
   783  	defer rd.Close()
   784  
   785  	u := db.GetUser(rd, user, host, false)
   786  	if u == nil {
   787  		return "", mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "User not found '%v'", user)
   788  	}
   789  	if _, ok := db.plugins[u.Plugin]; ok {
   790  		return "mysql_clear_password", nil
   791  	}
   792  	return u.Plugin, nil
   793  }
   794  
   795  // Salt implements the interface mysql.AuthServer.
   796  func (db *MySQLDb) Salt() ([]byte, error) {
   797  	return mysql.NewSalt()
   798  }
   799  
   800  // ValidateHash implements the interface mysql.AuthServer. This is called when the method used is "mysql_native_password".
   801  func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, addr net.Addr) (mysql.Getter, error) {
   802  	var host string
   803  	var err error
   804  	if addr.Network() == "unix" {
   805  		host = "localhost"
   806  	} else {
   807  		host, _, err = net.SplitHostPort(addr.String())
   808  		if err != nil {
   809  			if err.(*net.AddrError).Err == "missing port in address" {
   810  				host = addr.String()
   811  			} else {
   812  				return nil, err
   813  			}
   814  		}
   815  	}
   816  
   817  	rd := db.Reader()
   818  	defer rd.Close()
   819  
   820  	if !db.Enabled() {
   821  		return MysqlConnectionUser{User: user, Host: host}, nil
   822  	}
   823  
   824  	userEntry := db.GetUser(rd, user, host, false)
   825  	if userEntry == nil || userEntry.Locked {
   826  		return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   827  	}
   828  	if len(userEntry.Password) > 0 {
   829  		if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) {
   830  			return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   831  		}
   832  	} else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set
   833  		// a password was given and the account has no password set, therefore access is denied
   834  		return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   835  	}
   836  
   837  	return MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
   838  }
   839  
   840  // Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
   841  func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.Getter, error) {
   842  	var host string
   843  	var err error
   844  	if addr.Network() == "unix" {
   845  		host = "localhost"
   846  	} else {
   847  		host, _, err = net.SplitHostPort(addr.String())
   848  		if err != nil {
   849  			if err.(*net.AddrError).Err == "missing port in address" {
   850  				host = addr.String()
   851  			} else {
   852  				return nil, err
   853  			}
   854  		}
   855  	}
   856  
   857  	rd := db.Reader()
   858  	defer rd.Close()
   859  
   860  	connUser := MysqlConnectionUser{User: user, Host: host}
   861  	if !db.Enabled() {
   862  		return connUser, nil
   863  	}
   864  	userEntry := db.GetUser(rd, user, host, false)
   865  
   866  	if userEntry.Plugin != "" {
   867  		authplugin, ok := db.plugins[userEntry.Plugin]
   868  		if !ok {
   869  			return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'; auth plugin %s not registered with server", user, userEntry.Plugin)
   870  		}
   871  		pass, err := mysql.AuthServerReadPacketString(c)
   872  		if err != nil {
   873  			return nil, err
   874  		}
   875  		authed, err := authplugin.Authenticate(db, user, userEntry, pass)
   876  		if err != nil {
   877  			return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v': %v", user, err)
   878  		}
   879  		if !authed {
   880  			return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   881  		}
   882  		return connUser, nil
   883  	}
   884  	return nil, fmt.Errorf(`the only user login interface currently supported is "mysql_native_password"`)
   885  }
   886  
   887  // Persist passes along all changes to the integrator.
   888  //
   889  // This takes an Editor, instead of a Reader, since presumably we have just
   890  // done a write. In any case, it's nice to not ACK a write until it is
   891  // persisted, and the write lock which the Editor takes can help with not
   892  // making these changes visible until it is persisted as well.
   893  func (db *MySQLDb) Persist(ctx *sql.Context, ed *Editor) error {
   894  	// Extract all user entries from table, and sort
   895  	var users []*User
   896  	var superUsers []*User
   897  	ed.VisitUsers(func(u *User) {
   898  		if !u.IsSuperUser {
   899  			users = append(users, u)
   900  		} else {
   901  			superUsers = append(superUsers, u)
   902  		}
   903  	})
   904  	sort.Slice(users, func(i, j int) bool {
   905  		if users[i].Host == users[j].Host {
   906  			return users[i].User < users[j].User
   907  		}
   908  		return users[i].Host < users[j].Host
   909  	})
   910  	sort.Slice(superUsers, func(i, j int) bool {
   911  		if superUsers[i].Host == superUsers[j].Host {
   912  			return superUsers[i].User < superUsers[j].User
   913  		}
   914  		return superUsers[i].Host < superUsers[j].Host
   915  	})
   916  
   917  	// Extract all role entries from table, and sort
   918  	var roles []*RoleEdge
   919  	ed.VisitRoleEdges(func(v *RoleEdge) {
   920  		roles = append(roles, v)
   921  	})
   922  	sort.Slice(roles, func(i, j int) bool {
   923  		if roles[i].FromHost == roles[j].FromHost {
   924  			if roles[i].FromUser == roles[j].FromUser {
   925  				if roles[i].ToHost == roles[j].ToHost {
   926  					return roles[i].ToUser < roles[j].ToUser
   927  				}
   928  				return roles[i].ToHost < roles[j].ToHost
   929  			}
   930  			return roles[i].FromUser < roles[j].FromUser
   931  		}
   932  		return roles[i].FromHost < roles[j].FromHost
   933  	})
   934  
   935  	// Extract all replica source info entries from table, and sort
   936  	var replicaSourceInfos []*ReplicaSourceInfo
   937  	ed.VisitReplicaSourceInfos(func(v *ReplicaSourceInfo) {
   938  		replicaSourceInfos = append(replicaSourceInfos, v)
   939  	})
   940  	sort.Slice(replicaSourceInfos, func(i, j int) bool {
   941  		if replicaSourceInfos[i].Host == replicaSourceInfos[j].Host {
   942  			if replicaSourceInfos[i].Port == replicaSourceInfos[j].Port {
   943  				return replicaSourceInfos[i].User < replicaSourceInfos[j].User
   944  			}
   945  			return replicaSourceInfos[i].Port < replicaSourceInfos[j].Port
   946  		}
   947  		return replicaSourceInfos[i].Host < replicaSourceInfos[j].Host
   948  	})
   949  
   950  	// TODO: serialize other tables when they exist
   951  
   952  	// Create flatbuffer
   953  	b := flatbuffers.NewBuilder(0)
   954  	user := serializeUser(b, users)
   955  	roleEdge := serializeRoleEdge(b, roles)
   956  	replicaSourceInfo := serializeReplicaSourceInfo(b, replicaSourceInfos)
   957  	superUser := serializeUser(b, superUsers)
   958  
   959  	// Write MySQL DB
   960  	serial.MySQLDbStart(b)
   961  	serial.MySQLDbAddUser(b, user)
   962  	serial.MySQLDbAddRoleEdges(b, roleEdge)
   963  	serial.MySQLDbAddReplicaSourceInfo(b, replicaSourceInfo)
   964  	serial.MySQLDbAddSuperUser(b, superUser)
   965  	mysqlDbOffset := serial.MySQLDbEnd(b)
   966  
   967  	// Finish writing
   968  	b.Finish(mysqlDbOffset)
   969  
   970  	// Persist data
   971  	return db.persister.Persist(ctx, b.FinishedBytes())
   972  }
   973  
   974  // columnTemplate takes in a column as a template, and returns a new column with a different name based on the given
   975  // template.
   976  func columnTemplate(name string, source string, isPk bool, template *sql.Column) *sql.Column {
   977  	newCol := *template
   978  	if newCol.Default != nil {
   979  		newCol.Default = &(*newCol.Default)
   980  	}
   981  	newCol.Name = name
   982  	newCol.Source = source
   983  	newCol.PrimaryKey = isPk
   984  	return &newCol
   985  }
   986  
   987  // validateMysqlNativePassword was taken directly from vitess and validates the password hash for "mysql_native_password".
   988  func validateMysqlNativePassword(authResponse, salt []byte, mysqlNativePassword string) bool {
   989  	// SERVER: recv(authResponse)
   990  	// 		   hash_stage1=xor(authResponse, sha1(salt,hash))
   991  	// 		   candidate_hash2=sha1(hash_stage1)
   992  	// 		   check(candidate_hash2==hash)
   993  	if len(authResponse) == 0 || len(mysqlNativePassword) == 0 {
   994  		return false
   995  	}
   996  	if mysqlNativePassword[0] == '*' {
   997  		mysqlNativePassword = mysqlNativePassword[1:]
   998  	}
   999  
  1000  	hash, err := hex.DecodeString(mysqlNativePassword)
  1001  	if err != nil {
  1002  		return false
  1003  	}
  1004  
  1005  	// scramble = SHA1(salt+hash)
  1006  	crypt := sha1.New()
  1007  	crypt.Write(salt)
  1008  	crypt.Write(hash)
  1009  	scramble := crypt.Sum(nil)
  1010  
  1011  	// token = scramble XOR stage1Hash
  1012  	for i := range scramble {
  1013  		scramble[i] ^= authResponse[i]
  1014  	}
  1015  	stage1Hash := scramble
  1016  	crypt.Reset()
  1017  	crypt.Write(stage1Hash)
  1018  	candidateHash2 := crypt.Sum(nil)
  1019  
  1020  	return bytes.Equal(candidateHash2, hash)
  1021  }
  1022  
  1023  // mustDefault enforces that no error occurred when constructing the column default value.
  1024  func mustDefault(expr sql.Expression, outType sql.Type, representsLiteral bool, mayReturnNil bool) *sql.ColumnDefaultValue {
  1025  	colDef, err := sql.NewColumnDefaultValue(expr, outType, representsLiteral, !representsLiteral, mayReturnNil)
  1026  	if err != nil {
  1027  		panic(err)
  1028  	}
  1029  	return colDef
  1030  }
  1031  
  1032  type dummyPartition struct{}
  1033  
  1034  var _ sql.Partition = dummyPartition{}
  1035  
  1036  // Key implements the interface sql.Partition.
  1037  func (d dummyPartition) Key() []byte {
  1038  	return nil
  1039  }