github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/session.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dsess
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	sqltypes "github.com/dolthub/go-mysql-server/sql/types"
    28  	"github.com/shopspring/decimal"
    29  
    30  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    35  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    36  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
    37  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
    38  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    39  	"github.com/dolthub/dolt/go/libraries/utils/config"
    40  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    41  	"github.com/dolthub/dolt/go/store/hash"
    42  	"github.com/dolthub/dolt/go/store/types"
    43  )
    44  
    45  const (
    46  	DbRevisionDelimiter = "/"
    47  )
    48  
    49  var ErrSessionNotPersistable = errors.New("session is not persistable")
    50  
    51  // DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance
    52  type DoltSession struct {
    53  	sql.Session
    54  	username         string
    55  	email            string
    56  	dbStates         map[string]*DatabaseSessionState
    57  	dbCache          *DatabaseCache
    58  	provider         DoltDatabaseProvider
    59  	tempTables       map[string][]sql.Table
    60  	globalsConf      config.ReadWriteConfig
    61  	branchController *branch_control.Controller
    62  	statsProv        sql.StatsProvider
    63  	mu               *sync.Mutex
    64  	fs               filesys.Filesys
    65  
    66  	// If non-nil, this will be returned from ValidateSession.
    67  	// Used by sqle/cluster to put a session into a terminal err state.
    68  	validateErr error
    69  }
    70  
    71  var _ sql.Session = (*DoltSession)(nil)
    72  var _ sql.PersistableSession = (*DoltSession)(nil)
    73  var _ sql.TransactionSession = (*DoltSession)(nil)
    74  var _ branch_control.Context = (*DoltSession)(nil)
    75  
    76  // DefaultSession creates a DoltSession with default values
    77  func DefaultSession(pro DoltDatabaseProvider) *DoltSession {
    78  	return &DoltSession{
    79  		Session:          sql.NewBaseSession(),
    80  		username:         "",
    81  		email:            "",
    82  		dbStates:         make(map[string]*DatabaseSessionState),
    83  		dbCache:          newDatabaseCache(),
    84  		provider:         pro,
    85  		tempTables:       make(map[string][]sql.Table),
    86  		globalsConf:      config.NewMapConfig(make(map[string]string)),
    87  		branchController: branch_control.CreateDefaultController(context.TODO()), // Default sessions are fine with the default controller
    88  		mu:               &sync.Mutex{},
    89  		fs:               pro.FileSystem(),
    90  	}
    91  }
    92  
    93  // NewDoltSession creates a DoltSession object from a standard sql.Session and 0 or more Database objects.
    94  func NewDoltSession(
    95  	sqlSess *sql.BaseSession,
    96  	pro DoltDatabaseProvider,
    97  	conf config.ReadWriteConfig,
    98  	branchController *branch_control.Controller,
    99  	statsProvider sql.StatsProvider,
   100  ) (*DoltSession, error) {
   101  	username := conf.GetStringOrDefault(config.UserNameKey, "")
   102  	email := conf.GetStringOrDefault(config.UserEmailKey, "")
   103  	globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix)
   104  
   105  	sess := &DoltSession{
   106  		Session:          sqlSess,
   107  		username:         username,
   108  		email:            email,
   109  		dbStates:         make(map[string]*DatabaseSessionState),
   110  		dbCache:          newDatabaseCache(),
   111  		provider:         pro,
   112  		tempTables:       make(map[string][]sql.Table),
   113  		globalsConf:      globals,
   114  		branchController: branchController,
   115  		statsProv:        statsProvider,
   116  		mu:               &sync.Mutex{},
   117  		fs:               pro.FileSystem(),
   118  	}
   119  
   120  	return sess, nil
   121  }
   122  
   123  // Provider returns the RevisionDatabaseProvider for this session.
   124  func (d *DoltSession) Provider() DoltDatabaseProvider {
   125  	return d.provider
   126  }
   127  
   128  // StatsProvider returns the sql.StatsProvider for this session.
   129  func (d *DoltSession) StatsProvider() sql.StatsProvider {
   130  	return d.statsProv
   131  }
   132  
   133  // DSessFromSess retrieves a dolt session from a standard sql.Session
   134  func DSessFromSess(sess sql.Session) *DoltSession {
   135  	return sess.(*DoltSession)
   136  }
   137  
   138  // lookupDbState is the private version of LookupDbState, returning a struct that has more information available than
   139  // the interface returned by the public method.
   140  func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchState, bool, error) {
   141  	dbName = strings.ToLower(dbName)
   142  
   143  	var baseName, rev string
   144  	baseName, rev = SplitRevisionDbName(dbName)
   145  
   146  	d.mu.Lock()
   147  	dbState, dbStateFound := d.dbStates[baseName]
   148  	d.mu.Unlock()
   149  
   150  	if dbStateFound {
   151  		// If we got an unqualified name, use the current working set head
   152  		if rev == "" {
   153  			rev = dbState.checkedOutRevSpec
   154  		}
   155  
   156  		branchState, ok := dbState.heads[strings.ToLower(rev)]
   157  
   158  		if ok {
   159  			if dbState.Err != nil {
   160  				return nil, false, dbState.Err
   161  			}
   162  
   163  			return branchState, ok, nil
   164  		}
   165  	}
   166  
   167  	// No state for this db / branch combination yet, look it up from the provider. We use the unqualified DB name (no
   168  	// branch) if the current DB has not yet been loaded into this session. It will resolve to that DB's default branch
   169  	// in that case.
   170  	revisionQualifiedName := dbName
   171  	if rev != "" {
   172  		revisionQualifiedName = RevisionDbName(baseName, rev)
   173  	}
   174  
   175  	database, ok, err := d.provider.SessionDatabase(ctx, revisionQualifiedName)
   176  	if err != nil {
   177  		return nil, false, err
   178  	}
   179  	if !ok {
   180  		return nil, false, nil
   181  	}
   182  
   183  	// Add the initial state to the session for future reuse
   184  	if err := d.addDB(ctx, database); err != nil {
   185  		return nil, false, err
   186  	}
   187  
   188  	d.mu.Lock()
   189  	dbState, dbStateFound = d.dbStates[baseName]
   190  	d.mu.Unlock()
   191  	if !dbStateFound {
   192  		// should be impossible
   193  		return nil, false, sql.ErrDatabaseNotFound.New(dbName)
   194  	}
   195  
   196  	return dbState.heads[strings.ToLower(database.Revision())], true, nil
   197  }
   198  
   199  // RevisionDbName returns the name of the revision db for the base name and revision string given
   200  func RevisionDbName(baseName string, rev string) string {
   201  	return baseName + DbRevisionDelimiter + rev
   202  }
   203  
   204  func SplitRevisionDbName(dbName string) (string, string) {
   205  	var baseName, rev string
   206  	parts := strings.SplitN(dbName, DbRevisionDelimiter, 2)
   207  	baseName = parts[0]
   208  	if len(parts) > 1 {
   209  		rev = parts[1]
   210  	}
   211  	return baseName, rev
   212  }
   213  
   214  // LookupDbState returns the session state for the database named. Unqualified database names, e.g. `mydb` get resolved
   215  // to the currently checked out HEAD, which could be a branch, a commit, a tag, etc. Revision-qualified database names,
   216  // e.g. `mydb/branch1` get resolved to the session state for the revision named.
   217  // A note on unqualified database names: unqualified names will resolve to a) the head last checked out with
   218  // `dolt_checkout`, or b) the database's default branch, if this session hasn't called `dolt_checkout` yet.
   219  // Also returns a bool indicating whether the database was found, and an error if one occurred.
   220  func (d *DoltSession) LookupDbState(ctx *sql.Context, dbName string) (SessionState, bool, error) {
   221  	s, ok, err := d.lookupDbState(ctx, dbName)
   222  	if err != nil {
   223  		return nil, false, err
   224  	}
   225  
   226  	return s, ok, nil
   227  }
   228  
   229  // RemoveDbState invalidates any cached db state in this session, for example, if a database is dropped.
   230  func (d *DoltSession) RemoveDbState(_ *sql.Context, dbName string) error {
   231  	d.mu.Lock()
   232  	defer d.mu.Unlock()
   233  	delete(d.dbStates, strings.ToLower(dbName))
   234  	// also clear out any db-level caches for this db
   235  	d.dbCache.Clear()
   236  	return nil
   237  }
   238  
   239  // RemoveBranchState removes the session state for a branch, for example, if a branch is deleted.
   240  func (d *DoltSession) RemoveBranchState(ctx *sql.Context, dbName string, branchName string) error {
   241  	baseName, _ := SplitRevisionDbName(dbName)
   242  
   243  	checkedOutState, ok, err := d.lookupDbState(ctx, baseName)
   244  	if err != nil {
   245  		return err
   246  	}
   247  	if !ok {
   248  		return sql.ErrDatabaseNotFound.New(baseName)
   249  	}
   250  
   251  	d.mu.Lock()
   252  	delete(checkedOutState.dbState.heads, strings.ToLower(branchName))
   253  	d.mu.Unlock()
   254  
   255  	db, ok := d.provider.BaseDatabase(ctx, baseName)
   256  	if !ok {
   257  		return sql.ErrDatabaseNotFound.New(baseName)
   258  	}
   259  
   260  	defaultHead, err := DefaultHead(baseName, db)
   261  	if err != nil {
   262  		return err
   263  	}
   264  
   265  	checkedOutState.dbState.checkedOutRevSpec = defaultHead
   266  
   267  	// also clear out any db-level caches for this db
   268  	d.dbCache.Clear()
   269  	return nil
   270  }
   271  
   272  // RenameBranchState replaces all references to a renamed branch with its new name
   273  func (d *DoltSession) RenameBranchState(ctx *sql.Context, dbName string, oldBranchName, newBranchName string) error {
   274  	baseName, _ := SplitRevisionDbName(dbName)
   275  
   276  	checkedOutState, ok, err := d.lookupDbState(ctx, baseName)
   277  	if err != nil {
   278  		return err
   279  	}
   280  	if !ok {
   281  		return sql.ErrDatabaseNotFound.New(baseName)
   282  	}
   283  
   284  	d.mu.Lock()
   285  	branch, ok := checkedOutState.dbState.heads[strings.ToLower(oldBranchName)]
   286  
   287  	if !ok {
   288  		// nothing to rename
   289  		d.mu.Unlock()
   290  		return nil
   291  	}
   292  
   293  	delete(checkedOutState.dbState.heads, strings.ToLower(oldBranchName))
   294  	branch.head = strings.ToLower(newBranchName)
   295  	checkedOutState.dbState.heads[strings.ToLower(newBranchName)] = branch
   296  
   297  	d.mu.Unlock()
   298  
   299  	// also clear out any db-level caches for this db
   300  	d.dbCache.Clear()
   301  	return nil
   302  }
   303  
   304  // SetValidateErr sets an error on this session to be returned from every call
   305  // to ValidateSession. This is effectively a way to disable a session.
   306  //
   307  // Used by sql/cluster logic to make sessions on a server which has
   308  // transitioned roles terminally error.
   309  func (d *DoltSession) SetValidateErr(err error) {
   310  	d.validateErr = err
   311  }
   312  
   313  // ValidateSession validates a working set if there are a valid sessionState with non-nil working set.
   314  // If there is no sessionState or its current working set not defined, then no need for validation,
   315  // so no error is returned.
   316  func (d *DoltSession) ValidateSession(ctx *sql.Context) error {
   317  	return d.validateErr
   318  }
   319  
// StartTransaction refreshes the state of this session and starts a new transaction.
//
// Steps, in order:
//  1. If transactions are disabled for |ctx|, returns DisabledTransaction{} without touching session state.
//  2. Clears all cached branch state in this session.
//  3. For each provider database that has a DoltDB: a RemoteReadReplicaDatabase in a valid replica state pulls from
//     its remote, and if the ReadReplicaRemote global is non-empty the DoltDB is rebased. Replication errors abort
//     the call unless IgnoreReplicationErrors() is true, in which case they are only warned.
//  4. Creates a DoltTransaction over those databases, clears session state again, and installs it on |ctx|.
//  5. Re-seeds per-database session variables from each database's current head, ignoring any failures.
func (d *DoltSession) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
	// TODO: this is only necessary to support filter-branch, which needs to set a root directly and not have the
	//  session state altered when a transaction begins
	if TransactionsDisabled(ctx) {
		return DisabledTransaction{}, nil
	}

	// New transaction, clear all session state
	d.clear()

	// Take a snapshot of the current noms root for every database under management
	doltDatabases := d.provider.DoltDatabases()
	txDbs := make([]SqlDatabase, 0, len(doltDatabases))
	for _, db := range doltDatabases {
		// TODO: this nil check is only necessary to support UserSpaceDatabase and clusterDatabase, come up with a better set of
		//  interfaces to capture these capabilities
		ddb := db.DbData().Ddb
		if ddb != nil {
			// Read replicas in a valid replica state pull from their remote before the transaction is created.
			rrd, ok := db.(RemoteReadReplicaDatabase)
			if ok && rrd.ValidReplicaState(ctx) {
				err := rrd.PullFromRemote(ctx)
				if err != nil && !IgnoreReplicationErrors() {
					return nil, fmt.Errorf("replication error: %w", err)
				} else if err != nil {
					WarnReplicationError(ctx, err)
				}
			}

			// TODO: this check is relatively expensive, we should cache this value when it changes instead of looking it
			//  up on each transaction start
			// When a read replica remote is configured globally, rebase the DoltDB (same error policy as above).
			if _, v, ok := sql.SystemVariables.GetGlobal(ReadReplicaRemote); ok && v != "" {
				err := ddb.Rebase(ctx)
				if err != nil && !IgnoreReplicationErrors() {
					return nil, err
				} else if err != nil {
					WarnReplicationError(ctx, err)
				}
			}

			// Only databases with a DoltDB participate in the transaction.
			txDbs = append(txDbs, db)
		}
	}

	tx, err := NewDoltTransaction(ctx, txDbs, tCharacteristic)
	if err != nil {
		return nil, err
	}

	// The engine sets the transaction after this call as well, but since we begin accessing data below, we need to set
	// this now to avoid seeding the session state with stale data in some cases. The duplication is harmless since the
	// code below cannot error. Additionally we clear any state that was cached by replication updates in the block above.
	d.clear()
	ctx.SetTransaction(tx)

	// Set session vars for every DB in this session using their current branch head
	for _, db := range doltDatabases {
		// faulty settings can make it impossible to load particular DB branch states, so we ignore any errors in this
		// loop and just decline to set the session vars. Throwing an error on transaction start in these cases makes it
		// impossible for the user to correct any problems.
		bs, ok, err := d.lookupDbState(ctx, db.Name())
		if err != nil || !ok {
			continue
		}

		_ = d.setDbSessionVars(ctx, bs, false)
	}

	return tx, nil
}
   390  
   391  // clear clears all DB state for this session
   392  func (d *DoltSession) clear() {
   393  	d.mu.Lock()
   394  	defer d.mu.Unlock()
   395  
   396  	for _, dbState := range d.dbStates {
   397  		for head := range dbState.heads {
   398  			delete(dbState.heads, head)
   399  		}
   400  	}
   401  }
   402  
   403  func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) {
   404  	dbData, _ := d.GetDbData(nil, dbName)
   405  
   406  	headSpec, _ := doltdb.NewCommitSpec("HEAD")
   407  	headRef, err := wsRef.ToHeadRef()
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  
   412  	optCmt, err := dbData.Ddb.Resolve(ctx, headSpec, headRef)
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  	headCommit, ok := optCmt.ToCommit()
   417  	if !ok {
   418  		return nil, doltdb.ErrGhostCommitEncountered
   419  	}
   420  
   421  	headRoot, err := headCommit.GetRootValue(ctx)
   422  	if err != nil {
   423  		return nil, err
   424  	}
   425  
   426  	return doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(headRoot).WithStagedRoot(headRoot), nil
   427  }
   428  
   429  // CommitTransaction commits the in-progress transaction. Depending on session settings, this may write only a new
   430  // working set, or may additionally create a new dolt commit for the current HEAD. If more than one branch head has
   431  // changes, the transaction is rejected.
   432  func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) (err error) {
   433  	// Any non-error path must set the ctx's transaction to nil even if no work was done, because the engine only clears
   434  	// out transaction state in some cases. Changes to only branch heads (creating a new branch, reset, etc.) have no
   435  	// changes to commit visible to the transaction logic, but they still need a new transaction on the next statement.
   436  	// See comment in |commitBranchState|
   437  	defer func() {
   438  		if err == nil {
   439  			ctx.SetTransaction(nil)
   440  		}
   441  	}()
   442  
   443  	if TransactionsDisabled(ctx) {
   444  		return nil
   445  	}
   446  
   447  	dirties := d.dirtyWorkingSets()
   448  	if len(dirties) == 0 {
   449  		return nil
   450  	}
   451  
   452  	if len(dirties) > 1 {
   453  		return ErrDirtyWorkingSets
   454  	}
   455  
   456  	performDoltCommitVar, err := d.Session.GetSessionVariable(ctx, DoltCommitOnTransactionCommit)
   457  	if err != nil {
   458  		return err
   459  	}
   460  
   461  	peformDoltCommitInt, ok := performDoltCommitVar.(int8)
   462  	if !ok {
   463  		return fmt.Errorf(fmt.Sprintf("Unexpected type for var %s: %T", DoltCommitOnTransactionCommit, performDoltCommitVar))
   464  	}
   465  
   466  	dirtyBranchState := dirties[0]
   467  	if peformDoltCommitInt == 1 {
   468  		// if the dirty working set doesn't belong to the currently checked out branch, that's an error
   469  		err = d.validateDoltCommit(ctx, dirtyBranchState)
   470  		if err != nil {
   471  			return err
   472  		}
   473  
   474  		message := "Transaction commit"
   475  		doltCommitMessageVar, err := d.Session.GetSessionVariable(ctx, DoltCommitOnTransactionCommitMessage)
   476  		if err != nil {
   477  			return err
   478  		}
   479  
   480  		doltCommitMessageString, ok := doltCommitMessageVar.(string)
   481  		if !ok && doltCommitMessageVar != nil {
   482  			return fmt.Errorf(fmt.Sprintf("Unexpected type for var %s: %T", DoltCommitOnTransactionCommitMessage, doltCommitMessageVar))
   483  		}
   484  
   485  		trimmedString := strings.TrimSpace(doltCommitMessageString)
   486  		if strings.TrimSpace(doltCommitMessageString) != "" {
   487  			message = trimmedString
   488  		}
   489  
   490  		var pendingCommit *doltdb.PendingCommit
   491  		pendingCommit, err = d.PendingCommitAllStaged(ctx, dirtyBranchState, actions.CommitStagedProps{
   492  			Message:    message,
   493  			Date:       ctx.QueryTime(),
   494  			AllowEmpty: false,
   495  			Force:      false,
   496  			Name:       d.Username(),
   497  			Email:      d.Email(),
   498  		})
   499  		if err != nil {
   500  			return err
   501  		}
   502  
   503  		// Nothing to stage, so fall back to CommitWorkingSet logic instead
   504  		if pendingCommit == nil {
   505  			return d.commitWorkingSet(ctx, dirtyBranchState, tx)
   506  		}
   507  
   508  		_, err = d.DoltCommit(ctx, ctx.GetCurrentDatabase(), tx, pendingCommit)
   509  		return err
   510  	} else {
   511  		return d.commitWorkingSet(ctx, dirtyBranchState, tx)
   512  	}
   513  }
   514  
   515  func (d *DoltSession) validateDoltCommit(ctx *sql.Context, dirtyBranchState *branchState) error {
   516  	currDb := ctx.GetCurrentDatabase()
   517  	if currDb == "" {
   518  		return fmt.Errorf("cannot dolt_commit with no database selected")
   519  	}
   520  	currDbBaseName, rev := SplitRevisionDbName(currDb)
   521  	dirtyDbBaseName := dirtyBranchState.dbState.dbName
   522  
   523  	if strings.ToLower(currDbBaseName) != strings.ToLower(dirtyDbBaseName) {
   524  		return fmt.Errorf("no changes to dolt_commit on database %s", currDbBaseName)
   525  	}
   526  
   527  	d.mu.Lock()
   528  	dbState, ok := d.dbStates[strings.ToLower(currDbBaseName)]
   529  	d.mu.Unlock()
   530  
   531  	if !ok {
   532  		return fmt.Errorf("no database state found for %s", currDbBaseName)
   533  	}
   534  
   535  	if rev == "" {
   536  		rev = dbState.checkedOutRevSpec
   537  	}
   538  
   539  	if strings.ToLower(rev) != strings.ToLower(dirtyBranchState.head) {
   540  		return fmt.Errorf("no changes to dolt_commit on branch %s", rev)
   541  	}
   542  
   543  	return nil
   544  }
   545  
   546  var ErrDirtyWorkingSets = errors.New("Cannot commit changes on more than one branch / database")
   547  
   548  // dirtyWorkingSets returns all dirty working sets for this session
   549  func (d *DoltSession) dirtyWorkingSets() []*branchState {
   550  	var dirtyStates []*branchState
   551  	for _, state := range d.dbStates {
   552  		for _, branchState := range state.heads {
   553  			if branchState.dirty {
   554  				dirtyStates = append(dirtyStates, branchState)
   555  			}
   556  		}
   557  	}
   558  
   559  	return dirtyStates
   560  }
   561  
   562  // CommitWorkingSet commits the working set for the transaction given, without creating a new dolt commit.
   563  // Clients should typically use CommitTransaction, which performs additional checks, instead of this method.
   564  func (d *DoltSession) CommitWorkingSet(ctx *sql.Context, dbName string, tx sql.Transaction) error {
   565  	commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) {
   566  		ws, err := dtx.Commit(ctx, workingSet, dbName)
   567  		return ws, nil, err
   568  	}
   569  
   570  	_, err := d.commitCurrentHead(ctx, dbName, tx, commitFunc)
   571  	return err
   572  }
   573  
   574  // commitWorkingSet commits the working set for the branch state given, without creating a new dolt commit.
   575  func (d *DoltSession) commitWorkingSet(ctx *sql.Context, branchState *branchState, tx sql.Transaction) error {
   576  	commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) {
   577  		ws, err := dtx.Commit(ctx, workingSet, branchState.RevisionDbName())
   578  		return ws, nil, err
   579  	}
   580  
   581  	_, err := d.commitBranchState(ctx, branchState, tx, commitFunc)
   582  	return err
   583  }
   584  
   585  // DoltCommit commits the working set and a new dolt commit with the properties given.
   586  // Clients should typically use CommitTransaction, which performs additional checks, instead of this method.
   587  func (d *DoltSession) DoltCommit(
   588  	ctx *sql.Context,
   589  	dbName string,
   590  	tx sql.Transaction,
   591  	commit *doltdb.PendingCommit,
   592  ) (*doltdb.Commit, error) {
   593  	commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) {
   594  		ws, commit, err := dtx.DoltCommit(
   595  			ctx,
   596  			workingSet.WithWorkingRoot(commit.Roots.Working).WithStagedRoot(commit.Roots.Staged),
   597  			commit,
   598  			dbName)
   599  		if err != nil {
   600  			return nil, nil, err
   601  		}
   602  
   603  		return ws, commit, err
   604  	}
   605  
   606  	return d.commitCurrentHead(ctx, dbName, tx, commitFunc)
   607  }
   608  
   609  // doCommitFunc is a function to write to the database, which involves updating the working set and potentially
   610  // updating HEAD with a new commit
   611  type doCommitFunc func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error)
   612  
   613  // commitBranchState performs a commit for the branch state given, using the doCommitFunc provided
   614  func (d *DoltSession) commitBranchState(
   615  	ctx *sql.Context,
   616  	branchState *branchState,
   617  	tx sql.Transaction,
   618  	commitFunc doCommitFunc,
   619  ) (*doltdb.Commit, error) {
   620  	dtx, ok := tx.(*DoltTransaction)
   621  	if !ok {
   622  		return nil, fmt.Errorf("expected a DoltTransaction")
   623  	}
   624  
   625  	_, newCommit, err := commitFunc(ctx, dtx, branchState.WorkingSet())
   626  	if err != nil {
   627  		return nil, err
   628  	}
   629  
   630  	// Anything that commits a transaction needs its current transaction state cleared so that the next statement starts
   631  	// a new transaction. This should in principle be done by the engine, but it currently only understands explicit
   632  	// COMMIT statements. Any other statements that commit a transaction, including stored procedures, needs to do this
   633  	// themselves.
   634  	ctx.SetTransaction(nil)
   635  	return newCommit, nil
   636  }
   637  
   638  // commitCurrentHead commits the current HEAD for the database given, using the doCommitFunc provided
   639  func (d *DoltSession) commitCurrentHead(ctx *sql.Context, dbName string, tx sql.Transaction, commitFunc doCommitFunc) (*doltdb.Commit, error) {
   640  	branchState, ok, err := d.lookupDbState(ctx, dbName)
   641  	if err != nil {
   642  		return nil, err
   643  	} else if !ok {
   644  		return nil, sql.ErrDatabaseNotFound.New(dbName)
   645  	}
   646  
   647  	return d.commitBranchState(ctx, branchState, tx, commitFunc)
   648  }
   649  
   650  // PendingCommitAllStaged returns a pending commit with all tables staged. Returns nil if there are no changes to stage.
   651  func (d *DoltSession) PendingCommitAllStaged(ctx *sql.Context, branchState *branchState, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
   652  	roots := branchState.roots()
   653  
   654  	var err error
   655  	roots, err = actions.StageAllTables(ctx, roots, true)
   656  	if err != nil {
   657  		return nil, err
   658  	}
   659  
   660  	return d.newPendingCommit(ctx, branchState, roots, props)
   661  }
   662  
   663  // NewPendingCommit returns a new |doltdb.PendingCommit| for the database named, using the roots given, adding any
   664  // merge parent from an in progress merge as appropriate. The session working set is not updated with these new roots,
   665  // but they are set in the returned |doltdb.PendingCommit|. If there are no changes staged, this method returns nil.
   666  func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
   667  	branchState, ok, err := d.lookupDbState(ctx, dbName)
   668  	if err != nil {
   669  		return nil, err
   670  	}
   671  	if !ok {
   672  		return nil, fmt.Errorf("session state for database %s not found", dbName)
   673  	}
   674  
   675  	return d.newPendingCommit(ctx, branchState, roots, props)
   676  }
   677  
   678  // newPendingCommit returns a new |doltdb.PendingCommit| for the database and head named by |branchState|
   679  // See NewPendingCommit
   680  func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
   681  	headCommit := branchState.headCommit
   682  	headHash, _ := headCommit.HashOf()
   683  
   684  	if branchState.WorkingSet() == nil {
   685  		return nil, doltdb.ErrOperationNotSupportedInDetachedHead
   686  	}
   687  
   688  	var mergeParentCommits []*doltdb.Commit
   689  	if branchState.WorkingSet().MergeCommitParents() {
   690  		mergeParentCommits = []*doltdb.Commit{branchState.WorkingSet().MergeState().Commit()}
   691  	} else if props.Amend {
   692  		numParentsHeadForAmend := headCommit.NumParents()
   693  		for i := 0; i < numParentsHeadForAmend; i++ {
   694  			optCmt, err := headCommit.GetParent(ctx, i)
   695  			if err != nil {
   696  				return nil, err
   697  			}
   698  			parentCommit, ok := optCmt.ToCommit()
   699  			if !ok {
   700  				return nil, doltdb.ErrGhostCommitEncountered
   701  			}
   702  
   703  			mergeParentCommits = append(mergeParentCommits, parentCommit)
   704  		}
   705  
   706  		// TODO: This is not the correct way to write this commit as an amend. While this commit is running
   707  		//  the branch head moves backwards and concurrency control here is not principled.
   708  		newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1")
   709  		if err != nil {
   710  			return nil, err
   711  		}
   712  
   713  		err = d.SetWorkingSet(ctx, ctx.GetCurrentDatabase(), branchState.WorkingSet().WithStagedRoot(newRoots.Staged))
   714  		if err != nil {
   715  			return nil, err
   716  		}
   717  
   718  		roots.Head = newRoots.Head
   719  	}
   720  
   721  	pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props)
   722  	if err != nil {
   723  		if props.Amend {
   724  			_, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String())
   725  			if err != nil {
   726  				return nil, err
   727  			}
   728  		}
   729  		if _, ok := err.(actions.NothingStaged); err != nil && !ok {
   730  			return nil, err
   731  		}
   732  	}
   733  
   734  	return pendingCommit, nil
   735  }
   736  
   737  // Rollback rolls the given transaction back
   738  func (d *DoltSession) Rollback(ctx *sql.Context, tx sql.Transaction) error {
   739  	// Nothing to do here, we just throw away all our work and let a new transaction begin next statement
   740  	d.clear()
   741  	return nil
   742  }
   743  
   744  // CreateSavepoint creates a new savepoint for this transaction with the name given. A previously created savepoint
   745  // with the same name will be overwritten.
   746  func (d *DoltSession) CreateSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error {
   747  	if TransactionsDisabled(ctx) {
   748  		return nil
   749  	}
   750  
   751  	dtx, ok := tx.(*DoltTransaction)
   752  	if !ok {
   753  		return fmt.Errorf("expected a DoltTransaction")
   754  	}
   755  
   756  	roots := make(map[string]doltdb.RootValue)
   757  	for _, db := range d.provider.DoltDatabases() {
   758  		branchState, ok, err := d.lookupDbState(ctx, db.Name())
   759  		if err != nil {
   760  			return err
   761  		}
   762  		if !ok {
   763  			return fmt.Errorf("session state for database %s not found", db.Name())
   764  		}
   765  		baseName, _ := SplitRevisionDbName(db.Name())
   766  		roots[strings.ToLower(baseName)] = branchState.WorkingSet().WorkingRoot()
   767  	}
   768  
   769  	dtx.CreateSavepoint(savepointName, roots)
   770  	return nil
   771  }
   772  
   773  // RollbackToSavepoint sets this session's root to the one saved in the savepoint name. It's an error if no savepoint
   774  // with that name exists.
   775  func (d *DoltSession) RollbackToSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error {
   776  	if TransactionsDisabled(ctx) {
   777  		return nil
   778  	}
   779  
   780  	dtx, ok := tx.(*DoltTransaction)
   781  	if !ok {
   782  		return fmt.Errorf("expected a DoltTransaction")
   783  	}
   784  
   785  	roots := dtx.RollbackToSavepoint(savepointName)
   786  	if roots == nil {
   787  		return sql.ErrSavepointDoesNotExist.New(savepointName)
   788  	}
   789  
   790  	for dbName, root := range roots {
   791  		err := d.SetWorkingRoot(ctx, dbName, root)
   792  		if err != nil {
   793  			return err
   794  		}
   795  	}
   796  
   797  	return nil
   798  }
   799  
   800  // ReleaseSavepoint removes the savepoint name from the transaction. It's an error if no savepoint with that name
   801  // exists.
   802  func (d *DoltSession) ReleaseSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error {
   803  	if TransactionsDisabled(ctx) {
   804  		return nil
   805  	}
   806  
   807  	dtx, ok := tx.(*DoltTransaction)
   808  	if !ok {
   809  		return fmt.Errorf("expected a DoltTransaction")
   810  	}
   811  
   812  	existed := dtx.ClearSavepoint(savepointName)
   813  	if !existed {
   814  		return sql.ErrSavepointDoesNotExist.New(savepointName)
   815  	}
   816  
   817  	return nil
   818  }
   819  
   820  // GetDoltDB returns the *DoltDB for a given database by name
   821  func (d *DoltSession) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB, bool) {
   822  	branchState, ok, err := d.lookupDbState(ctx, dbName)
   823  	if err != nil {
   824  		return nil, false
   825  	}
   826  	if !ok {
   827  		return nil, false
   828  	}
   829  
   830  	return branchState.dbData.Ddb, true
   831  }
   832  
   833  func (d *DoltSession) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bool) {
   834  	branchState, ok, err := d.lookupDbState(ctx, dbName)
   835  	if err != nil {
   836  		return env.DbData{}, false
   837  	}
   838  	if !ok {
   839  		return env.DbData{}, false
   840  	}
   841  
   842  	return branchState.dbData, true
   843  }
   844  
   845  // GetRoots returns the current roots for a given database associated with the session
   846  func (d *DoltSession) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, bool) {
   847  	branchState, ok, err := d.lookupDbState(ctx, dbName)
   848  	if err != nil {
   849  		return doltdb.Roots{}, false
   850  	}
   851  	if !ok {
   852  		return doltdb.Roots{}, false
   853  	}
   854  
   855  	return branchState.roots(), true
   856  }
   857  
   858  // ResolveRootForRef returns the root value for the ref given, which refers to either a commit spec or is one of the
   859  // special identifiers |WORKING| or |STAGED|
   860  // Returns the root value associated with the identifier given, its commit time and its hash string. The hash string
   861  // for special identifiers |WORKING| or |STAGED| would be itself, 'WORKING' or 'STAGED', respectively.
   862  func (d *DoltSession) ResolveRootForRef(ctx *sql.Context, dbName, refStr string) (doltdb.RootValue, *types.Timestamp, string, error) {
   863  	if refStr == doltdb.Working || refStr == doltdb.Staged {
   864  		// TODO: get from working set / staged update time
   865  		now := types.Timestamp(time.Now())
   866  		// TODO: no current database
   867  		roots, _ := d.GetRoots(ctx, ctx.GetCurrentDatabase())
   868  		if refStr == doltdb.Working {
   869  			return roots.Working, &now, refStr, nil
   870  		} else if refStr == doltdb.Staged {
   871  			return roots.Staged, &now, refStr, nil
   872  		}
   873  	}
   874  
   875  	var root doltdb.RootValue
   876  	var commitTime *types.Timestamp
   877  	cs, err := doltdb.NewCommitSpec(refStr)
   878  	if err != nil {
   879  		return nil, nil, "", err
   880  	}
   881  
   882  	dbData, ok := d.GetDbData(ctx, dbName)
   883  	if !ok {
   884  		return nil, nil, "", sql.ErrDatabaseNotFound.New(dbName)
   885  	}
   886  
   887  	headRef, err := d.CWBHeadRef(ctx, dbName)
   888  	if err == doltdb.ErrOperationNotSupportedInDetachedHead {
   889  		// leave head ref nil, we may not need it (commit hash)
   890  	} else if err != nil {
   891  		return nil, nil, "", err
   892  	}
   893  
   894  	optCmt, err := dbData.Ddb.Resolve(ctx, cs, headRef)
   895  	if err != nil {
   896  		return nil, nil, "", err
   897  	}
   898  	cm, ok := optCmt.ToCommit()
   899  	if !ok {
   900  		return nil, nil, "", doltdb.ErrGhostCommitRuntimeFailure
   901  	}
   902  
   903  	root, err = cm.GetRootValue(ctx)
   904  	if err != nil {
   905  		return nil, nil, "", err
   906  	}
   907  
   908  	meta, err := cm.GetCommitMeta(ctx)
   909  	if err != nil {
   910  		return nil, nil, "", err
   911  	}
   912  
   913  	t := meta.Time()
   914  	commitTime = (*types.Timestamp)(&t)
   915  
   916  	commitHash, err := cm.HashOf()
   917  	if err != nil {
   918  		return nil, nil, "", err
   919  	}
   920  
   921  	return root, commitTime, commitHash.String(), nil
   922  }
   923  
   924  // SetWorkingRoot sets a new root value for the session for the database named. This is the primary mechanism by which data
   925  // changes are communicated to the engine and persisted back to disk. All data changes should be followed by a call to
   926  // update the session's root value via this method.
   927  // The dbName given should generally be a revision-qualified database name.
   928  // Data changes contained in the |newRoot| aren't persisted until this session is committed.
   929  func (d *DoltSession) SetWorkingRoot(ctx *sql.Context, dbName string, newRoot doltdb.RootValue) error {
   930  	branchState, _, err := d.lookupDbState(ctx, dbName)
   931  	if err != nil {
   932  		return err
   933  	}
   934  
   935  	if branchState.WorkingSet() == nil {
   936  		return doltdb.ErrOperationNotSupportedInDetachedHead
   937  	}
   938  
   939  	if rootsEqual(branchState.roots().Working, newRoot) {
   940  		return nil
   941  	}
   942  
   943  	if branchState.readOnly {
   944  		return fmt.Errorf("cannot set root on read-only session")
   945  	}
   946  	branchState.workingSet = branchState.WorkingSet().WithWorkingRoot(newRoot)
   947  
   948  	return d.SetWorkingSet(ctx, dbName, branchState.WorkingSet())
   949  }
   950  
   951  // SetRoots sets new roots for the session for the database named. Typically, clients should only set the working root,
   952  // via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions.
   953  // Unlike setting the working root, this method always marks the database state dirty.
   954  func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error {
   955  	sessionState, _, err := d.LookupDbState(ctx, dbName)
   956  	if err != nil {
   957  		return err
   958  	}
   959  
   960  	if sessionState.WorkingSet() == nil {
   961  		return doltdb.ErrOperationNotSupportedInDetachedHead
   962  	}
   963  
   964  	workingSet := sessionState.WorkingSet().WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged)
   965  	return d.SetWorkingSet(ctx, dbName, workingSet)
   966  }
   967  
   968  func (d *DoltSession) SetFileSystem(fs filesys.Filesys) {
   969  	d.fs = fs
   970  }
   971  
   972  func (d *DoltSession) GetFileSystem() filesys.Filesys {
   973  	return d.fs
   974  }
   975  
   976  // SetWorkingSet sets the working set for this session.
   977  func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.WorkingSet) error {
   978  	if ws == nil {
   979  		panic("attempted to set a nil working set for the session")
   980  	}
   981  
   982  	branchState, _, err := d.lookupDbState(ctx, dbName)
   983  	if err != nil {
   984  		return err
   985  	}
   986  	if ws.Ref() != branchState.WorkingSet().Ref() {
   987  		return fmt.Errorf("must switch working sets with SwitchWorkingSet")
   988  	}
   989  	branchState.workingSet = ws
   990  
   991  	err = d.setDbSessionVars(ctx, branchState, true)
   992  	if err != nil {
   993  		return err
   994  	}
   995  
   996  	if writeSess := branchState.WriteSession(); writeSess != nil {
   997  		err = writeSess.SetWorkingSet(ctx, ws)
   998  		if err != nil {
   999  			return err
  1000  		}
  1001  	}
  1002  
  1003  	branchState.dirty = true
  1004  	return nil
  1005  }
  1006  
  1007  // SwitchWorkingSet switches to a new working set for this session. Unlike SetWorkingSet, this method expresses no
  1008  // intention to eventually persist any uncommitted changes. Rather, this method only changes the in memory state of
  1009  // this session. It's equivalent to starting a new session with the working set reference provided. If the current
  1010  // session is dirty, this method returns an error. Clients can only switch branches with a clean working set, and so
  1011  // must either commit or rollback any changes before attempting to switch working sets.
  1012  func (d *DoltSession) SwitchWorkingSet(
  1013  	ctx *sql.Context,
  1014  	dbName string,
  1015  	wsRef ref.WorkingSetRef,
  1016  ) error {
  1017  	headRef, err := wsRef.ToHeadRef()
  1018  	if err != nil {
  1019  		return err
  1020  	}
  1021  
  1022  	d.mu.Lock()
  1023  
  1024  	baseName, _ := SplitRevisionDbName(dbName)
  1025  	dbState, ok := d.dbStates[strings.ToLower(baseName)]
  1026  	if !ok {
  1027  		d.mu.Unlock()
  1028  		return sql.ErrDatabaseNotFound.New(dbName)
  1029  	}
  1030  	dbState.checkedOutRevSpec = headRef.GetPath()
  1031  
  1032  	d.mu.Unlock()
  1033  
  1034  	// bootstrap the db state as necessary
  1035  	branchState, ok, err := d.lookupDbState(ctx, baseName+DbRevisionDelimiter+headRef.GetPath())
  1036  	if err != nil {
  1037  		return err
  1038  	}
  1039  
  1040  	if !ok {
  1041  		return sql.ErrDatabaseNotFound.New(dbName)
  1042  	}
  1043  
  1044  	ctx.SetCurrentDatabase(baseName)
  1045  
  1046  	return d.setDbSessionVars(ctx, branchState, false)
  1047  }
  1048  
  1049  func (d *DoltSession) WorkingSet(ctx *sql.Context, dbName string) (*doltdb.WorkingSet, error) {
  1050  	// TODO: need to make sure we use a revision qualified DB name here
  1051  	sessionState, _, err := d.LookupDbState(ctx, dbName)
  1052  	if err != nil {
  1053  		return nil, err
  1054  	}
  1055  	if sessionState.WorkingSet() == nil {
  1056  		return nil, doltdb.ErrOperationNotSupportedInDetachedHead
  1057  	}
  1058  	return sessionState.WorkingSet(), nil
  1059  }
  1060  
  1061  // GetHeadCommit returns the parent commit of the current session.
  1062  func (d *DoltSession) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Commit, error) {
  1063  	branchState, ok, err := d.lookupDbState(ctx, dbName)
  1064  	if err != nil {
  1065  		return nil, err
  1066  	}
  1067  	if !ok {
  1068  		return nil, sql.ErrDatabaseNotFound.New(dbName)
  1069  	}
  1070  
  1071  	return branchState.headCommit, nil
  1072  }
  1073  
  1074  // SetSessionVariable is defined on sql.Session. We intercept it here to interpret the special semantics of the system
  1075  // vars that we define. Otherwise we pass it on to the base implementation.
  1076  func (d *DoltSession) SetSessionVariable(ctx *sql.Context, key string, value interface{}) error {
  1077  	if ok, db := IsHeadRefKey(key); ok {
  1078  		v, ok := value.(string)
  1079  		if !ok {
  1080  			return doltdb.ErrInvalidBranchOrHash
  1081  		}
  1082  		return d.setHeadRefSessionVar(ctx, db, v)
  1083  	}
  1084  	if IsReadOnlyVersionKey(key) {
  1085  		return sql.ErrSystemVariableReadOnly.New(key)
  1086  	}
  1087  
  1088  	if strings.ToLower(key) == "foreign_key_checks" {
  1089  		return d.setForeignKeyChecksSessionVar(ctx, key, value)
  1090  	}
  1091  
  1092  	return d.Session.SetSessionVariable(ctx, key, value)
  1093  }
  1094  
  1095  func (d *DoltSession) setHeadRefSessionVar(ctx *sql.Context, db, value string) error {
  1096  	headRef, err := ref.Parse(value)
  1097  	if err != nil {
  1098  		return err
  1099  	}
  1100  
  1101  	ws, err := ref.WorkingSetRefForHead(headRef)
  1102  	if err != nil {
  1103  		return err
  1104  	}
  1105  	err = d.SwitchWorkingSet(ctx, db, ws)
  1106  	if errors.Is(err, doltdb.ErrWorkingSetNotFound) {
  1107  		return fmt.Errorf("%w; %s: '%s'", doltdb.ErrBranchNotFound, err, value)
  1108  	}
  1109  	return err
  1110  }
  1111  
  1112  func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string, value interface{}) error {
  1113  	d.mu.Lock()
  1114  	defer d.mu.Unlock()
  1115  
  1116  	convertedVal, _, err := sqltypes.Int64.Convert(value)
  1117  	if err != nil {
  1118  		return err
  1119  	}
  1120  	intVal := int64(0)
  1121  	if convertedVal != nil {
  1122  		intVal = convertedVal.(int64)
  1123  	}
  1124  
  1125  	if intVal == 0 {
  1126  		for _, dbState := range d.dbStates {
  1127  			for _, branchState := range dbState.heads {
  1128  				if ws := branchState.WriteSession(); ws != nil {
  1129  					opts := ws.GetOptions()
  1130  					opts.ForeignKeyChecksDisabled = true
  1131  					ws.SetOptions(opts)
  1132  				}
  1133  			}
  1134  		}
  1135  	} else if intVal == 1 {
  1136  		for _, dbState := range d.dbStates {
  1137  			for _, branchState := range dbState.heads {
  1138  				if ws := branchState.WriteSession(); ws != nil {
  1139  					opts := ws.GetOptions()
  1140  					opts.ForeignKeyChecksDisabled = false
  1141  					ws.SetOptions(opts)
  1142  				}
  1143  			}
  1144  		}
  1145  	} else {
  1146  		return sql.ErrInvalidSystemVariableValue.New("foreign_key_checks", intVal)
  1147  	}
  1148  
  1149  	return d.Session.SetSessionVariable(ctx, key, value)
  1150  }
  1151  
  1152  // addDB adds the database given to this session. This establishes a starting root value for this session, as well as
  1153  // other state tracking metadata.
  1154  func (d *DoltSession) addDB(ctx *sql.Context, db SqlDatabase) error {
  1155  	revisionQualifiedName := strings.ToLower(db.RevisionQualifiedName())
  1156  	baseName, _ := SplitRevisionDbName(revisionQualifiedName)
  1157  
  1158  	DefineSystemVariablesForDB(baseName)
  1159  
  1160  	tx, usingDoltTransaction := d.GetTransaction().(*DoltTransaction)
  1161  
  1162  	d.mu.Lock()
  1163  	defer d.mu.Unlock()
  1164  	sessionState, sessionStateExists := d.dbStates[baseName]
  1165  
  1166  	// Before computing initial state for the DB, check to see if we have it in the cache
  1167  	var dbState InitialDbState
  1168  	var dbStateCached bool
  1169  	if usingDoltTransaction {
  1170  		nomsRoot, ok := tx.GetInitialRoot(baseName)
  1171  		if ok && sessionStateExists {
  1172  			dbState, dbStateCached = d.dbCache.GetCachedInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName)
  1173  		}
  1174  	}
  1175  
  1176  	if !dbStateCached {
  1177  		var err error
  1178  		dbState, err = db.InitialDBState(ctx)
  1179  		if err != nil {
  1180  			return err
  1181  		}
  1182  	}
  1183  
  1184  	if !sessionStateExists {
  1185  		sessionState = newEmptyDatabaseSessionState()
  1186  		d.dbStates[baseName] = sessionState
  1187  
  1188  		var err error
  1189  		sessionState.tmpFileDir, err = dbState.DbData.Rsw.TempTableFilesDir()
  1190  		if err != nil {
  1191  			if errors.Is(err, env.ErrDoltRepositoryNotFound) {
  1192  				return env.ErrFailedToAccessDB.New(dbState.Db.Name())
  1193  			}
  1194  			return err
  1195  		}
  1196  
  1197  		sessionState.dbName = baseName
  1198  
  1199  		baseDb, ok := d.provider.BaseDatabase(ctx, baseName)
  1200  		if !ok {
  1201  			return fmt.Errorf("unable to find database %s, this is a bug", baseName)
  1202  		}
  1203  
  1204  		// The checkedOutRevSpec should be the checked out branch of the database if available, or the revision
  1205  		// string otherwise
  1206  		sessionState.checkedOutRevSpec, err = DefaultHead(baseName, baseDb)
  1207  		if err != nil {
  1208  			return err
  1209  		}
  1210  	}
  1211  
  1212  	if !dbStateCached && usingDoltTransaction {
  1213  		nomsRoot, ok := tx.GetInitialRoot(baseName)
  1214  		if ok {
  1215  			d.dbCache.CacheInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName, dbState)
  1216  		}
  1217  	}
  1218  
  1219  	branchState := sessionState.NewEmptyBranchState(db.Revision(), db.RevisionType())
  1220  
  1221  	// TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and
  1222  	//  the writer with one that errors out
  1223  	// TODO: this no longer gets called at session creation time, so the error handling below never occurs when a
  1224  	//  database is deleted out from under a running server
  1225  	branchState.dbData = dbState.DbData
  1226  	adapter := NewSessionStateAdapter(d, db.Name(), dbState.Remotes, dbState.Branches, dbState.Backups)
  1227  	branchState.dbData.Rsr = adapter
  1228  	branchState.dbData.Rsw = adapter
  1229  	branchState.readOnly = dbState.ReadOnly
  1230  
  1231  	// TODO: figure out how to cast this to dsqle.SqlDatabase without creating import cycles
  1232  	// Or better yet, get rid of EditOptions from the database, it's a session setting
  1233  	nbf := types.Format_Default
  1234  	if branchState.dbData.Ddb != nil {
  1235  		nbf = branchState.dbData.Ddb.Format()
  1236  	}
  1237  	editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions()
  1238  
  1239  	if dbState.Err != nil {
  1240  		sessionState.Err = dbState.Err
  1241  	} else if dbState.WorkingSet != nil {
  1242  		branchState.workingSet = dbState.WorkingSet
  1243  
  1244  		// TODO: this is pretty clunky, there is a silly dependency between InitialDbState and globalstate.StateProvider
  1245  		//  that's hard to express with the current types
  1246  		stateProvider, ok := db.(globalstate.GlobalStateProvider)
  1247  		if !ok {
  1248  			return fmt.Errorf("database does not contain global state store")
  1249  		}
  1250  		sessionState.globalState = stateProvider.GetGlobalState()
  1251  
  1252  		tracker, err := sessionState.globalState.AutoIncrementTracker(ctx)
  1253  		if err != nil {
  1254  			return err
  1255  		}
  1256  		branchState.writeSession = writer.NewWriteSession(nbf, branchState.WorkingSet(), tracker, editOpts)
  1257  	}
  1258  
  1259  	// WorkingSet is nil in the case of a read only, detached head DB
  1260  	if dbState.HeadCommit != nil {
  1261  		headRoot, err := dbState.HeadCommit.GetRootValue(ctx)
  1262  		if err != nil {
  1263  			return err
  1264  		}
  1265  		branchState.headRoot = headRoot
  1266  	} else if dbState.HeadRoot != nil {
  1267  		branchState.headRoot = dbState.HeadRoot
  1268  	}
  1269  
  1270  	branchState.headCommit = dbState.HeadCommit
  1271  	return nil
  1272  }
  1273  
  1274  func (d *DoltSession) DatabaseCache(ctx *sql.Context) *DatabaseCache {
  1275  	return d.dbCache
  1276  }
  1277  
  1278  func (d *DoltSession) AddTemporaryTable(ctx *sql.Context, db string, tbl sql.Table) {
  1279  	d.tempTables[strings.ToLower(db)] = append(d.tempTables[strings.ToLower(db)], tbl)
  1280  }
  1281  
  1282  func (d *DoltSession) DropTemporaryTable(ctx *sql.Context, db, name string) {
  1283  	tables := d.tempTables[strings.ToLower(db)]
  1284  	for i, tbl := range d.tempTables[strings.ToLower(db)] {
  1285  		if strings.ToLower(tbl.Name()) == strings.ToLower(name) {
  1286  			tables = append(tables[:i], tables[i+1:]...)
  1287  			break
  1288  		}
  1289  	}
  1290  	d.tempTables[strings.ToLower(db)] = tables
  1291  }
  1292  
  1293  func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql.Table, bool) {
  1294  	for _, tbl := range d.tempTables[strings.ToLower(db)] {
  1295  		if strings.ToLower(tbl.Name()) == strings.ToLower(name) {
  1296  			return tbl, true
  1297  		}
  1298  	}
  1299  	return nil, false
  1300  }
  1301  
  1302  // GetAllTemporaryTables returns all temp tables for this session.
  1303  func (d *DoltSession) GetAllTemporaryTables(ctx *sql.Context, db string) ([]sql.Table, error) {
  1304  	return d.tempTables[strings.ToLower(db)], nil
  1305  }
  1306  
  1307  // CWBHeadRef returns the branch ref for this session HEAD for the database named
  1308  func (d *DoltSession) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, error) {
  1309  	branchState, ok, err := d.lookupDbState(ctx, dbName)
  1310  	if err != nil {
  1311  		return nil, err
  1312  	}
  1313  	if !ok {
  1314  		return nil, sql.ErrDatabaseNotFound.New(dbName)
  1315  	}
  1316  
  1317  	if branchState.revisionType != RevisionTypeBranch {
  1318  		return nil, doltdb.ErrOperationNotSupportedInDetachedHead
  1319  	}
  1320  
  1321  	return ref.NewBranchRef(branchState.head), nil
  1322  }
  1323  
  1324  // CurrentHead returns the current head for the db named, which must be unqualified. Used for bootstrap resolving the
  1325  // correct session head when a database name from the client is unqualified.
  1326  func (d *DoltSession) CurrentHead(ctx *sql.Context, dbName string) (string, bool, error) {
  1327  	baseName := strings.ToLower(dbName)
  1328  
  1329  	d.mu.Lock()
  1330  	dbState, ok := d.dbStates[baseName]
  1331  	d.mu.Unlock()
  1332  
  1333  	if ok {
  1334  		return dbState.checkedOutRevSpec, true, nil
  1335  	}
  1336  
  1337  	return "", false, nil
  1338  }
  1339  
  1340  func (d *DoltSession) Username() string {
  1341  	return d.username
  1342  }
  1343  
  1344  func (d *DoltSession) Email() string {
  1345  	return d.email
  1346  }
  1347  
  1348  // setDbSessionVars updates the three session vars that track the value of the session root hashes
  1349  func (d *DoltSession) setDbSessionVars(ctx *sql.Context, state *branchState, force bool) error {
  1350  	// This check is important even when we are forcing an update, because it updates the idea of staleness
  1351  	varsStale := d.dbSessionVarsStale(ctx, state)
  1352  	if !varsStale && !force {
  1353  		return nil
  1354  	}
  1355  
  1356  	baseName := state.dbState.dbName
  1357  
  1358  	// Different DBs have different requirements for what state is set, so we are maximally permissive on what's expected
  1359  	// in the state object here
  1360  	if state.WorkingSet() != nil {
  1361  		headRef, err := state.WorkingSet().Ref().ToHeadRef()
  1362  		if err != nil {
  1363  			return err
  1364  		}
  1365  
  1366  		err = d.Session.SetSessionVariable(ctx, HeadRefKey(baseName), headRef.String())
  1367  		if err != nil {
  1368  			return err
  1369  		}
  1370  	}
  1371  
  1372  	roots := state.roots()
  1373  
  1374  	if roots.Working != nil {
  1375  		h, err := roots.Working.HashOf()
  1376  		if err != nil {
  1377  			return err
  1378  		}
  1379  		err = d.Session.SetSessionVariable(ctx, WorkingKey(baseName), h.String())
  1380  		if err != nil {
  1381  			return err
  1382  		}
  1383  	}
  1384  
  1385  	if roots.Staged != nil {
  1386  		h, err := roots.Staged.HashOf()
  1387  		if err != nil {
  1388  			return err
  1389  		}
  1390  		err = d.Session.SetSessionVariable(ctx, StagedKey(baseName), h.String())
  1391  		if err != nil {
  1392  			return err
  1393  		}
  1394  	}
  1395  
  1396  	if state.headCommit != nil {
  1397  		h, err := state.headCommit.HashOf()
  1398  		if err != nil {
  1399  			return err
  1400  		}
  1401  		err = d.Session.SetSessionVariable(ctx, HeadKey(baseName), h.String())
  1402  		if err != nil {
  1403  			return err
  1404  		}
  1405  	}
  1406  
  1407  	return nil
  1408  }
  1409  
  1410  // dbSessionVarsStale returns whether the session vars for the database with the state provided need to be updated in
  1411  // the session
  1412  func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) bool {
  1413  	dtx, ok := ctx.GetTransaction().(*DoltTransaction)
  1414  	if !ok {
  1415  		return true
  1416  	}
  1417  
  1418  	return d.dbCache.CacheSessionVars(state, dtx)
  1419  }
  1420  
  1421  func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
  1422  	d.globalsConf = conf
  1423  	return &d
  1424  }
  1425  
  1426  // PersistGlobal implements sql.PersistableSession
  1427  func (d *DoltSession) PersistGlobal(sysVarName string, value interface{}) error {
  1428  	if d.globalsConf == nil {
  1429  		return ErrSessionNotPersistable
  1430  	}
  1431  
  1432  	sysVar, _, err := validatePersistableSysVar(sysVarName)
  1433  	if err != nil {
  1434  		return err
  1435  	}
  1436  
  1437  	d.mu.Lock()
  1438  	defer d.mu.Unlock()
  1439  	return setPersistedValue(d.globalsConf, sysVar.GetName(), value)
  1440  }
  1441  
  1442  // RemovePersistedGlobal implements sql.PersistableSession
  1443  func (d *DoltSession) RemovePersistedGlobal(sysVarName string) error {
  1444  	if d.globalsConf == nil {
  1445  		return ErrSessionNotPersistable
  1446  	}
  1447  
  1448  	sysVar, _, err := validatePersistableSysVar(sysVarName)
  1449  	if err != nil {
  1450  		return err
  1451  	}
  1452  
  1453  	d.mu.Lock()
  1454  	defer d.mu.Unlock()
  1455  	return d.globalsConf.Unset([]string{sysVar.GetName()})
  1456  }
  1457  
  1458  // RemoveAllPersistedGlobals implements sql.PersistableSession
  1459  func (d *DoltSession) RemoveAllPersistedGlobals() error {
  1460  	if d.globalsConf == nil {
  1461  		return ErrSessionNotPersistable
  1462  	}
  1463  
  1464  	allVars := make([]string, d.globalsConf.Size())
  1465  	i := 0
  1466  	d.globalsConf.Iter(func(k, v string) bool {
  1467  		allVars[i] = k
  1468  		i++
  1469  		return false
  1470  	})
  1471  
  1472  	d.mu.Lock()
  1473  	defer d.mu.Unlock()
  1474  	return d.globalsConf.Unset(allVars)
  1475  }
  1476  
  1477  // RemoveAllPersistedGlobals implements sql.PersistableSession
  1478  func (d *DoltSession) GetPersistedValue(k string) (interface{}, error) {
  1479  	if d.globalsConf == nil {
  1480  		return nil, ErrSessionNotPersistable
  1481  	}
  1482  
  1483  	return getPersistedValue(d.globalsConf, k)
  1484  }
  1485  
  1486  // SystemVariablesInConfig returns a list of System Variables associated with the session
  1487  func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
  1488  	if d.globalsConf == nil {
  1489  		return nil, ErrSessionNotPersistable
  1490  	}
  1491  	sysVars, _, err := SystemVariablesInConfig(d.globalsConf)
  1492  	if err != nil {
  1493  		return nil, err
  1494  	}
  1495  	return sysVars, nil
  1496  }
  1497  
  1498  // GetBranch implements the interface branch_control.Context.
  1499  func (d *DoltSession) GetBranch() (string, error) {
  1500  	// TODO: creating a new SQL context here is expensive
  1501  	ctx := sql.NewContext(context.Background(), sql.WithSession(d))
  1502  	currentDb := d.Session.GetCurrentDatabase()
  1503  
  1504  	// no branch if there's no current db
  1505  	if currentDb == "" {
  1506  		return "", nil
  1507  	}
  1508  
  1509  	branchState, _, err := d.LookupDbState(ctx, currentDb)
  1510  	if err != nil {
  1511  		return "", err
  1512  	}
  1513  
  1514  	if branchState.WorkingSet() != nil {
  1515  		branchRef, err := branchState.WorkingSet().Ref().ToHeadRef()
  1516  		if err != nil {
  1517  			return "", err
  1518  		}
  1519  		return branchRef.GetPath(), nil
  1520  	}
  1521  	// A nil working set probably means that we're not on a branch (like we may be on a commit), so we return an empty string
  1522  	return "", nil
  1523  }
  1524  
  1525  // GetUser implements the interface branch_control.Context.
  1526  func (d *DoltSession) GetUser() string {
  1527  	return d.Session.Client().User
  1528  }
  1529  
  1530  // GetHost implements the interface branch_control.Context.
  1531  func (d *DoltSession) GetHost() string {
  1532  	return d.Session.Client().Address
  1533  }
  1534  
  1535  // GetController implements the interface branch_control.Context.
  1536  func (d *DoltSession) GetController() *branch_control.Controller {
  1537  	return d.branchController
  1538  }
  1539  
  1540  // validatePersistedSysVar checks whether a system variable exists and is dynamic
  1541  func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) {
  1542  	sysVar, val, ok := sql.SystemVariables.GetGlobal(name)
  1543  	if !ok {
  1544  		return nil, nil, sql.ErrUnknownSystemVariable.New(name)
  1545  	}
  1546  	if sysVar.IsReadOnly() {
  1547  		return nil, nil, sql.ErrSystemVariableReadOnly.New(name)
  1548  	}
  1549  	return sysVar, val, nil
  1550  }
  1551  
  1552  // getPersistedValue reads and converts a config value to the associated MysqlSystemVariable type
  1553  func getPersistedValue(conf config.ReadableConfig, k string) (interface{}, error) {
  1554  	v, err := conf.GetString(k)
  1555  	if err != nil {
  1556  		return nil, err
  1557  	}
  1558  
  1559  	_, value, err := validatePersistableSysVar(k)
  1560  	if err != nil {
  1561  		return nil, err
  1562  	}
  1563  
  1564  	var res interface{}
  1565  	switch value.(type) {
  1566  	case int8:
  1567  		var tmp int64
  1568  		tmp, err = strconv.ParseInt(v, 10, 8)
  1569  		res = int8(tmp)
  1570  	case int, int16, int32, int64:
  1571  		res, err = strconv.ParseInt(v, 10, 64)
  1572  	case uint, uint8, uint16, uint32, uint64:
  1573  		res, err = strconv.ParseUint(v, 10, 64)
  1574  	case float32, float64:
  1575  		res, err = strconv.ParseFloat(v, 64)
  1576  	case bool:
  1577  		return nil, sql.ErrInvalidType.New(value)
  1578  	case string:
  1579  		return v, nil
  1580  	default:
  1581  		return nil, sql.ErrInvalidType.New(value)
  1582  	}
  1583  
  1584  	if err != nil {
  1585  		return nil, err
  1586  	}
  1587  
  1588  	return res, nil
  1589  }
  1590  
  1591  // setPersistedValue casts and persists a key value pair assuming thread safety
  1592  func setPersistedValue(conf config.WritableConfig, key string, value interface{}) error {
  1593  	switch v := value.(type) {
  1594  	case int:
  1595  		return config.SetInt(conf, key, int64(v))
  1596  	case int8:
  1597  		return config.SetInt(conf, key, int64(v))
  1598  	case int16:
  1599  		return config.SetInt(conf, key, int64(v))
  1600  	case int32:
  1601  		return config.SetInt(conf, key, int64(v))
  1602  	case int64:
  1603  		return config.SetInt(conf, key, v)
  1604  	case uint:
  1605  		return config.SetUint(conf, key, uint64(v))
  1606  	case uint8:
  1607  		return config.SetUint(conf, key, uint64(v))
  1608  	case uint16:
  1609  		return config.SetUint(conf, key, uint64(v))
  1610  	case uint32:
  1611  		return config.SetUint(conf, key, uint64(v))
  1612  	case uint64:
  1613  		return config.SetUint(conf, key, v)
  1614  	case float32:
  1615  		return config.SetFloat(conf, key, float64(v))
  1616  	case float64:
  1617  		return config.SetFloat(conf, key, v)
  1618  	case decimal.Decimal:
  1619  		f64, _ := v.Float64()
  1620  		return config.SetFloat(conf, key, f64)
  1621  	case string:
  1622  		return config.SetString(conf, key, v)
  1623  	case bool:
  1624  		if v {
  1625  			return config.SetInt(conf, key, 1)
  1626  		} else {
  1627  			return config.SetInt(conf, key, 0)
  1628  		}
  1629  	default:
  1630  		return sql.ErrInvalidType.New(v)
  1631  	}
  1632  }
  1633  
  1634  // SystemVariablesInConfig returns system variables from the persisted config
  1635  // and a list of persisted keys that have no corresponding definition in
  1636  // |sql.SystemVariables|.
  1637  func SystemVariablesInConfig(conf config.ReadableConfig) ([]sql.SystemVariable, []string, error) {
  1638  	allVars := make([]sql.SystemVariable, conf.Size())
  1639  	var missingKeys []string
  1640  	i := 0
  1641  	var err error
  1642  	var def interface{}
  1643  	conf.Iter(func(k, v string) bool {
  1644  		def, err = getPersistedValue(conf, k)
  1645  		if err != nil {
  1646  			if sql.ErrUnknownSystemVariable.Is(err) {
  1647  				err = nil
  1648  				missingKeys = append(missingKeys, k)
  1649  				return false
  1650  			}
  1651  			err = fmt.Errorf("key: '%s'; %w", k, err)
  1652  			return true
  1653  		}
  1654  		// getPersistedVal already checked for errors
  1655  		sysVar, _, _ := sql.SystemVariables.GetGlobal(k)
  1656  		sysVar.SetDefault(def)
  1657  		allVars[i] = sysVar
  1658  		i++
  1659  		return false
  1660  	})
  1661  	if err != nil {
  1662  		return nil, nil, err
  1663  	}
  1664  	return allVars, missingKeys, nil
  1665  }
  1666  
  1667  var initMu = sync.Mutex{}
  1668  
  1669  func InitPersistedSystemVars(dEnv *env.DoltEnv) error {
  1670  	initMu.Lock()
  1671  	defer initMu.Unlock()
  1672  
  1673  	var globals config.ReadWriteConfig
  1674  	if localConf, ok := dEnv.Config.GetConfig(env.LocalConfig); ok {
  1675  		globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix)
  1676  	} else if globalConf, ok := dEnv.Config.GetConfig(env.GlobalConfig); ok {
  1677  		globals = config.NewPrefixConfig(globalConf, env.SqlServerGlobalsPrefix)
  1678  	} else {
  1679  		cli.Println("warning: no local or global Dolt configuration found; session is not persistable")
  1680  		globals = config.NewMapConfig(make(map[string]string))
  1681  	}
  1682  
  1683  	persistedGlobalVars, missingKeys, err := SystemVariablesInConfig(globals)
  1684  	if err != nil {
  1685  		return err
  1686  	}
  1687  	for _, k := range missingKeys {
  1688  		cli.Printf("warning: persisted system variable %s was not loaded since its definition does not exist.\n", k)
  1689  	}
  1690  	sql.SystemVariables.AddSystemVariables(persistedGlobalVars)
  1691  	return nil
  1692  }
  1693  
  1694  // TransactionRoot returns the noms root for the given database in the current transaction
  1695  func TransactionRoot(ctx *sql.Context, db SqlDatabase) (hash.Hash, error) {
  1696  	tx, ok := ctx.GetTransaction().(*DoltTransaction)
  1697  	// We don't have a real transaction in some cases (esp. PREPARE), in which case we need to use the tip of the data
  1698  	if !ok {
  1699  		return db.DbData().Ddb.NomsRoot(ctx)
  1700  	}
  1701  
  1702  	nomsRoot, ok := tx.GetInitialRoot(db.Name())
  1703  	if !ok {
  1704  		return hash.Hash{}, fmt.Errorf("could not resolve initial root for database %s", db.Name())
  1705  	}
  1706  
  1707  	return nomsRoot, nil
  1708  }
  1709  
  1710  // DefaultHead returns the head for the database given when one isn't specified
  1711  func DefaultHead(baseName string, db SqlDatabase) (string, error) {
  1712  	head := ""
  1713  
  1714  	// First check the global variable for the default branch
  1715  	_, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(baseName))
  1716  	if ok {
  1717  		head = val.(string)
  1718  		branchRef, err := ref.Parse(head)
  1719  		if err == nil {
  1720  			head = branchRef.GetPath()
  1721  		} else {
  1722  			head = ""
  1723  			// continue to below
  1724  		}
  1725  	}
  1726  
  1727  	// Fall back to the database's initially checked out branch
  1728  	if head == "" {
  1729  		rsr := db.DbData().Rsr
  1730  		if rsr != nil {
  1731  			headRef, err := rsr.CWBHeadRef()
  1732  			if err != nil {
  1733  				return "", err
  1734  			}
  1735  			head = headRef.GetPath()
  1736  		}
  1737  	}
  1738  
  1739  	if head == "" {
  1740  		head = db.Revision()
  1741  	}
  1742  
  1743  	return head, nil
  1744  }