github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/branch_control/branch_control.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package branch_control
    16  
    17  import (
    18  	"context"
    19  	goerrors "errors"
    20  	"fmt"
    21  	"os"
    22  	"sync/atomic"
    23  
    24  	flatbuffers "github.com/dolthub/flatbuffers/v23/go"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"gopkg.in/src-d/go-errors.v1"
    27  
    28  	"github.com/dolthub/dolt/go/gen/fb/serial"
    29  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    30  )
    31  
    32  var (
    33  	ErrIncorrectPermissions  = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
    34  	ErrCannotCreateBranch    = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
    35  	ErrCannotDeleteBranch    = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
    36  	ErrExpressionsTooLong    = errors.NewKind("expressions are too long [%q, %q, %q, %q]")
    37  	ErrInsertingAccessRow    = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q, %q]")
    38  	ErrInsertingNamespaceRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q]")
    39  	ErrUpdatingRow           = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q]")
    40  	ErrUpdatingToRow         = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q] to the new branch expression [%q, %q]")
    41  	ErrDeletingRow           = errors.NewKind("`%s`@`%s` cannot delete the row [%q, %q, %q, %q]")
    42  	ErrMissingController     = errors.NewKind("a context has a non-nil session but is missing its branch controller")
    43  )
    44  
    45  // Context represents the interface that must be inherited from the context.
    46  type Context interface {
    47  	GetBranch() (string, error)
    48  	GetCurrentDatabase() string
    49  	GetUser() string
    50  	GetHost() string
    51  	GetPrivilegeSet() (sql.PrivilegeSet, uint64)
    52  	GetController() *Controller
    53  	GetFileSystem() filesys.Filesys
    54  }
    55  
    56  // Controller is the central hub for branch control functions. This is passed within a context.
    57  type Controller struct {
    58  	Access    *Access
    59  	Namespace *Namespace
    60  
    61  	Serialized atomic.Pointer[[]byte]
    62  
    63  	// A callback which we call when we successfully save new data.
    64  	// The new data will be available in |Serialized|.
    65  	SavedCallback func(context.Context)
    66  
    67  	branchControlFilePath string
    68  	doltConfigDirPath     string
    69  }
    70  
    71  // CreateDefaultController returns a default controller, which only has a single entry allowing all users to have write
    72  // permissions on all branches (only the super user has admin, if a super user has been set). This is equivalent to
    73  // passing empty strings to LoadData.
    74  func CreateDefaultController(ctx context.Context) *Controller {
    75  	controller, err := LoadData(ctx, "", "")
    76  	if err != nil {
    77  		panic(err) // should never happen
    78  	}
    79  	return controller
    80  }
    81  
    82  // LoadData loads the data from the given location and returns a controller. Returns the default controller if the
    83  // `branchControlFilePath` is empty.
    84  func LoadData(ctx context.Context, branchControlFilePath string, doltConfigDirPath string) (*Controller, error) {
    85  	accessTbl := newAccess()
    86  	controller := &Controller{
    87  		Access:                accessTbl,
    88  		Namespace:             newNamespace(accessTbl),
    89  		branchControlFilePath: branchControlFilePath,
    90  		doltConfigDirPath:     doltConfigDirPath,
    91  	}
    92  
    93  	// Do not attempt to load from an empty file path
    94  	if len(branchControlFilePath) == 0 {
    95  		// If the path is empty, then we should populate the controller with the default row to ensure normal (expected) operation
    96  		controller.Access.insertDefaultRow()
    97  		return controller, nil
    98  	}
    99  
   100  	data, err := os.ReadFile(branchControlFilePath)
   101  	if err != nil && !goerrors.Is(err, os.ErrNotExist) {
   102  		return nil, err
   103  	}
   104  
   105  	err = controller.LoadData(ctx, data /* isFirstLoad */, true)
   106  	if err != nil {
   107  		return nil, fmt.Errorf("failed to deserialize config at '%s': %w", branchControlFilePath, err)
   108  	}
   109  	return controller, nil
   110  }
   111  
   112  func (controller *Controller) LoadData(ctx context.Context, data []byte, isFirstLoad bool) error {
   113  	controller.Access.RWMutex.Lock()
   114  	defer controller.Access.RWMutex.Unlock()
   115  
   116  	// Nothing to load so we can return
   117  	if len(data) == 0 {
   118  		// As there is nothing to load, we should populate the controller with the default row to ensure normal (expected) operation
   119  		controller.Access.insertDefaultRow()
   120  		controller.Serialized.Store(&data)
   121  		if controller.SavedCallback != nil {
   122  			controller.SavedCallback(ctx)
   123  		}
   124  		return nil
   125  	}
   126  	// Load the tables
   127  	if serial.GetFileID(data) != serial.BranchControlFileID {
   128  		return fmt.Errorf("unable to deserialize branch controller, unknown file ID `%s`", serial.GetFileID(data))
   129  	}
   130  	bc, err := serial.TryGetRootAsBranchControl(data, serial.MessagePrefixSz)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	access, err := bc.TryAccessTbl(nil)
   135  	if err != nil {
   136  		return err
   137  	}
   138  	namespace, err := bc.TryNamespaceTbl(nil)
   139  	if err != nil {
   140  		return err
   141  	}
   142  
   143  	rollback := controller.Serialized.Load()
   144  
   145  	// TODO: Better concurrency control here. We see |Namespace| and
   146  	// |Access| in different views of the data here.
   147  
   148  	// The Deserialize functions acquire write locks, so we don't acquire them here
   149  	if err = controller.Access.Deserialize(access); err != nil {
   150  		// TODO: More principaled rollback. Hopefully this does not fail.
   151  		controller.LoadData(ctx, *rollback, isFirstLoad)
   152  		return err
   153  	}
   154  	if err = controller.Namespace.Deserialize(namespace); err != nil {
   155  		// TODO: More principaled rollback. Hopefully this does not fail.
   156  		controller.LoadData(ctx, *rollback, isFirstLoad)
   157  		return err
   158  	}
   159  
   160  	controller.Serialized.Store(&data)
   161  	if controller.SavedCallback != nil {
   162  		controller.SavedCallback(ctx)
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  // SaveData saves the data from the context's controller to the location pointed by it.
   169  func SaveData(ctx context.Context) error {
   170  	branchAwareSession := GetBranchAwareSession(ctx)
   171  	// A nil session means we're not in the SQL context, so we've got nothing to serialize
   172  	if branchAwareSession == nil {
   173  		return nil
   174  	}
   175  	controller := branchAwareSession.GetController()
   176  	// If there is no controller in the context, then we have nothing to serialize
   177  	if controller == nil {
   178  		return nil
   179  	}
   180  
   181  	return controller.SaveData(ctx, branchAwareSession.GetFileSystem())
   182  }
   183  
   184  func (controller *Controller) SaveData(ctx context.Context, fs filesys.Filesys) error {
   185  	// If we never set a save location then we just return
   186  	if len(controller.branchControlFilePath) == 0 {
   187  		return nil
   188  	}
   189  
   190  	// Create the doltcfg directory if it doesn't exist
   191  	if len(controller.doltConfigDirPath) != 0 {
   192  		if mkErr := fs.MkDirs(controller.doltConfigDirPath); mkErr != nil {
   193  			return mkErr
   194  		}
   195  	}
   196  
   197  	controller.Access.RWMutex.Lock()
   198  	defer controller.Access.RWMutex.Unlock()
   199  
   200  	b := flatbuffers.NewBuilder(1024)
   201  	// The Serialize functions acquire read locks, so we don't acquire them here
   202  	accessOffset := controller.Access.Serialize(b)
   203  	namespaceOffset := controller.Namespace.Serialize(b)
   204  	serial.BranchControlStart(b)
   205  	serial.BranchControlAddAccessTbl(b, accessOffset)
   206  	serial.BranchControlAddNamespaceTbl(b, namespaceOffset)
   207  	root := serial.BranchControlEnd(b)
   208  	// serial.FinishMessage() limits files to 2^24 bytes, so this works around it while maintaining read compatibility
   209  	b.Prep(1, flatbuffers.SizeInt32+4+serial.MessagePrefixSz)
   210  	b.FinishWithFileIdentifier(root, []byte(serial.BranchControlFileID))
   211  	data := b.Bytes[b.Head()-serial.MessagePrefixSz:]
   212  
   213  	err := fs.WriteFile(controller.branchControlFilePath, data, 0660)
   214  	if err != nil {
   215  		return err
   216  	}
   217  
   218  	controller.Serialized.Store(&data)
   219  	if controller.SavedCallback != nil {
   220  		controller.SavedCallback(ctx)
   221  	}
   222  	return nil
   223  }
   224  
   225  // CheckAccess returns whether the given context has the correct permissions on its selected branch. In general, SQL
   226  // statements will almost always return a *sql.Context, so any checks from the SQL path will correctly check for branch
   227  // permissions. However, not all CLI commands use *sql.Context, and therefore will not have any user associated with
   228  // the context. In these cases, CheckAccess will pass as we want to allow all local commands to ignore branch
   229  // permissions.
   230  func CheckAccess(ctx context.Context, flags Permissions) error {
   231  	branchAwareSession := GetBranchAwareSession(ctx)
   232  	// A nil session means we're not in the SQL context, so we allow all operations
   233  	if branchAwareSession == nil {
   234  		return nil
   235  	}
   236  	controller := branchAwareSession.GetController()
   237  	// Any context that has a non-nil session should always have a non-nil controller, so this is an error
   238  	if controller == nil {
   239  		return ErrMissingController.New()
   240  	}
   241  	controller.Access.RWMutex.RLock()
   242  	defer controller.Access.RWMutex.RUnlock()
   243  
   244  	user := branchAwareSession.GetUser()
   245  	host := branchAwareSession.GetHost()
   246  	database := branchAwareSession.GetCurrentDatabase()
   247  	branch, err := branchAwareSession.GetBranch()
   248  	if err != nil {
   249  		return err
   250  	}
   251  	// Get the permissions for the branch, user, and host combination
   252  	_, perms := controller.Access.Match(database, branch, user, host)
   253  	// If either the flags match or the user is an admin for this branch, then we allow access
   254  	if (perms&flags == flags) || (perms&Permissions_Admin == Permissions_Admin) {
   255  		return nil
   256  	}
   257  	return ErrIncorrectPermissions.New(user, host, branch)
   258  }
   259  
   260  // CanCreateBranch returns whether the given context can create a branch with the given name. In general, SQL statements
   261  // will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
   262  // However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
   263  // these cases, CanCreateBranch will pass as we want to allow all local commands to freely create branches.
   264  func CanCreateBranch(ctx context.Context, branchName string) error {
   265  	branchAwareSession := GetBranchAwareSession(ctx)
   266  	// A nil session means we're not in the SQL context, so we allow the create operation
   267  	if branchAwareSession == nil {
   268  		return nil
   269  	}
   270  	controller := branchAwareSession.GetController()
   271  	// Any context that has a non-nil session should always have a non-nil controller, so this is an error
   272  	if controller == nil {
   273  		return ErrMissingController.New()
   274  	}
   275  	controller.Namespace.RWMutex.RLock()
   276  	defer controller.Namespace.RWMutex.RUnlock()
   277  
   278  	user := branchAwareSession.GetUser()
   279  	host := branchAwareSession.GetHost()
   280  	database := branchAwareSession.GetCurrentDatabase()
   281  	if controller.Namespace.CanCreate(database, branchName, user, host) {
   282  		return nil
   283  	}
   284  	return ErrCannotCreateBranch.New(user, host, branchName)
   285  }
   286  
   287  // CanDeleteBranch returns whether the given context can delete a branch with the given name. In general, SQL statements
   288  // will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
   289  // However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
   290  // these cases, CanDeleteBranch will pass as we want to allow all local commands to freely delete branches.
   291  func CanDeleteBranch(ctx context.Context, branchName string) error {
   292  	branchAwareSession := GetBranchAwareSession(ctx)
   293  	// A nil session means we're not in the SQL context, so we allow the delete operation
   294  	if branchAwareSession == nil {
   295  		return nil
   296  	}
   297  	controller := branchAwareSession.GetController()
   298  	// Any context that has a non-nil session should always have a non-nil controller, so this is an error
   299  	if controller == nil {
   300  		return ErrMissingController.New()
   301  	}
   302  	controller.Access.RWMutex.RLock()
   303  	defer controller.Access.RWMutex.RUnlock()
   304  
   305  	user := branchAwareSession.GetUser()
   306  	host := branchAwareSession.GetHost()
   307  	database := branchAwareSession.GetCurrentDatabase()
   308  	// Get the permissions for the branch, user, and host combination
   309  	_, perms := controller.Access.Match(database, branchName, user, host)
   310  	// If the user has the write or admin flags, then we allow access
   311  	if (perms&Permissions_Write == Permissions_Write) || (perms&Permissions_Admin == Permissions_Admin) {
   312  		return nil
   313  	}
   314  	return ErrCannotDeleteBranch.New(user, host, branchName)
   315  }
   316  
   317  // AddAdminForContext adds an entry in the access table for the user represented by the given context. If the
   318  // context is missing some functionality that is needed to perform the addition, such as a user or the Controller, then
   319  // this simply returns.
   320  func AddAdminForContext(ctx context.Context, branchName string) error {
   321  	branchAwareSession := GetBranchAwareSession(ctx)
   322  	if branchAwareSession == nil {
   323  		return nil
   324  	}
   325  	controller := branchAwareSession.GetController()
   326  	if controller == nil {
   327  		return nil
   328  	}
   329  
   330  	user := branchAwareSession.GetUser()
   331  	host := branchAwareSession.GetHost()
   332  	database := branchAwareSession.GetCurrentDatabase()
   333  	// Check if we already have admin permissions for the given branch, as there's no need to do another insertion if so
   334  	controller.Access.RWMutex.RLock()
   335  	_, modPerms := controller.Access.Match(database, branchName, user, host)
   336  	controller.Access.RWMutex.RUnlock()
   337  	if modPerms&Permissions_Admin == Permissions_Admin {
   338  		return nil
   339  	}
   340  	controller.Access.RWMutex.Lock()
   341  	controller.Access.Insert(database, branchName, user, host, Permissions_Admin)
   342  	controller.Access.RWMutex.Unlock()
   343  	return SaveData(ctx)
   344  }
   345  
   346  // GetBranchAwareSession returns the session contained within the context. If the context does NOT contain a session,
   347  // then nil is returned.
   348  func GetBranchAwareSession(ctx context.Context) Context {
   349  	if sqlCtx, ok := ctx.(*sql.Context); ok {
   350  		if bas, ok := sqlCtx.Session.(Context); ok {
   351  			return bas
   352  		}
   353  	} else if bas, ok := ctx.(Context); ok {
   354  		return bas
   355  	}
   356  	return nil
   357  }
   358  
   359  // HasDatabasePrivileges returns whether the given context's user has the correct privileges to modify any table entries
   360  // that match the given database. The following are the required privileges:
   361  //
   362  // Global Space:   SUPER, GRANT
   363  // Global Space:   CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
   364  // Database Space: CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
   365  //
   366  // Any user that may grant SUPER is considered to be a super user. In addition, any user that may grant the suite of
   367  // alteration privileges is also considered a super user. The SUPER privilege does not exist at the database level, it
   368  // is a global privilege only.
   369  func HasDatabasePrivileges(ctx Context, database string) bool {
   370  	if ctx == nil {
   371  		return true
   372  	}
   373  	privSet, counter := ctx.GetPrivilegeSet()
   374  	if counter == 0 {
   375  		return false
   376  	}
   377  	hasSuper := privSet.Has(sql.PrivilegeType_Super, sql.PrivilegeType_GrantOption)
   378  	isGlobalAdmin := privSet.Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
   379  		sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_GrantOption)
   380  	isDatabaseAdmin := privSet.Database(database).Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
   381  		sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_GrantOption)
   382  	return hasSuper || isGlobalAdmin || isDatabaseAdmin
   383  }