github.com/dolthub/go-mysql-server@v0.18.0/sql/mysql_db/user.go (about)

     1  // Copyright 2021-2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql_db
    16  
    17  import (
    18  	"encoding/json"
    19  	"fmt"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/in_mem_table"
    25  )
    26  
// User represents a user from the user Grant Table.
//
// UserToRow and UserFromRow convert between a User and a row of the user table (userTblSchema). The comments on the
// fields below name the column each field is written to and read from by those functions.
type User struct {
	// User is the account name (User column).
	User                string
	// Host is the host part of the account (Host column).
	Host                string
	// PrivilegeSet holds the account's privileges. Its global static privileges are stored in the "*_priv" columns,
	// where a held privilege is written as enum value 2.
	PrivilegeSet        PrivilegeSet
	// Plugin is the authentication plugin name (plugin column).
	Plugin              string
	// Password is stored verbatim in the authentication_string column.
	Password            string
	// PasswordLastChanged maps to the password_last_changed column. UserFromRow substitutes the current UTC time when
	// that column does not hold a time.Time.
	PasswordLastChanged time.Time
	// Locked maps to the account_locked column: true is written and read as enum value 2.
	Locked              bool
	// Attributes maps to the User_attributes column. It is nil when that column does not hold a string; a nil value
	// leaves the column at its schema default when converted to a row.
	Attributes          *string
	// Identity maps to the identity column.
	Identity            string
	// IsSuperUser is not written by UserToRow, read by UserFromRow, or compared by UserEquals.
	// NOTE(review): presumably populated outside the grant table — confirm with callers.
	IsSuperUser         bool
	//TODO: add the remaining fields

	// IsRole is an additional field that states whether the User represents a role or user. In MySQL this must be a
	// hidden column, therefore it's represented here as an additional field.
	IsRole bool
}
    45  
    46  func UserToRow(ctx *sql.Context, u *User) (sql.Row, error) {
    47  	row := make(sql.Row, len(userTblSchema))
    48  	var err error
    49  	for i, col := range userTblSchema {
    50  		row[i], err = col.Default.Eval(ctx, nil)
    51  		if err != nil {
    52  			panic(err) // Should never happen, schema is static
    53  		}
    54  	}
    55  	//TODO: once the remaining fields are added, fill those in as well
    56  	row[userTblColIndex_User] = u.User
    57  	row[userTblColIndex_Host] = u.Host
    58  	row[userTblColIndex_plugin] = u.Plugin
    59  	row[userTblColIndex_authentication_string] = u.Password
    60  	row[userTblColIndex_password_last_changed] = u.PasswordLastChanged
    61  	row[userTblColIndex_identity] = u.Identity
    62  	if u.Locked {
    63  		row[userTblColIndex_account_locked] = uint16(2)
    64  	}
    65  	if u.Attributes != nil {
    66  		row[userTblColIndex_User_attributes] = *u.Attributes
    67  	}
    68  	u.privSetToRow(ctx, row)
    69  	return row, nil
    70  }
    71  
    72  func UserFromRow(ctx *sql.Context, row sql.Row) (*User, error) {
    73  	if err := userTblSchema.CheckRow(row); err != nil {
    74  		return nil, err
    75  	}
    76  	//TODO: once the remaining fields are added, fill those in as well
    77  	var attributes *string
    78  	passwordLastChanged := time.Now().UTC()
    79  	if val, ok := row[userTblColIndex_User_attributes].(string); ok {
    80  		attributes = &val
    81  	}
    82  	if val, ok := row[userTblColIndex_password_last_changed].(time.Time); ok {
    83  		passwordLastChanged = val
    84  	}
    85  	return &User{
    86  		User:                row[userTblColIndex_User].(string),
    87  		Host:                row[userTblColIndex_Host].(string),
    88  		PrivilegeSet:        UserRowToPrivSet(ctx, row),
    89  		Plugin:              row[userTblColIndex_plugin].(string),
    90  		Password:            row[userTblColIndex_authentication_string].(string),
    91  		PasswordLastChanged: passwordLastChanged,
    92  		Locked:              row[userTblColIndex_account_locked].(uint16) == 2,
    93  		Attributes:          attributes,
    94  		Identity:            row[userTblColIndex_identity].(string),
    95  		IsRole:              false,
    96  	}, nil
    97  }
    98  
    99  func UserUpdateWithRow(ctx *sql.Context, row sql.Row, u *User) (*User, error) {
   100  	updatedUser, err := UserFromRow(ctx, row)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	updatedUser.IsRole = u.IsRole
   105  	return updatedUser, nil
   106  }
   107  
   108  var UserOps = in_mem_table.ValueOps[*User]{
   109  	ToRow:         UserToRow,
   110  	FromRow:       UserFromRow,
   111  	UpdateWithRow: UserUpdateWithRow,
   112  }
   113  
   114  func UserEquals(left, right *User) bool {
   115  	// IsRole is not tested for equality, as it is additional information
   116  	//TODO: once the remaining fields are added, fill those in as well
   117  	if left.User != right.User ||
   118  		left.Host != right.Host ||
   119  		left.Plugin != right.Plugin ||
   120  		left.Password != right.Password ||
   121  		left.Identity != right.Identity ||
   122  		!left.PasswordLastChanged.Equal(right.PasswordLastChanged) ||
   123  		left.Locked != right.Locked ||
   124  		!left.PrivilegeSet.Equals(right.PrivilegeSet) ||
   125  		left.Attributes == nil && right.Attributes != nil ||
   126  		left.Attributes != nil && right.Attributes == nil ||
   127  		(left.Attributes != nil && *left.Attributes != *right.Attributes) {
   128  		return false
   129  	}
   130  	return true
   131  }
   132  
   133  func UserCopy(u *User) *User {
   134  	uu := *u
   135  	uu.PrivilegeSet = NewPrivilegeSet()
   136  	uu.PrivilegeSet.UnionWith(u.PrivilegeSet)
   137  	return &uu
   138  }
   139  
   140  // FromJson implements the interface in_mem_table.Entry.
   141  func (u User) FromJson(ctx *sql.Context, jsonStr string) (*User, error) {
   142  	newUser := &User{}
   143  	if err := json.Unmarshal([]byte(jsonStr), newUser); err != nil {
   144  		return nil, err
   145  	}
   146  	return newUser, nil
   147  }
   148  
   149  // ToJson implements the interface in_mem_table.Entry.
   150  func (u *User) ToJson(ctx *sql.Context) (string, error) {
   151  	jsonData, err := json.Marshal(*u)
   152  	if err != nil {
   153  		return "", err
   154  	}
   155  	return string(jsonData), nil
   156  }
   157  
   158  // UserHostToString returns the user and host as a formatted string using the quotes given. Using the default root
   159  // account with the backtick as the quote, root@localhost would become `root`@`localhost`. Different quotes are used
   160  // in different places in MySQL. In addition, if the quote is used in a section as part of the name, it is escaped by
   161  // doubling the quote (which also mimics MySQL behavior).
   162  func (u User) UserHostToString(quote string) string {
   163  	replacement := quote + quote
   164  	user := strings.ReplaceAll(u.User, quote, replacement)
   165  	host := strings.ReplaceAll(u.Host, quote, replacement)
   166  	return fmt.Sprintf("%s%s%s@%s%s%s", quote, user, quote, quote, host, quote)
   167  }
   168  
   169  func UserRowToPrivSet(ctx *sql.Context, row sql.Row) PrivilegeSet {
   170  	privSet := NewPrivilegeSet()
   171  	for i, val := range row {
   172  		switch i {
   173  		case userTblColIndex_Select_priv:
   174  			if val.(uint16) == 2 {
   175  				privSet.AddGlobalStatic(sql.PrivilegeType_Select)
   176  			}
   177  		case userTblColIndex_Insert_priv:
   178  			if val.(uint16) == 2 {
   179  				privSet.AddGlobalStatic(sql.PrivilegeType_Insert)
   180  			}
   181  		case userTblColIndex_Update_priv:
   182  			if val.(uint16) == 2 {
   183  				privSet.AddGlobalStatic(sql.PrivilegeType_Update)
   184  			}
   185  		case userTblColIndex_Delete_priv:
   186  			if val.(uint16) == 2 {
   187  				privSet.AddGlobalStatic(sql.PrivilegeType_Delete)
   188  			}
   189  		case userTblColIndex_Create_priv:
   190  			if val.(uint16) == 2 {
   191  				privSet.AddGlobalStatic(sql.PrivilegeType_Create)
   192  			}
   193  		case userTblColIndex_Drop_priv:
   194  			if val.(uint16) == 2 {
   195  				privSet.AddGlobalStatic(sql.PrivilegeType_Drop)
   196  			}
   197  		case userTblColIndex_Reload_priv:
   198  			if val.(uint16) == 2 {
   199  				privSet.AddGlobalStatic(sql.PrivilegeType_Reload)
   200  			}
   201  		case userTblColIndex_Shutdown_priv:
   202  			if val.(uint16) == 2 {
   203  				privSet.AddGlobalStatic(sql.PrivilegeType_Shutdown)
   204  			}
   205  		case userTblColIndex_Process_priv:
   206  			if val.(uint16) == 2 {
   207  				privSet.AddGlobalStatic(sql.PrivilegeType_Process)
   208  			}
   209  		case userTblColIndex_File_priv:
   210  			if val.(uint16) == 2 {
   211  				privSet.AddGlobalStatic(sql.PrivilegeType_File)
   212  			}
   213  		case userTblColIndex_Grant_priv:
   214  			if val.(uint16) == 2 {
   215  				privSet.AddGlobalStatic(sql.PrivilegeType_GrantOption)
   216  			}
   217  		case userTblColIndex_References_priv:
   218  			if val.(uint16) == 2 {
   219  				privSet.AddGlobalStatic(sql.PrivilegeType_References)
   220  			}
   221  		case userTblColIndex_Index_priv:
   222  			if val.(uint16) == 2 {
   223  				privSet.AddGlobalStatic(sql.PrivilegeType_Index)
   224  			}
   225  		case userTblColIndex_Alter_priv:
   226  			if val.(uint16) == 2 {
   227  				privSet.AddGlobalStatic(sql.PrivilegeType_Alter)
   228  			}
   229  		case userTblColIndex_Show_db_priv:
   230  			if val.(uint16) == 2 {
   231  				privSet.AddGlobalStatic(sql.PrivilegeType_ShowDB)
   232  			}
   233  		case userTblColIndex_Super_priv:
   234  			if val.(uint16) == 2 {
   235  				privSet.AddGlobalStatic(sql.PrivilegeType_Super)
   236  			}
   237  		case userTblColIndex_Create_tmp_table_priv:
   238  			if val.(uint16) == 2 {
   239  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateTempTable)
   240  			}
   241  		case userTblColIndex_Lock_tables_priv:
   242  			if val.(uint16) == 2 {
   243  				privSet.AddGlobalStatic(sql.PrivilegeType_LockTables)
   244  			}
   245  		case userTblColIndex_Execute_priv:
   246  			if val.(uint16) == 2 {
   247  				privSet.AddGlobalStatic(sql.PrivilegeType_Execute)
   248  			}
   249  		case userTblColIndex_Repl_slave_priv:
   250  			if val.(uint16) == 2 {
   251  				privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationSlave)
   252  			}
   253  		case userTblColIndex_Repl_client_priv:
   254  			if val.(uint16) == 2 {
   255  				privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationClient)
   256  			}
   257  		case userTblColIndex_Create_view_priv:
   258  			if val.(uint16) == 2 {
   259  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateView)
   260  			}
   261  		case userTblColIndex_Show_view_priv:
   262  			if val.(uint16) == 2 {
   263  				privSet.AddGlobalStatic(sql.PrivilegeType_ShowView)
   264  			}
   265  		case userTblColIndex_Create_routine_priv:
   266  			if val.(uint16) == 2 {
   267  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateRoutine)
   268  			}
   269  		case userTblColIndex_Alter_routine_priv:
   270  			if val.(uint16) == 2 {
   271  				privSet.AddGlobalStatic(sql.PrivilegeType_AlterRoutine)
   272  			}
   273  		case userTblColIndex_Create_user_priv:
   274  			if val.(uint16) == 2 {
   275  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateUser)
   276  			}
   277  		case userTblColIndex_Event_priv:
   278  			if val.(uint16) == 2 {
   279  				privSet.AddGlobalStatic(sql.PrivilegeType_Event)
   280  			}
   281  		case userTblColIndex_Trigger_priv:
   282  			if val.(uint16) == 2 {
   283  				privSet.AddGlobalStatic(sql.PrivilegeType_Trigger)
   284  			}
   285  		case userTblColIndex_Create_tablespace_priv:
   286  			if val.(uint16) == 2 {
   287  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateTablespace)
   288  			}
   289  		case userTblColIndex_Create_role_priv:
   290  			if val.(uint16) == 2 {
   291  				privSet.AddGlobalStatic(sql.PrivilegeType_CreateRole)
   292  			}
   293  		case userTblColIndex_Drop_role_priv:
   294  			if val.(uint16) == 2 {
   295  				privSet.AddGlobalStatic(sql.PrivilegeType_DropRole)
   296  			}
   297  		}
   298  	}
   299  	return privSet
   300  }
   301  
   302  // privSetToRow applies the this User's set of privileges to the given row. Only sets privileges that exist to "Y",
   303  // therefore any privileges that do not exist will have their default values.
   304  func (u *User) privSetToRow(ctx *sql.Context, row sql.Row) {
   305  	for _, priv := range u.PrivilegeSet.ToSlice() {
   306  		switch priv {
   307  		case sql.PrivilegeType_Select:
   308  			row[userTblColIndex_Select_priv] = uint16(2)
   309  		case sql.PrivilegeType_Insert:
   310  			row[userTblColIndex_Insert_priv] = uint16(2)
   311  		case sql.PrivilegeType_Update:
   312  			row[userTblColIndex_Update_priv] = uint16(2)
   313  		case sql.PrivilegeType_Delete:
   314  			row[userTblColIndex_Delete_priv] = uint16(2)
   315  		case sql.PrivilegeType_Create:
   316  			row[userTblColIndex_Create_priv] = uint16(2)
   317  		case sql.PrivilegeType_Drop:
   318  			row[userTblColIndex_Drop_priv] = uint16(2)
   319  		case sql.PrivilegeType_Reload:
   320  			row[userTblColIndex_Reload_priv] = uint16(2)
   321  		case sql.PrivilegeType_Shutdown:
   322  			row[userTblColIndex_Shutdown_priv] = uint16(2)
   323  		case sql.PrivilegeType_Process:
   324  			row[userTblColIndex_Process_priv] = uint16(2)
   325  		case sql.PrivilegeType_File:
   326  			row[userTblColIndex_File_priv] = uint16(2)
   327  		case sql.PrivilegeType_GrantOption:
   328  			row[userTblColIndex_Grant_priv] = uint16(2)
   329  		case sql.PrivilegeType_References:
   330  			row[userTblColIndex_References_priv] = uint16(2)
   331  		case sql.PrivilegeType_Index:
   332  			row[userTblColIndex_Index_priv] = uint16(2)
   333  		case sql.PrivilegeType_Alter:
   334  			row[userTblColIndex_Alter_priv] = uint16(2)
   335  		case sql.PrivilegeType_ShowDB:
   336  			row[userTblColIndex_Show_db_priv] = uint16(2)
   337  		case sql.PrivilegeType_Super:
   338  			row[userTblColIndex_Super_priv] = uint16(2)
   339  		case sql.PrivilegeType_CreateTempTable:
   340  			row[userTblColIndex_Create_tmp_table_priv] = uint16(2)
   341  		case sql.PrivilegeType_LockTables:
   342  			row[userTblColIndex_Lock_tables_priv] = uint16(2)
   343  		case sql.PrivilegeType_Execute:
   344  			row[userTblColIndex_Execute_priv] = uint16(2)
   345  		case sql.PrivilegeType_ReplicationSlave:
   346  			row[userTblColIndex_Repl_slave_priv] = uint16(2)
   347  		case sql.PrivilegeType_ReplicationClient:
   348  			row[userTblColIndex_Repl_client_priv] = uint16(2)
   349  		case sql.PrivilegeType_CreateView:
   350  			row[userTblColIndex_Create_view_priv] = uint16(2)
   351  		case sql.PrivilegeType_ShowView:
   352  			row[userTblColIndex_Show_view_priv] = uint16(2)
   353  		case sql.PrivilegeType_CreateRoutine:
   354  			row[userTblColIndex_Create_routine_priv] = uint16(2)
   355  		case sql.PrivilegeType_AlterRoutine:
   356  			row[userTblColIndex_Alter_routine_priv] = uint16(2)
   357  		case sql.PrivilegeType_CreateUser:
   358  			row[userTblColIndex_Create_user_priv] = uint16(2)
   359  		case sql.PrivilegeType_Event:
   360  			row[userTblColIndex_Event_priv] = uint16(2)
   361  		case sql.PrivilegeType_Trigger:
   362  			row[userTblColIndex_Trigger_priv] = uint16(2)
   363  		case sql.PrivilegeType_CreateTablespace:
   364  			row[userTblColIndex_Create_tablespace_priv] = uint16(2)
   365  		case sql.PrivilegeType_CreateRole:
   366  			row[userTblColIndex_Create_role_priv] = uint16(2)
   367  		case sql.PrivilegeType_DropRole:
   368  			row[userTblColIndex_Drop_role_priv] = uint16(2)
   369  		}
   370  	}
   371  }